Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions core/src/main/java/com/google/adk/agents/BaseAgentState.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Copyright 2026 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.adk.agents;

import com.google.adk.JsonBaseModel;

/** Base class for all agent states. */
public class BaseAgentState extends JsonBaseModel {

protected BaseAgentState() {}

/** Returns a new {@link Builder} for creating {@link BaseAgentState} instances. */
public static Builder builder() {
return new Builder();
}

/** Builder for {@link BaseAgentState}. */
public static class Builder {
private Builder() {}

public BaseAgentState build() {
return new BaseAgentState();
}
}
}
46 changes: 46 additions & 0 deletions core/src/main/java/com/google/adk/agents/InvocationContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ public class InvocationContext {
private final Session session;
private final Optional<Content> userContent;
private final RunConfig runConfig;
private final Map<String, BaseAgentState> agentStates;
private final Map<String, Boolean> endOfAgents;
private final ResumabilityConfig resumabilityConfig;
private final InvocationCostManager invocationCostManager;

Expand All @@ -71,6 +73,8 @@ protected InvocationContext(Builder builder) {
this.userContent = builder.userContent;
this.runConfig = builder.runConfig;
this.endInvocation = builder.endInvocation;
this.agentStates = builder.agentStates;
this.endOfAgents = builder.endOfAgents;
this.resumabilityConfig = builder.resumabilityConfig;
this.invocationCostManager = builder.invocationCostManager;
}
Expand Down Expand Up @@ -299,6 +303,16 @@ public RunConfig runConfig() {
return runConfig;
}

/** Returns agent-specific state saved within this invocation. */
public Map<String, BaseAgentState> agentStates() {
return agentStates;
}

/** Returns map of agents that ended during this invocation. */
public Map<String, Boolean> endOfAgents() {
return endOfAgents;
}

/**
* Returns whether this invocation should be ended, e.g., due to reaching a terminal state or
* error.
Expand Down Expand Up @@ -410,6 +424,8 @@ private Builder(InvocationContext context) {
this.userContent = context.userContent;
this.runConfig = context.runConfig;
this.endInvocation = context.endInvocation;
this.agentStates = new ConcurrentHashMap<>(context.agentStates);
this.endOfAgents = new ConcurrentHashMap<>(context.endOfAgents);
this.resumabilityConfig = context.resumabilityConfig;
this.invocationCostManager = context.invocationCostManager;
}
Expand All @@ -427,6 +443,8 @@ private Builder(InvocationContext context) {
private Optional<Content> userContent = Optional.empty();
private RunConfig runConfig = RunConfig.builder().build();
private boolean endInvocation = false;
private Map<String, BaseAgentState> agentStates = new ConcurrentHashMap<>();
private Map<String, Boolean> endOfAgents = new ConcurrentHashMap<>();
private ResumabilityConfig resumabilityConfig = new ResumabilityConfig();
private InvocationCostManager invocationCostManager = new InvocationCostManager();

Expand Down Expand Up @@ -616,6 +634,30 @@ public Builder endInvocation(boolean endInvocation) {
return this;
}

/**
* Sets agent-specific state saved within this invocation.
*
* @param agentStates agent-specific state saved within this invocation.
* @return this builder instance for chaining.
*/
@CanIgnoreReturnValue
public Builder agentStates(Map<String, BaseAgentState> agentStates) {
this.agentStates = agentStates;
return this;
}

/**
* Sets agent end-of-invocation status.
*
* @param endOfAgents agent end-of-invocation status.
* @return this builder instance for chaining.
*/
@CanIgnoreReturnValue
public Builder endOfAgents(Map<String, Boolean> endOfAgents) {
this.endOfAgents = endOfAgents;
return this;
}

/**
* Sets the resumability configuration for the current agent run.
*
Expand Down Expand Up @@ -660,6 +702,8 @@ public boolean equals(Object o) {
&& Objects.equals(session, that.session)
&& Objects.equals(userContent, that.userContent)
&& Objects.equals(runConfig, that.runConfig)
&& Objects.equals(agentStates, that.agentStates)
&& Objects.equals(endOfAgents, that.endOfAgents)
&& Objects.equals(resumabilityConfig, that.resumabilityConfig)
&& Objects.equals(invocationCostManager, that.invocationCostManager);
}
Expand All @@ -680,6 +724,8 @@ public int hashCode() {
userContent,
runConfig,
endInvocation,
agentStates,
endOfAgents,
resumabilityConfig,
invocationCostManager);
}
Expand Down
71 changes: 31 additions & 40 deletions core/src/main/java/com/google/adk/agents/LlmAgent.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
package com.google.adk.agents;

import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.util.Objects.requireNonNullElse;
import static java.util.stream.Collectors.joining;

import com.fasterxml.jackson.core.JsonProcessingException;
Expand Down Expand Up @@ -104,12 +103,12 @@ public enum IncludeContents {
private final Optional<Integer> maxSteps;
private final boolean disallowTransferToParent;
private final boolean disallowTransferToPeers;
private final ImmutableList<? extends BeforeModelCallback> beforeModelCallback;
private final ImmutableList<? extends AfterModelCallback> afterModelCallback;
private final ImmutableList<? extends OnModelErrorCallback> onModelErrorCallback;
private final ImmutableList<? extends BeforeToolCallback> beforeToolCallback;
private final ImmutableList<? extends AfterToolCallback> afterToolCallback;
private final ImmutableList<? extends OnToolErrorCallback> onToolErrorCallback;
private final Optional<List<? extends BeforeModelCallback>> beforeModelCallback;
private final Optional<List<? extends AfterModelCallback>> afterModelCallback;
private final Optional<List<? extends OnModelErrorCallback>> onModelErrorCallback;
private final Optional<List<? extends BeforeToolCallback>> beforeToolCallback;
private final Optional<List<? extends AfterToolCallback>> afterToolCallback;
private final Optional<List<? extends OnToolErrorCallback>> onToolErrorCallback;
private final Optional<Schema> inputSchema;
private final Optional<Schema> outputSchema;
private final Optional<Executor> executor;
Expand All @@ -127,28 +126,29 @@ protected LlmAgent(Builder builder) {
builder.beforeAgentCallback,
builder.afterAgentCallback);
this.model = Optional.ofNullable(builder.model);
this.instruction = requireNonNullElse(builder.instruction, new Instruction.Static(""));
this.instruction =
builder.instruction == null ? new Instruction.Static("") : builder.instruction;
this.globalInstruction =
requireNonNullElse(builder.globalInstruction, new Instruction.Static(""));
builder.globalInstruction == null ? new Instruction.Static("") : builder.globalInstruction;
this.generateContentConfig = Optional.ofNullable(builder.generateContentConfig);
this.exampleProvider = Optional.ofNullable(builder.exampleProvider);
this.includeContents = requireNonNullElse(builder.includeContents, IncludeContents.DEFAULT);
this.includeContents =
builder.includeContents != null ? builder.includeContents : IncludeContents.DEFAULT;
this.planning = builder.planning != null && builder.planning;
this.maxSteps = Optional.ofNullable(builder.maxSteps);
this.disallowTransferToParent = builder.disallowTransferToParent;
this.disallowTransferToPeers = builder.disallowTransferToPeers;
this.beforeModelCallback = requireNonNullElse(builder.beforeModelCallback, ImmutableList.of());
this.afterModelCallback = requireNonNullElse(builder.afterModelCallback, ImmutableList.of());
this.onModelErrorCallback =
requireNonNullElse(builder.onModelErrorCallback, ImmutableList.of());
this.beforeToolCallback = requireNonNullElse(builder.beforeToolCallback, ImmutableList.of());
this.afterToolCallback = requireNonNullElse(builder.afterToolCallback, ImmutableList.of());
this.onToolErrorCallback = requireNonNullElse(builder.onToolErrorCallback, ImmutableList.of());
this.beforeModelCallback = Optional.ofNullable(builder.beforeModelCallback);
this.afterModelCallback = Optional.ofNullable(builder.afterModelCallback);
this.onModelErrorCallback = Optional.ofNullable(builder.onModelErrorCallback);
this.beforeToolCallback = Optional.ofNullable(builder.beforeToolCallback);
this.afterToolCallback = Optional.ofNullable(builder.afterToolCallback);
this.onToolErrorCallback = Optional.ofNullable(builder.onToolErrorCallback);
this.inputSchema = Optional.ofNullable(builder.inputSchema);
this.outputSchema = Optional.ofNullable(builder.outputSchema);
this.executor = Optional.ofNullable(builder.executor);
this.outputKey = Optional.ofNullable(builder.outputKey);
this.toolsUnion = requireNonNullElse(builder.toolsUnion, ImmutableList.of());
this.toolsUnion = builder.toolsUnion != null ? builder.toolsUnion : ImmutableList.of();
this.toolsets = extractToolsets(this.toolsUnion);
this.codeExecutor = Optional.ofNullable(builder.codeExecutor);

Expand Down Expand Up @@ -704,16 +704,7 @@ private static boolean isThought(Part part) {

@Override
protected Flowable<Event> runAsyncImpl(InvocationContext invocationContext) {
return llmFlow
.run(invocationContext)
.concatMap(
event -> {
this.maybeSaveOutputToState(event);
if (invocationContext.shouldPauseInvocation(event)) {
return Flowable.just(event).concatWith(Flowable.empty());
}
return Flowable.just(event);
});
return llmFlow.run(invocationContext).doOnNext(this::maybeSaveOutputToState);
}

@Override
Expand Down Expand Up @@ -850,27 +841,27 @@ public boolean disallowTransferToPeers() {
return disallowTransferToPeers;
}

public List<? extends BeforeModelCallback> beforeModelCallback() {
public Optional<List<? extends BeforeModelCallback>> beforeModelCallback() {
return beforeModelCallback;
}

public List<? extends AfterModelCallback> afterModelCallback() {
public Optional<List<? extends AfterModelCallback>> afterModelCallback() {
return afterModelCallback;
}

public List<? extends BeforeToolCallback> beforeToolCallback() {
public Optional<List<? extends BeforeToolCallback>> beforeToolCallback() {
return beforeToolCallback;
}

public List<? extends AfterToolCallback> afterToolCallback() {
public Optional<List<? extends AfterToolCallback>> afterToolCallback() {
return afterToolCallback;
}

public List<? extends OnModelErrorCallback> onModelErrorCallback() {
public Optional<List<? extends OnModelErrorCallback>> onModelErrorCallback() {
return onModelErrorCallback;
}

public List<? extends OnToolErrorCallback> onToolErrorCallback() {
public Optional<List<? extends OnToolErrorCallback>> onToolErrorCallback() {
return onToolErrorCallback;
}

Expand All @@ -880,7 +871,7 @@ public List<? extends OnToolErrorCallback> onToolErrorCallback() {
* <p>This method is only for use by Agent Development Kit.
*/
public List<? extends BeforeModelCallback> canonicalBeforeModelCallbacks() {
return beforeModelCallback;
return beforeModelCallback.orElse(ImmutableList.of());
}

/**
Expand All @@ -889,7 +880,7 @@ public List<? extends BeforeModelCallback> canonicalBeforeModelCallbacks() {
* <p>This method is only for use by Agent Development Kit.
*/
public List<? extends AfterModelCallback> canonicalAfterModelCallbacks() {
return afterModelCallback;
return afterModelCallback.orElse(ImmutableList.of());
}

/**
Expand All @@ -898,7 +889,7 @@ public List<? extends AfterModelCallback> canonicalAfterModelCallbacks() {
* <p>This method is only for use by Agent Development Kit.
*/
public List<? extends OnModelErrorCallback> canonicalOnModelErrorCallbacks() {
return onModelErrorCallback;
return onModelErrorCallback.orElse(ImmutableList.of());
}

/**
Expand All @@ -907,7 +898,7 @@ public List<? extends OnModelErrorCallback> canonicalOnModelErrorCallbacks() {
* <p>This method is only for use by Agent Development Kit.
*/
public List<? extends BeforeToolCallback> canonicalBeforeToolCallbacks() {
return beforeToolCallback;
return beforeToolCallback.orElse(ImmutableList.of());
}

/**
Expand All @@ -916,7 +907,7 @@ public List<? extends BeforeToolCallback> canonicalBeforeToolCallbacks() {
* <p>This method is only for use by Agent Development Kit.
*/
public List<? extends AfterToolCallback> canonicalAfterToolCallbacks() {
return afterToolCallback;
return afterToolCallback.orElse(ImmutableList.of());
}

/**
Expand All @@ -925,7 +916,7 @@ public List<? extends AfterToolCallback> canonicalAfterToolCallbacks() {
* <p>This method is only for use by Agent Development Kit.
*/
public List<? extends OnToolErrorCallback> canonicalOnToolErrorCallbacks() {
return onToolErrorCallback;
return onToolErrorCallback.orElse(ImmutableList.of());
}

public Optional<Schema> inputSchema() {
Expand Down
Loading