Skip to content

Commit fb49ae6

Browse files
committed
feat: support file_data URI references in GcsArtifactService #5230
1 parent 8bc5728 commit fb49ae6

File tree

2 files changed

+142
-6
lines changed

2 files changed

+142
-6
lines changed

src/google/adk/artifacts/gcs_artifact_service.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from google.genai import types
3333
from typing_extensions import override
3434

35+
from . import artifact_util
3536
from ..errors.input_validation_error import InputValidationError
3637
from .base_artifact_service import ArtifactVersion
3738
from .base_artifact_service import BaseArtifactService
@@ -230,9 +231,21 @@ def _save_artifact(
230231
content_type="text/plain",
231232
)
232233
elif artifact.file_data:
233-
raise NotImplementedError(
234-
"Saving artifact with file_data is not supported yet in"
235-
" GcsArtifactService."
234+
if not artifact.file_data.file_uri:
235+
raise InputValidationError("Artifact file_data must have a file_uri.")
236+
if artifact_util.is_artifact_ref(artifact):
237+
if not artifact_util.parse_artifact_uri(artifact.file_data.file_uri):
238+
raise InputValidationError(
239+
f"Invalid artifact reference URI: {artifact.file_data.file_uri}"
240+
)
241+
# Store the URI as blob metadata; no content to upload.
242+
blob.metadata = {
243+
**(blob.metadata or {}),
244+
"file_uri": artifact.file_data.file_uri,
245+
}
246+
blob.upload_from_string(
247+
b"",
248+
content_type=artifact.file_data.mime_type or None,
236249
)
237250
else:
238251
raise InputValidationError(
@@ -263,15 +276,25 @@ def _load_artifact(
263276
blob_name = self._get_blob_name(
264277
app_name, user_id, filename, version, session_id
265278
)
266-
blob = self.bucket.blob(blob_name)
279+
blob = self.bucket.get_blob(blob_name)
280+
if blob is None:
281+
return None
282+
283+
# If the artifact was saved as a file_data URI reference, restore it.
284+
if blob.metadata and "file_uri" in blob.metadata:
285+
return types.Part(
286+
file_data=types.FileData(
287+
file_uri=blob.metadata["file_uri"],
288+
mime_type=blob.content_type or None,
289+
)
290+
)
267291

268292
artifact_bytes = blob.download_as_bytes()
269293
if not artifact_bytes:
270294
return None
271-
artifact = types.Part.from_bytes(
295+
return types.Part.from_bytes(
272296
data=artifact_bytes, mime_type=blob.content_type
273297
)
274-
return artifact
275298

276299
def _list_artifact_keys(
277300
self, app_name: str, user_id: str, session_id: Optional[str]

tests/unittests/artifacts/test_artifact_service.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -865,6 +865,119 @@ def test_converts_text_dict(self):
865865
assert result.text == "hello world"
866866

867867

868+
# ---------------------------------------------------------------------------
869+
# GCS file_data (URI reference) tests
870+
# ---------------------------------------------------------------------------
871+
872+
873+
@pytest.mark.asyncio
874+
async def test_gcs_save_artifact_with_external_gcs_uri():
875+
"""GcsArtifactService saves and loads a gs:// file_data URI reference."""
876+
service = mock_gcs_artifact_service()
877+
artifact = types.Part(
878+
file_data=types.FileData(
879+
file_uri="gs://my-bucket/report.pdf",
880+
mime_type="application/pdf",
881+
)
882+
)
883+
884+
version = await service.save_artifact(
885+
app_name="app",
886+
user_id="user1",
887+
session_id="sess1",
888+
filename="report.pdf",
889+
artifact=artifact,
890+
)
891+
assert version == 0
892+
893+
loaded = await service.load_artifact(
894+
app_name="app",
895+
user_id="user1",
896+
session_id="sess1",
897+
filename="report.pdf",
898+
)
899+
assert loaded is not None
900+
assert loaded.file_data is not None
901+
assert loaded.file_data.file_uri == "gs://my-bucket/report.pdf"
902+
assert loaded.file_data.mime_type == "application/pdf"
903+
904+
905+
@pytest.mark.asyncio
906+
async def test_gcs_save_artifact_with_artifact_ref_uri():
907+
"""GcsArtifactService saves and loads an internal artifact:// URI reference."""
908+
service = mock_gcs_artifact_service()
909+
artifact_ref_uri = "artifact://apps/app/users/user1/sessions/sess1/artifacts/source.txt/versions/0"
910+
artifact = types.Part(
911+
file_data=types.FileData(
912+
file_uri=artifact_ref_uri,
913+
mime_type="text/plain",
914+
)
915+
)
916+
917+
version = await service.save_artifact(
918+
app_name="app",
919+
user_id="user1",
920+
session_id="sess1",
921+
filename="ref.txt",
922+
artifact=artifact,
923+
)
924+
assert version == 0
925+
926+
loaded = await service.load_artifact(
927+
app_name="app",
928+
user_id="user1",
929+
session_id="sess1",
930+
filename="ref.txt",
931+
)
932+
assert loaded is not None
933+
assert loaded.file_data is not None
934+
assert loaded.file_data.file_uri == artifact_ref_uri
935+
936+
937+
@pytest.mark.asyncio
938+
async def test_gcs_save_artifact_file_data_without_mime_type():
939+
"""GcsArtifactService handles file_data with no mime_type."""
940+
service = mock_gcs_artifact_service()
941+
artifact = types.Part(
942+
file_data=types.FileData(file_uri="gs://my-bucket/data.bin")
943+
)
944+
945+
version = await service.save_artifact(
946+
app_name="app",
947+
user_id="user1",
948+
session_id="sess1",
949+
filename="data.bin",
950+
artifact=artifact,
951+
)
952+
assert version == 0
953+
954+
loaded = await service.load_artifact(
955+
app_name="app",
956+
user_id="user1",
957+
session_id="sess1",
958+
filename="data.bin",
959+
)
960+
assert loaded is not None
961+
assert loaded.file_data is not None
962+
assert loaded.file_data.file_uri == "gs://my-bucket/data.bin"
963+
964+
965+
@pytest.mark.asyncio
966+
async def test_gcs_save_artifact_file_data_missing_uri_raises():
967+
"""GcsArtifactService raises InputValidationError when file_uri is empty."""
968+
service = mock_gcs_artifact_service()
969+
artifact = types.Part(file_data=types.FileData(file_uri=""))
970+
971+
with pytest.raises(InputValidationError):
972+
await service.save_artifact(
973+
app_name="app",
974+
user_id="user1",
975+
session_id="sess1",
976+
filename="empty.bin",
977+
artifact=artifact,
978+
)
979+
980+
868981
@pytest.mark.asyncio
869982
@pytest.mark.parametrize(
870983
"service_type",

0 commit comments

Comments
 (0)