diff --git a/core/src/main/java/com/google/adk/codeexecutors/BaseCodeExecutor.java b/core/src/main/java/com/google/adk/codeexecutors/BaseCodeExecutor.java index 92c3fe28..b13d69a9 100644 --- a/core/src/main/java/com/google/adk/codeexecutors/BaseCodeExecutor.java +++ b/core/src/main/java/com/google/adk/codeexecutors/BaseCodeExecutor.java @@ -21,7 +21,6 @@ import com.google.adk.codeexecutors.CodeExecutionUtils.CodeExecutionInput; import com.google.adk.codeexecutors.CodeExecutionUtils.CodeExecutionResult; import com.google.common.collect.ImmutableList; -import java.util.List; /** * Abstract base class for all code executors. @@ -30,6 +29,13 @@ * the execution results into the final response. */ public abstract class BaseCodeExecutor extends JsonBaseModel { + + private static final ImmutableList> CODE_BLOCK_DELIMITERS = + ImmutableList.of( + ImmutableList.of("```tool_code\n", "\n```"), ImmutableList.of("```python\n", "\n```")); + private static final ImmutableList EXECUTION_RESULT_DELIMITERS = + ImmutableList.of("```tool_output\n", "\n```"); + /** * If true, extract and process data files from the model request and attach them to the code * executor. @@ -57,6 +63,9 @@ public int errorRetryAttempts() { /** * The list of the enclosing delimiters to identify the code blocks. * + *

Each inner list contains a pair of start and end delimiters. This supports multiple pairs of + * delimiters. + * *

For example, the delimiter ('```python\n', '\n```') can be used to identify code blocks with * the following format: * @@ -66,19 +75,20 @@ public int errorRetryAttempts() { * *

``` */ - public List> codeBlockDelimiters() { - return ImmutableList.of( - ImmutableList.of("```tool_code\n", "\n```"), ImmutableList.of("```python\n", "\n```")); + public ImmutableList> codeBlockDelimiters() { + return CODE_BLOCK_DELIMITERS; } /** The delimiters to format the code execution result. */ - public List executionResultDelimiters() { - return ImmutableList.of("```tool_output\n", "\n```"); + public ImmutableList executionResultDelimiters() { + return EXECUTION_RESULT_DELIMITERS; } /** * Executes code and return the code execution result. * + *

This method may perform blocking operations. + * * @param invocationContext The invocation context of the code execution. * @param codeExecutionInput The code execution input. * @return The code execution result. diff --git a/core/src/main/java/com/google/adk/codeexecutors/BuiltInCodeExecutor.java b/core/src/main/java/com/google/adk/codeexecutors/BuiltInCodeExecutor.java index f92c79a8..972082dd 100644 --- a/core/src/main/java/com/google/adk/codeexecutors/BuiltInCodeExecutor.java +++ b/core/src/main/java/com/google/adk/codeexecutors/BuiltInCodeExecutor.java @@ -46,9 +46,8 @@ public void processLlmRequest(LlmRequest.Builder llmRequestBuilder) { if (ModelNameUtils.isGemini2Model(llmRequest.model().orElse(null))) { GenerateContentConfig.Builder configBuilder = llmRequest.config().map(c -> c.toBuilder()).orElseGet(GenerateContentConfig::builder); - ImmutableList.Builder toolsBuilder = - ImmutableList.builder() - .addAll(configBuilder.build().tools().orElse(ImmutableList.of())); + ImmutableList.Builder toolsBuilder = ImmutableList.builder(); + llmRequest.config().ifPresent(c -> c.tools().ifPresent(toolsBuilder::addAll)); toolsBuilder.add(Tool.builder().codeExecution(ToolCodeExecution.builder().build()).build()); configBuilder.tools(toolsBuilder.build()); llmRequestBuilder.config(configBuilder.build()); diff --git a/core/src/main/java/com/google/adk/codeexecutors/CodeExecutionUtils.java b/core/src/main/java/com/google/adk/codeexecutors/CodeExecutionUtils.java index ee5be789..b9afdcaf 100644 --- a/core/src/main/java/com/google/adk/codeexecutors/CodeExecutionUtils.java +++ b/core/src/main/java/com/google/adk/codeexecutors/CodeExecutionUtils.java @@ -128,7 +128,7 @@ public static Content convertCodeExecutionParts( * @return The extracted code if found. */ public static Optional extractCodeAndTruncateContent( - Content.Builder contentBuilder, List> codeBlockDelimiters) { + Content.Builder contentBuilder, List> codeBlockDelimiters) { Content content = contentBuilder.build(); if (content.parts().isEmpty() || content.parts().get().isEmpty()) { return Optional.empty(); diff --git a/core/src/main/java/com/google/adk/codeexecutors/VertexAiCodeExecutor.java b/core/src/main/java/com/google/adk/codeexecutors/VertexAiCodeExecutor.java index 03110df6..5268edf3 100644 --- a/core/src/main/java/com/google/adk/codeexecutors/VertexAiCodeExecutor.java +++ b/core/src/main/java/com/google/adk/codeexecutors/VertexAiCodeExecutor.java @@ -37,8 +37,8 @@ import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.logging.Level; -import java.util.logging.Logger; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * A code executor that uses Vertex Code Interpreter Extension to execute code. @@ -51,7 +51,7 @@ * setup. */ public final class VertexAiCodeExecutor extends BaseCodeExecutor { - private static final Logger logger = Logger.getLogger(VertexAiCodeExecutor.class.getName()); + private static final Logger logger = LoggerFactory.getLogger(VertexAiCodeExecutor.class); private static final ImmutableList SUPPORTED_IMAGE_TYPES = ImmutableList.of("png", "jpg", "jpeg"); @@ -108,7 +108,8 @@ public final class VertexAiCodeExecutor extends BaseCodeExecutor { + "{df_info}\"\"\")"; private final String resourceName; - private final ExtensionExecutionServiceClient codeInterpreterExtension; + private ExtensionExecutionServiceClient codeInterpreterExtension; + private final Object extensionClientLock = new Object(); /** * Initializes the VertexAiCodeExecutor. @@ -123,26 +124,11 @@ public VertexAiCodeExecutor(String resourceName) { } if (resolvedResourceName == null || resolvedResourceName.isEmpty()) { - logger.warning( + logger.warn( "No resource name found for Vertex AI Code Interpreter. It will not be available."); this.resourceName = null; - this.codeInterpreterExtension = null; } else { this.resourceName = resolvedResourceName; - try { - String[] parts = this.resourceName.split("/"); - if (parts.length < 4 || !parts[2].equals("locations")) { - throw new IllegalArgumentException("Invalid resource name format: " + this.resourceName); - } - String location = parts[3]; - String endpoint = String.format("%s-aiplatform.googleapis.com:443", location); - ExtensionExecutionServiceSettings settings = - ExtensionExecutionServiceSettings.newBuilder().setEndpoint(endpoint).build(); - this.codeInterpreterExtension = ExtensionExecutionServiceClient.create(settings); - } catch (IOException e) { - logger.log(Level.SEVERE, "Failed to create ExtensionExecutionServiceClient", e); - throw new IllegalStateException("Failed to create ExtensionExecutionServiceClient", e); - } } } @@ -188,9 +174,9 @@ public CodeExecutionResult executeCode( private Map executeCodeInterpreter( String code, List inputFiles, Optional sessionId) { + ExtensionExecutionServiceClient codeInterpreterExtension = getCodeInterpreterExtension(); if (codeInterpreterExtension == null) { - logger.warning( - "Vertex AI Code Interpreter execution is not available. Returning empty result."); + logger.warn("Vertex AI Code Interpreter execution is not available. Returning empty result."); return ImmutableMap.of( "execution_result", "", "execution_error", "", "output_files", new ArrayList<>()); } @@ -231,7 +217,7 @@ private Map executeCodeInterpreter( ObjectMapper mapper = new ObjectMapper(); return mapper.readValue(jsonOutput, new TypeReference>() {}); } catch (IOException e) { - logger.log(Level.SEVERE, "Failed to parse JSON from code interpreter: " + jsonOutput, e); + logger.error("Failed to parse JSON from code interpreter: " + jsonOutput, e); return ImmutableMap.of( "execution_result", "", @@ -242,6 +228,31 @@ private Map executeCodeInterpreter( } } + private ExtensionExecutionServiceClient getCodeInterpreterExtension() { + if (this.resourceName == null) { + return null; + } + synchronized (extensionClientLock) { + if (this.codeInterpreterExtension == null) { + try { + String[] parts = this.resourceName.split("/"); + if (parts.length < 4 || !parts[2].equals("locations")) { + throw new IllegalArgumentException( + "Invalid resource name format: " + this.resourceName); + } + String location = parts[3]; + String endpoint = String.format("%s-aiplatform.googleapis.com:443", location); + ExtensionExecutionServiceSettings settings = + ExtensionExecutionServiceSettings.newBuilder().setEndpoint(endpoint).build(); + this.codeInterpreterExtension = ExtensionExecutionServiceClient.create(settings); + } catch (IOException e) { + throw new IllegalStateException("Failed to create ExtensionExecutionServiceClient", e); + } + } + return this.codeInterpreterExtension; + } + } + private String getCodeWithImports(String code) { return String.format("%s\n\n%s", IMPORTED_LIBRARIES, code); }