diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/HandoffAgentExecutorTests.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/HandoffAgentExecutorTests.cs index 1a5b2ea4d1..29617ac0ed 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/HandoffAgentExecutorTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/HandoffAgentExecutorTests.cs @@ -1,8 +1,14 @@ // Copyright (c) Microsoft. All rights reserved. +using System; +using System.Collections.Generic; using System.Linq; +using System.Runtime.CompilerServices; +using System.Threading; using System.Threading.Tasks; +using FluentAssertions; using Microsoft.Agents.AI.Workflows.Specialized; +using Microsoft.Extensions.AI; namespace Microsoft.Agents.AI.Workflows.UnitTests; @@ -68,4 +74,96 @@ public async Task Test_HandoffAgentExecutor_EmitsResponseIFFConfiguredAsync(bool AgentResponseEvent[] updates = testContext.Events.OfType().ToArray(); CheckResponseEventsAgainstTestMessages(updates, expectingResponse: executorSetting, agent.GetDescriptiveId()); } + + [Fact] + public async Task Test_HandoffAgentExecutor_DoesNotBashOverExistingInstructionsAndToolsAsync() + { + // Arrange + const string BaseInstructions = "BaseInstructions"; + const string HandoffInstructions = "HandoffInstructions"; + + AITool someTool = AIFunctionFactory.CreateDeclaration("BaseTool", null, AIFunctionFactory.Create(() => { }).JsonSchema); + + OptionValidatingChatClient chatClient = new(BaseInstructions, HandoffInstructions, someTool); + AIAgent handoffAgent = chatClient.AsAIAgent(BaseInstructions, tools: [someTool]); + AIAgent targetAgent = new TestEchoAgent(); + + HandoffAgentExecutorOptions options = new("HandoffInstructions", false, null, HandoffToolCallFilteringBehavior.None); + HandoffTarget handoff = new(new TestEchoAgent()); + HandoffAgentExecutor executor = new(handoffAgent, [handoff], options); + + TestWorkflowContext testContext = new(executor.Id); + HandoffState state = new(new(false), null, [], null); + + // Act / Assert + Func runStreamingAsync = async () => await executor.HandleAsync(state, testContext); + await runStreamingAsync.Should().NotThrowAsync(); + } + + private sealed class OptionValidatingChatClient(string baseInstructions, string handoffInstructions, AITool baseTool) : IChatClient + { + public void Dispose() + { + } + + private void CheckOptions(ChatOptions? options) + { + options.Should().NotBeNull(); + + options.Instructions.Should().Contain(baseInstructions, because: "Handoff orchestration should not bash over existing instructions."); + options.Instructions.Should().Contain(handoffInstructions, because: "Handoff orchestration should inject handoff instructions."); + + options.Tools.Should().Contain(tool => tool.Name == baseTool.Name, "Handoff orchestration should not bash over existing tools."); + options.Tools.Should().Contain(tool => tool.Name.StartsWith(HandoffWorkflowBuilder.FunctionPrefix), + because: "Handoff orchestration should inject handoff tools."); + } + + private List ResponseMessages => + [ + new ChatMessage(ChatRole.Assistant, "Ok") + { + MessageId = Guid.NewGuid().ToString(), + AuthorName = nameof(OptionValidatingChatClient) + } + ]; + + public Task GetResponseAsync(IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) + { + this.CheckOptions(options); + + ChatResponse response = new(this.ResponseMessages) + { + ResponseId = Guid.NewGuid().ToString("N"), + CreatedAt = DateTimeOffset.Now + }; + + return Task.FromResult(response); + } + + public object? GetService(Type serviceType, object? serviceKey = null) + { + if (serviceType == typeof(OptionValidatingChatClient)) + { + return this; + } + + return null; + } + + public async IAsyncEnumerable GetStreamingResponseAsync(IEnumerable messages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + this.CheckOptions(options); + + string responseId = Guid.NewGuid().ToString("N"); + foreach (ChatMessage message in this.ResponseMessages) + { + yield return new(message.Role, message.Contents) + { + ResponseId = responseId, + MessageId = message.MessageId, + CreatedAt = DateTimeOffset.Now + }; + } + } + } }