Use the configured insert timeout as the gRPC deadline for batch sends.

Summary
  * weaviate/collections/batch/base.py: __send_batch passes
    self.__connection.timeout_config.insert as the request timeout; the
    module constant DEFAULT_REQUEST_TIMEOUT (180) is removed.
  * mock_tests/conftest.py: adds a servicer that records
    context.time_remaining() on BatchObjects, a servicer that sleeps
    BATCH_INSERT_TIMEOUT_SHORT + 1 before replying, and one client fixture
    per insert-timeout constant (5 and 0.5).
  * mock_tests/test_timeouts.py: adds one test asserting the captured
    deadline is within 1 of BATCH_INSERT_TIMEOUT, and one asserting that the
    slow servicer yields exactly one failed object whose message contains
    "Deadline Exceeded".

Review notes
  * NOTE(review): this patch was restored from a copy whose line breaks and
    indentation had been collapsed. Indentation of added Python lines follows
    from their syntax. Indentation of unchanged context lines is a best
    guess, in particular the nesting depth inside __send_batch and the
    DEADLINE_EXCEEDED assert in test_timeout_grpc_insert. Confirm against the
    tree, or apply with "git apply --ignore-whitespace".
  * NOTE(review): grpc documents ServicerContext.time_remaining() as
    returning None when the RPC carries no deadline. captured_time_remaining
    is annotated as float, so if the deadline were ever dropped the new
    deadline test would raise TypeError instead of failing its assertion.
    TODO confirm that is acceptable.
  * NOTE(review): the two client fixtures differ only in the timeout
    constant; a shared helper would remove the duplication.
  * NOTE(review): the slow servicer blocks for 1.5 s per call, which adds
    that much wall time to the timeout test.

