diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs index 641825d1dc..290f353cf9 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs @@ -79,6 +79,19 @@ protected AIContextProvider( /// protected Func, IEnumerable> StoreInputResponseMessageFilter { get; } + /// + /// Merges input messages with messages provided by this context provider. + /// + /// The original input messages for the invocation. + /// The messages provided by this context provider, after source attribution has been applied. + /// The merged messages to use for the invocation. + /// + /// The default implementation appends provided messages after input messages. + /// Override this method to customize placement, for example to prepend provided context before user messages. + /// + protected virtual IEnumerable MergeMessages(IEnumerable inputMessages, IEnumerable providedMessages) + => inputMessages.Concat(providedMessages); + /// /// Gets the set of keys used to store the provider state in the . /// @@ -180,7 +193,7 @@ protected virtual async ValueTask InvokingCoreAsync(InvokingContext c (null, null) => null, (var a, null) => a, (null, var b) => b, - (var a, var b) => a.Concat(b) + (var a, var b) => this.MergeMessages(a, b) }; var mergedTools = (inputContext.Tools, provided.Tools) switch diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/MessageAIContextProvider.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/MessageAIContextProvider.cs index 041e621341..546a9c3043 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/MessageAIContextProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/MessageAIContextProvider.cs @@ -124,7 +124,7 @@ protected virtual async ValueTask> InvokingCoreAsync(In // Stamp and merge provided messages. providedMessages = providedMessages.Select(m => m.WithAgentRequestMessageSource(AgentRequestMessageSourceType.AIContextProvider, this.GetType().FullName!)); - return inputMessages.Concat(providedMessages); + return this.MergeMessages(inputMessages, providedMessages); } /// diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AIContextProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AIContextProviderTests.cs index 0e664d1ac9..bcdfe94f3c 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AIContextProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AIContextProviderTests.cs @@ -694,10 +694,34 @@ public async Task InvokedCoreAsync_RequestAndResponseFiltersOperateIndependently #endregion + #region MergeMessages Tests + + [Fact] + public async Task InvokingAsync_OverriddenMergeMessages_CanPrependProvidedMessagesAsync() + { + // Arrange + var provideContext = new AIContext { Messages = [new ChatMessage(ChatRole.System, "Context")] }; + var provider = new TestAIContextProvider(provideContext: provideContext, prependProvidedMessages: true); + var inputContext = new AIContext { Messages = [new ChatMessage(ChatRole.User, "User input")] }; + var context = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, inputContext); + + // Act + var result = await provider.InvokingAsync(context); + var messages = result.Messages!.ToList(); + + // Assert + Assert.Equal(2, messages.Count); + Assert.Equal("Context", messages[0].Text); + Assert.Equal("User input", messages[1].Text); + } + + #endregion + private sealed class TestAIContextProvider : AIContextProvider { private readonly AIContext? _provideContext; private readonly bool _captureFilteredContext; + private readonly bool _prependProvidedMessages; public InvokedContext? LastStoredContext { get; private set; } @@ -708,13 +732,18 @@ public TestAIContextProvider( bool captureFilteredContext = false, Func, IEnumerable>? provideInputMessageFilter = null, Func, IEnumerable>? storeInputRequestMessageFilter = null, - Func, IEnumerable>? storeInputResponseMessageFilter = null) + Func, IEnumerable>? storeInputResponseMessageFilter = null, + bool prependProvidedMessages = false) : base(provideInputMessageFilter, storeInputRequestMessageFilter, storeInputResponseMessageFilter) { this._provideContext = provideContext; this._captureFilteredContext = captureFilteredContext; + this._prependProvidedMessages = prependProvidedMessages; } + protected override IEnumerable MergeMessages(IEnumerable inputMessages, IEnumerable providedMessages) + => this._prependProvidedMessages ? providedMessages.Concat(inputMessages) : base.MergeMessages(inputMessages, providedMessages); + protected override ValueTask ProvideAIContextAsync(InvokingContext context, CancellationToken cancellationToken = default) { if (this._captureFilteredContext) diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/MessageAIContextProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/MessageAIContextProviderTests.cs index 8c11de6b62..18ef0a062c 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/MessageAIContextProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/MessageAIContextProviderTests.cs @@ -267,6 +267,27 @@ public void InvokingContext_RequestMessages_SetterAcceptsValidValue() #endregion + #region MergeMessages Tests + + [Fact] + public async Task InvokingAsync_OverriddenMergeMessages_CanPrependProvidedMessagesAsync() + { + // Arrange + var providedMessages = new[] { new ChatMessage(ChatRole.System, "Context message") }; + var provider = new TestMessageProvider(provideMessages: providedMessages, prependProvidedMessages: true); + var context = new MessageAIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "User input")]); + + // Act + var result = (await provider.InvokingAsync(context)).ToList(); + + // Assert + Assert.Equal(2, result.Count); + Assert.Equal("Context message", result[0].Text); + Assert.Equal("User input", result[1].Text); + } + + #endregion + #region GetService Tests [Fact] @@ -289,6 +310,7 @@ private sealed class TestMessageProvider : MessageAIContextProvider { private readonly IEnumerable? _provideMessages; private readonly bool _captureFilteredContext; + private readonly bool _prependProvidedMessages; public InvokingContext? LastFilteredContext { get; private set; } @@ -296,13 +318,18 @@ public TestMessageProvider( IEnumerable? provideMessages = null, bool captureFilteredContext = false, Func, IEnumerable>? provideInputMessageFilter = null, - Func, IEnumerable>? storeInputMessageFilter = null) + Func, IEnumerable>? storeInputMessageFilter = null, + bool prependProvidedMessages = false) : base(provideInputMessageFilter, storeInputMessageFilter) { this._provideMessages = provideMessages; this._captureFilteredContext = captureFilteredContext; + this._prependProvidedMessages = prependProvidedMessages; } + protected override IEnumerable MergeMessages(IEnumerable inputMessages, IEnumerable providedMessages) + => this._prependProvidedMessages ? providedMessages.Concat(inputMessages) : base.MergeMessages(inputMessages, providedMessages); + protected override ValueTask> ProvideMessagesAsync(InvokingContext context, CancellationToken cancellationToken = default) { if (this._captureFilteredContext)