diff --git a/src/google/adk/artifacts/gcs_artifact_service.py b/src/google/adk/artifacts/gcs_artifact_service.py index f8706dedbd..9f4e7adced 100644 --- a/src/google/adk/artifacts/gcs_artifact_service.py +++ b/src/google/adk/artifacts/gcs_artifact_service.py @@ -32,6 +32,7 @@ from google.genai import types from typing_extensions import override +from . import artifact_util from ..errors.input_validation_error import InputValidationError from .base_artifact_service import ArtifactVersion from .base_artifact_service import BaseArtifactService @@ -230,9 +231,21 @@ def _save_artifact( content_type="text/plain", ) elif artifact.file_data: - raise NotImplementedError( - "Saving artifact with file_data is not supported yet in" - " GcsArtifactService." + if not artifact.file_data.file_uri: + raise InputValidationError("Artifact file_data must have a file_uri.") + if artifact_util.is_artifact_ref(artifact): + if not artifact_util.parse_artifact_uri(artifact.file_data.file_uri): + raise InputValidationError( + f"Invalid artifact reference URI: {artifact.file_data.file_uri}" + ) + # Store the URI as blob metadata; no content to upload. + blob.metadata = { + **(blob.metadata or {}), + "file_uri": artifact.file_data.file_uri, + } + blob.upload_from_string( + b"", + content_type=artifact.file_data.mime_type or None, ) else: raise InputValidationError( @@ -263,15 +276,25 @@ def _load_artifact( blob_name = self._get_blob_name( app_name, user_id, filename, version, session_id ) - blob = self.bucket.blob(blob_name) + blob = self.bucket.get_blob(blob_name) + if blob is None: + return None + + # If the artifact was saved as a file_data URI reference, restore it. + if blob.metadata and "file_uri" in blob.metadata: + return types.Part( + file_data=types.FileData( + file_uri=blob.metadata["file_uri"], + mime_type=blob.content_type or None, + ) + ) artifact_bytes = blob.download_as_bytes() if not artifact_bytes: return None - artifact = types.Part.from_bytes( + return types.Part.from_bytes( data=artifact_bytes, mime_type=blob.content_type ) - return artifact def _list_artifact_keys( self, app_name: str, user_id: str, session_id: Optional[str] diff --git a/tests/unittests/artifacts/test_artifact_service.py b/tests/unittests/artifacts/test_artifact_service.py index 8b82397097..7a0ef8879c 100644 --- a/tests/unittests/artifacts/test_artifact_service.py +++ b/tests/unittests/artifacts/test_artifact_service.py @@ -865,6 +865,119 @@ def test_converts_text_dict(self): assert result.text == "hello world" +# --------------------------------------------------------------------------- +# GCS file_data (URI reference) tests +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_gcs_save_artifact_with_external_gcs_uri(): + """GcsArtifactService saves and loads a gs:// file_data URI reference.""" + service = mock_gcs_artifact_service() + artifact = types.Part( + file_data=types.FileData( + file_uri="gs://my-bucket/report.pdf", + mime_type="application/pdf", + ) + ) + + version = await service.save_artifact( + app_name="app", + user_id="user1", + session_id="sess1", + filename="report.pdf", + artifact=artifact, + ) + assert version == 0 + + loaded = await service.load_artifact( + app_name="app", + user_id="user1", + session_id="sess1", + filename="report.pdf", + ) + assert loaded is not None + assert loaded.file_data is not None + assert loaded.file_data.file_uri == "gs://my-bucket/report.pdf" + assert loaded.file_data.mime_type == "application/pdf" + + +@pytest.mark.asyncio +async def test_gcs_save_artifact_with_artifact_ref_uri(): + """GcsArtifactService saves and loads an internal artifact:// URI reference.""" + service = mock_gcs_artifact_service() + artifact_ref_uri = "artifact://apps/app/users/user1/sessions/sess1/artifacts/source.txt/versions/0" + artifact = types.Part( + file_data=types.FileData( + file_uri=artifact_ref_uri, + mime_type="text/plain", + ) + ) + + version = await service.save_artifact( + app_name="app", + user_id="user1", + session_id="sess1", + filename="ref.txt", + artifact=artifact, + ) + assert version == 0 + + loaded = await service.load_artifact( + app_name="app", + user_id="user1", + session_id="sess1", + filename="ref.txt", + ) + assert loaded is not None + assert loaded.file_data is not None + assert loaded.file_data.file_uri == artifact_ref_uri + + +@pytest.mark.asyncio +async def test_gcs_save_artifact_file_data_without_mime_type(): + """GcsArtifactService handles file_data with no mime_type.""" + service = mock_gcs_artifact_service() + artifact = types.Part( + file_data=types.FileData(file_uri="gs://my-bucket/data.bin") + ) + + version = await service.save_artifact( + app_name="app", + user_id="user1", + session_id="sess1", + filename="data.bin", + artifact=artifact, + ) + assert version == 0 + + loaded = await service.load_artifact( + app_name="app", + user_id="user1", + session_id="sess1", + filename="data.bin", + ) + assert loaded is not None + assert loaded.file_data is not None + assert loaded.file_data.file_uri == "gs://my-bucket/data.bin" + + +@pytest.mark.asyncio +async def test_gcs_save_artifact_file_data_missing_uri_raises(): + """GcsArtifactService raises InputValidationError when file_uri is empty.""" + service = mock_gcs_artifact_service() + artifact = types.Part(file_data=types.FileData(file_uri="")) + + with pytest.raises(InputValidationError): + await service.save_artifact( + app_name="app", + user_id="user1", + session_id="sess1", + filename="empty.bin", + artifact=artifact, + ) + + @pytest.mark.asyncio @pytest.mark.parametrize( "service_type",