Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion src/google/adk/tools/discovery_engine_search_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from typing import Any
from typing import Optional

from google.api_core.client_options import ClientOptions
from google.api_core.exceptions import GoogleAPICallError
import google.auth
from google.cloud import discoveryengine_v1beta as discoveryengine
Expand Down Expand Up @@ -71,9 +72,29 @@ def __init__(
self._filter = filter
self._max_results = max_results

# Extract location from data_store_id or search_engine_id
# Format: projects/{project}/locations/{location}/...
resource_id = data_store_id or search_engine_id
location = "global"
if resource_id:
parts = resource_id.split("/")
if "locations" in parts:
try:
loc_index = parts.index("locations") + 1
if loc_index < len(parts):
location = parts[loc_index]
except ValueError:
pass

client_options = None
if location != "global":
api_endpoint = f"{location}-discoveryengine.googleapis.com"
client_options = ClientOptions(api_endpoint=api_endpoint)

credentials, _ = google.auth.default()
self._discovery_engine_client = discoveryengine.SearchServiceClient(
credentials=credentials
credentials=credentials,
client_options=client_options,
)

def discovery_engine_search(
Expand Down
12 changes: 12 additions & 0 deletions tests/unittests/tools/test_discovery_engine_search_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,18 @@ def test_init_with_data_store_specs_without_search_engine_id_raises_error(
data_store_id="test_data_store", data_store_specs=[{"id": "123"}]
)

@mock.patch(
"google.cloud.discoveryengine_v1beta.SearchServiceClient",
)
def test_init_with_location(self, mock_search_client):
"""Test initialization with location extracted from data_store_id."""
data_store_id = "projects/test-project/locations/us/collections/default_collection/dataStores/test-datastore"
DiscoveryEngineSearchTool(data_store_id=data_store_id)

# Check if SearchServiceClient was called with correct client_options
args, kwargs = mock_search_client.call_args
assert kwargs["client_options"].api_endpoint == "us-discoveryengine.googleapis.com"

@mock.patch(
"google.cloud.discoveryengine_v1beta.SearchServiceClient",
)
Expand Down