diff --git a/src/google/adk/tools/discovery_engine_search_tool.py b/src/google/adk/tools/discovery_engine_search_tool.py index 0e771ece4f..16e925b853 100644 --- a/src/google/adk/tools/discovery_engine_search_tool.py +++ b/src/google/adk/tools/discovery_engine_search_tool.py @@ -17,6 +17,7 @@ from typing import Any from typing import Optional +from google.api_core.client_options import ClientOptions from google.api_core.exceptions import GoogleAPICallError import google.auth from google.cloud import discoveryengine_v1beta as discoveryengine @@ -71,9 +72,29 @@ def __init__( self._filter = filter self._max_results = max_results + # Extract location from data_store_id or search_engine_id + # Format: projects/{project}/locations/{location}/... + resource_id = data_store_id or search_engine_id + location = "global" + if resource_id: + parts = resource_id.split("/") + if "locations" in parts: + try: + loc_index = parts.index("locations") + 1 + if loc_index < len(parts): + location = parts[loc_index] + except ValueError: + pass + + client_options = None + if location != "global": + api_endpoint = f"{location}-discoveryengine.googleapis.com" + client_options = ClientOptions(api_endpoint=api_endpoint) + credentials, _ = google.auth.default() self._discovery_engine_client = discoveryengine.SearchServiceClient( - credentials=credentials + credentials=credentials, + client_options=client_options, ) def discovery_engine_search( diff --git a/tests/unittests/tools/test_discovery_engine_search_tool.py b/tests/unittests/tools/test_discovery_engine_search_tool.py index d10da252c7..0c2061ea24 100644 --- a/tests/unittests/tools/test_discovery_engine_search_tool.py +++ b/tests/unittests/tools/test_discovery_engine_search_tool.py @@ -76,6 +76,18 @@ def test_init_with_data_store_specs_without_search_engine_id_raises_error( data_store_id="test_data_store", data_store_specs=[{"id": "123"}] ) + @mock.patch( + "google.cloud.discoveryengine_v1beta.SearchServiceClient", + ) + def test_init_with_location(self, mock_search_client): + """Test initialization with location extracted from data_store_id.""" + data_store_id = "projects/test-project/locations/us/collections/default_collection/dataStores/test-datastore" + DiscoveryEngineSearchTool(data_store_id=data_store_id) + + # Check if SearchServiceClient was called with correct client_options + args, kwargs = mock_search_client.call_args + assert kwargs["client_options"].api_endpoint == "us-discoveryengine.googleapis.com" + @mock.patch( "google.cloud.discoveryengine_v1beta.SearchServiceClient", )