Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import com.google.adk.codeexecutors.CodeExecutionUtils.CodeExecutionInput;
import com.google.adk.codeexecutors.CodeExecutionUtils.CodeExecutionResult;
import com.google.common.collect.ImmutableList;
import java.util.List;

/**
* Abstract base class for all code executors.
Expand All @@ -30,6 +29,13 @@
* the execution results into the final response.
*/
public abstract class BaseCodeExecutor extends JsonBaseModel {

private static final ImmutableList<ImmutableList<String>> CODE_BLOCK_DELIMITERS =
ImmutableList.of(
ImmutableList.of("```tool_code\n", "\n```"), ImmutableList.of("```python\n", "\n```"));
private static final ImmutableList<String> EXECUTION_RESULT_DELIMITERS =
ImmutableList.of("```tool_output\n", "\n```");

/**
* If true, extract and process data files from the model request and attach them to the code
* executor.
Expand Down Expand Up @@ -57,6 +63,9 @@ public int errorRetryAttempts() {
/**
* The list of the enclosing delimiters to identify the code blocks.
*
* <p>Each inner list contains a pair of start and end delimiters. This supports multiple pairs of
* delimiters.
*
* <p>For example, the delimiter ('```python\n', '\n```') can be used to identify code blocks with
* the following format:
*
Expand All @@ -66,19 +75,20 @@ public int errorRetryAttempts() {
*
* <p>```
*/
public List<List<String>> codeBlockDelimiters() {
return ImmutableList.of(
ImmutableList.of("```tool_code\n", "\n```"), ImmutableList.of("```python\n", "\n```"));
public ImmutableList<ImmutableList<String>> codeBlockDelimiters() {
return CODE_BLOCK_DELIMITERS;
}

/** The delimiters to format the code execution result. */
public List<String> executionResultDelimiters() {
return ImmutableList.of("```tool_output\n", "\n```");
public ImmutableList<String> executionResultDelimiters() {
return EXECUTION_RESULT_DELIMITERS;
}

/**
* Executes code and return the code execution result.
*
* <p>This method may perform blocking operations.
*
* @param invocationContext The invocation context of the code execution.
* @param codeExecutionInput The code execution input.
* @return The code execution result.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,8 @@ public void processLlmRequest(LlmRequest.Builder llmRequestBuilder) {
if (ModelNameUtils.isGemini2Model(llmRequest.model().orElse(null))) {
GenerateContentConfig.Builder configBuilder =
llmRequest.config().map(c -> c.toBuilder()).orElseGet(GenerateContentConfig::builder);
ImmutableList.Builder<Tool> toolsBuilder =
ImmutableList.<Tool>builder()
.addAll(configBuilder.build().tools().orElse(ImmutableList.<Tool>of()));
ImmutableList.Builder<Tool> toolsBuilder = ImmutableList.<Tool>builder();
llmRequest.config().ifPresent(c -> c.tools().ifPresent(toolsBuilder::addAll));
toolsBuilder.add(Tool.builder().codeExecution(ToolCodeExecution.builder().build()).build());
configBuilder.tools(toolsBuilder.build());
llmRequestBuilder.config(configBuilder.build());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ public static Content convertCodeExecutionParts(
* @return The extracted code if found.
*/
public static Optional<String> extractCodeAndTruncateContent(
Content.Builder contentBuilder, List<List<String>> codeBlockDelimiters) {
Content.Builder contentBuilder, List<? extends List<String>> codeBlockDelimiters) {
Content content = contentBuilder.build();
if (content.parts().isEmpty() || content.parts().get().isEmpty()) {
return Optional.empty();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* A code executor that uses Vertex Code Interpreter Extension to execute code.
Expand All @@ -51,7 +51,7 @@
* setup.
*/
public final class VertexAiCodeExecutor extends BaseCodeExecutor {
private static final Logger logger = Logger.getLogger(VertexAiCodeExecutor.class.getName());
private static final Logger logger = LoggerFactory.getLogger(VertexAiCodeExecutor.class);

private static final ImmutableList<String> SUPPORTED_IMAGE_TYPES =
ImmutableList.of("png", "jpg", "jpeg");
Expand Down Expand Up @@ -108,7 +108,8 @@ public final class VertexAiCodeExecutor extends BaseCodeExecutor {
+ "{df_info}\"\"\")";

private final String resourceName;
private final ExtensionExecutionServiceClient codeInterpreterExtension;
private ExtensionExecutionServiceClient codeInterpreterExtension;
private final Object extensionClientLock = new Object();

/**
* Initializes the VertexAiCodeExecutor.
Expand All @@ -123,26 +124,11 @@ public VertexAiCodeExecutor(String resourceName) {
}

if (resolvedResourceName == null || resolvedResourceName.isEmpty()) {
logger.warning(
logger.warn(
"No resource name found for Vertex AI Code Interpreter. It will not be available.");
this.resourceName = null;
this.codeInterpreterExtension = null;
} else {
this.resourceName = resolvedResourceName;
try {
String[] parts = this.resourceName.split("/");
if (parts.length < 4 || !parts[2].equals("locations")) {
throw new IllegalArgumentException("Invalid resource name format: " + this.resourceName);
}
String location = parts[3];
String endpoint = String.format("%s-aiplatform.googleapis.com:443", location);
ExtensionExecutionServiceSettings settings =
ExtensionExecutionServiceSettings.newBuilder().setEndpoint(endpoint).build();
this.codeInterpreterExtension = ExtensionExecutionServiceClient.create(settings);
} catch (IOException e) {
logger.log(Level.SEVERE, "Failed to create ExtensionExecutionServiceClient", e);
throw new IllegalStateException("Failed to create ExtensionExecutionServiceClient", e);
}
}
}

Expand Down Expand Up @@ -188,9 +174,9 @@ public CodeExecutionResult executeCode(

private Map<String, Object> executeCodeInterpreter(
String code, List<File> inputFiles, Optional<String> sessionId) {
ExtensionExecutionServiceClient codeInterpreterExtension = getCodeInterpreterExtension();
if (codeInterpreterExtension == null) {
logger.warning(
"Vertex AI Code Interpreter execution is not available. Returning empty result.");
logger.warn("Vertex AI Code Interpreter execution is not available. Returning empty result.");
return ImmutableMap.of(
"execution_result", "", "execution_error", "", "output_files", new ArrayList<>());
}
Expand Down Expand Up @@ -231,7 +217,7 @@ private Map<String, Object> executeCodeInterpreter(
ObjectMapper mapper = new ObjectMapper();
return mapper.readValue(jsonOutput, new TypeReference<Map<String, Object>>() {});
} catch (IOException e) {
logger.log(Level.SEVERE, "Failed to parse JSON from code interpreter: " + jsonOutput, e);
logger.error("Failed to parse JSON from code interpreter: " + jsonOutput, e);
return ImmutableMap.of(
"execution_result",
"",
Expand All @@ -242,6 +228,31 @@ private Map<String, Object> executeCodeInterpreter(
}
}

private ExtensionExecutionServiceClient getCodeInterpreterExtension() {
if (this.resourceName == null) {
return null;
}
synchronized (extensionClientLock) {
if (this.codeInterpreterExtension == null) {
try {
String[] parts = this.resourceName.split("/");
if (parts.length < 4 || !parts[2].equals("locations")) {
throw new IllegalArgumentException(
"Invalid resource name format: " + this.resourceName);
}
String location = parts[3];
String endpoint = String.format("%s-aiplatform.googleapis.com:443", location);
ExtensionExecutionServiceSettings settings =
ExtensionExecutionServiceSettings.newBuilder().setEndpoint(endpoint).build();
this.codeInterpreterExtension = ExtensionExecutionServiceClient.create(settings);
} catch (IOException e) {
throw new IllegalStateException("Failed to create ExtensionExecutionServiceClient", e);
}
}
return this.codeInterpreterExtension;
}
}

private String getCodeWithImports(String code) {
return String.format("%s\n\n%s", IMPORTED_LIBRARIES, code);
}
Expand Down