Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -618,10 +618,15 @@ def create_typing_activity() -> "Activity":
"""
return Activity(type=ActivityTypes.typing)

def get_conversation_reference(self) -> ConversationReference:
def get_conversation_reference(
self, force_base_channel: bool | None = None
) -> ConversationReference:
Comment thread
MattB-msft marked this conversation as resolved.
"""
Creates a ConversationReference based on this activity.

:param force_base_channel: Optional, when True use only the base channel value
from the channel id (for example ``msteams`` from ``msteams:copilot-web``).
Composite values are split only on the first ``:``.
:returns: A conversation reference for the conversation that contains this activity.
"""
return pick_model(
Expand All @@ -635,7 +640,11 @@ def get_conversation_reference(self) -> ConversationReference:
user=copy(self.from_property),
agent=copy(self.recipient),
conversation=copy(self.conversation),
channel_id=self.channel_id,
channel_id=(
self.channel_id.split(":", 1)[0]
if force_base_channel and self.channel_id is not None
else self.channel_id
),
locale=self.locale,
service_url=self.service_url,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ async def begin_flow(self, activity: Activity) -> _FlowResponse:

token_exchange_state = TokenExchangeState(
connection_name=self._abs_oauth_connection_name,
conversation=activity.get_conversation_reference(),
conversation=activity.get_conversation_reference(force_base_channel=True),
Comment thread
axelsrz marked this conversation as resolved.
relates_to=activity.relates_to,
ms_app_id=self._ms_app_id,
)
Expand Down
37 changes: 37 additions & 0 deletions tests/activity/test_activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,43 @@ def test_get_conversation_reference(self, activity):
assert activity.locale == conversation_reference.locale
assert activity.service_url == conversation_reference.service_url

def test_get_conversation_reference_force_base_channel(self, activity):
activity.channel_id = "msteams:copilot-web"

conversation_reference = activity.get_conversation_reference(
force_base_channel=True
)

assert conversation_reference.channel_id == "msteams"

@pytest.mark.parametrize(
"channel_id, expected_base_channel",
[
("msteams", "msteams"),
("msteams:copilot-web", "msteams"),
("msteams:copilot:web", "msteams"),
],
)
def test_get_conversation_reference_force_base_channel_variants(
self, activity, channel_id, expected_base_channel
):
activity.channel_id = channel_id

conversation_reference = activity.get_conversation_reference(
force_base_channel=True
)

assert conversation_reference.channel_id == expected_base_channel

def test_get_conversation_reference_does_not_force_base_channel(self, activity):
activity.channel_id = "msteams:copilot-web"

conversation_reference = activity.get_conversation_reference(
force_base_channel=False
Comment thread
axelsrz marked this conversation as resolved.
)

assert conversation_reference.channel_id == "msteams:copilot-web"

def test_get_reply_conversation_reference(self, activity):
reply = ResourceResponse(id="1234")
conversation_reference = activity.get_reply_conversation_reference(reply)
Expand Down
19 changes: 11 additions & 8 deletions tests/hosting_core/_oauth/test_oauth_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,25 +32,22 @@ def create_testing_Activity(
value=None,
text="a",
):
# mock_conversation_ref = mocker.MagicMock(ConversationReference)
conversation_reference = ConversationReference(
conversation={"id": "conv1"},
)
mocker.patch.object(
Activity,
"get_conversation_reference",
return_value=conversation_reference,
)
return Activity(
activity = Activity(
type=type,
name=name,
from_property=ChannelAccount(id=DEFAULTS.user_id),
recipient=ChannelAccount(id="agent-id"),
conversation={"id": "conv1"},
channel_id=DEFAULTS.channel_id,
# get_conversation_reference=mocker.Mock(return_value=conv_ref),
service_url=DEFAULTS.service_url,
relates_to=conversation_reference,
value=value,
text=text,
)
return activity


class TestUtils(FlowStateFixtures):
Expand Down Expand Up @@ -174,6 +171,9 @@ async def test_begin_flow_easy_case(self, mocker, flow_state, activity):
mocker.patch.object(
TokenExchangeState, "get_encoded_state", return_value="encoded_state"
)
get_conversation_reference_spy = mocker.spy(
Activity, "get_conversation_reference"
)
flow = _OAuthFlow(flow_state, user_token_client)
expected_flow_state = flow_state
expected_flow_state.tag = _FlowStateTag.COMPLETE
Expand All @@ -197,6 +197,9 @@ async def test_begin_flow_easy_case(self, mocker, flow_state, activity):
activity.channel_id,
"encoded_state",
)
get_conversation_reference_spy.assert_any_call(
activity, force_base_channel=True
)

@pytest.mark.asyncio
Comment thread
axelsrz marked this conversation as resolved.
async def test_begin_flow_long_case(self, mocker, flow_state, activity):
Expand Down
Loading