Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 30 additions & 30 deletions core/src/main/java/com/google/adk/agents/LlmAgent.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package com.google.adk.agents;

import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.util.Objects.requireNonNullElse;
import static java.util.stream.Collectors.joining;

import com.fasterxml.jackson.core.JsonProcessingException;
Expand Down Expand Up @@ -103,12 +104,12 @@ public enum IncludeContents {
private final Optional<Integer> maxSteps;
private final boolean disallowTransferToParent;
private final boolean disallowTransferToPeers;
private final Optional<List<? extends BeforeModelCallback>> beforeModelCallback;
private final Optional<List<? extends AfterModelCallback>> afterModelCallback;
private final Optional<List<? extends OnModelErrorCallback>> onModelErrorCallback;
private final Optional<List<? extends BeforeToolCallback>> beforeToolCallback;
private final Optional<List<? extends AfterToolCallback>> afterToolCallback;
private final Optional<List<? extends OnToolErrorCallback>> onToolErrorCallback;
private final ImmutableList<? extends BeforeModelCallback> beforeModelCallback;
private final ImmutableList<? extends AfterModelCallback> afterModelCallback;
private final ImmutableList<? extends OnModelErrorCallback> onModelErrorCallback;
private final ImmutableList<? extends BeforeToolCallback> beforeToolCallback;
private final ImmutableList<? extends AfterToolCallback> afterToolCallback;
private final ImmutableList<? extends OnToolErrorCallback> onToolErrorCallback;
private final Optional<Schema> inputSchema;
private final Optional<Schema> outputSchema;
private final Optional<Executor> executor;
Expand All @@ -126,29 +127,28 @@ protected LlmAgent(Builder builder) {
builder.beforeAgentCallback,
builder.afterAgentCallback);
this.model = Optional.ofNullable(builder.model);
this.instruction =
builder.instruction == null ? new Instruction.Static("") : builder.instruction;
this.instruction = requireNonNullElse(builder.instruction, new Instruction.Static(""));
this.globalInstruction =
builder.globalInstruction == null ? new Instruction.Static("") : builder.globalInstruction;
requireNonNullElse(builder.globalInstruction, new Instruction.Static(""));
this.generateContentConfig = Optional.ofNullable(builder.generateContentConfig);
this.exampleProvider = Optional.ofNullable(builder.exampleProvider);
this.includeContents =
builder.includeContents != null ? builder.includeContents : IncludeContents.DEFAULT;
this.includeContents = requireNonNullElse(builder.includeContents, IncludeContents.DEFAULT);
this.planning = builder.planning != null && builder.planning;
this.maxSteps = Optional.ofNullable(builder.maxSteps);
this.disallowTransferToParent = builder.disallowTransferToParent;
this.disallowTransferToPeers = builder.disallowTransferToPeers;
this.beforeModelCallback = Optional.ofNullable(builder.beforeModelCallback);
this.afterModelCallback = Optional.ofNullable(builder.afterModelCallback);
this.onModelErrorCallback = Optional.ofNullable(builder.onModelErrorCallback);
this.beforeToolCallback = Optional.ofNullable(builder.beforeToolCallback);
this.afterToolCallback = Optional.ofNullable(builder.afterToolCallback);
this.onToolErrorCallback = Optional.ofNullable(builder.onToolErrorCallback);
this.beforeModelCallback = requireNonNullElse(builder.beforeModelCallback, ImmutableList.of());
this.afterModelCallback = requireNonNullElse(builder.afterModelCallback, ImmutableList.of());
this.onModelErrorCallback =
requireNonNullElse(builder.onModelErrorCallback, ImmutableList.of());
this.beforeToolCallback = requireNonNullElse(builder.beforeToolCallback, ImmutableList.of());
this.afterToolCallback = requireNonNullElse(builder.afterToolCallback, ImmutableList.of());
this.onToolErrorCallback = requireNonNullElse(builder.onToolErrorCallback, ImmutableList.of());
this.inputSchema = Optional.ofNullable(builder.inputSchema);
this.outputSchema = Optional.ofNullable(builder.outputSchema);
this.executor = Optional.ofNullable(builder.executor);
this.outputKey = Optional.ofNullable(builder.outputKey);
this.toolsUnion = builder.toolsUnion != null ? builder.toolsUnion : ImmutableList.of();
this.toolsUnion = requireNonNullElse(builder.toolsUnion, ImmutableList.of());
this.toolsets = extractToolsets(this.toolsUnion);
this.codeExecutor = Optional.ofNullable(builder.codeExecutor);

Expand Down Expand Up @@ -850,27 +850,27 @@ public boolean disallowTransferToPeers() {
return disallowTransferToPeers;
}

public Optional<List<? extends BeforeModelCallback>> beforeModelCallback() {
public List<? extends BeforeModelCallback> beforeModelCallback() {
return beforeModelCallback;
}

public Optional<List<? extends AfterModelCallback>> afterModelCallback() {
public List<? extends AfterModelCallback> afterModelCallback() {
return afterModelCallback;
}

public Optional<List<? extends BeforeToolCallback>> beforeToolCallback() {
public List<? extends BeforeToolCallback> beforeToolCallback() {
return beforeToolCallback;
}

public Optional<List<? extends AfterToolCallback>> afterToolCallback() {
public List<? extends AfterToolCallback> afterToolCallback() {
return afterToolCallback;
}

public Optional<List<? extends OnModelErrorCallback>> onModelErrorCallback() {
public List<? extends OnModelErrorCallback> onModelErrorCallback() {
return onModelErrorCallback;
}

public Optional<List<? extends OnToolErrorCallback>> onToolErrorCallback() {
public List<? extends OnToolErrorCallback> onToolErrorCallback() {
return onToolErrorCallback;
}

Expand All @@ -880,7 +880,7 @@ public Optional<List<? extends OnToolErrorCallback>> onToolErrorCallback() {
* <p>This method is only for use by Agent Development Kit.
*/
public List<? extends BeforeModelCallback> canonicalBeforeModelCallbacks() {
return beforeModelCallback.orElse(ImmutableList.of());
return beforeModelCallback;
}

/**
Expand All @@ -889,7 +889,7 @@ public List<? extends BeforeModelCallback> canonicalBeforeModelCallbacks() {
* <p>This method is only for use by Agent Development Kit.
*/
public List<? extends AfterModelCallback> canonicalAfterModelCallbacks() {
return afterModelCallback.orElse(ImmutableList.of());
return afterModelCallback;
}

/**
Expand All @@ -898,7 +898,7 @@ public List<? extends AfterModelCallback> canonicalAfterModelCallbacks() {
* <p>This method is only for use by Agent Development Kit.
*/
public List<? extends OnModelErrorCallback> canonicalOnModelErrorCallbacks() {
return onModelErrorCallback.orElse(ImmutableList.of());
return onModelErrorCallback;
}

/**
Expand All @@ -907,7 +907,7 @@ public List<? extends OnModelErrorCallback> canonicalOnModelErrorCallbacks() {
* <p>This method is only for use by Agent Development Kit.
*/
public List<? extends BeforeToolCallback> canonicalBeforeToolCallbacks() {
return beforeToolCallback.orElse(ImmutableList.of());
return beforeToolCallback;
}

/**
Expand All @@ -916,7 +916,7 @@ public List<? extends BeforeToolCallback> canonicalBeforeToolCallbacks() {
* <p>This method is only for use by Agent Development Kit.
*/
public List<? extends AfterToolCallback> canonicalAfterToolCallbacks() {
return afterToolCallback.orElse(ImmutableList.of());
return afterToolCallback;
}

/**
Expand All @@ -925,7 +925,7 @@ public List<? extends AfterToolCallback> canonicalAfterToolCallbacks() {
* <p>This method is only for use by Agent Development Kit.
*/
public List<? extends OnToolErrorCallback> canonicalOnToolErrorCallbacks() {
return onToolErrorCallback.orElse(ImmutableList.of());
return onToolErrorCallback;
}

public Optional<Schema> inputSchema() {
Expand Down
31 changes: 16 additions & 15 deletions core/src/test/java/com/google/adk/agents/ConfigAgentUtilsTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -1161,20 +1161,25 @@ public void fromConfig_withConfiguredCallbacks_resolvesCallbacks()

String pfx = "test.callbacks.";
registry.register(
pfx + "before_agent_1", (Callbacks.BeforeAgentCallback) (ctx) -> Maybe.empty());
pfx + "before_agent_1", (Callbacks.BeforeAgentCallback) (unusedCtx) -> Maybe.empty());
registry.register(
pfx + "before_agent_2", (Callbacks.BeforeAgentCallback) (ctx) -> Maybe.empty());
registry.register(pfx + "after_agent_1", (Callbacks.AfterAgentCallback) (ctx) -> Maybe.empty());
pfx + "before_agent_2", (Callbacks.BeforeAgentCallback) (unusedCtx) -> Maybe.empty());
registry.register(
pfx + "before_model_1", (Callbacks.BeforeModelCallback) (ctx, req) -> Maybe.empty());
pfx + "after_agent_1", (Callbacks.AfterAgentCallback) (unusedCtx) -> Maybe.empty());
registry.register(
pfx + "after_model_1", (Callbacks.AfterModelCallback) (ctx, resp) -> Maybe.empty());
pfx + "before_model_1",
(Callbacks.BeforeModelCallback) (unusedCtx, unusedReq) -> Maybe.empty());
registry.register(
pfx + "after_model_1",
(Callbacks.AfterModelCallback) (unusedCtx, unusedResp) -> Maybe.empty());
registry.register(
pfx + "before_tool_1",
(Callbacks.BeforeToolCallback) (inv, tool, args, toolCtx) -> Maybe.empty());
(Callbacks.BeforeToolCallback)
(unusedInv, unusedTool, unusedArgs, unusedToolCtx) -> Maybe.empty());
registry.register(
pfx + "after_tool_1",
(Callbacks.AfterToolCallback) (inv, tool, args, toolCtx, resp) -> Maybe.empty());
(Callbacks.AfterToolCallback)
(unusedInv, unusedTool, unusedArgs, unusedToolCtx, unusedResp) -> Maybe.empty());

File configFile = tempFolder.newFile("with_callbacks.yaml");
Files.writeString(
Expand Down Expand Up @@ -1209,15 +1214,11 @@ public void fromConfig_withConfiguredCallbacks_resolvesCallbacks()
assertThat(agent.afterAgentCallback()).isPresent();
assertThat(agent.afterAgentCallback().get()).hasSize(1);

assertThat(llm.beforeModelCallback()).isPresent();
assertThat(llm.beforeModelCallback().get()).hasSize(1);
assertThat(llm.afterModelCallback()).isPresent();
assertThat(llm.afterModelCallback().get()).hasSize(1);
assertThat(llm.beforeModelCallback()).hasSize(1);
assertThat(llm.afterModelCallback()).hasSize(1);

assertThat(llm.beforeToolCallback()).isPresent();
assertThat(llm.beforeToolCallback().get()).hasSize(1);
assertThat(llm.afterToolCallback()).isPresent();
assertThat(llm.afterToolCallback().get()).hasSize(1);
assertThat(llm.beforeToolCallback()).hasSize(1);
assertThat(llm.afterToolCallback()).hasSize(1);
}

@Test
Expand Down
14 changes: 14 additions & 0 deletions core/src/test/java/com/google/adk/agents/LlmAgentTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,13 @@ public void canonicalCallbacks_returnsEmptyListWhenNull() {
assertThat(agent.canonicalBeforeToolCallbacks()).isEmpty();
assertThat(agent.canonicalAfterToolCallbacks()).isEmpty();
assertThat(agent.canonicalOnToolErrorCallbacks()).isEmpty();

assertThat(agent.beforeModelCallback()).isEmpty();
assertThat(agent.afterModelCallback()).isEmpty();
assertThat(agent.onModelErrorCallback()).isEmpty();
assertThat(agent.beforeToolCallback()).isEmpty();
assertThat(agent.afterToolCallback()).isEmpty();
assertThat(agent.onToolErrorCallback()).isEmpty();
}

@Test
Expand Down Expand Up @@ -371,5 +378,12 @@ public void canonicalCallbacks_returnsListWhenPresent() {
assertThat(agent.canonicalBeforeToolCallbacks()).containsExactly(btc);
assertThat(agent.canonicalAfterToolCallbacks()).containsExactly(atc);
assertThat(agent.canonicalOnToolErrorCallbacks()).containsExactly(otec);

assertThat(agent.beforeModelCallback()).containsExactly(bmc);
assertThat(agent.afterModelCallback()).containsExactly(amc);
assertThat(agent.onModelErrorCallback()).containsExactly(omec);
assertThat(agent.beforeToolCallback()).containsExactly(btc);
assertThat(agent.afterToolCallback()).containsExactly(atc);
assertThat(agent.onToolErrorCallback()).containsExactly(otec);
}
}