Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,19 @@ protected AIContextProvider(
/// </summary>
protected Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>> StoreInputResponseMessageFilter { get; }

/// <summary>
/// Merges input messages with messages provided by this context provider.
/// </summary>
/// <param name="inputMessages">The original input messages for the invocation.</param>
/// <param name="providedMessages">The messages provided by this context provider, after source attribution has been applied.</param>
/// <returns>The merged messages to use for the invocation.</returns>
/// <remarks>
/// The default implementation appends provided messages after input messages.
/// Override this method to customize placement, for example to prepend provided context before user messages.
/// </remarks>
protected virtual IEnumerable<ChatMessage> MergeMessages(IEnumerable<ChatMessage> inputMessages, IEnumerable<ChatMessage> providedMessages)
=> inputMessages.Concat(providedMessages);

/// <summary>
/// Gets the set of keys used to store the provider state in the <see cref="AgentSession.StateBag"/>.
/// </summary>
Expand Down Expand Up @@ -180,7 +193,7 @@ protected virtual async ValueTask<AIContext> InvokingCoreAsync(InvokingContext c
(null, null) => null,
(var a, null) => a,
(null, var b) => b,
(var a, var b) => a.Concat(b)
(var a, var b) => this.MergeMessages(a, b)
};

var mergedTools = (inputContext.Tools, provided.Tools) switch
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ protected virtual async ValueTask<IEnumerable<ChatMessage>> InvokingCoreAsync(In

// Stamp and merge provided messages.
providedMessages = providedMessages.Select(m => m.WithAgentRequestMessageSource(AgentRequestMessageSourceType.AIContextProvider, this.GetType().FullName!));
return inputMessages.Concat(providedMessages);
return this.MergeMessages(inputMessages, providedMessages);
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -694,10 +694,34 @@ public async Task InvokedCoreAsync_RequestAndResponseFiltersOperateIndependently

#endregion

#region MergeMessages Tests

[Fact]
public async Task InvokingAsync_OverriddenMergeMessages_CanPrependProvidedMessagesAsync()
{
// Arrange
var provideContext = new AIContext { Messages = [new ChatMessage(ChatRole.System, "Context")] };
var provider = new TestAIContextProvider(provideContext: provideContext, prependProvidedMessages: true);
var inputContext = new AIContext { Messages = [new ChatMessage(ChatRole.User, "User input")] };
var context = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, inputContext);

// Act
var result = await provider.InvokingAsync(context);
var messages = result.Messages!.ToList();

// Assert
Assert.Equal(2, messages.Count);
Assert.Equal("Context", messages[0].Text);
Assert.Equal("User input", messages[1].Text);
}

#endregion

private sealed class TestAIContextProvider : AIContextProvider
{
private readonly AIContext? _provideContext;
private readonly bool _captureFilteredContext;
private readonly bool _prependProvidedMessages;

public InvokedContext? LastStoredContext { get; private set; }

Expand All @@ -708,13 +732,18 @@ public TestAIContextProvider(
bool captureFilteredContext = false,
Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>>? provideInputMessageFilter = null,
Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>>? storeInputRequestMessageFilter = null,
Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>>? storeInputResponseMessageFilter = null)
Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>>? storeInputResponseMessageFilter = null,
bool prependProvidedMessages = false)
: base(provideInputMessageFilter, storeInputRequestMessageFilter, storeInputResponseMessageFilter)
{
this._provideContext = provideContext;
this._captureFilteredContext = captureFilteredContext;
this._prependProvidedMessages = prependProvidedMessages;
}

protected override IEnumerable<ChatMessage> MergeMessages(IEnumerable<ChatMessage> inputMessages, IEnumerable<ChatMessage> providedMessages)
=> this._prependProvidedMessages ? providedMessages.Concat(inputMessages) : base.MergeMessages(inputMessages, providedMessages);

protected override ValueTask<AIContext> ProvideAIContextAsync(InvokingContext context, CancellationToken cancellationToken = default)
{
if (this._captureFilteredContext)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,27 @@ public void InvokingContext_RequestMessages_SetterAcceptsValidValue()

#endregion

#region MergeMessages Tests

[Fact]
public async Task InvokingAsync_OverriddenMergeMessages_CanPrependProvidedMessagesAsync()
{
// Arrange
var providedMessages = new[] { new ChatMessage(ChatRole.System, "Context message") };
var provider = new TestMessageProvider(provideMessages: providedMessages, prependProvidedMessages: true);
var context = new MessageAIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "User input")]);

// Act
var result = (await provider.InvokingAsync(context)).ToList();

// Assert
Assert.Equal(2, result.Count);
Assert.Equal("Context message", result[0].Text);
Assert.Equal("User input", result[1].Text);
}

#endregion

#region GetService Tests

[Fact]
Expand All @@ -289,20 +310,26 @@ private sealed class TestMessageProvider : MessageAIContextProvider
{
private readonly IEnumerable<ChatMessage>? _provideMessages;
private readonly bool _captureFilteredContext;
private readonly bool _prependProvidedMessages;

public InvokingContext? LastFilteredContext { get; private set; }

public TestMessageProvider(
IEnumerable<ChatMessage>? provideMessages = null,
bool captureFilteredContext = false,
Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>>? provideInputMessageFilter = null,
Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>>? storeInputMessageFilter = null)
Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>>? storeInputMessageFilter = null,
bool prependProvidedMessages = false)
: base(provideInputMessageFilter, storeInputMessageFilter)
{
this._provideMessages = provideMessages;
this._captureFilteredContext = captureFilteredContext;
this._prependProvidedMessages = prependProvidedMessages;
}

protected override IEnumerable<ChatMessage> MergeMessages(IEnumerable<ChatMessage> inputMessages, IEnumerable<ChatMessage> providedMessages)
=> this._prependProvidedMessages ? providedMessages.Concat(inputMessages) : base.MergeMessages(inputMessages, providedMessages);

protected override ValueTask<IEnumerable<ChatMessage>> ProvideMessagesAsync(InvokingContext context, CancellationToken cancellationToken = default)
{
if (this._captureFilteredContext)
Expand Down
Loading