Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 29 additions & 6 deletions src/google/adk/artifacts/gcs_artifact_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from google.genai import types
from typing_extensions import override

from . import artifact_util
from ..errors.input_validation_error import InputValidationError
from .base_artifact_service import ArtifactVersion
from .base_artifact_service import BaseArtifactService
Expand Down Expand Up @@ -230,9 +231,21 @@ def _save_artifact(
content_type="text/plain",
)
elif artifact.file_data:
raise NotImplementedError(
"Saving artifact with file_data is not supported yet in"
" GcsArtifactService."
if not artifact.file_data.file_uri:
raise InputValidationError("Artifact file_data must have a file_uri.")
if artifact_util.is_artifact_ref(artifact):
if not artifact_util.parse_artifact_uri(artifact.file_data.file_uri):
raise InputValidationError(
f"Invalid artifact reference URI: {artifact.file_data.file_uri}"
)
# Store the URI as blob metadata; no content to upload.
blob.metadata = {
**(blob.metadata or {}),
"file_uri": artifact.file_data.file_uri,
}
blob.upload_from_string(
b"",
content_type=artifact.file_data.mime_type or None,
)
else:
raise InputValidationError(
Expand Down Expand Up @@ -263,15 +276,25 @@ def _load_artifact(
blob_name = self._get_blob_name(
app_name, user_id, filename, version, session_id
)
blob = self.bucket.blob(blob_name)
blob = self.bucket.get_blob(blob_name)
if blob is None:
return None

# If the artifact was saved as a file_data URI reference, restore it.
if blob.metadata and "file_uri" in blob.metadata:
return types.Part(
file_data=types.FileData(
file_uri=blob.metadata["file_uri"],
mime_type=blob.content_type or None,
)
)

artifact_bytes = blob.download_as_bytes()
if not artifact_bytes:
return None
artifact = types.Part.from_bytes(
return types.Part.from_bytes(
data=artifact_bytes, mime_type=blob.content_type
)
return artifact

def _list_artifact_keys(
self, app_name: str, user_id: str, session_id: Optional[str]
Expand Down
113 changes: 113 additions & 0 deletions tests/unittests/artifacts/test_artifact_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -865,6 +865,119 @@ def test_converts_text_dict(self):
assert result.text == "hello world"


# ---------------------------------------------------------------------------
# GCS file_data (URI reference) tests
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_gcs_save_artifact_with_external_gcs_uri():
"""GcsArtifactService saves and loads a gs:// file_data URI reference."""
service = mock_gcs_artifact_service()
artifact = types.Part(
file_data=types.FileData(
file_uri="gs://my-bucket/report.pdf",
mime_type="application/pdf",
)
)

version = await service.save_artifact(
app_name="app",
user_id="user1",
session_id="sess1",
filename="report.pdf",
artifact=artifact,
)
assert version == 0

loaded = await service.load_artifact(
app_name="app",
user_id="user1",
session_id="sess1",
filename="report.pdf",
)
assert loaded is not None
assert loaded.file_data is not None
assert loaded.file_data.file_uri == "gs://my-bucket/report.pdf"
assert loaded.file_data.mime_type == "application/pdf"


@pytest.mark.asyncio
async def test_gcs_save_artifact_with_artifact_ref_uri():
"""GcsArtifactService saves and loads an internal artifact:// URI reference."""
service = mock_gcs_artifact_service()
artifact_ref_uri = "artifact://apps/app/users/user1/sessions/sess1/artifacts/source.txt/versions/0"
artifact = types.Part(
file_data=types.FileData(
file_uri=artifact_ref_uri,
mime_type="text/plain",
)
)

version = await service.save_artifact(
app_name="app",
user_id="user1",
session_id="sess1",
filename="ref.txt",
artifact=artifact,
)
assert version == 0

loaded = await service.load_artifact(
app_name="app",
user_id="user1",
session_id="sess1",
filename="ref.txt",
)
assert loaded is not None
assert loaded.file_data is not None
assert loaded.file_data.file_uri == artifact_ref_uri


@pytest.mark.asyncio
async def test_gcs_save_artifact_file_data_without_mime_type():
"""GcsArtifactService handles file_data with no mime_type."""
service = mock_gcs_artifact_service()
artifact = types.Part(
file_data=types.FileData(file_uri="gs://my-bucket/data.bin")
)

version = await service.save_artifact(
app_name="app",
user_id="user1",
session_id="sess1",
filename="data.bin",
artifact=artifact,
)
assert version == 0

loaded = await service.load_artifact(
app_name="app",
user_id="user1",
session_id="sess1",
filename="data.bin",
)
assert loaded is not None
assert loaded.file_data is not None
assert loaded.file_data.file_uri == "gs://my-bucket/data.bin"


@pytest.mark.asyncio
async def test_gcs_save_artifact_file_data_missing_uri_raises():
"""GcsArtifactService raises InputValidationError when file_uri is empty."""
service = mock_gcs_artifact_service()
artifact = types.Part(file_data=types.FileData(file_uri=""))

with pytest.raises(InputValidationError):
await service.save_artifact(
app_name="app",
user_id="user1",
session_id="sess1",
filename="empty.bin",
artifact=artifact,
)


@pytest.mark.asyncio
@pytest.mark.parametrize(
"service_type",
Expand Down