Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Collections.Generic;
using Microsoft.Extensions.AI;

namespace Microsoft.Agents.AI.Workflows;

internal sealed class AgentAIContextProviderMsgEvent(IReadOnlyList<ChatMessage> messages) : WorkflowEvent(messages)
Comment thread
XiongHaoTrigger marked this conversation as resolved.
{
public IReadOnlyList<ChatMessage> Messages { get; } = messages;
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Agents.AI;
using Microsoft.Extensions.AI;

namespace Microsoft.Agents.AI.Workflows.Specialized;
Expand Down Expand Up @@ -219,13 +220,19 @@ private async ValueTask<AgentResponse> InvokeAgentAsync(IEnumerable<ChatMessage>
{
AgentResponse response;
AIAgentUnservicedRequestsCollector collector = new(this._userInputHandler, this._functionCallHandler);
AgentSession session = await this.EnsureSessionAsync(context, cancellationToken).ConfigureAwait(false);
bool shouldCaptureEnrichedRequestMessages = this.HasAIContextProviders();
List<ChatMessage>? historyBefore = shouldCaptureEnrichedRequestMessages
? await this.GetStoredChatHistorySnapshotAsync(session, cancellationToken).ConfigureAwait(false)
: null;
List<ChatMessage> requestMessages = messages as List<ChatMessage> ?? messages.ToList();
Comment thread
XiongHaoTrigger marked this conversation as resolved.

if (emitUpdateEvents)
{
// Run the agent in streaming mode only when agent run update events are to be emitted.
IAsyncEnumerable<AgentResponseUpdate> agentStream = this._agent.RunStreamingAsync(
messages,
await this.EnsureSessionAsync(context, cancellationToken).ConfigureAwait(false),
requestMessages,
session,
cancellationToken: cancellationToken);

List<AgentResponseUpdate> updates = [];
Expand All @@ -241,8 +248,8 @@ await this.EnsureSessionAsync(context, cancellationToken).ConfigureAwait(false),
else
{
// Otherwise, run the agent in non-streaming mode.
response = await this._agent.RunAsync(messages,
await this.EnsureSessionAsync(context, cancellationToken).ConfigureAwait(false),
response = await this._agent.RunAsync(requestMessages,
session,
cancellationToken: cancellationToken)
.ConfigureAwait(false);

Expand All @@ -254,11 +261,147 @@ await this.EnsureSessionAsync(context, cancellationToken).ConfigureAwait(false),
await context.YieldOutputAsync(response, cancellationToken).ConfigureAwait(false);
}

if (shouldCaptureEnrichedRequestMessages)
{
await this.EmitEnrichedRequestMessagesAsync(historyBefore, session, context, cancellationToken).ConfigureAwait(false);
}

await collector.SubmitAsync(context, cancellationToken).ConfigureAwait(false);

return response;
}

private bool HasAIContextProviders()
=> this._agent.GetService<ChatClientAgent>()?.AIContextProviders is { Count: > 0 }
|| this._agent.GetService<ChatClientAgentOptions>()?.AIContextProviders?.Any() == true;

/// <summary>
/// Get a snapshot of the chat history for the given session.
/// </summary>
/// <param name="session"> The session to get the chat history for. </param>
/// <param name="cancellationToken"> Cancellation token. </param>
/// <returns></returns>
private async ValueTask<List<ChatMessage>?> GetStoredChatHistorySnapshotAsync(AgentSession session, CancellationToken cancellationToken)
Comment thread
XiongHaoTrigger marked this conversation as resolved.
{
ChatHistoryProvider? provider = this._agent.GetService<ChatHistoryProvider>();
if (provider is null)
{
return null;
}
// if the provider is InMemoryChatHistoryProvider, get the messages directly
if (provider is InMemoryChatHistoryProvider inMemoryProvider)
{
return [.. inMemoryProvider.GetMessages(session)];
}

// otherwise, invoke the provider to get the messages
#pragma warning disable MAAI001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed.
ChatHistoryProvider.InvokingContext invokingContext = new(this._agent, session, []);
#pragma warning restore MAAI001
IEnumerable<ChatMessage> messages = await provider.InvokingAsync(invokingContext, cancellationToken).ConfigureAwait(false);
return [.. messages];
}

/// <summary>
/// Detects request messages that were injected by <see cref="AIContextProvider"/> during the
/// latest agent invocation and raises them as <see cref="AgentAIContextProviderMsgEvent"/> so that
/// the workflow layer can persist them into its chat history.
/// </summary>
/// <remarks>
/// <para>
/// This method compares the agent's stored chat history before and after the agent run.
/// Any newly-added messages whose <see cref="AgentRequestMessageSourceType"/> equals
/// <see cref="AgentRequestMessageSourceType.AIContextProvider"/> are considered enriched
/// request messages and are forwarded to the workflow event stream.
/// </para>
/// <para>
/// If <paramref name="historyBefore"/> is <see langword="null"/> (e.g. the agent does not
/// expose a <see cref="ChatHistoryProvider"/>), this method performs no work.
/// </para>
/// </remarks>
/// <param name="historyBefore">
/// Snapshot of the agent's stored chat history taken before the agent was invoked.
/// </param>
/// <param name="session">The agent session used for the invocation.</param>
/// <param name="context">The current workflow context used to emit events.</param>
/// <param name="cancellationToken">
/// The <see cref="CancellationToken"/> to monitor for cancellation requests.
/// </param>
private async ValueTask EmitEnrichedRequestMessagesAsync(
List<ChatMessage>? historyBefore,
AgentSession session,
IWorkflowContext context,
CancellationToken cancellationToken)
{
if (historyBefore is null)
{
return;
}

List<ChatMessage>? historyAfter = await this.GetStoredChatHistorySnapshotAsync(session, cancellationToken).ConfigureAwait(false);
if (historyAfter is null)
{
return;
}

int firstNewMessageIndex = FindFirstDivergenceIndex(historyBefore, historyAfter);
if (firstNewMessageIndex >= historyAfter.Count)
{
return;
}

List<ChatMessage> enrichedRequestMessages =
[
.. historyAfter
.Skip(firstNewMessageIndex)
.Where(message => message.GetAgentRequestMessageSourceType() == AgentRequestMessageSourceType.AIContextProvider)
];
Comment thread
XiongHaoTrigger marked this conversation as resolved.

if (enrichedRequestMessages.Count > 0)
{
await context.AddEventAsync(new AgentAIContextProviderMsgEvent(enrichedRequestMessages), cancellationToken).ConfigureAwait(false);
}
}

/// <summary>
/// Finds the first index where <paramref name="historyAfter"/> diverges from <paramref name="historyBefore"/>.
/// Ensure that the messages can be properly filled in when they are being truncated.
/// </summary>
/// <returns>The index of the first new or changed message.</returns>
private static int FindFirstDivergenceIndex(List<ChatMessage> historyBefore, List<ChatMessage> historyAfter)
{
int commonLength = Math.Min(historyBefore.Count, historyAfter.Count);
for (int i = 0; i < commonLength; i++)
{
if (!MessagesCompare(historyBefore[i], historyAfter[i]))
{
return i;
}
}

return commonLength;
}

/// <summary>
/// Compare two messages
/// </summary>
/// <param name="before">Previous messages</param>
/// <param name="after">Cuurrent messages</param>
/// <returns></returns>
private static bool MessagesCompare(ChatMessage before, ChatMessage after)
{
if (before.MessageId is not null && after.MessageId is not null)
{
return string.Equals(before.MessageId, after.MessageId, StringComparison.Ordinal);
}

return before.Role == after.Role
&& string.Equals(before.AuthorName, after.AuthorName, StringComparison.Ordinal)
&& string.Equals(before.Text, after.Text, StringComparison.Ordinal)
&& before.GetAgentRequestMessageSourceType() == after.GetAgentRequestMessageSourceType()
&& string.Equals(before.GetAgentRequestMessageSourceId(), after.GetAgentRequestMessageSourceId(), StringComparison.Ordinal);
}

/// <summary>
/// Content types that represent meaningful conversational content portable across agents.
/// Messages containing only content types not in this set (e.g. reasoning tokens, web search
Expand Down
5 changes: 5 additions & 0 deletions dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,11 @@ IAsyncEnumerable<AgentResponseUpdate> InvokeStageAsync(
yield return update;
break;

case AgentAIContextProviderMsgEvent requestMessages:
// Add the message in the AIContentProvider to the ChatHistoryProvider of the Workflow.
this.ChatHistoryProvider.AddMessages(this, requestMessages.Messages);
break;
Comment thread
XiongHaoTrigger marked this conversation as resolved.

case WorkflowErrorEvent workflowError:
Exception? exception = workflowError.Exception;
if (exception is TargetInvocationException tie && tie.InnerException != null)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using FluentAssertions;
using Microsoft.Extensions.AI;

namespace Microsoft.Agents.AI.Workflows.UnitTests;

/// <summary>
/// Validates that messages injected by <see cref="AIContextProvider"/> into an inner agent
/// are correctly persisted into the workflow's chat history, without leaking to downstream agents.
/// </summary>
public class AIContextProviderWorkflowTests
{
private const string UserText = "Where is Taggia?";
private const string ContextText = "Taggia is a city in Liguria.";
private const string FirstAgentResponseText = "Taggia is in Liguria.";

/// <summary>
/// Ensures that AIContextProvider-injected messages appear in the workflow session's
/// chat history and survive serialization (regression test for the bug where such
/// messages were lost because WorkflowHostAgent only persisted model outputs).
/// </summary>
[Fact]
public async Task Test_WorkflowAsAgent_SerializesAIContextProviderRequestMessagesAsync()
{
// Arrange
ChatClientAgent innerAgent = CreateContextAwareAgent();
AIAgent workflowAgent = AgentWorkflowBuilder.BuildSequential(innerAgent).AsAIAgent();
AgentSession session = await workflowAgent.CreateSessionAsync();

// Act
await workflowAgent.RunAsync(new ChatMessage(ChatRole.User, UserText), session);
JsonElement serializedSession = await workflowAgent.SerializeSessionAsync(session);

// Assert
WorkflowSession workflowSession = session.Should().BeOfType<WorkflowSession>().Subject;
string[] historyTexts =
[
.. workflowSession.ChatHistoryProvider
.GetAllMessages(workflowSession)
.Select(message => message.Text)
];

historyTexts.Should().Contain(UserText);
historyTexts.Should().Contain(ContextText);
historyTexts.Should().Contain(FirstAgentResponseText);
serializedSession.GetRawText().Should().Contain(ContextText);
}

/// <summary>
/// Ensures that AIContextProvider-injected messages are still persisted when inner chat history is pruned.
/// </summary>
[Fact]
public async Task Test_WorkflowAsAgent_SerializesAIContextProviderRequestMessagesWhenInnerHistoryIsPrunedAsync()
{
// Arrange
RetainingChatHistoryProvider chatHistoryProvider = new(maxStoredMessages: 2);
chatHistoryProvider.Add(new ChatMessage(ChatRole.User, "Previous question") { MessageId = "previous-user" });
chatHistoryProvider.Add(new ChatMessage(ChatRole.Assistant, "Previous answer") { MessageId = "previous-assistant" });
ChatClientAgent innerAgent = CreateContextAwareAgent(chatHistoryProvider);
AIAgent workflowAgent = AgentWorkflowBuilder.BuildSequential(innerAgent).AsAIAgent();
AgentSession session = await workflowAgent.CreateSessionAsync();

// Act
await workflowAgent.RunAsync(new ChatMessage(ChatRole.User, UserText), session);

// Assert
WorkflowSession workflowSession = session.Should().BeOfType<WorkflowSession>().Subject;
workflowSession.ChatHistoryProvider
.GetAllMessages(workflowSession)
.Select(message => message.Text)
.Should()
.Contain(ContextText);
}

/// <summary>
/// Ensures that AIContextProvider-injected messages are saved to workflow history
/// but are NOT forwarded as part of the input to subsequent agents in the workflow.
/// </summary>
[Fact]
public async Task Test_WorkflowAsAgent_DoesNotForwardAIContextProviderRequestMessagesToDownstreamAgentAsync()
{
// Arrange
ChatClientAgent innerAgent = CreateContextAwareAgent();
RecordingEchoAgent downstreamAgent = new(id: "downstream", name: "downstream", prefix: "downstream:");
AIAgent workflowAgent = AgentWorkflowBuilder.BuildSequential(innerAgent, downstreamAgent).AsAIAgent();

// Act
await workflowAgent.RunAsync(new ChatMessage(ChatRole.User, UserText), await workflowAgent.CreateSessionAsync());

// Assert
downstreamAgent.RecordedInputs.Should().ContainSingle();
string[] downstreamTexts = [.. downstreamAgent.RecordedInputs[0].Select(message => message.Text)];
downstreamTexts.Should().Contain(FirstAgentResponseText);
downstreamTexts.Should().NotContain(ContextText);
}

/// <summary>Builds an agent whose IChatClient always replies with <see cref="FirstAgentResponseText"/>, prepopulated with a <see cref="StaticAIContextProvider"/>.</summary>
private static ChatClientAgent CreateContextAwareAgent(ChatHistoryProvider? chatHistoryProvider = null)
{
return new ChatClientAgent(
new StubChatClient(_ => new ChatResponse([new ChatMessage(ChatRole.Assistant, FirstAgentResponseText)])),
new ChatClientAgentOptions
{
Name = "inner",
ChatHistoryProvider = chatHistoryProvider,
AIContextProviders = [new StaticAIContextProvider(ContextText)]
});
}

/// <summary>Always injects a single System message containing the configured text.</summary>
private sealed class StaticAIContextProvider(string text) : AIContextProvider
{
protected override ValueTask<AIContext> ProvideAIContextAsync(InvokingContext context, CancellationToken cancellationToken = default)
{
return new(new AIContext
{
Messages = [new ChatMessage(ChatRole.System, text)]
});
}
}

private sealed class RetainingChatHistoryProvider(int maxStoredMessages) : ChatHistoryProvider
{
private readonly List<ChatMessage> _messages = [];

public void Add(ChatMessage message)
{
this._messages.Add(message);
}

protected override ValueTask<IEnumerable<ChatMessage>> InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default)
{
return new(this._messages.Concat(context.RequestMessages));
}

protected override ValueTask StoreChatHistoryAsync(InvokedContext context, CancellationToken cancellationToken = default)
{
this._messages.AddRange(context.RequestMessages);
if (context.ResponseMessages is not null)
{
this._messages.AddRange(context.ResponseMessages);
}

if (this._messages.Count > maxStoredMessages)
{
this._messages.RemoveRange(0, this._messages.Count - maxStoredMessages);
}

return default;
}
}

/// <summary>Test double for <see cref="IChatClient"/> that returns deterministic responses via the supplied factory.</summary>
private sealed class StubChatClient(Func<IEnumerable<ChatMessage>, ChatResponse> responseFactory) : IChatClient
{
public Task<ChatResponse> GetResponseAsync(IEnumerable<ChatMessage> messages, ChatOptions? options = null, CancellationToken cancellationToken = default)
=> Task.FromResult(responseFactory(messages));

public async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
IEnumerable<ChatMessage> messages,
ChatOptions? options = null,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
ChatResponse response = await this.GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false);
foreach (ChatResponseUpdate update in response.ToChatResponseUpdates())
{
yield return update;
}
}

public object? GetService(Type serviceType, object? serviceKey = null) => null;

public void Dispose()
{
}
}
}