From 83d60d7b62f02c0f15a804b7afeff357c4dd0fc9 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Wed, 28 Jan 2026 09:43:23 -0800 Subject: [PATCH] feat: Adding SafeFilesAsArtifactsPlugin Also, adding a new `ArtifactService.saveAndReloadArtifact()` method so that the plugin doesn't have to perform a second i/o call just to get the full file path. PiperOrigin-RevId: 862275776 --- .../adk/artifacts/BaseArtifactService.java | 17 + .../adk/artifacts/GcsArtifactService.java | 104 ++++-- .../artifacts/InMemoryArtifactService.java | 8 + .../plugins/SaveFilesAsArtifactsPlugin.java | 166 +++++++++ .../adk/artifacts/GcsArtifactServiceTest.java | 32 ++ .../InMemoryArtifactServiceTest.java | 77 ++++ .../SaveFilesAsArtifactsPluginTest.java | 347 ++++++++++++++++++ 7 files changed, 723 insertions(+), 28 deletions(-) create mode 100644 core/src/main/java/com/google/adk/plugins/SaveFilesAsArtifactsPlugin.java create mode 100644 core/src/test/java/com/google/adk/artifacts/InMemoryArtifactServiceTest.java create mode 100644 core/src/test/java/com/google/adk/plugins/SaveFilesAsArtifactsPluginTest.java diff --git a/core/src/main/java/com/google/adk/artifacts/BaseArtifactService.java b/core/src/main/java/com/google/adk/artifacts/BaseArtifactService.java index 847e88dd..1cb717a0 100644 --- a/core/src/main/java/com/google/adk/artifacts/BaseArtifactService.java +++ b/core/src/main/java/com/google/adk/artifacts/BaseArtifactService.java @@ -72,6 +72,23 @@ Maybe loadArtifact( */ Completable deleteArtifact(String appName, String userId, String sessionId, String filename); + /** + * Saves an artifact and returns it with fileData if available. + * + * @param appName the app name + * @param userId the user ID + * @param sessionId the session ID + * @param filename the filename + * @param artifact the artifact to save + * @return the saved artifact with fileData if available. + */ + default Maybe saveAndReloadArtifact( + String appName, String userId, String sessionId, String filename, Part artifact) { + return saveArtifact(appName, userId, sessionId, filename, artifact) + .flatMapMaybe( + version -> loadArtifact(appName, userId, sessionId, filename, Optional.of(version))); + } + /** * Lists all the versions (as revision IDs) of an artifact. * diff --git a/core/src/main/java/com/google/adk/artifacts/GcsArtifactService.java b/core/src/main/java/com/google/adk/artifacts/GcsArtifactService.java index 1bfef8cf..4dafca5a 100644 --- a/core/src/main/java/com/google/adk/artifacts/GcsArtifactService.java +++ b/core/src/main/java/com/google/adk/artifacts/GcsArtifactService.java @@ -18,6 +18,7 @@ import static java.util.Collections.max; +import com.google.auto.value.AutoValue; import com.google.cloud.storage.Blob; import com.google.cloud.storage.BlobId; import com.google.cloud.storage.BlobInfo; @@ -27,6 +28,7 @@ import com.google.common.base.Splitter; import com.google.common.base.VerifyException; import com.google.common.collect.ImmutableList; +import com.google.genai.types.FileData; import com.google.genai.types.Part; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Maybe; @@ -108,34 +110,8 @@ private String getBlobName( @Override public Single saveArtifact( String appName, String userId, String sessionId, String filename, Part artifact) { - return listVersions(appName, userId, sessionId, filename) - .map(versions -> versions.isEmpty() ? 0 : max(versions) + 1) - .map( - nextVersion -> { - String blobName = getBlobName(appName, userId, sessionId, filename, nextVersion); - BlobId blobId = BlobId.of(bucketName, blobName); - - BlobInfo blobInfo = - BlobInfo.newBuilder(blobId) - .setContentType(artifact.inlineData().get().mimeType().orElse(null)) - .build(); - - try { - byte[] dataToSave = - artifact - .inlineData() - .get() - .data() - .orElseThrow( - () -> - new IllegalArgumentException( - "Saveable artifact data must be non-empty.")); - storageClient.create(blobInfo, dataToSave); - return nextVersion; - } catch (StorageException e) { - throw new VerifyException("Failed to save artifact to GCS", e); - } - }); + return saveArtifactAndReturnBlob(appName, userId, sessionId, filename, artifact) + .map(SaveResult::version); } /** @@ -275,4 +251,76 @@ public Single> listVersions( return Single.just(ImmutableList.of()); } } + + @Override + public Maybe saveAndReloadArtifact( + String appName, String userId, String sessionId, String filename, Part artifact) { + return saveArtifactAndReturnBlob(appName, userId, sessionId, filename, artifact) + .flatMapMaybe( + blob -> { + Blob savedBlob = blob.blob(); + String resultMimeType = savedBlob.getContentType(); + if (resultMimeType == null && artifact.inlineData().isPresent()) { + resultMimeType = artifact.inlineData().get().mimeType().orElse(null); + } + if (resultMimeType == null) { + resultMimeType = "application/octet-stream"; + } + return Maybe.just( + Part.builder() + .fileData( + FileData.builder() + .fileUri("gs://" + savedBlob.getBucket() + "/" + savedBlob.getName()) + .mimeType(resultMimeType) + .build()) + .build()); + }); + } + + @AutoValue + abstract static class SaveResult { + static SaveResult create(Blob blob, int version) { + return new AutoValue_GcsArtifactService_SaveResult(blob, version); + } + + abstract Blob blob(); + + abstract int version(); + } + + private Single saveArtifactAndReturnBlob( + String appName, String userId, String sessionId, String filename, Part artifact) { + return listVersions(appName, userId, sessionId, filename) + .map(versions -> versions.isEmpty() ? 0 : max(versions) + 1) + .map( + nextVersion -> { + if (artifact.inlineData().isEmpty()) { + throw new IllegalArgumentException("Saveable artifact must have inline data."); + } + + String blobName = getBlobName(appName, userId, sessionId, filename, nextVersion); + BlobId blobId = BlobId.of(bucketName, blobName); + + BlobInfo blobInfo = + BlobInfo.newBuilder(blobId) + .setContentType(artifact.inlineData().get().mimeType().orElse(null)) + .build(); + + try { + byte[] dataToSave = + artifact + .inlineData() + .get() + .data() + .orElseThrow( + () -> + new IllegalArgumentException( + "Saveable artifact data must be non-empty.")); + Blob blob = storageClient.create(blobInfo, dataToSave); + return SaveResult.create(blob, nextVersion); + } catch (StorageException e) { + throw new VerifyException("Failed to save artifact to GCS", e); + } + }); + } } diff --git a/core/src/main/java/com/google/adk/artifacts/InMemoryArtifactService.java b/core/src/main/java/com/google/adk/artifacts/InMemoryArtifactService.java index 27b85136..b1f61019 100644 --- a/core/src/main/java/com/google/adk/artifacts/InMemoryArtifactService.java +++ b/core/src/main/java/com/google/adk/artifacts/InMemoryArtifactService.java @@ -125,6 +125,14 @@ public Single> listVersions( return Single.just(IntStream.range(0, size).boxed().collect(toImmutableList())); } + @Override + public Maybe saveAndReloadArtifact( + String appName, String userId, String sessionId, String filename, Part artifact) { + return saveArtifact(appName, userId, sessionId, filename, artifact) + .flatMapMaybe( + version -> loadArtifact(appName, userId, sessionId, filename, Optional.of(version))); + } + private Map> getArtifactsMap(String appName, String userId, String sessionId) { return artifacts .computeIfAbsent(appName, unused -> new HashMap<>()) diff --git a/core/src/main/java/com/google/adk/plugins/SaveFilesAsArtifactsPlugin.java b/core/src/main/java/com/google/adk/plugins/SaveFilesAsArtifactsPlugin.java new file mode 100644 index 00000000..2474a772 --- /dev/null +++ b/core/src/main/java/com/google/adk/plugins/SaveFilesAsArtifactsPlugin.java @@ -0,0 +1,166 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.plugins; + +import static com.google.common.collect.ImmutableList.toImmutableList; + +import com.google.adk.agents.InvocationContext; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import com.google.genai.types.Blob; +import com.google.genai.types.Content; +import com.google.genai.types.FileData; +import com.google.genai.types.Part; +import io.reactivex.rxjava3.core.Flowable; +import io.reactivex.rxjava3.core.Maybe; +import io.reactivex.rxjava3.core.Single; +import java.net.URI; +import java.net.URISyntaxException; +import java.util.List; +import java.util.Locale; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicInteger; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A plugin that saves files embedded in user messages as artifacts. + * + *

