Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,12 @@
from google.cloud.aiplatform_v1beta1.services.vertex_rag_data_service import (
VertexRagDataServiceClient,
)
from google.cloud.aiplatform_v1beta1.services.vertex_rag_data_service import pagers
from google.cloud.aiplatform_v1beta1.services.vertex_rag_data_service import transports
from google.cloud.aiplatform_v1beta1.services.vertex_rag_data_service import (
pagers,
)
from google.cloud.aiplatform_v1beta1.services.vertex_rag_data_service import (
transports,
)
from google.cloud.aiplatform_v1beta1.types import api_auth
from google.cloud.aiplatform_v1beta1.types import encryption_spec
from google.cloud.aiplatform_v1beta1.types import io
Expand Down Expand Up @@ -5294,6 +5298,79 @@ async def test_delete_rag_file_flattened_error_async():
)


@pytest.mark.parametrize(
"request_type",
[
vertex_rag_data_service.BatchCreateRagDataSchemasRequest,
dict,
],
)
def test_batch_create_rag_data_schemas(request_type, transport: str = "grpc"):
client = VertexRagDataServiceClient(
credentials=ga_credentials.AnonymousCredentials(),
transport=transport,
)

# Everything is optional in proto3 as far as the runtime is concerned,
# and we are mocking out the actual API, so just send an empty request.
request = request_type()

# Mock the actual call within the gRPC stub, and fake the request.
with mock.patch.object(
type(client.transport.batch_create_rag_data_schemas), "__call__"
) as call:
# Designate an appropriate return value for the call.
call.return_value = operations_pb2.Operation(name="operations/spam")
response = client.batch_create_rag_data_schemas(request)

# Establish that the underlying gRPC stub method was called.
assert len(call.mock_calls) == 1
_, args, _ = call.mock_calls[0]
request = vertex_rag_data_service.BatchCreateRagDataSchemasRequest()
assert args[0] == request

# Establish that the response is the type that we expect.
assert isinstance(response, future.Future)


@pytest.mark.parametrize(
"request_type",
[
vertex_rag_data_service.ListRagDataSchemasRequest,
dict,
],
)
def test_list_rag_data_schemas(request_type, transport: str = "grpc"):
client = VertexRagDataServiceClient(
credentials=ga_credentials.AnonymousCredentials(),
transport=transport,
)

# Everything is optional in proto3 as far as the runtime is concerned,
# and we are mocking out the actual API, so just send an empty request.
request = request_type()

# Mock the actual call within the gRPC stub, and fake the request.
with mock.patch.object(
type(client.transport.list_rag_data_schemas), "__call__"
) as call:
# Designate an appropriate return value for the call.
call.return_value = vertex_rag_data_service.ListRagDataSchemasResponse(
next_page_token="next_page_token_value",
)
response = client.list_rag_data_schemas(request)

# Establish that the underlying gRPC stub method was called.
assert len(call.mock_calls) == 1
_, args, _ = call.mock_calls[0]
request = vertex_rag_data_service.ListRagDataSchemasRequest()
assert args[0] == request

# Establish that the response is the type that we expect.
assert isinstance(response, pagers.ListRagDataSchemasPager)
assert response.next_page_token == "next_page_token_value"


