diff --git a/core/src/main/java/com/google/adk/agents/BaseAgentState.java b/core/src/main/java/com/google/adk/agents/BaseAgentState.java new file mode 100644 index 000000000..dedcb93ab --- /dev/null +++ b/core/src/main/java/com/google/adk/agents/BaseAgentState.java @@ -0,0 +1,39 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.agents; + +import com.google.adk.JsonBaseModel; + +/** Base class for all agent states. */ +public class BaseAgentState extends JsonBaseModel { + + protected BaseAgentState() {} + + /** Returns a new {@link Builder} for creating {@link BaseAgentState} instances. */ + public static Builder builder() { + return new Builder(); + } + + /** Builder for {@link BaseAgentState}. */ + public static class Builder { + private Builder() {} + + public BaseAgentState build() { + return new BaseAgentState(); + } + } +} diff --git a/core/src/main/java/com/google/adk/agents/InvocationContext.java b/core/src/main/java/com/google/adk/agents/InvocationContext.java index 5add5dc9f..ace00db4c 100644 --- a/core/src/main/java/com/google/adk/agents/InvocationContext.java +++ b/core/src/main/java/com/google/adk/agents/InvocationContext.java @@ -50,6 +50,8 @@ public class InvocationContext { private final Session session; private final Optional userContent; private final RunConfig runConfig; + private final Map agentStates; + private final Map endOfAgents; private final ResumabilityConfig resumabilityConfig; private final InvocationCostManager invocationCostManager; @@ -71,6 +73,8 @@ protected InvocationContext(Builder builder) { this.userContent = builder.userContent; this.runConfig = builder.runConfig; this.endInvocation = builder.endInvocation; + this.agentStates = builder.agentStates; + this.endOfAgents = builder.endOfAgents; this.resumabilityConfig = builder.resumabilityConfig; this.invocationCostManager = builder.invocationCostManager; } @@ -299,6 +303,16 @@ public RunConfig runConfig() { return runConfig; } + /** Returns agent-specific state saved within this invocation. */ + public Map agentStates() { + return agentStates; + } + + /** Returns map of agents that ended during this invocation. */ + public Map endOfAgents() { + return endOfAgents; + } + /** * Returns whether this invocation should be ended, e.g., due to reaching a terminal state or * error. @@ -410,6 +424,8 @@ private Builder(InvocationContext context) { this.userContent = context.userContent; this.runConfig = context.runConfig; this.endInvocation = context.endInvocation; + this.agentStates = new ConcurrentHashMap<>(context.agentStates); + this.endOfAgents = new ConcurrentHashMap<>(context.endOfAgents); this.resumabilityConfig = context.resumabilityConfig; this.invocationCostManager = context.invocationCostManager; } @@ -427,6 +443,8 @@ private Builder(InvocationContext context) { private Optional userContent = Optional.empty(); private RunConfig runConfig = RunConfig.builder().build(); private boolean endInvocation = false; + private Map agentStates = new ConcurrentHashMap<>(); + private Map endOfAgents = new ConcurrentHashMap<>(); private ResumabilityConfig resumabilityConfig = new ResumabilityConfig(); private InvocationCostManager invocationCostManager = new InvocationCostManager(); @@ -616,6 +634,30 @@ public Builder endInvocation(boolean endInvocation) { return this; } + /** + * Sets agent-specific state saved within this invocation. + * + * @param agentStates agent-specific state saved within this invocation. + * @return this builder instance for chaining. + */ + @CanIgnoreReturnValue + public Builder agentStates(Map agentStates) { + this.agentStates = agentStates; + return this; + } + + /** + * Sets agent end-of-invocation status. + * + * @param endOfAgents agent end-of-invocation status. + * @return this builder instance for chaining. + */ + @CanIgnoreReturnValue + public Builder endOfAgents(Map endOfAgents) { + this.endOfAgents = endOfAgents; + return this; + } + /** * Sets the resumability configuration for the current agent run. * @@ -660,6 +702,8 @@ public boolean equals(Object o) { && Objects.equals(session, that.session) && Objects.equals(userContent, that.userContent) && Objects.equals(runConfig, that.runConfig) + && Objects.equals(agentStates, that.agentStates) + && Objects.equals(endOfAgents, that.endOfAgents) && Objects.equals(resumabilityConfig, that.resumabilityConfig) && Objects.equals(invocationCostManager, that.invocationCostManager); } @@ -680,6 +724,8 @@ public int hashCode() { userContent, runConfig, endInvocation, + agentStates, + endOfAgents, resumabilityConfig, invocationCostManager); } diff --git a/core/src/main/java/com/google/adk/agents/LlmAgent.java b/core/src/main/java/com/google/adk/agents/LlmAgent.java index f7628df71..444985971 100644 --- a/core/src/main/java/com/google/adk/agents/LlmAgent.java +++ b/core/src/main/java/com/google/adk/agents/LlmAgent.java @@ -17,7 +17,6 @@ package com.google.adk.agents; import static com.google.common.collect.ImmutableList.toImmutableList; -import static java.util.Objects.requireNonNullElse; import static java.util.stream.Collectors.joining; import com.fasterxml.jackson.core.JsonProcessingException; @@ -104,12 +103,12 @@ public enum IncludeContents { private final Optional maxSteps; private final boolean disallowTransferToParent; private final boolean disallowTransferToPeers; - private final ImmutableList beforeModelCallback; - private final ImmutableList afterModelCallback; - private final ImmutableList onModelErrorCallback; - private final ImmutableList beforeToolCallback; - private final ImmutableList afterToolCallback; - private final ImmutableList onToolErrorCallback; + private final Optional> beforeModelCallback; + private final Optional> afterModelCallback; + private final Optional> onModelErrorCallback; + private final Optional> beforeToolCallback; + private final Optional> afterToolCallback; + private final Optional> onToolErrorCallback; private final Optional inputSchema; private final Optional outputSchema; private final Optional executor; @@ -127,28 +126,29 @@ protected LlmAgent(Builder builder) { builder.beforeAgentCallback, builder.afterAgentCallback); this.model = Optional.ofNullable(builder.model); - this.instruction = requireNonNullElse(builder.instruction, new Instruction.Static("")); + this.instruction = + builder.instruction == null ? new Instruction.Static("") : builder.instruction; this.globalInstruction = - requireNonNullElse(builder.globalInstruction, new Instruction.Static("")); + builder.globalInstruction == null ? new Instruction.Static("") : builder.globalInstruction; this.generateContentConfig = Optional.ofNullable(builder.generateContentConfig); this.exampleProvider = Optional.ofNullable(builder.exampleProvider); - this.includeContents = requireNonNullElse(builder.includeContents, IncludeContents.DEFAULT); + this.includeContents = + builder.includeContents != null ? builder.includeContents : IncludeContents.DEFAULT; this.planning = builder.planning != null && builder.planning; this.maxSteps = Optional.ofNullable(builder.maxSteps); this.disallowTransferToParent = builder.disallowTransferToParent; this.disallowTransferToPeers = builder.disallowTransferToPeers; - this.beforeModelCallback = requireNonNullElse(builder.beforeModelCallback, ImmutableList.of()); - this.afterModelCallback = requireNonNullElse(builder.afterModelCallback, ImmutableList.of()); - this.onModelErrorCallback = - requireNonNullElse(builder.onModelErrorCallback, ImmutableList.of()); - this.beforeToolCallback = requireNonNullElse(builder.beforeToolCallback, ImmutableList.of()); - this.afterToolCallback = requireNonNullElse(builder.afterToolCallback, ImmutableList.of()); - this.onToolErrorCallback = requireNonNullElse(builder.onToolErrorCallback, ImmutableList.of()); + this.beforeModelCallback = Optional.ofNullable(builder.beforeModelCallback); + this.afterModelCallback = Optional.ofNullable(builder.afterModelCallback); + this.onModelErrorCallback = Optional.ofNullable(builder.onModelErrorCallback); + this.beforeToolCallback = Optional.ofNullable(builder.beforeToolCallback); + this.afterToolCallback = Optional.ofNullable(builder.afterToolCallback); + this.onToolErrorCallback = Optional.ofNullable(builder.onToolErrorCallback); this.inputSchema = Optional.ofNullable(builder.inputSchema); this.outputSchema = Optional.ofNullable(builder.outputSchema); this.executor = Optional.ofNullable(builder.executor); this.outputKey = Optional.ofNullable(builder.outputKey); - this.toolsUnion = requireNonNullElse(builder.toolsUnion, ImmutableList.of()); + this.toolsUnion = builder.toolsUnion != null ? builder.toolsUnion : ImmutableList.of(); this.toolsets = extractToolsets(this.toolsUnion); this.codeExecutor = Optional.ofNullable(builder.codeExecutor); @@ -704,16 +704,7 @@ private static boolean isThought(Part part) { @Override protected Flowable runAsyncImpl(InvocationContext invocationContext) { - return llmFlow - .run(invocationContext) - .concatMap( - event -> { - this.maybeSaveOutputToState(event); - if (invocationContext.shouldPauseInvocation(event)) { - return Flowable.just(event).concatWith(Flowable.empty()); - } - return Flowable.just(event); - }); + return llmFlow.run(invocationContext).doOnNext(this::maybeSaveOutputToState); } @Override @@ -850,27 +841,27 @@ public boolean disallowTransferToPeers() { return disallowTransferToPeers; } - public List beforeModelCallback() { + public Optional> beforeModelCallback() { return beforeModelCallback; } - public List afterModelCallback() { + public Optional> afterModelCallback() { return afterModelCallback; } - public List beforeToolCallback() { + public Optional> beforeToolCallback() { return beforeToolCallback; } - public List afterToolCallback() { + public Optional> afterToolCallback() { return afterToolCallback; } - public List onModelErrorCallback() { + public Optional> onModelErrorCallback() { return onModelErrorCallback; } - public List onToolErrorCallback() { + public Optional> onToolErrorCallback() { return onToolErrorCallback; } @@ -880,7 +871,7 @@ public List onToolErrorCallback() { *

This method is only for use by Agent Development Kit. */ public List canonicalBeforeModelCallbacks() { - return beforeModelCallback; + return beforeModelCallback.orElse(ImmutableList.of()); } /** @@ -889,7 +880,7 @@ public List canonicalBeforeModelCallbacks() { *

This method is only for use by Agent Development Kit. */ public List canonicalAfterModelCallbacks() { - return afterModelCallback; + return afterModelCallback.orElse(ImmutableList.of()); } /** @@ -898,7 +889,7 @@ public List canonicalAfterModelCallbacks() { *

This method is only for use by Agent Development Kit. */ public List canonicalOnModelErrorCallbacks() { - return onModelErrorCallback; + return onModelErrorCallback.orElse(ImmutableList.of()); } /** @@ -907,7 +898,7 @@ public List canonicalOnModelErrorCallbacks() { *

This method is only for use by Agent Development Kit. */ public List canonicalBeforeToolCallbacks() { - return beforeToolCallback; + return beforeToolCallback.orElse(ImmutableList.of()); } /** @@ -916,7 +907,7 @@ public List canonicalBeforeToolCallbacks() { *

This method is only for use by Agent Development Kit. */ public List canonicalAfterToolCallbacks() { - return afterToolCallback; + return afterToolCallback.orElse(ImmutableList.of()); } /** @@ -925,7 +916,7 @@ public List canonicalAfterToolCallbacks() { *

This method is only for use by Agent Development Kit. */ public List canonicalOnToolErrorCallbacks() { - return onToolErrorCallback; + return onToolErrorCallback.orElse(ImmutableList.of()); } public Optional inputSchema() { diff --git a/core/src/main/java/com/google/adk/events/EventActions.java b/core/src/main/java/com/google/adk/events/EventActions.java index 487cca2af..63909ee1a 100644 --- a/core/src/main/java/com/google/adk/events/EventActions.java +++ b/core/src/main/java/com/google/adk/events/EventActions.java @@ -15,8 +15,10 @@ */ package com.google.adk.events; +import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.databind.annotation.JsonDeserialize; +import com.google.adk.agents.BaseAgentState; import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.genai.types.Part; import java.util.Objects; @@ -37,8 +39,11 @@ public class EventActions { private Optional escalate; private ConcurrentMap> requestedAuthConfigs; private ConcurrentMap requestedToolConfirmations; + private boolean endOfAgent; + private ConcurrentMap agentState; private Optional endInvocation; private Optional compaction; + private Optional rewindBeforeInvocationId; /** Default constructor for Jackson. */ public EventActions() { @@ -49,8 +54,11 @@ public EventActions() { this.escalate = Optional.empty(); this.requestedAuthConfigs = new ConcurrentHashMap<>(); this.requestedToolConfirmations = new ConcurrentHashMap<>(); + this.endOfAgent = false; this.endInvocation = Optional.empty(); this.compaction = Optional.empty(); + this.agentState = new ConcurrentHashMap<>(); + this.rewindBeforeInvocationId = Optional.empty(); } private EventActions(Builder builder) { @@ -61,8 +69,11 @@ private EventActions(Builder builder) { this.escalate = builder.escalate; this.requestedAuthConfigs = builder.requestedAuthConfigs; this.requestedToolConfirmations = builder.requestedToolConfirmations; + this.endOfAgent = builder.endOfAgent; this.endInvocation = builder.endInvocation; this.compaction = builder.compaction; + this.agentState = builder.agentState; + this.rewindBeforeInvocationId = builder.rewindBeforeInvocationId; } @JsonProperty("skipSummarization") @@ -146,6 +157,16 @@ public void setRequestedToolConfirmations( this.requestedToolConfirmations = requestedToolConfirmations; } + @JsonProperty("endOfAgent") + @JsonInclude(JsonInclude.Include.NON_DEFAULT) + public boolean endOfAgent() { + return endOfAgent; + } + + public void setEndOfAgent(boolean endOfAgent) { + this.endOfAgent = endOfAgent; + } + @JsonProperty("endInvocation") public Optional endInvocation() { return endInvocation; @@ -168,6 +189,25 @@ public void setCompaction(Optional compaction) { this.compaction = compaction; } + @JsonProperty("agentState") + @JsonInclude(JsonInclude.Include.NON_EMPTY) + public ConcurrentMap agentState() { + return agentState; + } + + public void setAgentState(ConcurrentMap agentState) { + this.agentState = agentState; + } + + @JsonProperty("rewindBeforeInvocationId") + public Optional rewindBeforeInvocationId() { + return rewindBeforeInvocationId; + } + + public void setRewindBeforeInvocationId(@Nullable String rewindBeforeInvocationId) { + this.rewindBeforeInvocationId = Optional.ofNullable(rewindBeforeInvocationId); + } + public static Builder builder() { return new Builder(); } @@ -191,8 +231,11 @@ public boolean equals(Object o) { && Objects.equals(escalate, that.escalate) && Objects.equals(requestedAuthConfigs, that.requestedAuthConfigs) && Objects.equals(requestedToolConfirmations, that.requestedToolConfirmations) + && (endOfAgent == that.endOfAgent) && Objects.equals(endInvocation, that.endInvocation) - && Objects.equals(compaction, that.compaction); + && Objects.equals(compaction, that.compaction) + && Objects.equals(agentState, that.agentState) + && Objects.equals(rewindBeforeInvocationId, that.rewindBeforeInvocationId); } @Override @@ -205,8 +248,11 @@ public int hashCode() { escalate, requestedAuthConfigs, requestedToolConfirmations, + endOfAgent, endInvocation, - compaction); + compaction, + agentState, + rewindBeforeInvocationId); } /** Builder for {@link EventActions}. */ @@ -218,8 +264,11 @@ public static class Builder { private Optional escalate; private ConcurrentMap> requestedAuthConfigs; private ConcurrentMap requestedToolConfirmations; + private boolean endOfAgent = false; private Optional endInvocation; private Optional compaction; + private ConcurrentMap agentState; + private Optional rewindBeforeInvocationId; public Builder() { this.skipSummarization = Optional.empty(); @@ -231,6 +280,8 @@ public Builder() { this.requestedToolConfirmations = new ConcurrentHashMap<>(); this.endInvocation = Optional.empty(); this.compaction = Optional.empty(); + this.agentState = new ConcurrentHashMap<>(); + this.rewindBeforeInvocationId = Optional.empty(); } private Builder(EventActions eventActions) { @@ -242,8 +293,11 @@ private Builder(EventActions eventActions) { this.requestedAuthConfigs = new ConcurrentHashMap<>(eventActions.requestedAuthConfigs()); this.requestedToolConfirmations = new ConcurrentHashMap<>(eventActions.requestedToolConfirmations()); + this.endOfAgent = eventActions.endOfAgent(); this.endInvocation = eventActions.endInvocation(); this.compaction = eventActions.compaction(); + this.agentState = new ConcurrentHashMap<>(eventActions.agentState()); + this.rewindBeforeInvocationId = eventActions.rewindBeforeInvocationId(); } @CanIgnoreReturnValue @@ -296,6 +350,13 @@ public Builder requestedToolConfirmations(ConcurrentMap agentState) { + this.agentState = agentState; + return this; + } + + @CanIgnoreReturnValue + @JsonProperty("rewindBeforeInvocationId") + public Builder rewindBeforeInvocationId(String rewindBeforeInvocationId) { + this.rewindBeforeInvocationId = Optional.ofNullable(rewindBeforeInvocationId); + return this; + } + @CanIgnoreReturnValue public Builder merge(EventActions other) { other.skipSummarization().ifPresent(this::skipSummarization); @@ -319,8 +394,11 @@ public Builder merge(EventActions other) { other.escalate().ifPresent(this::escalate); this.requestedAuthConfigs.putAll(other.requestedAuthConfigs()); this.requestedToolConfirmations.putAll(other.requestedToolConfirmations()); + this.endOfAgent = other.endOfAgent(); other.endInvocation().ifPresent(this::endInvocation); other.compaction().ifPresent(this::compaction); + this.agentState.putAll(other.agentState()); + other.rewindBeforeInvocationId().ifPresent(this::rewindBeforeInvocationId); return this; } diff --git a/core/src/main/java/com/google/adk/sessions/SessionJsonConverter.java b/core/src/main/java/com/google/adk/sessions/SessionJsonConverter.java index 5d3c88af3..5dbbe76c7 100644 --- a/core/src/main/java/com/google/adk/sessions/SessionJsonConverter.java +++ b/core/src/main/java/com/google/adk/sessions/SessionJsonConverter.java @@ -19,8 +19,10 @@ import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import com.google.adk.JsonBaseModel; +import com.google.adk.agents.BaseAgentState; import com.google.adk.events.Event; import com.google.adk.events.EventActions; +import com.google.adk.events.ToolConfirmation; import com.google.common.base.Splitter; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; @@ -95,8 +97,11 @@ static String convertEventToJson(Event event) { actionsJson.put("escalate", event.actions().escalate()); actionsJson.put("requestedAuthConfigs", event.actions().requestedAuthConfigs()); actionsJson.put("requestedToolConfirmations", event.actions().requestedToolConfirmations()); - actionsJson.put("endInvocation", event.actions().endInvocation()); actionsJson.put("compaction", event.actions().compaction()); + if (!event.actions().agentState().isEmpty()) { + actionsJson.put("agentState", event.actions().agentState()); + } + actionsJson.put("rewindBeforeInvocationId", event.actions().rewindBeforeInvocationId()); eventJson.put("actions", actionsJson); } if (event.content().isPresent()) { @@ -121,6 +126,7 @@ static String convertEventToJson(Event event) { * @return parsed {@link Content}, or {@code null} if conversion fails. */ @Nullable + // Safe because we check instanceof Map before casting. @SuppressWarnings("unchecked") private static Content convertMapToContent(Object rawContentValue) { if (rawContentValue == null) { @@ -147,6 +153,7 @@ private static Content convertMapToContent(Object rawContentValue) { * * @return parsed {@link Event}. */ + // Safe because we are parsing from a raw Map structure that follows a known schema. @SuppressWarnings("unchecked") static Event fromApiEvent(Map apiEvent) { EventActions.Builder eventActionsBuilder = EventActions.builder(); @@ -171,6 +178,17 @@ static Event fromApiEvent(Map apiEvent) { Optional.ofNullable(actionsMap.get("requestedAuthConfigs")) .map(SessionJsonConverter::asConcurrentMapOfConcurrentMaps) .orElse(new ConcurrentHashMap<>())); + eventActionsBuilder.requestedToolConfirmations( + Optional.ofNullable(actionsMap.get("requestedToolConfirmations")) + .map(SessionJsonConverter::asConcurrentMapOfToolConfirmations) + .orElse(new ConcurrentHashMap<>())); + if (actionsMap.get("agentState") != null) { + eventActionsBuilder.agentState(asConcurrentMapOfAgentState(actionsMap.get("agentState"))); + } + if (actionsMap.get("rewindBeforeInvocationId") != null) { + eventActionsBuilder.rewindBeforeInvocationId( + (String) actionsMap.get("rewindBeforeInvocationId")); + } } Event event = @@ -245,6 +263,7 @@ private static Instant convertToInstant(Object timestampObj) { * @param artifactDeltaObj The raw object from which to parse the artifact delta. * @return A {@link ConcurrentMap} representing the artifact delta. */ + // Safe because we check instanceof Map before casting. @SuppressWarnings("unchecked") private static ConcurrentMap convertToArtifactDeltaMap(Object artifactDeltaObj) { if (!(artifactDeltaObj instanceof Map)) { @@ -268,6 +287,7 @@ private static ConcurrentMap convertToArtifactDeltaMap(Object arti * * @return thread-safe nested map. */ + // Safe because we are parsing from a raw Map structure that follows a known schema. @SuppressWarnings("unchecked") private static ConcurrentMap> asConcurrentMapOfConcurrentMaps(Object value) { @@ -278,4 +298,33 @@ private static ConcurrentMap convertToArtifactDeltaMap(Object arti (map, entry) -> map.put(entry.getKey(), new ConcurrentHashMap<>(entry.getValue())), ConcurrentHashMap::putAll); } + + // Safe because we are parsing from a raw Map structure that follows a known schema. + @SuppressWarnings("unchecked") + private static ConcurrentMap asConcurrentMapOfAgentState(Object value) { + return ((Map) value) + .entrySet().stream() + .collect( + ConcurrentHashMap::new, + (map, entry) -> + map.put( + entry.getKey(), + objectMapper.convertValue(entry.getValue(), BaseAgentState.class)), + ConcurrentHashMap::putAll); + } + + // Safe because we are parsing from a raw Map structure that follows a known schema. + @SuppressWarnings("unchecked") + private static ConcurrentMap asConcurrentMapOfToolConfirmations( + Object value) { + return ((Map) value) + .entrySet().stream() + .collect( + ConcurrentHashMap::new, + (map, entry) -> + map.put( + entry.getKey(), + objectMapper.convertValue(entry.getValue(), ToolConfirmation.class)), + ConcurrentHashMap::putAll); + } } diff --git a/core/src/main/java/com/google/adk/tools/FunctionTool.java b/core/src/main/java/com/google/adk/tools/FunctionTool.java index b856c2435..a6167ee46 100644 --- a/core/src/main/java/com/google/adk/tools/FunctionTool.java +++ b/core/src/main/java/com/google/adk/tools/FunctionTool.java @@ -56,6 +56,11 @@ public static FunctionTool create(Object instance, Method func) { } public static FunctionTool create(Object instance, Method func, boolean requireConfirmation) { + return create(instance, func, requireConfirmation, false); + } + + public static FunctionTool create( + Object instance, Method func, boolean requireConfirmation, boolean isLongRunning) { if (!areParametersAnnotatedWithSchema(func) && wasCompiledWithDefaultParameterNames(func)) { logger.error( """ @@ -72,7 +77,7 @@ public static FunctionTool create(Object instance, Method func, boolean requireC func.getDeclaringClass().getName(), instance.getClass().getName())); } return new FunctionTool( - instance, func, /* isLongRunning= */ false, /* requireConfirmation= */ requireConfirmation); + instance, func, isLongRunning, /* requireConfirmation= */ requireConfirmation); } public static FunctionTool create(Method func) { @@ -80,6 +85,11 @@ public static FunctionTool create(Method func) { } public static FunctionTool create(Method func, boolean requireConfirmation) { + return create(func, requireConfirmation, false); + } + + public static FunctionTool create( + Method func, boolean requireConfirmation, boolean isLongRunning) { if (!areParametersAnnotatedWithSchema(func) && wasCompiledWithDefaultParameterNames(func)) { logger.error( """ @@ -91,7 +101,7 @@ public static FunctionTool create(Method func, boolean requireConfirmation) { if (!Modifier.isStatic(func.getModifiers())) { throw new IllegalArgumentException("The method provided must be static."); } - return new FunctionTool(null, func, /* isLongRunning= */ false, requireConfirmation); + return new FunctionTool(null, func, isLongRunning, requireConfirmation); } public static FunctionTool create(Class cls, String methodName) { @@ -99,9 +109,14 @@ public static FunctionTool create(Class cls, String methodName) { } public static FunctionTool create(Class cls, String methodName, boolean requireConfirmation) { + return create(cls, methodName, requireConfirmation, false); + } + + public static FunctionTool create( + Class cls, String methodName, boolean requireConfirmation, boolean isLongRunning) { for (Method method : cls.getMethods()) { if (method.getName().equals(methodName) && Modifier.isStatic(method.getModifiers())) { - return create(null, method, requireConfirmation); + return create(null, method, requireConfirmation, isLongRunning); } } throw new IllegalArgumentException( @@ -114,10 +129,15 @@ public static FunctionTool create(Object instance, String methodName) { public static FunctionTool create( Object instance, String methodName, boolean requireConfirmation) { + return create(instance, methodName, requireConfirmation, false); + } + + public static FunctionTool create( + Object instance, String methodName, boolean requireConfirmation, boolean isLongRunning) { Class cls = instance.getClass(); for (Method method : cls.getMethods()) { if (method.getName().equals(methodName) && !Modifier.isStatic(method.getModifiers())) { - return create(instance, method, requireConfirmation); + return create(instance, method, requireConfirmation, isLongRunning); } } throw new IllegalArgumentException( diff --git a/core/src/test/java/com/google/adk/agents/ConfigAgentUtilsTest.java b/core/src/test/java/com/google/adk/agents/ConfigAgentUtilsTest.java index 11e07a094..4825efca1 100644 --- a/core/src/test/java/com/google/adk/agents/ConfigAgentUtilsTest.java +++ b/core/src/test/java/com/google/adk/agents/ConfigAgentUtilsTest.java @@ -1161,25 +1161,20 @@ public void fromConfig_withConfiguredCallbacks_resolvesCallbacks() String pfx = "test.callbacks."; registry.register( - pfx + "before_agent_1", (Callbacks.BeforeAgentCallback) (unusedCtx) -> Maybe.empty()); + pfx + "before_agent_1", (Callbacks.BeforeAgentCallback) (ctx) -> Maybe.empty()); registry.register( - pfx + "before_agent_2", (Callbacks.BeforeAgentCallback) (unusedCtx) -> Maybe.empty()); + pfx + "before_agent_2", (Callbacks.BeforeAgentCallback) (ctx) -> Maybe.empty()); + registry.register(pfx + "after_agent_1", (Callbacks.AfterAgentCallback) (ctx) -> Maybe.empty()); registry.register( - pfx + "after_agent_1", (Callbacks.AfterAgentCallback) (unusedCtx) -> Maybe.empty()); + pfx + "before_model_1", (Callbacks.BeforeModelCallback) (ctx, req) -> Maybe.empty()); registry.register( - pfx + "before_model_1", - (Callbacks.BeforeModelCallback) (unusedCtx, unusedReq) -> Maybe.empty()); - registry.register( - pfx + "after_model_1", - (Callbacks.AfterModelCallback) (unusedCtx, unusedResp) -> Maybe.empty()); + pfx + "after_model_1", (Callbacks.AfterModelCallback) (ctx, resp) -> Maybe.empty()); registry.register( pfx + "before_tool_1", - (Callbacks.BeforeToolCallback) - (unusedInv, unusedTool, unusedArgs, unusedToolCtx) -> Maybe.empty()); + (Callbacks.BeforeToolCallback) (inv, tool, args, toolCtx) -> Maybe.empty()); registry.register( pfx + "after_tool_1", - (Callbacks.AfterToolCallback) - (unusedInv, unusedTool, unusedArgs, unusedToolCtx, unusedResp) -> Maybe.empty()); + (Callbacks.AfterToolCallback) (inv, tool, args, toolCtx, resp) -> Maybe.empty()); File configFile = tempFolder.newFile("with_callbacks.yaml"); Files.writeString( @@ -1214,11 +1209,15 @@ public void fromConfig_withConfiguredCallbacks_resolvesCallbacks() assertThat(agent.afterAgentCallback()).isPresent(); assertThat(agent.afterAgentCallback().get()).hasSize(1); - assertThat(llm.beforeModelCallback()).hasSize(1); - assertThat(llm.afterModelCallback()).hasSize(1); + assertThat(llm.beforeModelCallback()).isPresent(); + assertThat(llm.beforeModelCallback().get()).hasSize(1); + assertThat(llm.afterModelCallback()).isPresent(); + assertThat(llm.afterModelCallback().get()).hasSize(1); - assertThat(llm.beforeToolCallback()).hasSize(1); - assertThat(llm.afterToolCallback()).hasSize(1); + assertThat(llm.beforeToolCallback()).isPresent(); + assertThat(llm.beforeToolCallback().get()).hasSize(1); + assertThat(llm.afterToolCallback()).isPresent(); + assertThat(llm.afterToolCallback().get()).hasSize(1); } @Test diff --git a/core/src/test/java/com/google/adk/agents/LlmAgentTest.java b/core/src/test/java/com/google/adk/agents/LlmAgentTest.java index 8a2ff6df8..519c90558 100644 --- a/core/src/test/java/com/google/adk/agents/LlmAgentTest.java +++ b/core/src/test/java/com/google/adk/agents/LlmAgentTest.java @@ -341,13 +341,6 @@ public void canonicalCallbacks_returnsEmptyListWhenNull() { assertThat(agent.canonicalBeforeToolCallbacks()).isEmpty(); assertThat(agent.canonicalAfterToolCallbacks()).isEmpty(); assertThat(agent.canonicalOnToolErrorCallbacks()).isEmpty(); - - assertThat(agent.beforeModelCallback()).isEmpty(); - assertThat(agent.afterModelCallback()).isEmpty(); - assertThat(agent.onModelErrorCallback()).isEmpty(); - assertThat(agent.beforeToolCallback()).isEmpty(); - assertThat(agent.afterToolCallback()).isEmpty(); - assertThat(agent.onToolErrorCallback()).isEmpty(); } @Test @@ -378,12 +371,5 @@ public void canonicalCallbacks_returnsListWhenPresent() { assertThat(agent.canonicalBeforeToolCallbacks()).containsExactly(btc); assertThat(agent.canonicalAfterToolCallbacks()).containsExactly(atc); assertThat(agent.canonicalOnToolErrorCallbacks()).containsExactly(otec); - - assertThat(agent.beforeModelCallback()).containsExactly(bmc); - assertThat(agent.afterModelCallback()).containsExactly(amc); - assertThat(agent.onModelErrorCallback()).containsExactly(omec); - assertThat(agent.beforeToolCallback()).containsExactly(btc); - assertThat(agent.afterToolCallback()).containsExactly(atc); - assertThat(agent.onToolErrorCallback()).containsExactly(otec); } }