diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_vertex_rag_data_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_vertex_rag_data_service.py index 10fdaf52e3..12dd9aed7f 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_vertex_rag_data_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_vertex_rag_data_service.py @@ -72,8 +72,12 @@ from google.cloud.aiplatform_v1beta1.services.vertex_rag_data_service import ( VertexRagDataServiceClient, ) -from google.cloud.aiplatform_v1beta1.services.vertex_rag_data_service import pagers -from google.cloud.aiplatform_v1beta1.services.vertex_rag_data_service import transports +from google.cloud.aiplatform_v1beta1.services.vertex_rag_data_service import ( + pagers, +) +from google.cloud.aiplatform_v1beta1.services.vertex_rag_data_service import ( + transports, +) from google.cloud.aiplatform_v1beta1.types import api_auth from google.cloud.aiplatform_v1beta1.types import encryption_spec from google.cloud.aiplatform_v1beta1.types import io @@ -5294,6 +5298,79 @@ async def test_delete_rag_file_flattened_error_async(): ) +@pytest.mark.parametrize( + "request_type", + [ + vertex_rag_data_service.BatchCreateRagDataSchemasRequest, + dict, + ], +) +def test_batch_create_rag_data_schemas(request_type, transport: str = "grpc"): + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.batch_create_rag_data_schemas), "__call__" + ) as call: + # Designate an appropriate return value for the call. 
+ call.return_value = operations_pb2.Operation(name="operations/spam") + response = client.batch_create_rag_data_schemas(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = vertex_rag_data_service.BatchCreateRagDataSchemasRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.parametrize( + "request_type", + [ + vertex_rag_data_service.ListRagDataSchemasRequest, + dict, + ], +) +def test_list_rag_data_schemas(request_type, transport: str = "grpc"): + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_rag_data_schemas), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = vertex_rag_data_service.ListRagDataSchemasResponse( + next_page_token="next_page_token_value", + ) + response = client.list_rag_data_schemas(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = vertex_rag_data_service.ListRagDataSchemasRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. 
+ assert isinstance(response, pagers.ListRagDataSchemasPager) + assert response.next_page_token == "next_page_token_value" + + @pytest.mark.parametrize( "request_type", [ diff --git a/tests/unit/vertex_rag/conftest.py b/tests/unit/vertex_rag/conftest.py index 3d3678ab0f..f9bbe134de 100644 --- a/tests/unit/vertex_rag/conftest.py +++ b/tests/unit/vertex_rag/conftest.py @@ -12,13 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from unittest.mock import patch from unittest import mock from google import auth from google.api_core import operation as ga_operation from google.auth import credentials as auth_credentials -from vertexai import rag -from vertexai.preview import rag as rag_preview from google.cloud.aiplatform_v1 import ( DeleteRagCorpusRequest, VertexRagDataServiceAsyncClient, @@ -51,17 +48,19 @@ def google_auth_mock(): @pytest.fixture def authorized_session_mock(): - with patch( - "google.auth.transport.requests.AuthorizedSession" - ) as MockAuthorizedSession: + from google.auth.transport import requests + + with mock.patch.object(requests, "AuthorizedSession") as MockAuthorizedSession: mock_auth_session = MockAuthorizedSession(_TEST_CREDENTIALS) yield mock_auth_session @pytest.fixture def rag_data_client_mock(): + from vertexai.rag.utils import _gapic_utils + with mock.patch.object( - rag.utils._gapic_utils, "create_rag_data_service_client" + _gapic_utils, "create_rag_data_service_client" ) as rag_data_client_mock: api_client_mock = mock.Mock(spec=VertexRagDataServiceClient) @@ -84,8 +83,10 @@ def rag_data_client_mock(): @pytest.fixture def rag_data_client_preview_mock(): + from vertexai.preview.rag.utils import _gapic_utils + with mock.patch.object( - rag_preview.utils._gapic_utils, "create_rag_data_service_client" + _gapic_utils, "create_rag_data_service_client" ) as rag_data_client_mock: api_client_mock = mock.Mock(spec=VertexRagDataServiceClientPreview) @@ -108,8 +109,10 @@ def 
rag_data_client_preview_mock(): @pytest.fixture def rag_data_client_mock_exception(): + from vertexai.rag.utils import _gapic_utils + with mock.patch.object( - rag.utils._gapic_utils, "create_rag_data_service_client" + _gapic_utils, "create_rag_data_service_client" ) as rag_data_client_mock_exception: api_client_mock = mock.Mock(spec=VertexRagDataServiceClient) # create_rag_corpus @@ -138,8 +141,10 @@ def rag_data_client_mock_exception(): @pytest.fixture def rag_data_client_preview_mock_exception(): + from vertexai.preview.rag.utils import _gapic_utils + with mock.patch.object( - rag_preview.utils._gapic_utils, "create_rag_data_service_client" + _gapic_utils, "create_rag_data_service_client" ) as rag_data_client_mock_exception: api_client_mock = mock.Mock(spec=VertexRagDataServiceClientPreview) # create_rag_corpus @@ -172,8 +177,10 @@ def rag_data_client_preview_mock_exception(): @pytest.fixture def rag_data_async_client_mock_exception(): + from vertexai.rag.utils import _gapic_utils + with mock.patch.object( - rag.utils._gapic_utils, "create_rag_data_service_async_client" + _gapic_utils, "create_rag_data_service_async_client" ) as rag_data_async_client_mock_exception: api_client_mock = mock.Mock(spec=VertexRagDataServiceAsyncClient) # import_rag_files @@ -184,8 +191,10 @@ def rag_data_async_client_mock_exception(): @pytest.fixture def rag_data_async_client_preview_mock_exception(): + from vertexai.preview.rag.utils import _gapic_utils + with mock.patch.object( - rag_preview.utils._gapic_utils, "create_rag_data_service_async_client" + _gapic_utils, "create_rag_data_service_async_client" ) as rag_data_async_client_mock_exception: api_client_mock = mock.Mock(spec=VertexRagDataServiceAsyncClientPreview) # import_rag_files diff --git a/tests/unit/vertex_rag/test_rag_constants_preview.py b/tests/unit/vertex_rag/test_rag_constants_preview.py index 0c0f3c810c..9819137268 100644 --- a/tests/unit/vertex_rag/test_rag_constants_preview.py +++ 
b/tests/unit/vertex_rag/test_rag_constants_preview.py @@ -23,18 +23,23 @@ ImportRagFilesRequest, ImportRagFilesResponse, JiraSource as GapicJiraSource, + MetadataValue as GapicMetadataValue, RagContexts, RagCorpus as GapicRagCorpus, + RagDataSchema as GapicRagDataSchema, RagEngineConfig as GapicRagEngineConfig, RagFileChunkingConfig, RagFileParsingConfig, RagFileTransformationConfig, RagFile as GapicRagFile, RagManagedDbConfig as GapicRagManagedDbConfig, + RagMetadataSchemaDetails as GapicRagMetadataSchemaDetails, + RagMetadata as GapicRagMetadata, RagVectorDbConfig as GapicRagVectorDbConfig, RetrieveContextsResponse, SharePointSources as GapicSharePointSources, SlackSource as GapicSlackSource, + UserSpecifiedMetadata as GapicUserSpecifiedMetadata, VertexAiSearchConfig as GapicVertexAiSearchConfig, ) from google.cloud.aiplatform_v1beta1.types import api_auth @@ -54,15 +59,19 @@ LlmParserConfig, LlmRanker, MemoryCorpus, + MetadataValue, Pinecone, RagCorpus, RagCorpusTypeConfig, + RagDataSchema, RagEmbeddingModelConfig, RagEngineConfig, RagFile, RagManagedDb, RagManagedDbConfig, RagManagedVertexVectorSearch, + RagMetadata, + RagMetadataSchemaDetails, RagResource, RagRetrievalConfig, RagVectorDbConfig, @@ -76,6 +85,7 @@ SlackChannelsSource, Spanner, Unprovisioned, + UserSpecifiedMetadata, VertexAiSearchConfig, VertexFeatureStore, VertexPredictionEndpoint, @@ -1146,3 +1156,54 @@ filter=Filter(vector_distance_threshold=0.5), ranking=Ranking(llm_ranker=LlmRanker(model_name="test-model-name")), ) + +# RagMetadata and RagDataSchema +TEST_RAG_DATA_SCHEMA_ID = "test-data-schema-id" +TEST_RAG_DATA_SCHEMA_RESOURCE_NAME = ( + f"{TEST_RAG_CORPUS_RESOURCE_NAME}/ragDataSchemas/{TEST_RAG_DATA_SCHEMA_ID}" +) +TEST_RAG_METADATA_ID = "test-metadata-id" +TEST_RAG_METADATA_RESOURCE_NAME = ( + f"{TEST_RAG_FILE_RESOURCE_NAME}/ragMetadata/{TEST_RAG_METADATA_ID}" +) + +TEST_GAPIC_RAG_DATA_SCHEMA = GapicRagDataSchema( + name=TEST_RAG_DATA_SCHEMA_RESOURCE_NAME, + key="key1", + 
schema_details=GapicRagMetadataSchemaDetails( + type=GapicRagMetadataSchemaDetails.DataType.STRING, + search_strategy=GapicRagMetadataSchemaDetails.SearchStrategy( + search_strategy_type=GapicRagMetadataSchemaDetails.SearchStrategy.SearchStrategyType.EXACT_SEARCH + ), + granularity=GapicRagMetadataSchemaDetails.Granularity.GRANULARITY_FILE_LEVEL, + ), +) + +TEST_RAG_DATA_SCHEMA = RagDataSchema( + name=TEST_RAG_DATA_SCHEMA_RESOURCE_NAME, + key="key1", + schema_details=RagMetadataSchemaDetails( + type="STRING", + search_strategy=RagMetadataSchemaDetails.SearchStrategy( + search_strategy_type="EXACT_SEARCH" + ), + granularity="GRANULARITY_FILE_LEVEL", + ), +) + +TEST_GAPIC_RAG_METADATA = GapicRagMetadata( + name=TEST_RAG_METADATA_RESOURCE_NAME, + user_specified_metadata=GapicUserSpecifiedMetadata( + key="key1", + value=GapicMetadataValue(str_value="value1"), + ), +) + +TEST_RAG_METADATA = RagMetadata( + name=TEST_RAG_METADATA_RESOURCE_NAME, + user_specified_metadata=UserSpecifiedMetadata( + values={ + "key1": MetadataValue(string_value="value1"), + } + ), +) diff --git a/tests/unit/vertex_rag/test_rag_data_preview.py b/tests/unit/vertex_rag/test_rag_data_preview.py index b1e7d4c3b0..6b558a5836 100644 --- a/tests/unit/vertex_rag/test_rag_data_preview.py +++ b/tests/unit/vertex_rag/test_rag_data_preview.py @@ -30,6 +30,10 @@ VertexRagDataServiceClient, ListRagCorporaResponse, ListRagFilesResponse, + BatchCreateRagDataSchemasResponse, + BatchCreateRagMetadataResponse, + ListRagDataSchemasResponse, + ListRagMetadataResponse, ) from google.cloud import aiplatform import mock @@ -809,6 +813,149 @@ def list_rag_files_pager_mock(): yield list_rag_files_pager_mock +@pytest.fixture +def batch_create_rag_data_schemas_mock(): + with mock.patch.object( + VertexRagDataServiceClient, + "batch_create_rag_data_schemas", + ) as batch_create_rag_data_schemas_mock: + batch_create_rag_data_schemas_lro_mock = mock.Mock(ga_operation.Operation) + 
batch_create_rag_data_schemas_lro_mock.done.return_value = True + batch_create_rag_data_schemas_lro_mock.result.return_value = ( + BatchCreateRagDataSchemasResponse( + rag_data_schemas=[test_rag_constants_preview.TEST_GAPIC_RAG_DATA_SCHEMA] + ) + ) + batch_create_rag_data_schemas_mock.return_value = ( + batch_create_rag_data_schemas_lro_mock + ) + yield batch_create_rag_data_schemas_mock + + +@pytest.fixture +def batch_delete_rag_data_schemas_mock(): + with mock.patch.object( + VertexRagDataServiceClient, + "batch_delete_rag_data_schemas", + ) as batch_delete_rag_data_schemas_mock: + batch_delete_rag_data_schemas_lro_mock = mock.Mock(ga_operation.Operation) + batch_delete_rag_data_schemas_lro_mock.done.return_value = True + batch_delete_rag_data_schemas_mock.return_value = ( + batch_delete_rag_data_schemas_lro_mock + ) + yield batch_delete_rag_data_schemas_mock + + +@pytest.fixture +def list_rag_data_schemas_pager_mock(): + with mock.patch.object( + VertexRagDataServiceClient, + "list_rag_data_schemas", + ) as list_rag_data_schemas_pager_mock: + list_rag_data_schemas_pager_mock.return_value = [ + ListRagDataSchemasResponse( + rag_data_schemas=[ + test_rag_constants_preview.TEST_GAPIC_RAG_DATA_SCHEMA, + ], + next_page_token=test_rag_constants_preview.TEST_PAGE_TOKEN, + ), + ] + yield list_rag_data_schemas_pager_mock + + +@pytest.fixture +def batch_create_rag_metadata_mock(): + with mock.patch.object( + VertexRagDataServiceClient, + "batch_create_rag_metadata", + ) as batch_create_rag_metadata_mock: + batch_create_rag_metadata_lro_mock = mock.Mock(ga_operation.Operation) + batch_create_rag_metadata_lro_mock.done.return_value = True + batch_create_rag_metadata_lro_mock.result.return_value = ( + BatchCreateRagMetadataResponse( + rag_metadata=[test_rag_constants_preview.TEST_GAPIC_RAG_METADATA] + ) + ) + batch_create_rag_metadata_mock.return_value = batch_create_rag_metadata_lro_mock + yield batch_create_rag_metadata_mock + + +@pytest.fixture +def 
batch_delete_rag_metadata_mock(): + with mock.patch.object( + VertexRagDataServiceClient, + "batch_delete_rag_metadata", + ) as batch_delete_rag_metadata_mock: + batch_delete_rag_metadata_lro_mock = mock.Mock(ga_operation.Operation) + batch_delete_rag_metadata_lro_mock.done.return_value = True + batch_delete_rag_metadata_mock.return_value = batch_delete_rag_metadata_lro_mock + yield batch_delete_rag_metadata_mock + + +@pytest.fixture +def list_rag_metadata_pager_mock(): + with mock.patch.object( + VertexRagDataServiceClient, + "list_rag_metadata", + ) as list_rag_metadata_pager_mock: + list_rag_metadata_pager_mock.return_value = [ + ListRagMetadataResponse( + rag_metadata=[ + test_rag_constants_preview.TEST_GAPIC_RAG_METADATA, + ], + next_page_token=test_rag_constants_preview.TEST_PAGE_TOKEN, + ), + ] + yield list_rag_metadata_pager_mock + + +@pytest.fixture +def update_rag_metadata_mock(): + with mock.patch.object( + VertexRagDataServiceClient, + "update_rag_metadata", + ) as update_rag_metadata_mock: + update_rag_metadata_mock.return_value = ( + test_rag_constants_preview.TEST_GAPIC_RAG_METADATA + ) + yield update_rag_metadata_mock + + +@pytest.fixture +def rag_data_client_preview_mock_exception(): + from vertexai.preview.rag.utils import _gapic_utils + + with mock.patch.object( + _gapic_utils, "create_rag_data_service_client" + ) as create_client_mock: + create_client_mock.return_value.create_rag_corpus.side_effect = Exception + create_client_mock.return_value.get_rag_corpus.side_effect = Exception + create_client_mock.return_value.update_rag_corpus.side_effect = Exception + create_client_mock.return_value.list_rag_corpora.side_effect = Exception + create_client_mock.return_value.delete_rag_corpus.side_effect = Exception + create_client_mock.return_value.upload_rag_file.side_effect = Exception + create_client_mock.return_value.import_rag_files.side_effect = Exception + create_client_mock.return_value.get_rag_file.side_effect = Exception + 
create_client_mock.return_value.list_rag_files.side_effect = Exception + create_client_mock.return_value.delete_rag_file.side_effect = Exception + create_client_mock.return_value.batch_create_rag_data_schemas.side_effect = ( + Exception + ) + create_client_mock.return_value.batch_delete_rag_data_schemas.side_effect = ( + Exception + ) + create_client_mock.return_value.list_rag_data_schemas.side_effect = Exception + create_client_mock.return_value.batch_create_rag_metadata.side_effect = ( + Exception + ) + create_client_mock.return_value.batch_delete_rag_metadata.side_effect = ( + Exception + ) + create_client_mock.return_value.list_rag_metadata.side_effect = Exception + create_client_mock.return_value.update_rag_metadata.side_effect = Exception + yield create_client_mock + + def create_transformation_config( chunk_size: int = test_rag_constants_preview.TEST_CHUNK_SIZE, chunk_overlap: int = test_rag_constants_preview.TEST_CHUNK_OVERLAP, @@ -821,6 +968,28 @@ def create_transformation_config( ) +def rag_data_schema_eq(returned_schema, expected_schema): + assert returned_schema.name == expected_schema.name + assert returned_schema.key == expected_schema.key + assert returned_schema.schema_details.type == expected_schema.schema_details.type + assert ( + returned_schema.schema_details.search_strategy.search_strategy_type + == expected_schema.schema_details.search_strategy.search_strategy_type + ) + assert ( + returned_schema.schema_details.granularity + == expected_schema.schema_details.granularity + ) + + +def rag_metadata_eq(returned_metadata, expected_metadata): + assert returned_metadata.name == expected_metadata.name + assert ( + returned_metadata.user_specified_metadata + == expected_metadata.user_specified_metadata + ) + + def rag_corpus_eq(returned_corpus, expected_corpus): assert returned_corpus.name == expected_corpus.name assert returned_corpus.display_name == expected_corpus.display_name @@ -1850,6 +2019,129 @@ def 
test_set_embedding_model_config_wrong_endpoint_format_error(self): ) e.match("endpoint must be of the format ") + def test_batch_create_data_schemas_success( + self, batch_create_rag_data_schemas_mock + ): + schemas = rag.batch_create_data_schemas( + corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, + requests=[test_rag_constants_preview.TEST_RAG_DATA_SCHEMA], + ) + batch_create_rag_data_schemas_mock.assert_called_once() + rag_data_schema_eq(schemas[0], test_rag_constants_preview.TEST_RAG_DATA_SCHEMA) + + @pytest.mark.usefixtures("rag_data_client_preview_mock_exception") + def test_batch_create_data_schemas_failure(self): + with pytest.raises(RuntimeError) as e: + rag.batch_create_data_schemas( + corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, + requests=[test_rag_constants_preview.TEST_RAG_DATA_SCHEMA], + ) + e.match("Failed in RagDataSchema batch creation due to") + + def test_batch_delete_data_schemas_success( + self, batch_delete_rag_data_schemas_mock + ): + rag.batch_delete_data_schemas( + corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, + names=[test_rag_constants_preview.TEST_RAG_DATA_SCHEMA_ID], + ) + batch_delete_rag_data_schemas_mock.assert_called_once() + + @pytest.mark.usefixtures("rag_data_client_preview_mock_exception") + def test_batch_delete_data_schemas_failure(self): + with pytest.raises(RuntimeError) as e: + rag.batch_delete_data_schemas( + corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, + names=[test_rag_constants_preview.TEST_RAG_DATA_SCHEMA_ID], + ) + e.match("Failed in RagDataSchema batch deletion due to") + + def test_list_data_schemas_success(self, list_rag_data_schemas_pager_mock): + pager = rag.list_data_schemas( + corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, + page_size=1, + ) + list_rag_data_schemas_pager_mock.assert_called_once() + assert pager[0].next_page_token == test_rag_constants_preview.TEST_PAGE_TOKEN + + 
@pytest.mark.usefixtures("rag_data_client_preview_mock_exception") + def test_list_data_schemas_failure(self): + with pytest.raises(RuntimeError) as e: + rag.list_data_schemas( + corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME + ) + e.match("Failed in listing the RagDataSchemas due to") + + def test_batch_create_metadata_success(self, batch_create_rag_metadata_mock): + metadata = rag.batch_create_metadata( + corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, + file_name=test_rag_constants_preview.TEST_RAG_FILE_RESOURCE_NAME, + requests=[test_rag_constants_preview.TEST_RAG_METADATA], + ) + batch_create_rag_metadata_mock.assert_called_once() + rag_metadata_eq(metadata[0], test_rag_constants_preview.TEST_RAG_METADATA) + + @pytest.mark.usefixtures("rag_data_client_preview_mock_exception") + def test_batch_create_metadata_failure(self): + with pytest.raises(RuntimeError) as e: + rag.batch_create_metadata( + corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, + file_name=test_rag_constants_preview.TEST_RAG_FILE_RESOURCE_NAME, + requests=[test_rag_constants_preview.TEST_RAG_METADATA], + ) + e.match("Failed in RagMetadata batch creation due to") + + def test_batch_delete_metadata_success(self, batch_delete_rag_metadata_mock): + rag.batch_delete_metadata( + corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, + file_name=test_rag_constants_preview.TEST_RAG_FILE_RESOURCE_NAME, + names=[test_rag_constants_preview.TEST_RAG_METADATA_ID], + ) + batch_delete_rag_metadata_mock.assert_called_once() + + @pytest.mark.usefixtures("rag_data_client_preview_mock_exception") + def test_batch_delete_metadata_failure(self): + with pytest.raises(RuntimeError) as e: + rag.batch_delete_metadata( + corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, + file_name=test_rag_constants_preview.TEST_RAG_FILE_RESOURCE_NAME, + names=[test_rag_constants_preview.TEST_RAG_METADATA_ID], + ) + e.match("Failed in 
RagMetadata batch deletion due to") + + def test_list_metadata_success(self, list_rag_metadata_pager_mock): + pager = rag.list_metadata( + corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, + file_name=test_rag_constants_preview.TEST_RAG_FILE_RESOURCE_NAME, + page_size=1, + ) + list_rag_metadata_pager_mock.assert_called_once() + assert pager[0].next_page_token == test_rag_constants_preview.TEST_PAGE_TOKEN + + @pytest.mark.usefixtures("rag_data_client_preview_mock_exception") + def test_list_metadata_failure(self): + with pytest.raises(RuntimeError) as e: + rag.list_metadata( + corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, + file_name=test_rag_constants_preview.TEST_RAG_FILE_RESOURCE_NAME, + ) + e.match("Failed in listing the RagMetadata due to") + + def test_update_metadata_success(self, update_rag_metadata_mock): + metadata = rag.update_metadata( + rag_metadata=test_rag_constants_preview.TEST_RAG_METADATA, + ) + update_rag_metadata_mock.assert_called_once() + rag_metadata_eq(metadata, test_rag_constants_preview.TEST_RAG_METADATA) + + @pytest.mark.usefixtures("rag_data_client_preview_mock_exception") + def test_update_metadata_failure(self): + with pytest.raises(RuntimeError) as e: + rag.update_metadata( + rag_metadata=test_rag_constants_preview.TEST_RAG_METADATA, + ) + e.match("Failed in RagMetadata update due to") + def test_update_rag_engine_config_success( self, update_rag_engine_config_basic_mock ): diff --git a/vertexai/preview/rag/__init__.py b/vertexai/preview/rag/__init__.py index e67eb7885d..92dde4801e 100644 --- a/vertexai/preview/rag/__init__.py +++ b/vertexai/preview/rag/__init__.py @@ -14,6 +14,10 @@ # limitations under the License. 
# from vertexai.preview.rag.rag_data import ( + batch_create_data_schemas, + batch_create_metadata, + batch_delete_data_schemas, + batch_delete_metadata, create_corpus, delete_corpus, delete_file, @@ -23,8 +27,11 @@ import_files, import_files_async, list_corpora, + list_data_schemas, list_files, + list_metadata, update_corpus, + update_metadata, update_rag_engine_config, upload_file, ) @@ -53,15 +60,19 @@ LlmParserConfig, LlmRanker, MemoryCorpus, + MetadataValue, Pinecone, RagCorpus, RagCorpusTypeConfig, + RagDataSchema, RagEmbeddingModelConfig, RagEngineConfig, RagFile, RagManagedDb, RagManagedDbConfig, RagManagedVertexVectorSearch, + RagMetadata, + RagMetadataSchemaDetails, RagResource, RagRetrievalConfig, RagVectorDbConfig, @@ -76,6 +87,7 @@ Spanner, TransformationConfig, Unprovisioned, + UserSpecifiedMetadata, VertexAiSearchConfig, VertexFeatureStore, VertexPredictionEndpoint, @@ -99,15 +111,19 @@ "LlmParserConfig", "LlmRanker", "MemoryCorpus", + "MetadataValue", "Pinecone", "RagEngineConfig", "RagCorpus", "RagCorpusTypeConfig", + "RagDataSchema", "RagEmbeddingModelConfig", "RagFile", "RagManagedDb", "RagManagedDbConfig", "RagManagedVertexVectorSearch", + "RagMetadata", + "RagMetadataSchemaDetails", "RagResource", "RagRetrievalConfig", "RagVectorDbConfig", @@ -123,6 +139,7 @@ "Spanner", "TransformationConfig", "Unprovisioned", + "UserSpecifiedMetadata", "VertexAiSearchConfig", "VertexFeatureStore", "VertexPredictionEndpoint", @@ -130,6 +147,10 @@ "VertexVectorSearch", "Weaviate", "ask_contexts", + "batch_create_data_schemas", + "batch_create_metadata", + "batch_delete_data_schemas", + "batch_delete_metadata", "create_corpus", "delete_corpus", "delete_file", @@ -138,11 +159,14 @@ "import_files", "import_files_async", "list_corpora", + "list_data_schemas", "list_files", + "list_metadata", "retrieval_query", "async_retrieve_contexts", "upload_file", "update_corpus", + "update_metadata", "update_rag_engine_config", "get_rag_engine_config", ) diff --git 
a/vertexai/preview/rag/rag_data.py b/vertexai/preview/rag/rag_data.py index 3b1d9553c6..ef91de78ac 100644 --- a/vertexai/preview/rag/rag_data.py +++ b/vertexai/preview/rag/rag_data.py @@ -15,7 +15,7 @@ # limitations under the License. # """RAG data management SDK.""" - + from typing import Optional, Sequence, Union from google import auth from google.api_core import operation_async @@ -24,7 +24,13 @@ from google.cloud.aiplatform import initializer from google.cloud.aiplatform import utils from google.cloud.aiplatform_v1beta1 import ( + BatchCreateRagDataSchemasRequest, + BatchCreateRagMetadataRequest, + BatchDeleteRagDataSchemasRequest, + BatchDeleteRagMetadataRequest, CreateRagCorpusRequest, + CreateRagDataSchemaRequest, + CreateRagMetadataRequest, DeleteRagCorpusRequest, DeleteRagFileRequest, GetRagCorpusRequest, @@ -32,14 +38,19 @@ GetRagFileRequest, ImportRagFilesResponse, ListRagCorporaRequest, + ListRagDataSchemasRequest, ListRagFilesRequest, + ListRagMetadataRequest, RagCorpus as GapicRagCorpus, UpdateRagCorpusRequest, UpdateRagEngineConfigRequest, + UpdateRagMetadataRequest, ) from google.cloud.aiplatform_v1beta1.services.vertex_rag_data_service.pagers import ( ListRagCorporaPager, + ListRagDataSchemasPager, ListRagFilesPager, + ListRagMetadataPager, ) from google.cloud.aiplatform_v1beta1.types import EncryptionSpec from vertexai.preview.rag.utils import ( @@ -53,10 +64,12 @@ Pinecone, RagCorpus, RagCorpusTypeConfig, + RagDataSchema, RagEngineConfig, RagFile, RagManagedDb, RagManagedVertexVectorSearch, + RagMetadata, RagVectorDbConfig, SharePointSources, SlackChannelsSource, @@ -383,7 +396,6 @@ def delete_corpus(name: str) -> None: client = _gapic_utils.create_rag_data_service_client() try: client.delete_rag_corpus(request=request) - print("Successfully deleted the RagCorpus.") except Exception as e: raise RuntimeError("Failed in RagCorpus deletion due to: ", e) from e return None @@ -1002,12 +1014,250 @@ def delete_file(name: str, corpus_name: 
Optional[str] = None) -> None: client = _gapic_utils.create_rag_data_service_client() try: client.delete_rag_file(request=request) - print("Successfully deleted the RagFile.") except Exception as e: raise RuntimeError("Failed in RagFile deletion due to: ", e) from e return None +def batch_create_data_schemas( + corpus_name: str, + requests: Sequence[RagDataSchema], + timeout: int = 600, +) -> Sequence[RagDataSchema]: + """Batch creates RagDataSchema resources. + + Args: + corpus_name: The name of the RagCorpus resource to create the + RagDataSchemas in. Format: + ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}`` or + ``{rag_corpus}``. + requests: The RagDataSchemas to create. + timeout: Default is 600 seconds. + + Returns: + Sequence of RagDataSchema. + """ + corpus_name = _gapic_utils.get_corpus_name(corpus_name) + gapic_requests = [] + for request in requests: + gapic_requests.append( + CreateRagDataSchemaRequest( + parent=corpus_name, + rag_data_schema=_gapic_utils.convert_rag_data_schema_to_gapic(request), + ) + ) + request = BatchCreateRagDataSchemasRequest( + parent=corpus_name, + requests=gapic_requests, + ) + client = _gapic_utils.create_rag_data_service_client() + try: + response = client.batch_create_rag_data_schemas(request=request) + except Exception as e: + raise RuntimeError("Failed in RagDataSchema batch creation due to: ", e) from e + result = response.result(timeout=timeout) + if result.rag_data_schemas: + return [ + _gapic_utils.convert_gapic_to_rag_data_schema(schema) + for schema in result.rag_data_schemas + ] + return [] + + +def batch_delete_data_schemas( + corpus_name: str, + names: Sequence[str], + timeout: int = 600, +) -> None: + """Batch deletes RagDataSchema resources. + + Args: + corpus_name: The name of the RagCorpus resource from which to delete the + RagDataSchemas. Format: + ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}`` or + ``{rag_corpus}``. 
+ names: The RagDataSchema resource names to delete. + timeout: Default is 600 seconds. + """ + corpus_name = _gapic_utils.get_corpus_name(corpus_name) + full_names = [ + _gapic_utils.get_data_schema_name(name, corpus_name) for name in names + ] + request = BatchDeleteRagDataSchemasRequest( + parent=corpus_name, + names=full_names, + ) + client = _gapic_utils.create_rag_data_service_client() + try: + response = client.batch_delete_rag_data_schemas(request=request) + response.result(timeout=timeout) + except Exception as e: + raise RuntimeError("Failed in RagDataSchema batch deletion due to: ", e) from e + return None + + +def list_data_schemas( + corpus_name: str, + page_size: Optional[int] = None, + page_token: Optional[str] = None, +) -> ListRagDataSchemasPager: + """Lists RagDataSchemas in an existing RagCorpus. + + Args: + corpus_name: An existing RagCorpus name. Format: + ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}`` or + ``{rag_corpus}``. + page_size: The standard list page size. + page_token: The standard list page token. + + Returns: + ListRagDataSchemasPager. + """ + corpus_name = _gapic_utils.get_corpus_name(corpus_name) + request = ListRagDataSchemasRequest( + parent=corpus_name, + page_size=page_size, + page_token=page_token, + ) + client = _gapic_utils.create_rag_data_service_client() + try: + pager = client.list_rag_data_schemas(request=request) + except Exception as e: + raise RuntimeError("Failed in listing the RagDataSchemas due to: ", e) from e + return pager + + +def batch_create_metadata( + corpus_name: str, + file_name: str, + requests: Sequence[RagMetadata], + timeout: int = 600, +) -> Sequence[RagMetadata]: + """Batch creates RagMetadata resources. + + Args: + corpus_name: The name of the RagCorpus resource. + file_name: The name of the RagFile resource. + requests: The RagMetadata to create. + timeout: Default is 600 seconds. + + Returns: + Sequence of RagMetadata. 
+ """ + corpus_name = _gapic_utils.get_corpus_name(corpus_name) + file_name = _gapic_utils.get_file_name(file_name, corpus_name) + gapic_requests = [] + for request in requests: + gapic_requests.append( + CreateRagMetadataRequest( + parent=file_name, + rag_metadata=_gapic_utils.convert_rag_metadata_to_gapic(request), + ) + ) + request = BatchCreateRagMetadataRequest( + parent=file_name, + requests=gapic_requests, + ) + client = _gapic_utils.create_rag_data_service_client() + try: + response = client.batch_create_rag_metadata(request=request) + except Exception as e: + raise RuntimeError("Failed in RagMetadata batch creation due to: ", e) from e + result = response.result(timeout=timeout) + if result.rag_metadata: + return [ + _gapic_utils.convert_gapic_to_rag_metadata(metadata) + for metadata in result.rag_metadata + ] + return [] + + +def batch_delete_metadata( + corpus_name: str, + file_name: str, + names: Sequence[str], + timeout: int = 600, +) -> None: + """Batch deletes RagMetadata resources. + + Args: + corpus_name: The name of the RagCorpus resource. + file_name: The name of the RagFile resource. + names: The RagMetadata resource names to delete. + timeout: Default is 600 seconds. 
+ """ + corpus_name = _gapic_utils.get_corpus_name(corpus_name) + file_name = _gapic_utils.get_file_name(file_name, corpus_name) + full_names = [ + _gapic_utils.get_metadata_name(name, corpus_name, file_name) for name in names + ] + request = BatchDeleteRagMetadataRequest( + parent=file_name, + names=full_names, + ) + client = _gapic_utils.create_rag_data_service_client() + try: + response = client.batch_delete_rag_metadata(request=request) + response.result(timeout=timeout) + except Exception as e: + raise RuntimeError("Failed in RagMetadata batch deletion due to: ", e) from e + return None + + +def list_metadata( + corpus_name: str, + file_name: str, + page_size: Optional[int] = None, + page_token: Optional[str] = None, +) -> ListRagMetadataPager: + """Lists RagMetadata in an existing RagFile. + + Args: + corpus_name: An existing RagCorpus name. + file_name: An existing RagFile name. + page_size: The standard list page size. + page_token: The standard list page token. + + Returns: + ListRagMetadataPager. + """ + corpus_name = _gapic_utils.get_corpus_name(corpus_name) + file_name = _gapic_utils.get_file_name(file_name, corpus_name) + request = ListRagMetadataRequest( + parent=file_name, + page_size=page_size, + page_token=page_token, + ) + client = _gapic_utils.create_rag_data_service_client() + try: + pager = client.list_rag_metadata(request=request) + except Exception as e: + raise RuntimeError("Failed in listing the RagMetadata due to: ", e) from e + return pager + + +def update_metadata( + rag_metadata: RagMetadata, +) -> RagMetadata: + """Updates a RagMetadata resource. + + Args: + rag_metadata: The RagMetadata which replaces the resource on the server. + + Returns: + RagMetadata. 
+ """ + request = UpdateRagMetadataRequest( + rag_metadata=_gapic_utils.convert_rag_metadata_to_gapic(rag_metadata), + ) + client = _gapic_utils.create_rag_data_service_client() + try: + response = client.update_rag_metadata(request=request) + except Exception as e: + raise RuntimeError("Failed in RagMetadata update due to: ", e) from e + return _gapic_utils.convert_gapic_to_rag_metadata(response) + + def update_rag_engine_config( rag_engine_config: RagEngineConfig, timeout: int = 600, diff --git a/vertexai/preview/rag/utils/_gapic_utils.py b/vertexai/preview/rag/utils/_gapic_utils.py index 3579300c7d..a372afe97f 100644 --- a/vertexai/preview/rag/utils/_gapic_utils.py +++ b/vertexai/preview/rag/utils/_gapic_utils.py @@ -23,6 +23,8 @@ VertexRagDataAsyncClientWithOverride, VertexRagDataClientWithOverride, ) + +# from google.cloud.aiplatform_v1beta1 import ( GoogleDriveSource, ImportRagFilesConfig, @@ -43,6 +45,9 @@ ) from google.cloud.aiplatform_v1beta1.types import api_auth from google.cloud.aiplatform_v1beta1.types import EncryptionSpec +from google.cloud.aiplatform_v1beta1.types import ( + vertex_rag_data as GapicRagDataTypes, +) from vertexai.preview.rag.utils.resources import ( ANN, Basic, @@ -54,15 +59,19 @@ LayoutParserConfig, LlmParserConfig, MemoryCorpus, + MetadataValue, Pinecone, RagCorpus, RagCorpusTypeConfig, + RagDataSchema, RagEmbeddingModelConfig, RagEngineConfig, RagFile, RagManagedDb, RagManagedDbConfig, RagManagedVertexVectorSearch, + RagMetadata, + RagMetadataSchemaDetails, RagVectorDbConfig, Scaled, Serverless, @@ -71,6 +80,7 @@ Spanner, TransformationConfig, Unprovisioned, + UserSpecifiedMetadata, VertexAiSearchConfig, VertexFeatureStore, VertexPredictionEndpoint, @@ -420,6 +430,191 @@ def convert_gapic_to_rag_file(gapic_rag_file: GapicRagFile) -> RagFile: return rag_file +def convert_gapic_to_rag_metadata( + gapic_rag_metadata: GapicRagDataTypes.RagMetadata, +) -> RagMetadata: + """Convert Gapic RagMetadata to RagMetadata.""" + return 
RagMetadata( + name=gapic_rag_metadata.name, + user_specified_metadata=convert_gapic_to_user_specified_metadata( + gapic_rag_metadata.user_specified_metadata + ), + ) + + +def convert_gapic_to_user_specified_metadata( + gapic_user_specified_metadata: GapicRagDataTypes.UserSpecifiedMetadata, +) -> UserSpecifiedMetadata: + """Convert Gapic UserSpecifiedMetadata to UserSpecifiedMetadata.""" + if not gapic_user_specified_metadata: + return None + return UserSpecifiedMetadata( + values={ + gapic_user_specified_metadata.key: convert_gapic_to_metadata_value( + gapic_user_specified_metadata.value + ) + } + ) + + +def convert_gapic_to_metadata_value( + gapic_metadata_value: GapicRagDataTypes.MetadataValue, +) -> MetadataValue: + """Convert Gapic MetadataValue to MetadataValue.""" + if not gapic_metadata_value: + return None + oneof_field = gapic_metadata_value._pb.WhichOneof("value") + if oneof_field == "str_value": + return MetadataValue(string_value=gapic_metadata_value.str_value) + elif oneof_field == "int_value": + return MetadataValue(int_value=gapic_metadata_value.int_value) + elif oneof_field == "float_value": + return MetadataValue(float_value=gapic_metadata_value.float_value) + elif oneof_field == "bool_value": + return MetadataValue(bool_value=gapic_metadata_value.bool_value) + return MetadataValue() + + +def convert_rag_metadata_to_gapic( + rag_metadata: RagMetadata, +) -> GapicRagDataTypes.RagMetadata: + """Convert RagMetadata to Gapic RagMetadata.""" + return GapicRagDataTypes.RagMetadata( + name=rag_metadata.name, + user_specified_metadata=convert_user_specified_metadata_to_gapic( + rag_metadata.user_specified_metadata + ), + ) + + +def convert_user_specified_metadata_to_gapic( + user_specified_metadata: UserSpecifiedMetadata, +) -> GapicRagDataTypes.UserSpecifiedMetadata: + """Convert UserSpecifiedMetadata to Gapic UserSpecifiedMetadata.""" + if not user_specified_metadata: + return None + if user_specified_metadata.values: + if 
len(user_specified_metadata.values) > 1: + raise ValueError( + "Only one key-value pair is supported in UserSpecifiedMetadata." + ) + key = list(user_specified_metadata.values.keys())[0] + return GapicRagDataTypes.UserSpecifiedMetadata( + key=key, + value=convert_metadata_value_to_gapic(user_specified_metadata.values[key]), + ) + return GapicRagDataTypes.UserSpecifiedMetadata() + + +def convert_metadata_value_to_gapic( + metadata_value: MetadataValue, +) -> GapicRagDataTypes.MetadataValue: + """Convert MetadataValue to Gapic MetadataValue.""" + if not metadata_value: + return None + if metadata_value.string_value is not None: + return GapicRagDataTypes.MetadataValue(str_value=metadata_value.string_value) + if metadata_value.int_value is not None: + return GapicRagDataTypes.MetadataValue(int_value=metadata_value.int_value) + if metadata_value.float_value is not None: + return GapicRagDataTypes.MetadataValue(float_value=metadata_value.float_value) + if metadata_value.bool_value is not None: + return GapicRagDataTypes.MetadataValue(bool_value=metadata_value.bool_value) + return GapicRagDataTypes.MetadataValue() + + +def convert_gapic_to_rag_data_schema( + gapic_rag_data_schema: GapicRagDataTypes.RagDataSchema, +) -> RagDataSchema: + """Convert Gapic RagDataSchema to RagDataSchema.""" + return RagDataSchema( + name=gapic_rag_data_schema.name, + key=gapic_rag_data_schema.key, + schema_details=convert_gapic_to_rag_metadata_schema_details( + gapic_rag_data_schema.schema_details + ), + ) + + +def convert_gapic_to_rag_metadata_schema_details( + gapic_details: GapicRagDataTypes.RagMetadataSchemaDetails, +) -> RagMetadataSchemaDetails: + """Convert Gapic RagMetadataSchemaDetails to RagMetadataSchemaDetails.""" + if not gapic_details: + return None + list_config = None + if gapic_details.list_config: + list_config = RagMetadataSchemaDetails.ListConfig( + value_schema=convert_gapic_to_rag_metadata_schema_details( + gapic_details.list_config.value_schema + ) + ) + 
search_strategy = None + if gapic_details.search_strategy: + search_strategy = RagMetadataSchemaDetails.SearchStrategy( + search_strategy_type=GapicRagDataTypes.RagMetadataSchemaDetails.SearchStrategy.SearchStrategyType( + gapic_details.search_strategy.search_strategy_type + ).name + ) + return RagMetadataSchemaDetails( + type=GapicRagDataTypes.RagMetadataSchemaDetails.DataType( + gapic_details.type_ + ).name, + granularity=GapicRagDataTypes.RagMetadataSchemaDetails.Granularity( + gapic_details.granularity + ).name, + list_config=list_config, + search_strategy=search_strategy, + ) + + +def convert_rag_data_schema_to_gapic( + rag_data_schema: RagDataSchema, +) -> GapicRagDataTypes.RagDataSchema: + """Convert RagDataSchema to Gapic RagDataSchema.""" + return GapicRagDataTypes.RagDataSchema( + name=rag_data_schema.name, + key=rag_data_schema.key, + schema_details=convert_rag_metadata_schema_details_to_gapic( + rag_data_schema.schema_details + ), + ) + + +def convert_rag_metadata_schema_details_to_gapic( + details: RagMetadataSchemaDetails, +) -> GapicRagDataTypes.RagMetadataSchemaDetails: + """Convert RagMetadataSchemaDetails to Gapic RagMetadataSchemaDetails.""" + if not details: + return None + list_config = None + if details.list_config: + list_config = GapicRagDataTypes.RagMetadataSchemaDetails.ListConfig( + value_schema=convert_rag_metadata_schema_details_to_gapic( + details.list_config.value_schema + ) + ) + search_strategy = None + if details.search_strategy: + search_strategy = GapicRagDataTypes.RagMetadataSchemaDetails.SearchStrategy( + search_strategy_type=details.search_strategy.search_strategy_type + ) + return GapicRagDataTypes.RagMetadataSchemaDetails( + type_=( + details.type + if details.type + else GapicRagDataTypes.RagMetadataSchemaDetails.DataType.DATA_TYPE_UNSPECIFIED + ), + granularity=( + details.granularity + if details.granularity + else GapicRagDataTypes.RagMetadataSchemaDetails.Granularity.GRANULARITY_UNSPECIFIED + ), + 
list_config=list_config, + search_strategy=search_strategy, + ) + + def convert_json_to_rag_file(upload_rag_file_response: Dict[str, Any]) -> RagFile: """Converts a JSON response to a RagFile.""" rag_file = RagFile( @@ -728,7 +923,8 @@ def get_file_name( if not corpus_name: raise ValueError( "corpus_name must be provided if name is a `{rag_file}`, not a " - "full resource name (`projects/{project}/locations/{location}/ragCorpora/{rag_corpus}/ragFiles/{rag_file}`). " + "full resource name" + " (`projects/{project}/locations/{location}/ragCorpora/{rag_corpus}/ragFiles/{rag_file}`). " ) return client.rag_file_path( project=initializer.global_config.project, @@ -738,10 +934,76 @@ def get_file_name( ) else: raise ValueError( - "name must be of the format `projects/{project}/locations/{location}/ragCorpora/{rag_corpus}/ragFiles/{rag_file}` or `{rag_file}`" + "name must be of the format" + " `projects/{project}/locations/{location}/ragCorpora/{rag_corpus}/ragFiles/{rag_file}`" + " or `{rag_file}`" ) +def get_data_schema_name( + name: str, + corpus_name: str, +) -> str: + """Get the full resource name for a RagDataSchema.""" + if name: + if len(name.split("/")) == 8: + return name + elif re.match("^{}$".format(_VALID_RESOURCE_NAME_REGEX), name): + if not corpus_name: + raise ValueError( + "corpus_name must be provided if name is a `{rag_data_schema}`," + " not a " + "full resource name" + " (`projects/{project}/locations/{location}/ragCorpora/{rag_corpus}/ragDataSchemas/{rag_data_schema}`). 
" + ) + return "projects/{project}/locations/{location}/ragCorpora/{rag_corpus}/ragDataSchemas/{rag_data_schema}".format( + project=initializer.global_config.project, + location=initializer.global_config.location, + rag_corpus=get_corpus_name(corpus_name).split("/")[-1], + rag_data_schema=name, + ) + else: + raise ValueError( + "name must be of the format" + " `projects/{project}/locations/{location}/ragCorpora/{rag_corpus}/ragDataSchemas/{rag_data_schema}`" + " or `{rag_data_schema}`" + ) + return name + + +def get_metadata_name( + name: str, + corpus_name: str, + file_name: str, +) -> str: + """Get the full resource name for a RagMetadata.""" + if name: + if len(name.split("/")) == 10: + return name + elif re.match("^{}$".format(_VALID_RESOURCE_NAME_REGEX), name): + if not corpus_name or not file_name: + raise ValueError( + "corpus_name and file_name must be provided if name is a" + " `{rag_metadata}`, not a " + "full resource name" + " (`projects/{project}/locations/{location}/ragCorpora/{rag_corpus}/ragFiles/{rag_file}/ragMetadata/{rag_metadata}`). 
" + ) + return "projects/{project}/locations/{location}/ragCorpora/{rag_corpus}/ragFiles/{rag_file}/ragMetadata/{rag_metadata}".format( + project=initializer.global_config.project, + location=initializer.global_config.location, + rag_corpus=get_corpus_name(corpus_name).split("/")[-1], + rag_file=get_file_name(file_name, corpus_name).split("/")[-1], + rag_metadata=name, + ) + else: + raise ValueError( + "name must be of the format" + " `projects/{project}/locations/{location}/ragCorpora/{rag_corpus}/ragFiles/{rag_file}/ragMetadata/{rag_metadata}`" + " or `{rag_metadata}`" + ) + return name + + def set_corpus_type_config( corpus_type_config: RagCorpusTypeConfig, rag_corpus: GapicRagCorpus, diff --git a/vertexai/preview/rag/utils/resources.py b/vertexai/preview/rag/utils/resources.py index 140906d796..a73c3ace27 100644 --- a/vertexai/preview/rag/utils/resources.py +++ b/vertexai/preview/rag/utils/resources.py @@ -25,6 +25,7 @@ DEPRECATION_DATE = "June 2025" +# @dataclasses.dataclass class RagFile: """RAG file (output only). @@ -736,3 +737,103 @@ class RagCorpus: vertex_ai_search_config: Optional[VertexAiSearchConfig] = None backend_config: Optional[RagVectorDbConfig] = None encryption_spec: Optional[EncryptionSpec] = None + + +@dataclasses.dataclass +class RagMetadataSchemaDetails: + """Data schema details indicates the data type and the data + + struct corresponding to the key of user specified metadata. + + Attributes: + type (str): Type of the metadata. + list_config (RagMetadataSchemaDetails.ListConfig): Config for List data + type. + granularity (str): The granularity associated with this RagMetadataSchema. + search_strategy (RagMetadataSchemaDetails.SearchStrategy): The search + strategy for the metadata value of the key. + """ + + @dataclasses.dataclass + class ListConfig: + """Config for List data type. + + Attributes: + value_schema (RagMetadataSchemaDetails): The value's data type in the + list. 
+ """ + + value_schema: Optional["RagMetadataSchemaDetails"] = None + + @dataclasses.dataclass + class SearchStrategy: + """The search strategy for the metadata value of the key. + + Attributes: + search_strategy_type (str): The search strategy type to be applied on + the metadata key. + """ + + search_strategy_type: Optional[str] = None + + type: Optional[str] = None + list_config: Optional[ListConfig] = None + granularity: Optional[str] = None + search_strategy: Optional[SearchStrategy] = None + + +@dataclasses.dataclass +class RagDataSchema: + """The schema of the user specified metadata. + + Attributes: + name (str): Identifier. Resource name of the data schema. + key (str): Required. The key of this data schema. + schema_details (RagMetadataSchemaDetails): The schema details mapping to + the key. + """ + + name: Optional[str] = None + key: Optional[str] = None + schema_details: Optional[RagMetadataSchemaDetails] = None + + +@dataclasses.dataclass +class MetadataValue: + """The value of metadata. + + Attributes: + string_value (str): The string value. + int_value (int): The int value. + float_value (float): The float value. + bool_value (bool): The bool value. + """ + + string_value: Optional[str] = None + int_value: Optional[int] = None + float_value: Optional[float] = None + bool_value: Optional[bool] = None + + +@dataclasses.dataclass +class RagMetadata: + """Metadata for RagFile provided by users. + + Attributes: + name (str): Identifier. Resource name of the RagMetadata. + user_specified_metadata (UserSpecifiedMetadata): User provided metadata. + """ + + name: Optional[str] = None + user_specified_metadata: Optional["UserSpecifiedMetadata"] = None + + +@dataclasses.dataclass +class UserSpecifiedMetadata: + """Metadata provided by users. + + Attributes: + values (Dict[str, MetadataValue]): Required. The values of the metadata. + """ + + values: dict[str, MetadataValue]