diff --git a/libraries/microsoft-agents-activity/microsoft_agents/activity/activity.py b/libraries/microsoft-agents-activity/microsoft_agents/activity/activity.py index b146a459..a21ec8c0 100644 --- a/libraries/microsoft-agents-activity/microsoft_agents/activity/activity.py +++ b/libraries/microsoft-agents-activity/microsoft_agents/activity/activity.py @@ -618,10 +618,15 @@ def create_typing_activity() -> "Activity": """ return Activity(type=ActivityTypes.typing) - def get_conversation_reference(self) -> ConversationReference: + def get_conversation_reference( + self, force_base_channel: bool | None = None + ) -> ConversationReference: """ Creates a ConversationReference based on this activity. + :param force_base_channel: Optional, when True use only the base channel value + from the channel id (for example ``msteams`` from ``msteams:copilot-web``). + Composite values are split only on the first ``:``. :returns: A conversation reference for the conversation that contains this activity. """ return pick_model( @@ -635,7 +640,11 @@ def get_conversation_reference(self) -> ConversationReference: user=copy(self.from_property), agent=copy(self.recipient), conversation=copy(self.conversation), - channel_id=self.channel_id, + channel_id=( + self.channel_id.split(":", 1)[0] + if force_base_channel and self.channel_id is not None + else self.channel_id + ), locale=self.locale, service_url=self.service_url, ) diff --git a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/_oauth/_oauth_flow.py b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/_oauth/_oauth_flow.py index f396c17f..0c1369b4 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/_oauth/_oauth_flow.py +++ b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/_oauth/_oauth_flow.py @@ -187,7 +187,7 @@ async def begin_flow(self, activity: Activity) -> _FlowResponse: token_exchange_state = TokenExchangeState( connection_name=self._abs_oauth_connection_name, - conversation=activity.get_conversation_reference(), + conversation=activity.get_conversation_reference(force_base_channel=True), relates_to=activity.relates_to, ms_app_id=self._ms_app_id, ) diff --git a/tests/activity/test_activity.py b/tests/activity/test_activity.py index 35a04893..f44667dc 100644 --- a/tests/activity/test_activity.py +++ b/tests/activity/test_activity.py @@ -66,6 +66,43 @@ def test_get_conversation_reference(self, activity): assert activity.locale == conversation_reference.locale assert activity.service_url == conversation_reference.service_url + def test_get_conversation_reference_force_base_channel(self, activity): + activity.channel_id = "msteams:copilot-web" + + conversation_reference = activity.get_conversation_reference( + force_base_channel=True + ) + + assert conversation_reference.channel_id == "msteams" + + @pytest.mark.parametrize( + "channel_id, expected_base_channel", + [ + ("msteams", "msteams"), + ("msteams:copilot-web", "msteams"), + ("msteams:copilot:web", "msteams"), + ], + ) + def test_get_conversation_reference_force_base_channel_variants( + self, activity, channel_id, expected_base_channel + ): + activity.channel_id = channel_id + + conversation_reference = activity.get_conversation_reference( + force_base_channel=True + ) + + assert conversation_reference.channel_id == expected_base_channel + + def test_get_conversation_reference_does_not_force_base_channel(self, activity): + activity.channel_id = "msteams:copilot-web" + + conversation_reference = activity.get_conversation_reference( + force_base_channel=False + ) + + assert conversation_reference.channel_id == "msteams:copilot-web" + def test_get_reply_conversation_reference(self, activity): reply = ResourceResponse(id="1234") conversation_reference = activity.get_reply_conversation_reference(reply) diff --git a/tests/hosting_core/_oauth/test_oauth_flow.py b/tests/hosting_core/_oauth/test_oauth_flow.py index 9fe4e3dc..aae663d6 100644 --- a/tests/hosting_core/_oauth/test_oauth_flow.py +++ b/tests/hosting_core/_oauth/test_oauth_flow.py @@ -32,25 +32,22 @@ def create_testing_Activity( value=None, text="a", ): - # mock_conversation_ref = mocker.MagicMock(ConversationReference) conversation_reference = ConversationReference( conversation={"id": "conv1"}, ) - mocker.patch.object( - Activity, - "get_conversation_reference", - return_value=conversation_reference, - ) - return Activity( + activity = Activity( type=type, name=name, from_property=ChannelAccount(id=DEFAULTS.user_id), + recipient=ChannelAccount(id="agent-id"), + conversation={"id": "conv1"}, channel_id=DEFAULTS.channel_id, - # get_conversation_reference=mocker.Mock(return_value=conv_ref), + service_url=DEFAULTS.service_url, relates_to=conversation_reference, value=value, text=text, ) + return activity class TestUtils(FlowStateFixtures): @@ -174,6 +171,9 @@ async def test_begin_flow_easy_case(self, mocker, flow_state, activity): mocker.patch.object( TokenExchangeState, "get_encoded_state", return_value="encoded_state" ) + get_conversation_reference_spy = mocker.spy( + Activity, "get_conversation_reference" + ) flow = _OAuthFlow(flow_state, user_token_client) expected_flow_state = flow_state expected_flow_state.tag = _FlowStateTag.COMPLETE @@ -197,6 +197,9 @@ async def test_begin_flow_easy_case(self, mocker, flow_state, activity): activity.channel_id, "encoded_state", ) + get_conversation_reference_spy.assert_any_call( + activity, force_base_channel=True + ) @pytest.mark.asyncio async def test_begin_flow_long_case(self, mocker, flow_state, activity):