From 7fd715cd917636c2aa2d0363683e096295907379 Mon Sep 17 00:00:00 2001 From: Gian <47775302+gpunto@users.noreply.github.com> Date: Tue, 9 Jun 2026 14:42:13 +0200 Subject: [PATCH] Prevent tracking channel states for malformed CIDs --- .../android/client/api/state/StateRegistry.kt | 141 ++++++++++-------- .../internal/EventHandlerSequential.kt | 12 +- .../plugin/logic/internal/LogicRegistry.kt | 114 ++++++++------ .../internal/QueryChannelsStateLogic.kt | 51 +++---- .../state/sync/internal/SyncManager.kt | 33 ++-- .../client/utils/internal/ChannelId.kt | 38 +++++ .../state/internal/SyncManagerTest.kt | 62 +++++++- .../logic/internal/LogicRegistryTest.kt | 47 +++++- .../internal/QueryChannelsStateLogicTest.kt | 61 ++++++-- .../state/plugin/state/StateRegistryTest.kt | 108 ++++++++++---- .../client/utils/internal/ChannelIdTest.kt | 81 ++++++++++ 11 files changed, 535 insertions(+), 213 deletions(-) create mode 100644 stream-chat-android-client/src/main/java/io/getstream/chat/android/client/utils/internal/ChannelId.kt create mode 100644 stream-chat-android-client/src/test/java/io/getstream/chat/android/client/utils/internal/ChannelIdTest.kt diff --git a/stream-chat-android-client/src/main/java/io/getstream/chat/android/client/api/state/StateRegistry.kt b/stream-chat-android-client/src/main/java/io/getstream/chat/android/client/api/state/StateRegistry.kt index 3080525055d..2810729cd0b 100644 --- a/stream-chat-android-client/src/main/java/io/getstream/chat/android/client/api/state/StateRegistry.kt +++ b/stream-chat-android-client/src/main/java/io/getstream/chat/android/client/api/state/StateRegistry.kt @@ -27,6 +27,7 @@ import io.getstream.chat.android.client.internal.state.plugin.state.channel.inte import io.getstream.chat.android.client.internal.state.plugin.state.channel.thread.internal.ThreadMutableState import io.getstream.chat.android.client.internal.state.plugin.state.querychannels.internal.QueryChannelsMutableState import io.getstream.chat.android.client.internal.state.plugin.state.querythreads.internal.QueryThreadsMutableState +import io.getstream.chat.android.client.utils.internal.ChannelId import io.getstream.chat.android.core.internal.InternalStreamChatApi import io.getstream.chat.android.models.Channel import io.getstream.chat.android.models.FilterObject @@ -57,7 +58,7 @@ import java.util.concurrent.ConcurrentHashMap * @param mutedUsers The current list of muted users. * @param useLegacyChannelState Whether to use the legacy channel state implementation. */ -@Suppress("LongParameterList") +@Suppress("LongParameterList", "TooManyFunctions") public class StateRegistry @JvmOverloads constructor( private val userStateFlow: StateFlow, private var latestUsers: StateFlow>, @@ -74,8 +75,8 @@ public class StateRegistry @JvmOverloads constructor( private val queryChannels: ConcurrentHashMap = ConcurrentHashMap() - private val legacyChannels: ConcurrentHashMap, ChannelStateLegacyImpl> = ConcurrentHashMap() - private val channels: ConcurrentHashMap, ChannelStateImpl> = ConcurrentHashMap() + private val legacyChannels: ConcurrentHashMap = ConcurrentHashMap() + private val channels: ConcurrentHashMap = ConcurrentHashMap() private val queryThreads: ConcurrentHashMap>, QueryThreadsMutableState> = ConcurrentHashMap() private val threads: ConcurrentHashMap = ConcurrentHashMap() @@ -120,78 +121,100 @@ public class StateRegistry @JvmOverloads constructor( } /** - * Returns [ChannelState] that represents a state of particular channel. + * Returns the [ChannelState] for the given channel. * - * @param channelType The channel type. ie messaging. - * @param channelId The channel id. ie 123. - * - * @return [ChannelState] object. - */ - public fun channel(channelType: String, channelId: String): ChannelState = if (useLegacyChannelState) { - legacyChannelState(channelType, channelId) - } else { - channelState(channelType, channelId) - } - - /** - * Returns [ChannelStateLegacyImpl] that represents a state of particular channel. + * A malformed cid yields a fresh, non-cached state so callers still get a non-null object, but + * the registry won't track it and the state will never receive updates. * * @param channelType The channel type. ie messaging. * @param channelId The channel id. ie 123. * * @return [ChannelState] object. */ - internal fun legacyChannelState(channelType: String, channelId: String): ChannelStateLegacyImpl { - return legacyChannels.getOrPut(channelType to channelId) { - val baseMessageLimit = messageLimitConfig.channelMessageLimits - .find { it.channelType == channelType } - ?.baseLimit - ChannelStateLegacyImpl( - channelType = channelType, - channelId = channelId, - userFlow = userStateFlow, - latestUsers = latestUsers, - activeLiveLocations = activeLiveLocations, - baseMessageLimit = baseMessageLimit, - now = now, - ) + public fun channel(channelType: String, channelId: String): ChannelState { + val id = ChannelId.fromTypeAndId(channelType, channelId) + if (id == null) { + logger.w { "[channel] rejected malformed cid: $channelType:$channelId" } + return newChannelState(channelType, channelId) } + return channel(id) } - internal fun channelState(channelType: String, channelId: String): ChannelStateImpl { - val baseMessageLimit = messageLimitConfig.channelMessageLimits - .find { it.channelType == channelType } - ?.baseLimit - return channels.getOrPut(channelType to channelId) { - ChannelStateImpl( - channelType = channelType, - channelId = channelId, - currentUser = userStateFlow, - latestUsers = latestUsers, - mutedUsers = mutedUsers, - liveLocations = activeLiveLocations, - messageLimit = baseMessageLimit, - ) + /** Returns the cached [ChannelState] for an already-validated [ChannelId]. */ + internal fun channel(channelId: ChannelId): ChannelState = + if (useLegacyChannelState) legacyChannelState(channelId) else channelState(channelId) + + internal fun legacyChannelState(channelId: ChannelId): ChannelStateLegacyImpl = + legacyChannels.getOrPut(channelId) { + buildLegacyChannelState(channelId.type, channelId.id) + } + + internal fun legacyChannelState(channelType: String, channelId: String): ChannelStateLegacyImpl = + ChannelId.fromTypeAndId(channelType, channelId) + ?.let(::legacyChannelState) + ?: buildLegacyChannelState(channelType, channelId) + + internal fun channelState(channelId: ChannelId): ChannelStateImpl = + channels.getOrPut(channelId) { + buildChannelState(channelId.type, channelId.id) } - } + + internal fun channelState(channelType: String, channelId: String): ChannelStateImpl = + ChannelId.fromTypeAndId(channelType, channelId) + ?.let(::channelState) + ?: buildChannelState(channelType, channelId) /** * Checks if the channel is already present in the state. * Should be used to prevent creating [ChannelState] objects without populated data. * - * @param channelType The channel type. ie messaging. - * @param channelId The channel id. ie 123. - * * @return true if the channel is active. */ - internal fun isActiveChannel(channelType: String, channelId: String): Boolean { + internal fun isActiveChannel(channelId: ChannelId): Boolean { return if (useLegacyChannelState) { - legacyChannels.containsKey(channelType to channelId) + legacyChannels.containsKey(channelId) } else { - channels.containsKey(channelType to channelId) + channels.containsKey(channelId) } } + private fun newChannelState(channelType: String, channelId: String): ChannelState = + if (useLegacyChannelState) { + buildLegacyChannelState(channelType, channelId) + } else { + buildChannelState(channelType, channelId) + } + + private fun buildLegacyChannelState(channelType: String, channelId: String): ChannelStateLegacyImpl { + val baseMessageLimit = messageLimitConfig.channelMessageLimits + .find { it.channelType == channelType } + ?.baseLimit + return ChannelStateLegacyImpl( + channelType = channelType, + channelId = channelId, + userFlow = userStateFlow, + latestUsers = latestUsers, + activeLiveLocations = activeLiveLocations, + baseMessageLimit = baseMessageLimit, + now = now, + ) + } + + private fun buildChannelState(channelType: String, channelId: String): ChannelStateImpl { + val baseMessageLimit = messageLimitConfig.channelMessageLimits + .find { it.channelType == channelType } + ?.baseLimit + return ChannelStateImpl( + channelType = channelType, + channelId = channelId, + currentUser = userStateFlow, + latestUsers = latestUsers, + mutedUsers = mutedUsers, + liveLocations = activeLiveLocations, + messageLimit = baseMessageLimit, + ) + } + /** * Returns a [QueryThreadsState] holding the current state of the threads data. */ @@ -227,13 +250,8 @@ public class StateRegistry @JvmOverloads constructor( ThreadMutableState(messageId, scope) } - internal fun getActiveChannelStates(): List { - return if (useLegacyChannelState) { - legacyChannels.values.toList() - } else { - channels.values.toList() - } - } + internal fun getActiveChannelStates(): Map = + if (useLegacyChannelState) legacyChannels.toMap() else channels.toMap() /** * Clear state of all state objects. @@ -269,10 +287,11 @@ public class StateRegistry @JvmOverloads constructor( } private fun removeChanel(channelType: String, channelId: String) { + val id = ChannelId.fromTypeAndId(channelType, channelId) ?: return val removed = if (useLegacyChannelState) { - legacyChannels.remove(channelType to channelId)?.destroy() + legacyChannels.remove(id)?.destroy() } else { - channels.remove(channelType to channelId)?.destroy() + channels.remove(id)?.destroy() } logger.i { "[removeChanel] removed channel($channelType, $channelId): $removed" } } diff --git a/stream-chat-android-client/src/main/java/io/getstream/chat/android/client/internal/state/event/handler/internal/EventHandlerSequential.kt b/stream-chat-android-client/src/main/java/io/getstream/chat/android/client/internal/state/event/handler/internal/EventHandlerSequential.kt index 5429e63c619..32f1f534384 100644 --- a/stream-chat-android-client/src/main/java/io/getstream/chat/android/client/internal/state/event/handler/internal/EventHandlerSequential.kt +++ b/stream-chat-android-client/src/main/java/io/getstream/chat/android/client/internal/state/event/handler/internal/EventHandlerSequential.kt @@ -496,15 +496,11 @@ internal class EventHandlerSequential( sortedEvents.find { it is UserPresenceChangedEvent }?.let { userPresenceChanged -> val event = userPresenceChanged as UserPresenceChangedEvent - stateRegistry.getActiveChannelStates() - .filter { channelState -> channelState.members.containsWithUserId(event.user.id) } - .forEach { channelState -> - val channelLogic: ChannelLogic = logicRegistry.channel( - channelType = channelState.channelType, - channelId = channelState.channelId, - ) - channelLogic.handleEvent(userPresenceChanged) + stateRegistry.getActiveChannelStates().forEach { (id, state) -> + if (state.members.containsWithUserId(event.user.id)) { + logicRegistry.channel(id).handleEvent(userPresenceChanged) } + } } // Handle `user.messages.deleted` event diff --git a/stream-chat-android-client/src/main/java/io/getstream/chat/android/client/internal/state/plugin/logic/internal/LogicRegistry.kt b/stream-chat-android-client/src/main/java/io/getstream/chat/android/client/internal/state/plugin/logic/internal/LogicRegistry.kt index 0e81c0ce7e0..0b45ccd500a 100644 --- a/stream-chat-android-client/src/main/java/io/getstream/chat/android/client/internal/state/plugin/logic/internal/LogicRegistry.kt +++ b/stream-chat-android-client/src/main/java/io/getstream/chat/android/client/internal/state/plugin/logic/internal/LogicRegistry.kt @@ -22,7 +22,6 @@ import io.getstream.chat.android.client.api.models.QueryThreadsRequest import io.getstream.chat.android.client.api.state.StateRegistry import io.getstream.chat.android.client.channel.ChannelMessagesUpdateLogic import io.getstream.chat.android.client.channel.state.ChannelStateLogicProvider -import io.getstream.chat.android.client.extensions.cidToTypeAndId import io.getstream.chat.android.client.internal.state.plugin.QueryChannelsIdentifier import io.getstream.chat.android.client.internal.state.plugin.identifier import io.getstream.chat.android.client.internal.state.plugin.logic.channel.internal.ChannelLogic @@ -43,10 +42,12 @@ import io.getstream.chat.android.client.internal.state.plugin.state.global.inter import io.getstream.chat.android.client.internal.state.plugin.state.querychannels.internal.toMutableState import io.getstream.chat.android.client.persistance.repository.RepositoryFacade import io.getstream.chat.android.client.setup.state.ClientState +import io.getstream.chat.android.client.utils.internal.ChannelId import io.getstream.chat.android.models.FilterObject import io.getstream.chat.android.models.Message import io.getstream.chat.android.models.Thread import io.getstream.chat.android.models.querysort.QuerySorter +import io.getstream.log.taggedLogger import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.launch import java.util.concurrent.ConcurrentHashMap @@ -71,9 +72,11 @@ internal class LogicRegistry internal constructor( private val useLegacyChannelLogic: Boolean, ) : ChannelStateLogicProvider { + private val logger by taggedLogger("Chat:LogicRegistry") + private val queryChannels: ConcurrentHashMap = ConcurrentHashMap() - private val channels: ConcurrentHashMap, ChannelLogic> = ConcurrentHashMap() + private val channels: ConcurrentHashMap = ConcurrentHashMap() private val queryThreads: ConcurrentHashMap?>, QueryThreadsLogic> = ConcurrentHashMap() private val threads: ConcurrentHashMap = ConcurrentHashMap() @@ -108,19 +111,38 @@ internal class LogicRegistry internal constructor( internal fun queryChannels(queryChannelsRequest: QueryChannelsRequest): QueryChannelsLogic = queryChannels(queryChannelsRequest.identifier) - /** Returns [ChannelLogic] by channelType and channelId combination. */ + /** + * Returns [ChannelLogic] by channelType and channelId combination. + * + * A malformed cid yields a fresh, non-cached logic so callers still get a non-null object, but + * the registry won't track it and updates fed to it will be discarded. + */ fun channel(channelType: String, channelId: String): ChannelLogic { - return if (useLegacyChannelLogic) { - legacyChannelLogic(channelType, channelId) - } else { - channelLogic(channelType, channelId) + val id = ChannelId.fromTypeAndId(channelType, channelId) + if (id == null) { + logger.w { "[channel] rejected malformed cid: $channelType:$channelId" } + return newChannelLogic(channelType, channelId) } + return channel(id) + } + + /** Returns [ChannelLogic] for a validated [ChannelId], creating and caching it on first access. */ + fun channel(channelId: ChannelId): ChannelLogic = channels.getOrPut(channelId) { + newChannelLogic(channelId.type, channelId.id) } internal fun removeChannel(channelType: String, channelId: String) { - channels.remove(channelType to channelId) + val id = ChannelId.fromTypeAndId(channelType, channelId) ?: return + channels.remove(id) } + private fun newChannelLogic(type: String, id: String): ChannelLogic = + if (useLegacyChannelLogic) { + buildLegacyChannelLogic(type, id) + } else { + buildChannelLogic(type, id) + } + fun channelFromMessageId(messageId: String): ChannelLogic? { return channels.values.find { channelLogic -> channelLogic.getMessage(messageId) != null @@ -159,8 +181,7 @@ internal class LogicRegistry internal constructor( */ fun channelFromMessage(message: Message): ChannelLogic? { return if (message.parentId == null || message.showInChannel) { - val (channelType, channelId) = message.cid.cidToTypeAndId() - channel(channelType, channelId) + ChannelId.fromCid(message.cid)?.let(::channel) } else { null } @@ -252,8 +273,10 @@ internal class LogicRegistry internal constructor( * * @return True if the channel is active. */ - fun isActiveChannel(channelType: String, channelId: String): Boolean = - channels.containsKey(channelType to channelId) + fun isActiveChannel(channelType: String, channelId: String): Boolean { + val id = ChannelId.fromTypeAndId(channelType, channelId) ?: return false + return channels.containsKey(id) + } /** * Returns a list of [ChannelLogic] for all, active channel requests. @@ -276,44 +299,39 @@ internal class LogicRegistry internal constructor( mutableGlobalState.destroy() } - private fun legacyChannelLogic(type: String, id: String): ChannelLogic { - return channels.getOrPut(type to id) { - val mutableState = stateRegistry.legacyChannelState(type, id) - val stateLogic = ChannelStateLogic( - clientState = clientState, - mutableState = mutableState, - globalMutableState = mutableGlobalState, - searchLogic = SearchLogic(mutableState), - now = now, - coroutineScope = coroutineScope, - ) - - ChannelLogicLegacyImpl( - repos = repos, - userPresence = userPresence, - stateLogic = stateLogic, - coroutineScope = coroutineScope, - getCurrentUserId = { clientState.user.value?.id }, - ) - } + private fun buildLegacyChannelLogic(type: String, id: String): ChannelLogic { + val mutableState = stateRegistry.legacyChannelState(type, id) + val stateLogic = ChannelStateLogic( + clientState = clientState, + mutableState = mutableState, + globalMutableState = mutableGlobalState, + searchLogic = SearchLogic(mutableState), + now = now, + coroutineScope = coroutineScope, + ) + return ChannelLogicLegacyImpl( + repos = repos, + userPresence = userPresence, + stateLogic = stateLogic, + coroutineScope = coroutineScope, + getCurrentUserId = { clientState.user.value?.id }, + ) } - private fun channelLogic(type: String, id: String): ChannelLogic { - return channels.getOrPut(type to id) { - val state = stateRegistry.channelState(type, id) - val messagesUpdateLogic = ChannelMessagesUpdateLogicImpl(state) - ChannelLogicImpl( - cid = "$type:$id", - messagesUpdateLogic = messagesUpdateLogic, - repository = repos, - state = state, - mutableGlobalState = mutableGlobalState, - userPresence = userPresence, - coroutineScope = coroutineScope, - getCurrentUserId = { clientState.user.value?.id }, - now = now, - ) - } + private fun buildChannelLogic(type: String, id: String): ChannelLogic { + val state = stateRegistry.channelState(type, id) + val messagesUpdateLogic = ChannelMessagesUpdateLogicImpl(state) + return ChannelLogicImpl( + cid = "$type:$id", + messagesUpdateLogic = messagesUpdateLogic, + repository = repos, + state = state, + mutableGlobalState = mutableGlobalState, + userPresence = userPresence, + coroutineScope = coroutineScope, + getCurrentUserId = { clientState.user.value?.id }, + now = now, + ) } companion object { diff --git a/stream-chat-android-client/src/main/java/io/getstream/chat/android/client/internal/state/plugin/logic/querychannels/internal/QueryChannelsStateLogic.kt b/stream-chat-android-client/src/main/java/io/getstream/chat/android/client/internal/state/plugin/logic/querychannels/internal/QueryChannelsStateLogic.kt index 5b482ad9168..a072fe30926 100644 --- a/stream-chat-android-client/src/main/java/io/getstream/chat/android/client/internal/state/plugin/logic/querychannels/internal/QueryChannelsStateLogic.kt +++ b/stream-chat-android-client/src/main/java/io/getstream/chat/android/client/internal/state/plugin/logic/querychannels/internal/QueryChannelsStateLogic.kt @@ -22,12 +22,11 @@ import io.getstream.chat.android.client.api.state.QueryChannelsState import io.getstream.chat.android.client.api.state.StateRegistry import io.getstream.chat.android.client.channel.state.ChannelState import io.getstream.chat.android.client.events.ChatEvent -import io.getstream.chat.android.client.extensions.cidToTypeAndId -import io.getstream.chat.android.client.extensions.internal.toCid import io.getstream.chat.android.client.extensions.internal.users import io.getstream.chat.android.client.internal.state.plugin.logic.internal.LogicRegistry import io.getstream.chat.android.client.internal.state.plugin.state.querychannels.internal.QueryChannelsMutableState import io.getstream.chat.android.client.query.QueryChannelsSpec +import io.getstream.chat.android.client.utils.internal.ChannelId import io.getstream.chat.android.models.Channel import io.getstream.chat.android.models.FilterObject import io.getstream.chat.android.models.User @@ -35,6 +34,7 @@ import io.getstream.chat.android.models.querysort.QuerySorter import io.getstream.log.taggedLogger import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.async +import kotlinx.coroutines.awaitAll @Suppress("TooManyFunctions") internal class QueryChannelsStateLogic( @@ -153,26 +153,26 @@ internal class QueryChannelsStateLogic( * @param channels List. */ internal suspend fun addChannelsState(channels: List) { - mutableState.setCids(mutableState.queryChannelsSpec.cids + channels.map { it.cid }) - val existingChannels = mutableState.rawChannels ?: emptyMap() + val validated = channels.mapNotNull { channel -> + ChannelId.fromCid(channel.cid)?.let { it to channel } + } + mutableState.setCids(mutableState.queryChannelsSpec.cids + validated.map { (id, _) -> id.cid }) + val existingChannels = mutableState.rawChannels.orEmpty() mutableState.setChannels( - existingChannels + - channels.map { - it.cid to it.joinMessages(existingChannels[it.cid]) - .joinMembers(existingChannels[it.cid]) - }, + existingChannels + validated.associate { (id, channel) -> + id.cid to channel.joinMessages(existingChannels[id.cid]) + .joinMembers(existingChannels[id.cid]) + }, ) - channels.map { channel -> + validated.map { (id, channel) -> coroutineScope.async { - logicRegistry.channel(channel.type, channel.id).updateDataForChannel( + logicRegistry.channel(id).updateDataForChannel( channel = channel, messageLimit = channel.messages.size, isChannelsStateUpdate = true, ) } - }.map { - it.await() - } + }.awaitAll() } private fun Channel.joinMessages(existingChannel: Channel?): Channel = @@ -237,19 +237,10 @@ internal class QueryChannelsStateLogic( val newChannels = existingChannels + mutableState.queryChannelsSpec.cids .intersect(cidList.toSet()) - .map { cid -> cid.cidToTypeAndId() } - .filter { (channelType, channelId) -> - stateRegistry.isActiveChannel( - channelType = channelType, - channelId = channelId, - ) - } - .associate { (channelType, channelId) -> - val cid = (channelType to channelId).toCid() - cid to stateRegistry.channel( - channelType = channelType, - channelId = channelId, - ).toChannel() + .mapNotNull(ChannelId::fromCid) + .filter(stateRegistry::isActiveChannel) + .associate { id -> + id.cid to stateRegistry.channel(id).toChannel() } mutableState.setChannels(newChannels) @@ -260,9 +251,9 @@ internal class QueryChannelsStateLogic( * channel is active, or `null` otherwise. */ internal fun getActiveChannelState(cid: String): Channel? { - val (type, id) = cid.cidToTypeAndId() - if (!stateRegistry.isActiveChannel(type, id)) return null - return stateRegistry.channel(type, id).toChannel() + val id = ChannelId.fromCid(cid) ?: return null + if (!stateRegistry.isActiveChannel(id)) return null + return stateRegistry.channel(id).toChannel() } /** diff --git a/stream-chat-android-client/src/main/java/io/getstream/chat/android/client/internal/state/sync/internal/SyncManager.kt b/stream-chat-android-client/src/main/java/io/getstream/chat/android/client/internal/state/sync/internal/SyncManager.kt index b24142da822..913718930a2 100644 --- a/stream-chat-android-client/src/main/java/io/getstream/chat/android/client/internal/state/sync/internal/SyncManager.kt +++ b/stream-chat-android-client/src/main/java/io/getstream/chat/android/client/internal/state/sync/internal/SyncManager.kt @@ -39,12 +39,14 @@ import io.getstream.chat.android.client.query.CreateChannelParams import io.getstream.chat.android.client.setup.state.ClientState import io.getstream.chat.android.client.sync.SyncState import io.getstream.chat.android.client.sync.stringify +import io.getstream.chat.android.client.utils.internal.ChannelId import io.getstream.chat.android.client.utils.internal.ServerClockOffset import io.getstream.chat.android.client.utils.message.isDeleted import io.getstream.chat.android.client.utils.observable.Disposable import io.getstream.chat.android.core.internal.coroutines.Tube import io.getstream.chat.android.core.utils.date.diff import io.getstream.chat.android.models.Attachment +import io.getstream.chat.android.models.Channel import io.getstream.chat.android.models.Filters import io.getstream.chat.android.models.MemberData import io.getstream.chat.android.models.Message @@ -469,17 +471,16 @@ internal class SyncManager( logger.d { "[updateActiveChannels] recoverAll: $recoverAll, online: $online, cidsToExclude.size: ${cidsToExclude.size}" } - val missingCids: List = stateRegistry.getActiveChannelStates() - .asSequence() - .filter { (it.recoveryNeeded || recoverAll) && !cidsToExclude.contains(it.cid) } + val missingChannelIds: List = stateRegistry.getActiveChannelStates() + .filter { (id, state) -> (state.recoveryNeeded || recoverAll) && id.cid !in cidsToExclude } + .keys .take(n = 30) - .map { it.cid } - .toList() - logger.v { "[updateActiveChannels] missingCids.size: ${missingCids.size}" } - if (missingCids.isEmpty() || !online) { + logger.v { "[updateActiveChannels] missingCids.size: ${missingChannelIds.size}" } + if (missingChannelIds.isEmpty() || !online) { return } + val missingCids = missingChannelIds.map(ChannelId::cid) val filter = Filters.`in`("cid", missingCids) val request = QueryChannelsRequest(filter, offset = 0, limit = 30) logger.v { "[updateActiveChannels] request: $request" } @@ -493,19 +494,18 @@ internal class SyncManager( logger.v { "[updateActiveChannels] request completed; foundChannels.size: ${foundChannels.size}" } foundChannels.forEach { channel -> - val channelLogic = logicRegistry.channel(channel.type, channel.id) - channelLogic.updateDataForChannel(channel, channel.messages.size) + ChannelId.fromTypeAndId(channel.type, channel.id) + ?.let(logicRegistry::channel) + ?.updateDataForChannel(channel, channel.messages.size) } repos.storeStateForChannels(foundChannels) - val foundCids = foundChannels.map { it.cid } - val stillMissingCids = missingCids - foundCids.toSet() - logger.v { "[updateActiveChannels] stillMissingCids.size: ${stillMissingCids.size}" } + val foundCids = foundChannels.mapTo(mutableSetOf(), Channel::cid) + val stillMissingChannelIds = missingChannelIds.filterNot { it.cid in foundCids } + logger.v { "[updateActiveChannels] stillMissingCids.size: ${stillMissingChannelIds.size}" } // create channels that are not present on the API - stillMissingCids.forEach { cid -> - val (type, id) = cid.cidToTypeAndId() - val channelLogic = logicRegistry.channel(type, id) - channelLogic.watch(userPresence = userPresence) + stillMissingChannelIds.forEach { id -> + logicRegistry.channel(id).watch(userPresence = userPresence) } } } @@ -572,6 +572,7 @@ internal class SyncManager( message.isDeleted() && !message.deletedForMe -> { retryDeletionOfMessageWithSyncedAttachments(id, message, channelClient) } + message.deletedForMe -> retryDeleteMessageForMe(id) message.updatedLocallyAt != null && message.createdAt != null -> { retryUpdateOfMessageWithSyncedAttachments(id, message, channelClient) diff --git a/stream-chat-android-client/src/main/java/io/getstream/chat/android/client/utils/internal/ChannelId.kt b/stream-chat-android-client/src/main/java/io/getstream/chat/android/client/utils/internal/ChannelId.kt new file mode 100644 index 00000000000..578fa761105 --- /dev/null +++ b/stream-chat-android-client/src/main/java/io/getstream/chat/android/client/utils/internal/ChannelId.kt @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2014-2026 Stream.io Inc. All rights reserved. + * + * Licensed under the Stream License; + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://github.com/GetStream/stream-chat-android/blob/main/LICENSE + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.getstream.chat.android.client.utils.internal + +/** + * A channel identifier whose underlying `type:id` cid has been validated. + */ +@JvmInline +internal value class ChannelId private constructor(val cid: String) { + + val type: String get() = cid.substringBefore(':') + val id: String get() = cid.substringAfter(':') + + companion object { + fun fromCid(cid: String): ChannelId? = try { + ChannelId(validateCid(cid)) + } catch (_: IllegalArgumentException) { + null + } + + fun fromTypeAndId(channelType: String, channelId: String): ChannelId? = + fromCid("$channelType:$channelId") + } +} diff --git a/stream-chat-android-client/src/test/java/io/getstream/chat/android/client/internal/state/internal/SyncManagerTest.kt b/stream-chat-android-client/src/test/java/io/getstream/chat/android/client/internal/state/internal/SyncManagerTest.kt index 769051a55a5..3a6bceb0aa5 100644 --- a/stream-chat-android-client/src/test/java/io/getstream/chat/android/client/internal/state/internal/SyncManagerTest.kt +++ b/stream-chat-android-client/src/test/java/io/getstream/chat/android/client/internal/state/internal/SyncManagerTest.kt @@ -18,7 +18,10 @@ package io.getstream.chat.android.client.internal.state.internal import app.cash.turbine.test import io.getstream.chat.android.client.ChatClient +import io.getstream.chat.android.client.api.models.QueryChannelsRequest +import io.getstream.chat.android.client.api.models.QueryChannelsResult import io.getstream.chat.android.client.api.state.StateRegistry +import io.getstream.chat.android.client.channel.state.ChannelState import io.getstream.chat.android.client.errors.ChatErrorCode import io.getstream.chat.android.client.events.ChatEvent import io.getstream.chat.android.client.events.ConnectedEvent @@ -32,6 +35,7 @@ import io.getstream.chat.android.client.persistance.repository.RepositoryFacade import io.getstream.chat.android.client.setup.state.ClientState import io.getstream.chat.android.client.sync.SyncState import io.getstream.chat.android.client.test.randomConnectedEvent +import io.getstream.chat.android.client.utils.internal.ChannelId import io.getstream.chat.android.client.utils.internal.ServerClockOffset import io.getstream.chat.android.client.utils.observable.Disposable import io.getstream.chat.android.core.internal.coroutines.Tube @@ -67,7 +71,6 @@ import org.amshove.kluent.shouldNotBeNull import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.Test -import org.junit.jupiter.api.TestInstance import org.mockito.kotlin.any import org.mockito.kotlin.doReturn import org.mockito.kotlin.eq @@ -79,7 +82,6 @@ import org.mockito.kotlin.whenever import java.util.Date @ExperimentalCoroutinesApi -@TestInstance(TestInstance.Lifecycle.PER_CLASS) internal class SyncManagerTest { private val testDispatcher = UnconfinedTestDispatcher() @@ -191,7 +193,7 @@ internal class SyncManagerTest { _syncEvents.test { syncManager.onEvent(connectingEvent) - assertEquals(listOf(mockedChatEvent), awaitItem()) + expectNoEvents() } } @@ -504,6 +506,60 @@ internal class SyncManagerTest { verify(chatClient).getSyncHistory(any(), any()) } + @Test + fun `on reconnect, recoverable active channels found by the server get updated and missing ones get watched`() = + runTest(testDispatcher) { + val createdAt = localDate() + val rawCreatedAt = streamDateFormatter.format(createdAt) + val syncState = SyncState( + userId = user.id, + activeChannelIds = emptyList(), + lastSyncedAt = createdAt, + rawLastSyncedAt = rawCreatedAt, + markedAllReadAt = createdAt, + ) + whenever(repositoryFacade.selectSyncState(user.id)) doReturn syncState + whenever(chatClient.getSyncHistory(any(), any())) doReturn TestCall(Result.Success(emptyList())) + whenever(clientState.isOnline) doReturn true + + val foundChannelId = ChannelId.fromTypeAndId("messaging", "found")!! + val missingChannelId = ChannelId.fromTypeAndId("messaging", "missing")!! + val foundChannelState: ChannelState = mock { on(it.recoveryNeeded) doReturn true } + val missingChannelState: ChannelState = mock { on(it.recoveryNeeded) doReturn true } + whenever(stateRegistry.getActiveChannelStates()) doReturn mapOf( + foundChannelId to foundChannelState, + missingChannelId to missingChannelState, + ) + + val foundChannel = randomChannel(type = "messaging", id = "found") + whenever(chatClient.queryChannelsInternal(any())) doReturn TestCall( + Result.Success(QueryChannelsResult(channels = listOf(foundChannel), predefinedFilter = null)), + ) + + val foundLogic: ChannelLogic = mock() + val missingLogic: ChannelLogic = mock { + onBlocking { it.watch(any(), any()) } doReturn Result.Success(foundChannel) + } + whenever(logicRegistry.channel(foundChannelId)) doReturn foundLogic + whenever(logicRegistry.channel(missingChannelId)) doReturn missingLogic + + val syncManager = buildSyncManager(isAutomaticSyncOnReconnectEnabled = true) + + syncManager.onEvent( + ConnectedEvent( + type = "type", + createdAt = createdAt, + rawCreatedAt = rawCreatedAt, + connectionId = randomString(), + me = user, + ), + ) + + verify(foundLogic).updateDataForChannel(foundChannel, foundChannel.messages.size) + verify(repositoryFacade).storeStateForChannels(listOf(foundChannel)) + verify(missingLogic).watch(userPresence = true) + } + @Test fun `when isAutomaticSyncOnReconnectEnabled is false, getSyncHistory should not be called on connected event`() = runTest(testDispatcher) { diff --git a/stream-chat-android-client/src/test/java/io/getstream/chat/android/client/internal/state/plugin/logic/internal/LogicRegistryTest.kt b/stream-chat-android-client/src/test/java/io/getstream/chat/android/client/internal/state/plugin/logic/internal/LogicRegistryTest.kt index 4935ad6c658..e281c77919d 100644 --- a/stream-chat-android-client/src/test/java/io/getstream/chat/android/client/internal/state/plugin/logic/internal/LogicRegistryTest.kt +++ b/stream-chat-android-client/src/test/java/io/getstream/chat/android/client/internal/state/plugin/logic/internal/LogicRegistryTest.kt @@ -140,7 +140,7 @@ internal class LogicRegistryTest { // Stub new channel state (ChannelStateImpl mocks) channelStateMocks.clear() - whenever(stateRegistry.channelState(any(), any())).thenAnswer { + whenever(stateRegistry.channelState(any(), any())).thenAnswer { val type = it.getArgument(0) val id = it.getArgument(1) channelStateMocks.getOrPut(type to id) { mock() } @@ -148,7 +148,7 @@ internal class LogicRegistryTest { // Stub legacy channel state (real ChannelStateLegacyImpl instances) legacyChannelStateMocks.clear() - whenever(stateRegistry.legacyChannelState(any(), any())).thenAnswer { + whenever(stateRegistry.legacyChannelState(any(), any())).thenAnswer { val type = it.getArgument(0) val id = it.getArgument(1) legacyChannelStateMocks.getOrPut(type to id) { @@ -1203,4 +1203,47 @@ internal class LogicRegistryTest { } // endregion + + // region Malformed cid handling + + @Test + fun `channel with malformed cid returns a non-cached ChannelLogic`() { + val first = logicRegistry.channel("messaging", "") + val second = logicRegistry.channel("messaging", "") + + Assertions.assertNotSame(first, second) + Assertions.assertTrue(logicRegistry.getActiveChannelsLogic().isEmpty()) + Assertions.assertFalse(logicRegistry.isActiveChannel("messaging", "")) + } + + @Test + fun `legacy channel with malformed cid returns a non-cached ChannelLogic`() { + val first = legacyLogicRegistry.channel("", "id") + val second = legacyLogicRegistry.channel("", "id") + + Assertions.assertNotSame(first, second) + Assertions.assertTrue(legacyLogicRegistry.getActiveChannelsLogic().isEmpty()) + Assertions.assertFalse(legacyLogicRegistry.isActiveChannel("", "id")) + } + + @Test + fun `removeChannel with malformed cid is a no-op`() { + logicRegistry.channel("messaging", "123") + Assertions.assertEquals(1, logicRegistry.getActiveChannelsLogic().size) + + logicRegistry.removeChannel("", "123") + + Assertions.assertEquals(1, logicRegistry.getActiveChannelsLogic().size) + } + + @Test + fun `channelFromMessage returns null when cid is malformed`() { + val message = Message(id = "msg1", cid = "not-a-cid", parentId = null) + + val result = logicRegistry.channelFromMessage(message) + + Assertions.assertNull(result) + } + + // endregion } diff --git a/stream-chat-android-client/src/test/java/io/getstream/chat/android/client/internal/state/plugin/logic/querychannels/internal/QueryChannelsStateLogicTest.kt b/stream-chat-android-client/src/test/java/io/getstream/chat/android/client/internal/state/plugin/logic/querychannels/internal/QueryChannelsStateLogicTest.kt index 45ad60ed1fd..34bab6dafae 100644 --- a/stream-chat-android-client/src/test/java/io/getstream/chat/android/client/internal/state/plugin/logic/querychannels/internal/QueryChannelsStateLogicTest.kt +++ b/stream-chat-android-client/src/test/java/io/getstream/chat/android/client/internal/state/plugin/logic/querychannels/internal/QueryChannelsStateLogicTest.kt @@ -18,11 +18,11 @@ package io.getstream.chat.android.client.internal.state.plugin.logic.querychanne import io.getstream.chat.android.client.api.state.StateRegistry import io.getstream.chat.android.client.channel.state.ChannelState -import io.getstream.chat.android.client.extensions.cidToTypeAndId import io.getstream.chat.android.client.extensions.internal.toCid import io.getstream.chat.android.client.internal.state.plugin.logic.internal.LogicRegistry import io.getstream.chat.android.client.internal.state.plugin.state.querychannels.internal.QueryChannelsMutableState import io.getstream.chat.android.client.query.QueryChannelsSpec +import io.getstream.chat.android.client.utils.internal.ChannelId import io.getstream.chat.android.models.Filters import io.getstream.chat.android.models.querysort.QuerySortByField import io.getstream.chat.android.randomCID @@ -49,6 +49,7 @@ internal class QueryChannelsStateLogicTest { private val type = randomString() private val id = randomString() private val testCid = (type to id).toCid() + private val testChannelId = ChannelId.fromCid(testCid)!! private val queryChannelsSpec = QueryChannelsSpec( filter = Filters.neutral(), @@ -63,6 +64,7 @@ internal class QueryChannelsStateLogicTest { private val stateRegistry: StateRegistry = mock() private val logicRegistry: LogicRegistry = mock { on(it.channel(any(), any())) doReturn mock() + on(it.channel(any())) doReturn mock() } private val queryChannelsStateLogic = @@ -75,10 +77,8 @@ internal class QueryChannelsStateLogicTest { on(it.toChannel()) doReturn channel } - val (channelType, channelId) = testCid.cidToTypeAndId() - - whenever(stateRegistry.isActiveChannel(channelType, channelId)) doReturn true - whenever(stateRegistry.channel(channelType, channelId)) doReturn channelState + whenever(stateRegistry.isActiveChannel(testChannelId)) doReturn true + whenever(stateRegistry.channel(testChannelId)) doReturn channelState queryChannelsStateLogic.refreshChannels(listOf(testCid)) @@ -92,11 +92,10 @@ internal class QueryChannelsStateLogicTest { on(it.toChannel()) doReturn channel } - val (channelType, channelId) = testCid.cidToTypeAndId() val cidOutsideSpecs = randomCID() - whenever(stateRegistry.isActiveChannel(channel.type, channel.id)) doReturn true - whenever(stateRegistry.channel(channelType, channelId)) doReturn channelState + whenever(stateRegistry.isActiveChannel(testChannelId)) doReturn true + whenever(stateRegistry.channel(testChannelId)) doReturn channelState queryChannelsStateLogic.refreshChannels(listOf(cidOutsideSpecs)) @@ -111,10 +110,9 @@ internal class QueryChannelsStateLogicTest { } val cidOutsideSpecs = randomCID() - val (channelType, channelId) = testCid.cidToTypeAndId() - whenever(stateRegistry.isActiveChannel(channel.type, channel.id)) doReturn false - whenever(stateRegistry.channel(channelType, channelId)) doReturn channelState + whenever(stateRegistry.isActiveChannel(testChannelId)) doReturn false + whenever(stateRegistry.channel(testChannelId)) doReturn channelState queryChannelsStateLogic.refreshChannels(listOf(cidOutsideSpecs)) @@ -142,8 +140,8 @@ internal class QueryChannelsStateLogicTest { on(it.toChannel()) doReturn channel } - whenever(stateRegistry.isActiveChannel(type, id)) doReturn true - whenever(stateRegistry.channel(type, id)) doReturn channelState + whenever(stateRegistry.isActiveChannel(testChannelId)) doReturn true + whenever(stateRegistry.channel(testChannelId)) doReturn channelState val result = queryChannelsStateLogic.getActiveChannelState(testCid) @@ -152,10 +150,45 @@ internal class QueryChannelsStateLogicTest { @Test fun `getActiveChannelState should return null when channel is not active in state registry`() { - whenever(stateRegistry.isActiveChannel(type, id)) doReturn false + whenever(stateRegistry.isActiveChannel(testChannelId)) doReturn false val result = queryChannelsStateLogic.getActiveChannelState(testCid) assertNull(result) } + + @Test + fun `getActiveChannelState should return null when cid is malformed`() { + assertNull(queryChannelsStateLogic.getActiveChannelState("not-a-cid")) + } + + @Test + fun `addChannelsState drops channels with malformed cid`() = runTest { + val valid = randomChannel(type = "messaging", id = "valid-${randomString()}") + val malformed = randomChannel(type = "", id = "x") + whenever(mutableState.queryChannelsSpec) doReturn queryChannelsSpec + + queryChannelsStateLogic.addChannelsState(listOf(valid, malformed)) + + verify(mutableState).setCids(setOf(testCid, valid.cid)) + verify(mutableState).setChannels(mapOf(valid.cid to valid)) + verify(logicRegistry).channel(ChannelId.fromCid(valid.cid)!!) + } + + @Test + fun `refreshChannels skips malformed cids in the spec intersection`() { + val malformedCid = "no-colon-cid" + val specWithMalformed = QueryChannelsSpec( + filter = Filters.neutral(), + querySort = QuerySortByField.descByName(""), + cids = setOf(testCid, malformedCid), + ) + whenever(mutableState.queryChannelsSpec) doReturn specWithMalformed + whenever(mutableState.rawChannels) doReturn emptyMap() + whenever(stateRegistry.isActiveChannel(testChannelId)) doReturn false + + queryChannelsStateLogic.refreshChannels(listOf(testCid, malformedCid)) + + verify(mutableState).setChannels(emptyMap()) + } } diff --git a/stream-chat-android-client/src/test/java/io/getstream/chat/android/client/internal/state/plugin/state/StateRegistryTest.kt b/stream-chat-android-client/src/test/java/io/getstream/chat/android/client/internal/state/plugin/state/StateRegistryTest.kt index 2c0a297d718..49b2d6a1f44 100644 --- a/stream-chat-android-client/src/test/java/io/getstream/chat/android/client/internal/state/plugin/state/StateRegistryTest.kt +++ b/stream-chat-android-client/src/test/java/io/getstream/chat/android/client/internal/state/plugin/state/StateRegistryTest.kt @@ -23,6 +23,7 @@ import io.getstream.chat.android.client.events.NotificationChannelDeletedEvent import io.getstream.chat.android.client.internal.state.event.handler.internal.batch.BatchEvent import io.getstream.chat.android.client.internal.state.plugin.state.channel.internal.ChannelStateImpl import io.getstream.chat.android.client.internal.state.plugin.state.channel.internal.ChannelStateLegacyImpl +import io.getstream.chat.android.client.utils.internal.ChannelId import io.getstream.chat.android.models.Channel import io.getstream.chat.android.models.FilterObject import io.getstream.chat.android.models.Filters @@ -55,6 +56,10 @@ internal class StateRegistryTest { val testCoroutines = TestCoroutineExtension() } + private val messaging123 = ChannelId.fromTypeAndId("messaging", "123")!! + private val messaging456 = ChannelId.fromTypeAndId("messaging", "456")!! + private val livestream789 = ChannelId.fromTypeAndId("livestream", "789")!! + private lateinit var legacyStateRegistry: StateRegistry private lateinit var stateRegistry: StateRegistry private lateinit var userStateFlow: StateFlow @@ -424,17 +429,17 @@ internal class StateRegistryTest { legacyStateRegistry.channel("messaging", "123") // Then - assertTrue(legacyStateRegistry.isActiveChannel("messaging", "123")) + assertTrue(legacyStateRegistry.isActiveChannel(messaging123)) } @Test fun `legacy isActiveChannel should return false for non-existent channel`() { // Then - assertFalse(legacyStateRegistry.isActiveChannel("messaging", "123")) + assertFalse(legacyStateRegistry.isActiveChannel(messaging123)) } @Test - fun `legacy getActiveChannelStates should return empty list initially`() { + fun `legacy getActiveChannelStates should return empty map initially`() { // When val activeStates = legacyStateRegistry.getActiveChannelStates() @@ -454,9 +459,9 @@ internal class StateRegistryTest { // Then assertEquals(3, activeStates.size) - assertTrue(activeStates.contains(state1)) - assertTrue(activeStates.contains(state2)) - assertTrue(activeStates.contains(state3)) + assertTrue(activeStates.containsValue(state1)) + assertTrue(activeStates.containsValue(state2)) + assertTrue(activeStates.containsValue(state3)) } @Test @@ -495,8 +500,8 @@ internal class StateRegistryTest { // Then assertTrue(legacyStateRegistry.getActiveChannelStates().isEmpty()) - assertFalse(legacyStateRegistry.isActiveChannel("messaging", "123")) - assertFalse(legacyStateRegistry.isActiveChannel("messaging", "456")) + assertFalse(legacyStateRegistry.isActiveChannel(messaging123)) + assertFalse(legacyStateRegistry.isActiveChannel(messaging456)) } @Test @@ -509,7 +514,7 @@ internal class StateRegistryTest { val newState = legacyStateRegistry.channel("messaging", "123") // Then - assertTrue(legacyStateRegistry.isActiveChannel("messaging", "123")) + assertTrue(legacyStateRegistry.isActiveChannel(messaging123)) assertEquals(1, legacyStateRegistry.getActiveChannelStates().size) assertEquals("messaging:123", newState.cid) } @@ -518,7 +523,7 @@ internal class StateRegistryTest { fun `legacy handleBatchEvent with ChannelDeletedEvent should remove channel`() { // Given legacyStateRegistry.channel("messaging", "123") - assertTrue(legacyStateRegistry.isActiveChannel("messaging", "123")) + assertTrue(legacyStateRegistry.isActiveChannel(messaging123)) val event = ChannelDeletedEvent( type = "channel.deleted", @@ -536,7 +541,7 @@ internal class StateRegistryTest { legacyStateRegistry.handleBatchEvent(batchEvent) // Then - assertFalse(legacyStateRegistry.isActiveChannel("messaging", "123")) + assertFalse(legacyStateRegistry.isActiveChannel(messaging123)) assertTrue(legacyStateRegistry.getActiveChannelStates().isEmpty()) } @@ -544,7 +549,7 @@ internal class StateRegistryTest { fun `legacy handleBatchEvent with NotificationChannelDeletedEvent should remove channel`() { // Given legacyStateRegistry.channel("messaging", "123") - assertTrue(legacyStateRegistry.isActiveChannel("messaging", "123")) + assertTrue(legacyStateRegistry.isActiveChannel(messaging123)) val event = NotificationChannelDeletedEvent( type = "notification.channel_deleted", @@ -561,7 +566,7 @@ internal class StateRegistryTest { legacyStateRegistry.handleBatchEvent(batchEvent) // Then - assertFalse(legacyStateRegistry.isActiveChannel("messaging", "123")) + assertFalse(legacyStateRegistry.isActiveChannel(messaging123)) assertTrue(legacyStateRegistry.getActiveChannelStates().isEmpty()) } @@ -598,9 +603,9 @@ internal class StateRegistryTest { legacyStateRegistry.handleBatchEvent(batchEvent) // Then - assertFalse(legacyStateRegistry.isActiveChannel("messaging", "123")) - assertTrue(legacyStateRegistry.isActiveChannel("messaging", "456")) - assertFalse(legacyStateRegistry.isActiveChannel("livestream", "789")) + assertFalse(legacyStateRegistry.isActiveChannel(messaging123)) + assertTrue(legacyStateRegistry.isActiveChannel(messaging456)) + assertFalse(legacyStateRegistry.isActiveChannel(livestream789)) assertEquals(1, legacyStateRegistry.getActiveChannelStates().size) } @@ -662,17 +667,17 @@ internal class StateRegistryTest { stateRegistry.channel("messaging", "123") // Then - assertTrue(stateRegistry.isActiveChannel("messaging", "123")) + assertTrue(stateRegistry.isActiveChannel(messaging123)) } @Test fun `isActiveChannel should return false for non-existent channel`() { // Then - assertFalse(stateRegistry.isActiveChannel("messaging", "123")) + assertFalse(stateRegistry.isActiveChannel(messaging123)) } @Test - fun `getActiveChannelStates should return empty list initially`() { + fun `getActiveChannelStates should return empty map initially`() { // When val activeStates = stateRegistry.getActiveChannelStates() @@ -692,9 +697,9 @@ internal class StateRegistryTest { // Then assertEquals(3, activeStates.size) - assertTrue(activeStates.contains(state1)) - assertTrue(activeStates.contains(state2)) - assertTrue(activeStates.contains(state3)) + assertTrue(activeStates.containsValue(state1)) + assertTrue(activeStates.containsValue(state2)) + assertTrue(activeStates.containsValue(state3)) } @Test @@ -733,15 +738,15 @@ internal class StateRegistryTest { // Then assertTrue(stateRegistry.getActiveChannelStates().isEmpty()) - assertFalse(stateRegistry.isActiveChannel("messaging", "123")) - assertFalse(stateRegistry.isActiveChannel("messaging", "456")) + assertFalse(stateRegistry.isActiveChannel(messaging123)) + assertFalse(stateRegistry.isActiveChannel(messaging456)) } @Test fun `handleBatchEvent with ChannelDeletedEvent should remove channel`() { // Given stateRegistry.channel("messaging", "123") - assertTrue(stateRegistry.isActiveChannel("messaging", "123")) + assertTrue(stateRegistry.isActiveChannel(messaging123)) val event = ChannelDeletedEvent( type = "channel.deleted", @@ -759,7 +764,7 @@ internal class StateRegistryTest { stateRegistry.handleBatchEvent(batchEvent) // Then - assertFalse(stateRegistry.isActiveChannel("messaging", "123")) + assertFalse(stateRegistry.isActiveChannel(messaging123)) assertTrue(stateRegistry.getActiveChannelStates().isEmpty()) } @@ -767,7 +772,7 @@ internal class StateRegistryTest { fun `handleBatchEvent with NotificationChannelDeletedEvent should remove channel`() { // Given stateRegistry.channel("messaging", "123") - assertTrue(stateRegistry.isActiveChannel("messaging", "123")) + assertTrue(stateRegistry.isActiveChannel(messaging123)) val event = NotificationChannelDeletedEvent( type = "notification.channel_deleted", @@ -784,7 +789,7 @@ internal class StateRegistryTest { stateRegistry.handleBatchEvent(batchEvent) // Then - assertFalse(stateRegistry.isActiveChannel("messaging", "123")) + assertFalse(stateRegistry.isActiveChannel(messaging123)) assertTrue(stateRegistry.getActiveChannelStates().isEmpty()) } @@ -821,10 +826,51 @@ internal class StateRegistryTest { stateRegistry.handleBatchEvent(batchEvent) // Then - assertFalse(stateRegistry.isActiveChannel("messaging", "123")) - assertTrue(stateRegistry.isActiveChannel("messaging", "456")) - assertFalse(stateRegistry.isActiveChannel("livestream", "789")) + assertFalse(stateRegistry.isActiveChannel(messaging123)) + assertTrue(stateRegistry.isActiveChannel(messaging456)) + assertFalse(stateRegistry.isActiveChannel(livestream789)) + assertEquals(1, stateRegistry.getActiveChannelStates().size) + } + + // endregion + + // region Malformed cid handling + + @Test + fun `channel with malformed cid returns a non-cached ChannelState`() { + val state = stateRegistry.channel("messaging", "") + assertEquals("messaging:", state.cid) + assertTrue(stateRegistry.getActiveChannelStates().isEmpty()) + } + + @Test + fun `legacy channel with malformed cid returns a non-cached ChannelState`() { + val state = legacyStateRegistry.channel("", "id") + assertEquals(":id", state.cid) + assertTrue(legacyStateRegistry.getActiveChannelStates().isEmpty()) + } + + @Test + fun `handleBatchEvent with malformed cid ChannelDeletedEvent is a no-op`() { + stateRegistry.channel("messaging", "123") + assertEquals(1, stateRegistry.getActiveChannelStates().size) + + val event = ChannelDeletedEvent( + type = "channel.deleted", + createdAt = Date(), + rawCreatedAt = "", + cid = ":id", + channelType = "", + channelId = "id", + channel = Channel(id = "id", type = ""), + user = null, + ) + val batchEvent = BatchEvent(sortedEvents = listOf(event), isFromHistorySync = false) + + stateRegistry.handleBatchEvent(batchEvent) + assertEquals(1, stateRegistry.getActiveChannelStates().size) + assertTrue(stateRegistry.isActiveChannel(messaging123)) } // endregion diff --git a/stream-chat-android-client/src/test/java/io/getstream/chat/android/client/utils/internal/ChannelIdTest.kt b/stream-chat-android-client/src/test/java/io/getstream/chat/android/client/utils/internal/ChannelIdTest.kt new file mode 100644 index 00000000000..7b1581294e6 --- /dev/null +++ b/stream-chat-android-client/src/test/java/io/getstream/chat/android/client/utils/internal/ChannelIdTest.kt @@ -0,0 +1,81 @@ +/* + * Copyright (c) 2014-2026 Stream.io Inc. All rights reserved. + * + * Licensed under the Stream License; + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://github.com/GetStream/stream-chat-android/blob/main/LICENSE + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.getstream.chat.android.client.utils.internal + +import org.amshove.kluent.`should be equal to` +import org.amshove.kluent.`should be null` +import org.junit.jupiter.api.Test +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.MethodSource + +internal class ChannelIdTest { + + @ParameterizedTest + @MethodSource("validCids") + fun `fromCid returns a ChannelId for a valid cid`(cid: String) { + val channelId = ChannelId.fromCid(cid) + channelId?.cid `should be equal to` cid + } + + @ParameterizedTest + @MethodSource("invalidCids") + fun `fromCid returns null for a malformed cid`(cid: String) { + ChannelId.fromCid(cid).`should be null`() + } + + @Test + fun `fromTypeAndId joins parts and validates the result`() { + val channelId = ChannelId.fromTypeAndId("messaging", "123") + channelId?.cid `should be equal to` "messaging:123" + channelId?.type `should be equal to` "messaging" + channelId?.id `should be equal to` "123" + } + + @Test + fun `fromTypeAndId returns null when a part is blank`() { + ChannelId.fromTypeAndId("", "123").`should be null`() + ChannelId.fromTypeAndId("messaging", "").`should be null`() + } + + @Test + fun `fromTypeAndId returns null when a part contains a colon`() { + ChannelId.fromTypeAndId("messaging:foo", "123").`should be null`() + } + + companion object { + + @JvmStatic + fun validCids() = listOf( + "messaging:123", + "a:e", + "messaging:!members-oNJ1lQqt2b9SKG6raDWRTn4wWLakkFkwvqlUn-EsatU", + "!members-hash:!members-hash", + ) + + @JvmStatic + fun invalidCids() = listOf( + "", + " ", + "messaging 123", + "messaging123", + "messaging::123", + "messaging:", + ":123", + ":", + ) + } +}