This is useful to allow users to upload files in the chat experience and have those files + * available to the agent within the current session. + * + *

We use Blob.display_name to determine the file name. By default, artifacts are session-scoped. + * For cross-session persistence, prefix the filename with "user:". + * + *

Artifacts with the same name will be overwritten. A placeholder with the artifact name will be + * put in place of the embedded file in the user message so the model knows where to find the file. + * You may want to add load_artifacts tool to the agent, or load the artifacts in your own tool to + * use the files. + */ +public class SaveFilesAsArtifactsPlugin extends BasePlugin { + private static final Logger logger = LoggerFactory.getLogger(SaveFilesAsArtifactsPlugin.class); + + private static final ImmutableSet MODEL_ACCESSIBLE_URI_SCHEMES = + ImmutableSet.of("gs", "https", "http"); + + public SaveFilesAsArtifactsPlugin(String name) { + super(name); + } + + public SaveFilesAsArtifactsPlugin() { + this("save_files_as_artifacts_plugin"); + } + + @Override + public Maybe onUserMessageCallback( + InvocationContext invocationContext, Content userMessage) { + if (invocationContext.artifactService() == null) { + logger.warn("Artifact service is not set. SaveFilesAsArtifactsPlugin will not be enabled."); + return Maybe.just(userMessage); + } + + if (userMessage.parts().stream() + .flatMap(List::stream) + .map(Part::inlineData) + .noneMatch(Optional::isPresent)) { + return Maybe.empty(); + } + + AtomicInteger index = new AtomicInteger(0); + + return Flowable.fromIterable(userMessage.parts().get()) + .concatMapSingle( + part -> { + if (part.inlineData().isEmpty()) { + return Single.just(ImmutableList.of(part)); + } + return saveArtifactAndBuildParts(invocationContext, part, index.getAndIncrement()); + }) + .toList() // Collects Single> into a Single>> + .map( + listOfLists -> + listOfLists.stream() + .flatMap(List::stream) + .collect(toImmutableList())) // Flatten the list of lists + .map( + parts -> Content.builder().parts(parts).role(userMessage.role().orElse("user")).build()) + .toMaybe(); + } + + private Single> saveArtifactAndBuildParts( + InvocationContext invocationContext, Part part, int index) { + Blob inlineData = part.inlineData().get(); + String fileName = + inlineData + .displayName() + .filter(s -> !s.isEmpty()) + .orElseGet( + () -> { + String generatedName = + String.format("artifact_%s_%d", invocationContext.invocationId(), index); + logger.info("No display_name found, using generated filename: {}", generatedName); + return generatedName; + }); + Part placeholderPart = Part.fromText(String.format("[Uploaded Artifact: \"%s\"]", fileName)); + + return invocationContext + .artifactService() + .saveAndReloadArtifact( + invocationContext.appName(), + invocationContext.userId(), + invocationContext.session().id(), + fileName, + part) + .doOnSuccess(unused -> logger.info("Successfully saved artifact: {}", fileName)) + .flatMap( + artifact -> + Maybe.fromOptional( + artifact + .fileData() + .filter(fd -> fd.fileUri().map(this::isModelAccessibleUri).orElse(false)) + .map(fd -> buildPartFromFileData(fd, inlineData.mimeType(), fileName)))) + .map(filePart -> ImmutableList.of(placeholderPart, filePart)) + .defaultIfEmpty(ImmutableList.of(placeholderPart)) + .onErrorReturn( + e -> { + logger.error("Failed to save artifact for part {}: {}", index, e); + return ImmutableList.of(part); // Keep original part if saving fails + }); + } + + private boolean isModelAccessibleUri(String uri) { + try { + URI parsed = new URI(uri); + return parsed.getScheme() != null + && MODEL_ACCESSIBLE_URI_SCHEMES.contains(parsed.getScheme().toLowerCase(Locale.ROOT)); + } catch (URISyntaxException e) { + return false; + } + } + + private Part buildPartFromFileData(FileData fd, Optional mimeType, String fileName) { + return Part.builder() + .fileData( + FileData.builder() + .fileUri(fd.fileUri().get()) + // Prioritize the mimeType from the original inlineData, as the artifact service + // might return a more generic type. + .mimeType(mimeType.or(fd::mimeType).orElse("application/octet-stream")) + .displayName(fileName) + .build()) + .build(); + } +} diff --git a/core/src/test/java/com/google/adk/artifacts/GcsArtifactServiceTest.java b/core/src/test/java/com/google/adk/artifacts/GcsArtifactServiceTest.java index 1df66c36..3b0dc5e4 100644 --- a/core/src/test/java/com/google/adk/artifacts/GcsArtifactServiceTest.java +++ b/core/src/test/java/com/google/adk/artifacts/GcsArtifactServiceTest.java @@ -76,6 +76,7 @@ private Blob mockBlob(String name, String contentType, byte[] content) { when(blob.exists()).thenReturn(true); BlobId blobId = BlobId.of(BUCKET_NAME, name); when(blob.getBlobId()).thenReturn(blobId); + when(blob.getBucket()).thenReturn(BUCKET_NAME); return blob; } @@ -89,6 +90,8 @@ public void save_firstVersion_savesCorrectly() { BlobInfo.newBuilder(expectedBlobId).setContentType("application/octet-stream").build(); when(mockBlobPage.iterateAll()).thenReturn(ImmutableList.of()); + Blob savedBlob = mockBlob(expectedBlobName, "application/octet-stream", new byte[] {1, 2, 3}); + when(mockStorage.create(eq(expectedBlobInfo), eq(new byte[] {1, 2, 3}))).thenReturn(savedBlob); int version = service.saveArtifact(APP_NAME, USER_ID, SESSION_ID, FILENAME, artifact).blockingGet(); @@ -109,6 +112,8 @@ public void save_subsequentVersion_savesCorrectly() { Blob blobV0 = mockBlob(blobNameV0, "text/plain", new byte[] {1}); when(mockBlobPage.iterateAll()).thenReturn(Collections.singletonList(blobV0)); + Blob savedBlob = mockBlob(expectedBlobNameV1, "image/png", new byte[] {4, 5}); + when(mockStorage.create(eq(expectedBlobInfoV1), eq(new byte[] {4, 5}))).thenReturn(savedBlob); int version = service.saveArtifact(APP_NAME, USER_ID, SESSION_ID, FILENAME, artifact).blockingGet(); @@ -126,6 +131,8 @@ public void save_userNamespace_savesCorrectly() { BlobInfo.newBuilder(expectedBlobId).setContentType("application/json").build(); when(mockBlobPage.iterateAll()).thenReturn(ImmutableList.of()); + Blob savedBlob = mockBlob(expectedBlobName, "application/json", new byte[] {1, 2, 3}); + when(mockStorage.create(eq(expectedBlobInfo), eq(new byte[] {1, 2, 3}))).thenReturn(savedBlob); int version = service.saveArtifact(APP_NAME, USER_ID, SESSION_ID, USER_FILENAME, artifact).blockingGet(); @@ -330,6 +337,31 @@ public void listVersions_noVersions_returnsEmptyList() { assertThat(versions).isEmpty(); } + @Test + public void saveAndReloadArtifact_savesAndReturnsFileData() { + Part artifact = Part.fromBytes(new byte[] {1, 2, 3}, "application/octet-stream"); + String expectedBlobName = + String.format("%s/%s/%s/%s/0", APP_NAME, USER_ID, SESSION_ID, FILENAME); + BlobId expectedBlobId = BlobId.of(BUCKET_NAME, expectedBlobName); + BlobInfo expectedBlobInfo = + BlobInfo.newBuilder(expectedBlobId).setContentType("application/octet-stream").build(); + + when(mockBlobPage.iterateAll()).thenReturn(ImmutableList.of()); + Blob savedBlob = mockBlob(expectedBlobName, "application/octet-stream", new byte[] {1, 2, 3}); + when(mockStorage.create(eq(expectedBlobInfo), eq(new byte[] {1, 2, 3}))).thenReturn(savedBlob); + + Optional result = + asOptional( + service.saveAndReloadArtifact(APP_NAME, USER_ID, SESSION_ID, FILENAME, artifact)); + + assertThat(result).isPresent(); + assertThat(result.get().fileData()).isPresent(); + assertThat(result.get().fileData().get().fileUri()) + .hasValue("gs://" + BUCKET_NAME + "/" + expectedBlobName); + assertThat(result.get().fileData().get().mimeType()).hasValue("application/octet-stream"); + verify(mockStorage).create(eq(expectedBlobInfo), eq(new byte[] {1, 2, 3})); + } + private static Optional asOptional(Maybe maybe) { return maybe.map(Optional::of).defaultIfEmpty(Optional.empty()).blockingGet(); } diff --git a/core/src/test/java/com/google/adk/artifacts/InMemoryArtifactServiceTest.java b/core/src/test/java/com/google/adk/artifacts/InMemoryArtifactServiceTest.java new file mode 100644 index 00000000..c9f7541a --- /dev/null +++ b/core/src/test/java/com/google/adk/artifacts/InMemoryArtifactServiceTest.java @@ -0,0 +1,77 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.artifacts; + +import static com.google.common.truth.Truth.assertThat; + +import com.google.genai.types.Part; +import io.reactivex.rxjava3.core.Maybe; +import java.util.Optional; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link InMemoryArtifactService}. */ +@RunWith(JUnit4.class) +public class InMemoryArtifactServiceTest { + + private static final String APP_NAME = "test-app"; + private static final String USER_ID = "test-user"; + private static final String SESSION_ID = "test-session"; + private static final String FILENAME = "test-file.txt"; + + private InMemoryArtifactService service; + + @Before + public void setUp() { + service = new InMemoryArtifactService(); + } + + @Test + public void saveArtifact_savesAndReturnsVersion() { + Part artifact = Part.fromBytes(new byte[] {1, 2, 3}, "text/plain"); + int version = + service.saveArtifact(APP_NAME, USER_ID, SESSION_ID, FILENAME, artifact).blockingGet(); + assertThat(version).isEqualTo(0); + } + + @Test + public void loadArtifact_loadsLatest() { + Part artifact1 = Part.fromBytes(new byte[] {1}, "text/plain"); + Part artifact2 = Part.fromBytes(new byte[] {1, 2}, "text/plain"); + var unused1 = + service.saveArtifact(APP_NAME, USER_ID, SESSION_ID, FILENAME, artifact1).blockingGet(); + var unused2 = + service.saveArtifact(APP_NAME, USER_ID, SESSION_ID, FILENAME, artifact2).blockingGet(); + Optional result = + asOptional(service.loadArtifact(APP_NAME, USER_ID, SESSION_ID, FILENAME, Optional.empty())); + assertThat(result).hasValue(artifact2); + } + + @Test + public void saveAndReloadArtifact_reloadsArtifact() { + Part artifact = Part.fromBytes(new byte[] {1, 2, 3}, "text/plain"); + Optional result = + asOptional( + service.saveAndReloadArtifact(APP_NAME, USER_ID, SESSION_ID, FILENAME, artifact)); + assertThat(result).hasValue(artifact); + } + + private static Optional asOptional(Maybe maybe) { + return maybe.map(Optional::of).defaultIfEmpty(Optional.empty()).blockingGet(); + } +} diff --git a/core/src/test/java/com/google/adk/plugins/SaveFilesAsArtifactsPluginTest.java b/core/src/test/java/com/google/adk/plugins/SaveFilesAsArtifactsPluginTest.java new file mode 100644 index 00000000..1d3b7ed0 --- /dev/null +++ b/core/src/test/java/com/google/adk/plugins/SaveFilesAsArtifactsPluginTest.java @@ -0,0 +1,347 @@ +package com.google.adk.plugins; + +import static com.google.common.truth.Truth.assertThat; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.when; + +import com.google.adk.agents.CallbackContext; +import com.google.adk.agents.InvocationContext; +import com.google.adk.artifacts.BaseArtifactService; +import com.google.adk.sessions.Session; +import com.google.adk.tools.BaseTool; +import com.google.common.collect.ImmutableList; +import com.google.genai.types.Blob; +import com.google.genai.types.Content; +import com.google.genai.types.FileData; +import com.google.genai.types.Part; +import com.google.protobuf.ByteString; +import io.reactivex.rxjava3.core.Maybe; +import java.util.Optional; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +/** + * Unit tests for {@link SaveFilesAsArtifactsPlugin}. + * + *

This class tests the following scenarios: + * + *

    + *
  • The plugin returns the original message if no artifact service is available. + *
  • The plugin returns no message if the user message contains no inline data. + *
  • The plugin correctly saves artifacts when inline data is present, and: + *
      + *
    • Returns a text part indicating upload if the saved artifact has no URI. + *
    • Returns a text part and a file data part if the saved artifact has an accessible URI. + *
    • Returns only a text part if the saved artifact has an inaccessible URI. + *
    + *
  • The plugin returns the original part if saving the artifact fails. + *
  • The plugin correctly handles messages with multiple parts, some with inline data and some + * without. + *
  • The plugin uses the display name from the inline data as the artifact file name if + * provided. + *
+ */ +@RunWith(JUnit4.class) +public class SaveFilesAsArtifactsPluginTest { + + @Rule public MockitoRule mockitoRule = MockitoJUnit.rule(); + + private static final String APP_NAME = "test_app"; + private static final String USER_ID = "test_user"; + private static final String SESSION_ID = "test_session"; + private static final String INVOCATION_ID = "test_invocation"; + + @Mock private BaseArtifactService mockArtifactService; + @Mock private CallbackContext mockCallbackContext; + @Mock private BaseTool mockTool; + private Session session; + + private SaveFilesAsArtifactsPlugin plugin; + private InvocationContext invocationContext; + private InvocationContext invocationContextWithNoArtifactService; + + @Before + public void setUp() { + session = Session.builder(SESSION_ID).appName(APP_NAME).userId(USER_ID).build(); + + invocationContext = + InvocationContext.builder() + .invocationId(INVOCATION_ID) + .session(session) + .artifactService(mockArtifactService) + .build(); + invocationContextWithNoArtifactService = + InvocationContext.builder() + .invocationId(INVOCATION_ID) + .session(session) + .artifactService(null) + .build(); + plugin = new SaveFilesAsArtifactsPlugin(); + } + + private Part createInlineDataPart(String mimeType, String data) { + return createInlineDataPart(mimeType, data, Optional.empty()); + } + + private Part createInlineDataPart(String mimeType, String data, Optional displayName) { + Blob.Builder blobBuilder = + Blob.builder().mimeType(mimeType).data(ByteString.copyFromUtf8(data).toByteArray()); + displayName.ifPresent(blobBuilder::displayName); + return Part.builder().inlineData(blobBuilder.build()).build(); + } + + @Test + public void getName_withDefaultConstructor_returnsDefaultName() { + SaveFilesAsArtifactsPlugin defaultPlugin = new SaveFilesAsArtifactsPlugin(); + assertThat(defaultPlugin.getName()).isEqualTo("save_files_as_artifacts_plugin"); + } + + @Test + public void getName_withNameConstructor_returnsName() { + SaveFilesAsArtifactsPlugin namedPlugin = new SaveFilesAsArtifactsPlugin("custom_name"); + assertThat(namedPlugin.getName()).isEqualTo("custom_name"); + } + + @Test + public void onUserMessageCallback_noArtifactService_returnsMessage() { + Part partWithInlineData = createInlineDataPart("text/plain", "hello"); + Content userMessage = Content.builder().parts(partWithInlineData).role("user").build(); + + plugin + .onUserMessageCallback(invocationContextWithNoArtifactService, userMessage) + .test() + .assertValue(userMessage); + } + + @Test + public void onUserMessageCallback_noInlineData_returnsEmpty() { + Content userMessage = Content.builder().parts(Part.fromText("hello")).role("user").build(); + plugin.onUserMessageCallback(invocationContext, userMessage).test().assertNoValues(); + } + + @Test + public void onUserMessageCallback_withInlineDataAndSuccessfulSaveAndNoUri_returnsTextPart() { + Part partWithInlineData = createInlineDataPart("text/plain", "hello"); + Content userMessage = Content.builder().parts(partWithInlineData).role("user").build(); + String fileName = "artifact_" + INVOCATION_ID + "_0"; + + // Load artifact returns part without FileData + when(mockArtifactService.saveAndReloadArtifact( + eq(APP_NAME), eq(USER_ID), eq(SESSION_ID), eq(fileName), eq(partWithInlineData))) + .thenReturn(Maybe.just(Part.fromText("a part without file data"))); + + Content result = plugin.onUserMessageCallback(invocationContext, userMessage).blockingGet(); + + assertThat(result.parts().get()) + .containsExactly(Part.fromText("[Uploaded Artifact: \"" + fileName + "\"]")); + } + + @Test + public void + onUserMessageCallback_withInlineDataAndSuccessfulSaveAndAccessibleUri_returnsTextAndUriParts() { + Part partWithInlineData = createInlineDataPart("text/plain", "hello"); + Content userMessage = Content.builder().parts(partWithInlineData).role("user").build(); + String fileName = "artifact_" + INVOCATION_ID + "_0"; + String fileUri = "gs://my-bucket/artifact_test_invocation_0"; + String mimeType = "text/plain"; + + when(mockArtifactService.saveAndReloadArtifact( + eq(APP_NAME), eq(USER_ID), eq(SESSION_ID), eq(fileName), eq(partWithInlineData))) + .thenReturn( + Maybe.just( + Part.builder() + .fileData( + FileData.builder() + .fileUri(fileUri) + .mimeType(mimeType) + .displayName(fileName) + .build()) + .build())); + + Content result = plugin.onUserMessageCallback(invocationContext, userMessage).blockingGet(); + + assertThat(result.parts().get()) + .containsExactly( + Part.fromText("[Uploaded Artifact: \"" + fileName + "\"]"), + Part.builder() + .fileData( + FileData.builder() + .fileUri(fileUri) + .mimeType(mimeType) + .displayName(fileName) + .build()) + .build()) + .inOrder(); + } + + @Test + public void + onUserMessageCallback_withInlineDataAndSuccessfulSaveAndInaccessibleUri_returnsTextPart() { + Part partWithInlineData = createInlineDataPart("text/plain", "hello"); + Content userMessage = Content.builder().parts(partWithInlineData).role("user").build(); + String fileName = "artifact_" + INVOCATION_ID + "_0"; + String fileUri = "file://my-bucket/artifact_test_invocation_0"; // Inaccessible scheme + String mimeType = "text/plain"; + + when(mockArtifactService.saveAndReloadArtifact( + eq(APP_NAME), eq(USER_ID), eq(SESSION_ID), eq(fileName), eq(partWithInlineData))) + .thenReturn( + Maybe.just( + Part.builder() + .fileData( + FileData.builder() + .fileUri(fileUri) + .mimeType(mimeType) + .displayName(fileName) + .build()) + .build())); + + Content result = plugin.onUserMessageCallback(invocationContext, userMessage).blockingGet(); + + assertThat(result.parts().get()) + .containsExactly(Part.fromText("[Uploaded Artifact: \"" + fileName + "\"]")); + } + + @Test + public void onUserMessageCallback_withInlineDataAndFailedSave_returnsOriginalPart() { + Part partWithInlineData = createInlineDataPart("text/plain", "hello"); + Content userMessage = Content.builder().parts(partWithInlineData).role("user").build(); + String fileName = "artifact_" + INVOCATION_ID + "_0"; + + when(mockArtifactService.saveAndReloadArtifact( + eq(APP_NAME), eq(USER_ID), eq(SESSION_ID), eq(fileName), eq(partWithInlineData))) + .thenReturn(Maybe.error(new RuntimeException("Failed to save"))); + + Content result = plugin.onUserMessageCallback(invocationContext, userMessage).blockingGet(); + + assertThat(result.parts().get()).containsExactly(partWithInlineData); + } + + @Test + public void onUserMessageCallback_withInlineDataAndMultipleParts_returnsMixedParts() { + Part textPart = Part.fromText("this is text"); + Part partWithInlineData1 = createInlineDataPart("text/plain", "inline1"); + Part partWithInlineData2 = createInlineDataPart("image/png", "inline2"); + Content userMessage = + Content.builder() + .parts(ImmutableList.of(textPart, partWithInlineData1, partWithInlineData2)) + .role("user") + .build(); + + String fileName1 = "artifact_" + INVOCATION_ID + "_0"; + String fileName2 = "artifact_" + INVOCATION_ID + "_1"; + String fileUri1 = "gs://my-bucket/artifact_test_invocation_0"; + String mimeType1 = "text/plain"; + + when(mockArtifactService.saveAndReloadArtifact( + eq(APP_NAME), eq(USER_ID), eq(SESSION_ID), eq(fileName1), eq(partWithInlineData1))) + .thenReturn( + Maybe.just( + Part.builder() + .fileData( + FileData.builder() + .fileUri(fileUri1) + .mimeType(mimeType1) + .displayName(fileName1) + .build()) + .build())); + // For 2nd artifact, do not return a file URI. + when(mockArtifactService.saveAndReloadArtifact( + eq(APP_NAME), eq(USER_ID), eq(SESSION_ID), eq(fileName2), eq(partWithInlineData2))) + .thenReturn(Maybe.empty()); + + Content result = plugin.onUserMessageCallback(invocationContext, userMessage).blockingGet(); + + assertThat(result.parts().get()) + .containsExactly( + textPart, + Part.fromText("[Uploaded Artifact: \"" + fileName1 + "\"]"), + Part.builder() + .fileData( + FileData.builder() + .fileUri(fileUri1) + .mimeType(mimeType1) + .displayName(fileName1) + .build()) + .build(), + Part.fromText("[Uploaded Artifact: \"" + fileName2 + "\"]")) + .inOrder(); + } + + @Test + public void onUserMessageCallback_withDisplayName_usesDisplayNameAsFileName() { + String displayName = "mydocument.txt"; + Part partWithInlineData = createInlineDataPart("text/plain", "hello", Optional.of(displayName)); + Content userMessage = Content.builder().parts(partWithInlineData).role("user").build(); + + when(mockArtifactService.saveAndReloadArtifact( + eq(APP_NAME), eq(USER_ID), eq(SESSION_ID), eq(displayName), eq(partWithInlineData))) + .thenReturn(Maybe.empty()); + + Content result = plugin.onUserMessageCallback(invocationContext, userMessage).blockingGet(); + + assertThat(result.parts().get()) + .containsExactly(Part.fromText("[Uploaded Artifact: \"" + displayName + "\"]")); + } + + @Test + public void beforeRunCallback_returnsEmpty() { + plugin.beforeRunCallback(invocationContext).test().assertNoValues(); + } + + @Test + public void afterRunCallback_returnsComplete() { + plugin.afterRunCallback(invocationContext).test().assertComplete(); + } + + @Test + public void beforeAgentCallback_returnsEmpty() { + plugin.beforeAgentCallback(null, mockCallbackContext).test().assertNoValues(); + } + + @Test + public void afterAgentCallback_returnsEmpty() { + plugin.afterAgentCallback(null, mockCallbackContext).test().assertNoValues(); + } + + @Test + public void beforeModelCallback_returnsEmpty() { + plugin.beforeModelCallback(mockCallbackContext, null).test().assertNoValues(); + } + + @Test + public void afterModelCallback_returnsEmpty() { + plugin.afterModelCallback(mockCallbackContext, null).test().assertNoValues(); + } + + @Test + public void beforeToolCallback_returnsEmpty() { + plugin.beforeToolCallback(mockTool, null, null).test().assertNoValues(); + } + + @Test + public void afterToolCallback_returnsEmpty() { + plugin.afterToolCallback(mockTool, null, null, null).test().assertNoValues(); + } + + @Test + public void onModelErrorCallback_returnsEmpty() { + plugin.onModelErrorCallback(mockCallbackContext, null, null).test().assertNoValues(); + } + + @Test + public void onToolErrorCallback_returnsEmpty() { + plugin.onToolErrorCallback(mockTool, null, null, null).test().assertNoValues(); + } + + @Test + public void onEventCallback_returnsEmpty() { + plugin.onEventCallback(invocationContext, null).test().assertNoValues(); + } +}