@pytest.mark.parametrize(
"request_type",
[
Expand Down
33 changes: 21 additions & 12 deletions tests/unit/vertex_rag/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from unittest.mock import patch
from unittest import mock
from google import auth
from google.api_core import operation as ga_operation
from google.auth import credentials as auth_credentials
from vertexai import rag
from vertexai.preview import rag as rag_preview
from google.cloud.aiplatform_v1 import (
DeleteRagCorpusRequest,
VertexRagDataServiceAsyncClient,
Expand Down Expand Up @@ -51,17 +48,19 @@ def google_auth_mock():

@pytest.fixture
def authorized_session_mock():
with patch(
"google.auth.transport.requests.AuthorizedSession"
) as MockAuthorizedSession:
from google.auth.transport import requests

with mock.patch.object(requests, "AuthorizedSession") as MockAuthorizedSession:
mock_auth_session = MockAuthorizedSession(_TEST_CREDENTIALS)
yield mock_auth_session


@pytest.fixture
def rag_data_client_mock():
from vertexai.rag.utils import _gapic_utils

with mock.patch.object(
rag.utils._gapic_utils, "create_rag_data_service_client"
_gapic_utils, "create_rag_data_service_client"
) as rag_data_client_mock:
api_client_mock = mock.Mock(spec=VertexRagDataServiceClient)

Expand All @@ -84,8 +83,10 @@ def rag_data_client_mock():

@pytest.fixture
def rag_data_client_preview_mock():
from vertexai.preview.rag.utils import _gapic_utils

with mock.patch.object(
rag_preview.utils._gapic_utils, "create_rag_data_service_client"
_gapic_utils, "create_rag_data_service_client"
) as rag_data_client_mock:
api_client_mock = mock.Mock(spec=VertexRagDataServiceClientPreview)

Expand All @@ -108,8 +109,10 @@ def rag_data_client_preview_mock():

@pytest.fixture
def rag_data_client_mock_exception():
from vertexai.rag.utils import _gapic_utils

with mock.patch.object(
rag.utils._gapic_utils, "create_rag_data_service_client"
_gapic_utils, "create_rag_data_service_client"
) as rag_data_client_mock_exception:
api_client_mock = mock.Mock(spec=VertexRagDataServiceClient)
# create_rag_corpus
Expand Down Expand Up @@ -138,8 +141,10 @@ def rag_data_client_mock_exception():

@pytest.fixture
def rag_data_client_preview_mock_exception():
from vertexai.preview.rag.utils import _gapic_utils

with mock.patch.object(
rag_preview.utils._gapic_utils, "create_rag_data_service_client"
_gapic_utils, "create_rag_data_service_client"
) as rag_data_client_mock_exception:
api_client_mock = mock.Mock(spec=VertexRagDataServiceClientPreview)
# create_rag_corpus
Expand Down Expand Up @@ -172,8 +177,10 @@ def rag_data_client_preview_mock_exception():

@pytest.fixture
def rag_data_async_client_mock_exception():
from vertexai.rag.utils import _gapic_utils

with mock.patch.object(
rag.utils._gapic_utils, "create_rag_data_service_async_client"
_gapic_utils, "create_rag_data_service_async_client"
) as rag_data_async_client_mock_exception:
api_client_mock = mock.Mock(spec=VertexRagDataServiceAsyncClient)
# import_rag_files
Expand All @@ -184,8 +191,10 @@ def rag_data_async_client_mock_exception():

@pytest.fixture
def rag_data_async_client_preview_mock_exception():
from vertexai.preview.rag.utils import _gapic_utils

with mock.patch.object(
rag_preview.utils._gapic_utils, "create_rag_data_service_async_client"
_gapic_utils, "create_rag_data_service_async_client"
) as rag_data_async_client_mock_exception:
api_client_mock = mock.Mock(spec=VertexRagDataServiceAsyncClientPreview)
# import_rag_files
Expand Down
61 changes: 61 additions & 0 deletions tests/unit/vertex_rag/test_rag_constants_preview.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,23 @@
ImportRagFilesRequest,
ImportRagFilesResponse,
JiraSource as GapicJiraSource,
MetadataValue as GapicMetadataValue,
RagContexts,
RagCorpus as GapicRagCorpus,
RagDataSchema as GapicRagDataSchema,
RagEngineConfig as GapicRagEngineConfig,
RagFileChunkingConfig,
RagFileParsingConfig,
RagFileTransformationConfig,
RagFile as GapicRagFile,
RagManagedDbConfig as GapicRagManagedDbConfig,
RagMetadataSchemaDetails as GapicRagMetadataSchemaDetails,
RagMetadata as GapicRagMetadata,
RagVectorDbConfig as GapicRagVectorDbConfig,
RetrieveContextsResponse,
SharePointSources as GapicSharePointSources,
SlackSource as GapicSlackSource,
UserSpecifiedMetadata as GapicUserSpecifiedMetadata,
VertexAiSearchConfig as GapicVertexAiSearchConfig,
)
from google.cloud.aiplatform_v1beta1.types import api_auth
Expand All @@ -54,15 +59,19 @@
LlmParserConfig,
LlmRanker,
MemoryCorpus,
MetadataValue,
Pinecone,
RagCorpus,
RagCorpusTypeConfig,
RagDataSchema,
RagEmbeddingModelConfig,
RagEngineConfig,
RagFile,
RagManagedDb,
RagManagedDbConfig,
RagManagedVertexVectorSearch,
RagMetadata,
RagMetadataSchemaDetails,
RagResource,
RagRetrievalConfig,
RagVectorDbConfig,
Expand All @@ -76,6 +85,7 @@
SlackChannelsSource,
Spanner,
Unprovisioned,
UserSpecifiedMetadata,
VertexAiSearchConfig,
VertexFeatureStore,
VertexPredictionEndpoint,
Expand Down Expand Up @@ -1146,3 +1156,54 @@
filter=Filter(vector_distance_threshold=0.5),
ranking=Ranking(llm_ranker=LlmRanker(model_name="test-model-name")),
)

# RagMetadata and RagDataSchema
TEST_RAG_DATA_SCHEMA_ID = "test-data-schema-id"
TEST_RAG_DATA_SCHEMA_RESOURCE_NAME = (
f"{TEST_RAG_CORPUS_RESOURCE_NAME}/ragDataSchemas/{TEST_RAG_DATA_SCHEMA_ID}"
)
TEST_RAG_METADATA_ID = "test-metadata-id"
TEST_RAG_METADATA_RESOURCE_NAME = (
f"{TEST_RAG_FILE_RESOURCE_NAME}/ragMetadata/{TEST_RAG_METADATA_ID}"
)

TEST_GAPIC_RAG_DATA_SCHEMA = GapicRagDataSchema(
name=TEST_RAG_DATA_SCHEMA_RESOURCE_NAME,
key="key1",
schema_details=GapicRagMetadataSchemaDetails(
type=GapicRagMetadataSchemaDetails.DataType.STRING,
search_strategy=GapicRagMetadataSchemaDetails.SearchStrategy(
search_strategy_type=GapicRagMetadataSchemaDetails.SearchStrategy.SearchStrategyType.EXACT_SEARCH
),
granularity=GapicRagMetadataSchemaDetails.Granularity.GRANULARITY_FILE_LEVEL,
),
)

TEST_RAG_DATA_SCHEMA = RagDataSchema(
name=TEST_RAG_DATA_SCHEMA_RESOURCE_NAME,
key="key1",
schema_details=RagMetadataSchemaDetails(
type="STRING",
search_strategy=RagMetadataSchemaDetails.SearchStrategy(
search_strategy_type="EXACT_SEARCH"
),
granularity="GRANULARITY_FILE_LEVEL",
),
)

TEST_GAPIC_RAG_METADATA = GapicRagMetadata(
name=TEST_RAG_METADATA_RESOURCE_NAME,
user_specified_metadata=GapicUserSpecifiedMetadata(
key="key1",
value=GapicMetadataValue(str_value="value1"),
),
)

TEST_RAG_METADATA = RagMetadata(
name=TEST_RAG_METADATA_RESOURCE_NAME,
user_specified_metadata=UserSpecifiedMetadata(
values={
"key1": MetadataValue(string_value="value1"),
}
),
)
Loading
Loading