diff --git a/src/google/adk/sessions/vertex_ai_session_service.py b/src/google/adk/sessions/vertex_ai_session_service.py index d8bbab1a95..3d5b382f5b 100644 --- a/src/google/adk/sessions/vertex_ai_session_service.py +++ b/src/google/adk/sessions/vertex_ai_session_service.py @@ -48,6 +48,11 @@ _USAGE_METADATA_CUSTOM_METADATA_KEY = '_usage_metadata' _SESSION_ID_PATTERN = re.compile(r'^[A-Za-z0-9_-]+$') +_SESSION_RESOURCE_NAME_PATTERN = re.compile( + r'^projects/[^/]+/locations/[^/]+/' + r'(?:collections/[^/]+/engines/[^/]+|reasoningEngines/[^/]+)/' + r'sessions/([^/]+)$' +) def _validate_session_id(session_id: str) -> None: @@ -61,6 +66,19 @@ def _validate_session_id(session_id: str) -> None: ) +def _normalize_session_id(session_id: str) -> str: + """Returns the plain session ID from a session ID or Vertex resource name.""" + match = ( + _SESSION_RESOURCE_NAME_PATTERN.fullmatch(session_id) + if isinstance(session_id, str) + else None + ) + if match: + session_id = match.group(1) + _validate_session_id(session_id) + return session_id + + def _quote_filter_literal(value: str) -> str: """Quotes filter values so embedded metacharacters stay inside the literal.""" escaped_value = value.replace('\\', '\\\\').replace('"', '\\"') @@ -177,7 +195,7 @@ async def get_session( session_id: str, config: Optional[GetSessionConfig] = None, ) -> Optional[Session]: - _validate_session_id(session_id) + session_id = _normalize_session_id(session_id) reasoning_engine_id = self._get_reasoning_engine_id(app_name) session_resource_name = ( f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}' @@ -277,7 +295,7 @@ async def list_sessions( async def delete_session( self, *, app_name: str, user_id: str, session_id: str ) -> None: - _validate_session_id(session_id) + session_id = _normalize_session_id(session_id) reasoning_engine_id = self._get_reasoning_engine_id(app_name) session_resource_name = ( f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}' diff --git a/tests/unittests/sessions/test_vertex_ai_session_service.py b/tests/unittests/sessions/test_vertex_ai_session_service.py index b8c71701dc..5ff10cdfbc 100644 --- a/tests/unittests/sessions/test_vertex_ai_session_service.py +++ b/tests/unittests/sessions/test_vertex_ai_session_service.py @@ -725,6 +725,53 @@ async def test_get_and_delete_session(): assert str(excinfo.value) == '404 Session not found: 1' +@pytest.mark.asyncio +@pytest.mark.usefixtures('mock_get_api_client') +@pytest.mark.parametrize( + 'resource_name', + [ + ( + 'projects/my-project/locations/global/collections/' + 'default_collection/engines/my-app/sessions/1' + ), + 'projects/my-project/locations/us-central1/reasoningEngines/123/sessions/1', + ], +) +async def test_get_session_accepts_fully_qualified_resource_name( + resource_name: str, +): + session_service = mock_vertex_ai_session_service() + + session = await session_service.get_session( + app_name='123', + user_id='user', + session_id=resource_name, + ) + + assert session == MOCK_SESSION + + +@pytest.mark.asyncio +@pytest.mark.usefixtures('mock_get_api_client') +async def test_delete_session_accepts_fully_qualified_resource_name(): + session_service = mock_vertex_ai_session_service() + + await session_service.delete_session( + app_name='123', + user_id='user', + session_id=( + 'projects/my-project/locations/global/collections/default_collection/' + 'engines/my-app/sessions/1' + ), + ) + + with pytest.raises(api_core_exceptions.NotFound) as excinfo: + await session_service.get_session( + app_name='123', user_id='user', session_id='1' + ) + assert str(excinfo.value) == '404 Session not found: 1' + + @pytest.mark.asyncio @pytest.mark.usefixtures('mock_get_api_client') async def test_delete_session_rejects_other_users_session(): @@ -753,7 +800,15 @@ async def test_session_id_path_traversal_rejected(): """Session IDs containing path-traversal characters must be rejected.""" session_service = mock_vertex_ai_session_service() - for bad_id in ['..', '../foo', '..?force=true', 'a/b', '']: + for bad_id in [ + '..', + '../foo', + '..?force=true', + 'a/b', + '', + 'projects/my-project/locations/global/sessions/../foo', + 'reasoningEngines/123/sessions/1', + ]: with pytest.raises(ValueError): await session_service.delete_session( app_name='123', user_id='user', session_id=bad_id