diff --git a/mock_tests/conftest.py b/mock_tests/conftest.py
index 9c0bf19ec..0fd7df100 100644
--- a/mock_tests/conftest.py
+++ b/mock_tests/conftest.py
@@ -281,6 +281,85 @@ def metadata_capture_collection(
     return weaviate_client.collections.use("MetadataCaptureCollection"), service
 
 
+BATCH_INSERT_TIMEOUT = 5
+
+
+class MockBatchDeadlineCaptureWeaviateService(weaviate_pb2_grpc.WeaviateServicer):
+    captured_time_remaining: float = -1.0
+
+    def BatchObjects(
+        self, request: batch_pb2.BatchObjectsRequest, context: grpc.ServicerContext
+    ) -> batch_pb2.BatchObjectsReply:
+        self.captured_time_remaining = context.time_remaining()
+        return batch_pb2.BatchObjectsReply()
+
+
+@pytest.fixture(scope="function")
+def weaviate_batch_insert_timeout_client(
+    weaviate_mock: HTTPServer, start_grpc_server: grpc.Server
+) -> Generator[weaviate.WeaviateClient, None, None]:
+    weaviate_mock.expect_request(f"/v1/schema/{mock_class['class']}").respond_with_json(mock_class)
+    client = weaviate.connect_to_local(
+        host=MOCK_IP,
+        port=MOCK_PORT,
+        grpc_port=MOCK_PORT_GRPC,
+        additional_config=weaviate.classes.init.AdditionalConfig(
+            timeout=weaviate.classes.init.Timeout(insert=BATCH_INSERT_TIMEOUT)
+        ),
+    )
+    yield client
+    client.close()
+
+
+@pytest.fixture(scope="function")
+def batch_deadline_capture_collection(
+    weaviate_batch_insert_timeout_client: weaviate.WeaviateClient,
+    start_grpc_server: grpc.Server,
+) -> tuple[weaviate.collections.Collection, MockBatchDeadlineCaptureWeaviateService]:
+    service = MockBatchDeadlineCaptureWeaviateService()
+    weaviate_pb2_grpc.add_WeaviateServicer_to_server(service, start_grpc_server)
+    return (
+        weaviate_batch_insert_timeout_client.collections.use(mock_class["class"]),
+        service,
+    )
+
+
+BATCH_INSERT_TIMEOUT_SHORT = 0.5
+
+
+@pytest.fixture(scope="function")
+def weaviate_batch_insert_timeout_short_client(
+    weaviate_mock: HTTPServer, start_grpc_server: grpc.Server
+) -> Generator[weaviate.WeaviateClient, None, None]:
+    weaviate_mock.expect_request(f"/v1/schema/{mock_class['class']}").respond_with_json(mock_class)
+    client = weaviate.connect_to_local(
+        host=MOCK_IP,
+        port=MOCK_PORT,
+        grpc_port=MOCK_PORT_GRPC,
+        additional_config=weaviate.classes.init.AdditionalConfig(
+            timeout=weaviate.classes.init.Timeout(insert=BATCH_INSERT_TIMEOUT_SHORT)
+        ),
+    )
+    yield client
+    client.close()
+
+
+@pytest.fixture(scope="function")
+def batch_slow_response_collection(
+    weaviate_batch_insert_timeout_short_client: weaviate.WeaviateClient,
+    start_grpc_server: grpc.Server,
+) -> weaviate.collections.Collection:
+    class MockWeaviateService(weaviate_pb2_grpc.WeaviateServicer):
+        def BatchObjects(
+            self, request: batch_pb2.BatchObjectsRequest, context: grpc.ServicerContext
+        ) -> batch_pb2.BatchObjectsReply:
+            time.sleep(BATCH_INSERT_TIMEOUT_SHORT + 1)
+            return batch_pb2.BatchObjectsReply()
+
+    weaviate_pb2_grpc.add_WeaviateServicer_to_server(MockWeaviateService(), start_grpc_server)
+    return weaviate_batch_insert_timeout_short_client.collections.use(mock_class["class"])
+
+
 class MockRetriesWeaviateService(weaviate_pb2_grpc.WeaviateServicer):
     search_count = 0
     tenants_count = 0
diff --git a/mock_tests/test_timeouts.py b/mock_tests/test_timeouts.py
index 5f5a51b57..77966e898 100644
--- a/mock_tests/test_timeouts.py
+++ b/mock_tests/test_timeouts.py
@@ -3,6 +3,8 @@
 import weaviate
 from weaviate.exceptions import WeaviateQueryError, WeaviateTimeoutError
 
+from .conftest import BATCH_INSERT_TIMEOUT, MockBatchDeadlineCaptureWeaviateService
+
 
 def test_timeout_rest_query(timeouts_collection: weaviate.collections.Collection):
     with pytest.raises(WeaviateTimeoutError):
@@ -24,3 +26,24 @@ def test_timeout_grpc_insert(timeouts_collection: weaviate.collections.Collectio
     with pytest.raises(WeaviateQueryError) as recwarn:
         timeouts_collection.data.insert_many([{"what": "ever"}])
     assert "DEADLINE_EXCEEDED" in str(recwarn)
+
+
+def test_batch_fixed_size_deadline_uses_insert_timeout(
+    batch_deadline_capture_collection: tuple[
+        weaviate.collections.Collection, MockBatchDeadlineCaptureWeaviateService
+    ],
+):
+    collection, service = batch_deadline_capture_collection
+    with collection.batch.fixed_size(batch_size=1) as batch:
+        batch.add_object({"what": "ever"})
+    assert abs(service.captured_time_remaining - BATCH_INSERT_TIMEOUT) < 1
+
+
+def test_batch_fixed_size_times_out_when_insert_exceeded(
+    batch_slow_response_collection: weaviate.collections.Collection,
+):
+    with batch_slow_response_collection.batch.fixed_size(batch_size=1) as batch:
+        batch.add_object({"what": "ever"})
+    failed = batch_slow_response_collection.batch.failed_objects
+    assert len(failed) == 1
+    assert "Deadline Exceeded" in failed[0].message
diff --git a/weaviate/collections/batch/base.py b/weaviate/collections/batch/base.py
index a700df53f..af6a9ea49 100644
--- a/weaviate/collections/batch/base.py
+++ b/weaviate/collections/batch/base.py
@@ -54,7 +54,6 @@ TBatchInput = TypeVar("TBatchInput")
 TBatchReturn = TypeVar("TBatchReturn")
 
 MAX_CONCURRENT_REQUESTS = 10
-DEFAULT_REQUEST_TIMEOUT = 180
 CONCURRENT_REQUESTS_DYNAMIC_VECTORIZER = 2
 BATCH_TIME_TARGET = 10
 VECTORIZER_BATCHING_STEP_SIZE = 48  # cohere max batch size is 96
@@ -612,7 +611,7 @@ def __send_batch(
                     self.__batch_grpc.objects(
                         connection=self.__connection,
                         objects=[obj._to_internal() for obj in objs],
-                        timeout=DEFAULT_REQUEST_TIMEOUT,
+                        timeout=self.__connection.timeout_config.insert,
                         max_retries=MAX_RETRIES,
                     )
                 )