diff --git a/core/src/main/java/com/google/adk/agents/LlmAgent.java b/core/src/main/java/com/google/adk/agents/LlmAgent.java index 78ebe757..f7628df7 100644 --- a/core/src/main/java/com/google/adk/agents/LlmAgent.java +++ b/core/src/main/java/com/google/adk/agents/LlmAgent.java @@ -17,6 +17,7 @@ package com.google.adk.agents; import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.Objects.requireNonNullElse; import static java.util.stream.Collectors.joining; import com.fasterxml.jackson.core.JsonProcessingException; @@ -103,12 +104,12 @@ public enum IncludeContents { private final Optional maxSteps; private final boolean disallowTransferToParent; private final boolean disallowTransferToPeers; - private final Optional> beforeModelCallback; - private final Optional> afterModelCallback; - private final Optional> onModelErrorCallback; - private final Optional> beforeToolCallback; - private final Optional> afterToolCallback; - private final Optional> onToolErrorCallback; + private final ImmutableList beforeModelCallback; + private final ImmutableList afterModelCallback; + private final ImmutableList onModelErrorCallback; + private final ImmutableList beforeToolCallback; + private final ImmutableList afterToolCallback; + private final ImmutableList onToolErrorCallback; private final Optional inputSchema; private final Optional outputSchema; private final Optional executor; @@ -126,29 +127,28 @@ protected LlmAgent(Builder builder) { builder.beforeAgentCallback, builder.afterAgentCallback); this.model = Optional.ofNullable(builder.model); - this.instruction = - builder.instruction == null ? new Instruction.Static("") : builder.instruction; + this.instruction = requireNonNullElse(builder.instruction, new Instruction.Static("")); this.globalInstruction = - builder.globalInstruction == null ? new Instruction.Static("") : builder.globalInstruction; + requireNonNullElse(builder.globalInstruction, new Instruction.Static("")); this.generateContentConfig = Optional.ofNullable(builder.generateContentConfig); this.exampleProvider = Optional.ofNullable(builder.exampleProvider); - this.includeContents = - builder.includeContents != null ? builder.includeContents : IncludeContents.DEFAULT; + this.includeContents = requireNonNullElse(builder.includeContents, IncludeContents.DEFAULT); this.planning = builder.planning != null && builder.planning; this.maxSteps = Optional.ofNullable(builder.maxSteps); this.disallowTransferToParent = builder.disallowTransferToParent; this.disallowTransferToPeers = builder.disallowTransferToPeers; - this.beforeModelCallback = Optional.ofNullable(builder.beforeModelCallback); - this.afterModelCallback = Optional.ofNullable(builder.afterModelCallback); - this.onModelErrorCallback = Optional.ofNullable(builder.onModelErrorCallback); - this.beforeToolCallback = Optional.ofNullable(builder.beforeToolCallback); - this.afterToolCallback = Optional.ofNullable(builder.afterToolCallback); - this.onToolErrorCallback = Optional.ofNullable(builder.onToolErrorCallback); + this.beforeModelCallback = requireNonNullElse(builder.beforeModelCallback, ImmutableList.of()); + this.afterModelCallback = requireNonNullElse(builder.afterModelCallback, ImmutableList.of()); + this.onModelErrorCallback = + requireNonNullElse(builder.onModelErrorCallback, ImmutableList.of()); + this.beforeToolCallback = requireNonNullElse(builder.beforeToolCallback, ImmutableList.of()); + this.afterToolCallback = requireNonNullElse(builder.afterToolCallback, ImmutableList.of()); + this.onToolErrorCallback = requireNonNullElse(builder.onToolErrorCallback, ImmutableList.of()); this.inputSchema = Optional.ofNullable(builder.inputSchema); this.outputSchema = Optional.ofNullable(builder.outputSchema); this.executor = Optional.ofNullable(builder.executor); this.outputKey = Optional.ofNullable(builder.outputKey); - this.toolsUnion = builder.toolsUnion != null ? builder.toolsUnion : ImmutableList.of(); + this.toolsUnion = requireNonNullElse(builder.toolsUnion, ImmutableList.of()); this.toolsets = extractToolsets(this.toolsUnion); this.codeExecutor = Optional.ofNullable(builder.codeExecutor); @@ -850,27 +850,27 @@ public boolean disallowTransferToPeers() { return disallowTransferToPeers; } - public Optional> beforeModelCallback() { + public List beforeModelCallback() { return beforeModelCallback; } - public Optional> afterModelCallback() { + public List afterModelCallback() { return afterModelCallback; } - public Optional> beforeToolCallback() { + public List beforeToolCallback() { return beforeToolCallback; } - public Optional> afterToolCallback() { + public List afterToolCallback() { return afterToolCallback; } - public Optional> onModelErrorCallback() { + public List onModelErrorCallback() { return onModelErrorCallback; } - public Optional> onToolErrorCallback() { + public List onToolErrorCallback() { return onToolErrorCallback; } @@ -880,7 +880,7 @@ public Optional> onToolErrorCallback() { *

This method is only for use by Agent Development Kit. */ public List canonicalBeforeModelCallbacks() { - return beforeModelCallback.orElse(ImmutableList.of()); + return beforeModelCallback; } /** @@ -889,7 +889,7 @@ public List canonicalBeforeModelCallbacks() { *

This method is only for use by Agent Development Kit. */ public List canonicalAfterModelCallbacks() { - return afterModelCallback.orElse(ImmutableList.of()); + return afterModelCallback; } /** @@ -898,7 +898,7 @@ public List canonicalAfterModelCallbacks() { *

This method is only for use by Agent Development Kit. */ public List canonicalOnModelErrorCallbacks() { - return onModelErrorCallback.orElse(ImmutableList.of()); + return onModelErrorCallback; } /** @@ -907,7 +907,7 @@ public List canonicalOnModelErrorCallbacks() { *

This method is only for use by Agent Development Kit. */ public List canonicalBeforeToolCallbacks() { - return beforeToolCallback.orElse(ImmutableList.of()); + return beforeToolCallback; } /** @@ -916,7 +916,7 @@ public List canonicalBeforeToolCallbacks() { *

This method is only for use by Agent Development Kit. */ public List canonicalAfterToolCallbacks() { - return afterToolCallback.orElse(ImmutableList.of()); + return afterToolCallback; } /** @@ -925,7 +925,7 @@ public List canonicalAfterToolCallbacks() { *

This method is only for use by Agent Development Kit. */ public List canonicalOnToolErrorCallbacks() { - return onToolErrorCallback.orElse(ImmutableList.of()); + return onToolErrorCallback; } public Optional inputSchema() { diff --git a/core/src/test/java/com/google/adk/agents/ConfigAgentUtilsTest.java b/core/src/test/java/com/google/adk/agents/ConfigAgentUtilsTest.java index 4825efca..11e07a09 100644 --- a/core/src/test/java/com/google/adk/agents/ConfigAgentUtilsTest.java +++ b/core/src/test/java/com/google/adk/agents/ConfigAgentUtilsTest.java @@ -1161,20 +1161,25 @@ public void fromConfig_withConfiguredCallbacks_resolvesCallbacks() String pfx = "test.callbacks."; registry.register( - pfx + "before_agent_1", (Callbacks.BeforeAgentCallback) (ctx) -> Maybe.empty()); + pfx + "before_agent_1", (Callbacks.BeforeAgentCallback) (unusedCtx) -> Maybe.empty()); registry.register( - pfx + "before_agent_2", (Callbacks.BeforeAgentCallback) (ctx) -> Maybe.empty()); - registry.register(pfx + "after_agent_1", (Callbacks.AfterAgentCallback) (ctx) -> Maybe.empty()); + pfx + "before_agent_2", (Callbacks.BeforeAgentCallback) (unusedCtx) -> Maybe.empty()); registry.register( - pfx + "before_model_1", (Callbacks.BeforeModelCallback) (ctx, req) -> Maybe.empty()); + pfx + "after_agent_1", (Callbacks.AfterAgentCallback) (unusedCtx) -> Maybe.empty()); registry.register( - pfx + "after_model_1", (Callbacks.AfterModelCallback) (ctx, resp) -> Maybe.empty()); + pfx + "before_model_1", + (Callbacks.BeforeModelCallback) (unusedCtx, unusedReq) -> Maybe.empty()); + registry.register( + pfx + "after_model_1", + (Callbacks.AfterModelCallback) (unusedCtx, unusedResp) -> Maybe.empty()); registry.register( pfx + "before_tool_1", - (Callbacks.BeforeToolCallback) (inv, tool, args, toolCtx) -> Maybe.empty()); + (Callbacks.BeforeToolCallback) + (unusedInv, unusedTool, unusedArgs, unusedToolCtx) -> Maybe.empty()); registry.register( pfx + "after_tool_1", - (Callbacks.AfterToolCallback) (inv, tool, args, toolCtx, resp) -> Maybe.empty()); + (Callbacks.AfterToolCallback) + (unusedInv, unusedTool, unusedArgs, unusedToolCtx, unusedResp) -> Maybe.empty()); File configFile = tempFolder.newFile("with_callbacks.yaml"); Files.writeString( @@ -1209,15 +1214,11 @@ public void fromConfig_withConfiguredCallbacks_resolvesCallbacks() assertThat(agent.afterAgentCallback()).isPresent(); assertThat(agent.afterAgentCallback().get()).hasSize(1); - assertThat(llm.beforeModelCallback()).isPresent(); - assertThat(llm.beforeModelCallback().get()).hasSize(1); - assertThat(llm.afterModelCallback()).isPresent(); - assertThat(llm.afterModelCallback().get()).hasSize(1); + assertThat(llm.beforeModelCallback()).hasSize(1); + assertThat(llm.afterModelCallback()).hasSize(1); - assertThat(llm.beforeToolCallback()).isPresent(); - assertThat(llm.beforeToolCallback().get()).hasSize(1); - assertThat(llm.afterToolCallback()).isPresent(); - assertThat(llm.afterToolCallback().get()).hasSize(1); + assertThat(llm.beforeToolCallback()).hasSize(1); + assertThat(llm.afterToolCallback()).hasSize(1); } @Test diff --git a/core/src/test/java/com/google/adk/agents/LlmAgentTest.java b/core/src/test/java/com/google/adk/agents/LlmAgentTest.java index 519c9055..8a2ff6df 100644 --- a/core/src/test/java/com/google/adk/agents/LlmAgentTest.java +++ b/core/src/test/java/com/google/adk/agents/LlmAgentTest.java @@ -341,6 +341,13 @@ public void canonicalCallbacks_returnsEmptyListWhenNull() { assertThat(agent.canonicalBeforeToolCallbacks()).isEmpty(); assertThat(agent.canonicalAfterToolCallbacks()).isEmpty(); assertThat(agent.canonicalOnToolErrorCallbacks()).isEmpty(); + + assertThat(agent.beforeModelCallback()).isEmpty(); + assertThat(agent.afterModelCallback()).isEmpty(); + assertThat(agent.onModelErrorCallback()).isEmpty(); + assertThat(agent.beforeToolCallback()).isEmpty(); + assertThat(agent.afterToolCallback()).isEmpty(); + assertThat(agent.onToolErrorCallback()).isEmpty(); } @Test @@ -371,5 +378,12 @@ public void canonicalCallbacks_returnsListWhenPresent() { assertThat(agent.canonicalBeforeToolCallbacks()).containsExactly(btc); assertThat(agent.canonicalAfterToolCallbacks()).containsExactly(atc); assertThat(agent.canonicalOnToolErrorCallbacks()).containsExactly(otec); + + assertThat(agent.beforeModelCallback()).containsExactly(bmc); + assertThat(agent.afterModelCallback()).containsExactly(amc); + assertThat(agent.onModelErrorCallback()).containsExactly(omec); + assertThat(agent.beforeToolCallback()).containsExactly(btc); + assertThat(agent.afterToolCallback()).containsExactly(atc); + assertThat(agent.onToolErrorCallback()).containsExactly(otec); } }