diff --git a/src/mock_vws/_mock_common.py b/src/mock_vws/_mock_common.py index 0ebbd379b..5a7069562 100644 --- a/src/mock_vws/_mock_common.py +++ b/src/mock_vws/_mock_common.py @@ -1,13 +1,30 @@ """Common utilities for creating mock routes.""" import json -from collections.abc import Iterable +from collections.abc import Iterable, Mapping from dataclasses import dataclass from typing import Any from beartype import beartype +@dataclass(frozen=True) +class RequestData: + """A library-agnostic representation of an HTTP request. + + Args: + method: The HTTP method of the request. + path: The path of the request. + headers: The headers sent with the request. + body: The body of the request. + """ + + method: str + path: str + headers: Mapping[str, str] + body: bytes + + @dataclass(frozen=True) class Route: """A representation of a VWS route. diff --git a/src/mock_vws/_requests_mock_server/decorators.py b/src/mock_vws/_requests_mock_server/decorators.py index 35535217b..04a961989 100644 --- a/src/mock_vws/_requests_mock_server/decorators.py +++ b/src/mock_vws/_requests_mock_server/decorators.py @@ -12,6 +12,7 @@ from requests import PreparedRequest from responses import RequestsMock +from mock_vws._mock_common import RequestData from mock_vws.database import CloudDatabase, VuMarkDatabase from mock_vws.image_matchers import ( ImageMatcher, @@ -27,7 +28,8 @@ from .mock_web_services_api import MockVuforiaWebServicesAPI _ResponseType = tuple[int, Mapping[str, str], str | bytes] -_Callback = Callable[[PreparedRequest], _ResponseType] +_MockCallback = Callable[[RequestData], _ResponseType] +_ResponsesCallback = Callable[[PreparedRequest], _ResponseType] _STRUCTURAL_SIMILARITY_MATCHER = StructuralSimilarityMatcher() _BRISQUE_TRACKING_RATER = BrisqueTargetTrackingRater() @@ -156,10 +158,10 @@ def add_vumark_database(self, vumark_database: VuMarkDatabase) -> None: @staticmethod def _wrap_callback( - callback: _Callback, + callback: _MockCallback, delay_seconds: float, sleep_fn: Callable[[float], None], - ) -> _Callback: + ) -> _ResponsesCallback: """Wrap a callback to add a response delay.""" def wrapped( @@ -186,7 +188,21 @@ def wrapped( sleep_fn(effective) raise requests.exceptions.Timeout - result = callback(request) + raw_body = request.body + if raw_body is None: + body_bytes = b"" + elif isinstance(raw_body, str): + body_bytes = raw_body.encode(encoding="utf-8") + else: + body_bytes = raw_body + + request_data = RequestData( + method=request.method or "", + path=request.path_url, + headers=dict(request.headers), + body=body_bytes, + ) + result = callback(request_data) sleep_fn(delay_seconds) return result diff --git a/src/mock_vws/_requests_mock_server/mock_web_query_api.py b/src/mock_vws/_requests_mock_server/mock_web_query_api.py index e2626a25a..6f8f519d3 100644 --- a/src/mock_vws/_requests_mock_server/mock_web_query_api.py +++ b/src/mock_vws/_requests_mock_server/mock_web_query_api.py @@ -10,9 +10,8 @@ from typing import ParamSpec, Protocol, runtime_checkable from beartype import beartype -from requests.models import PreparedRequest -from mock_vws._mock_common import Route +from mock_vws._mock_common import RequestData, Route from mock_vws._query_tools import ( get_query_match_response_text, ) @@ -78,21 +77,9 @@ def decorator( return decorator -@beartype -def _body_bytes(request: PreparedRequest) -> bytes: - """Return the body of a request as bytes.""" - if request.body is None or isinstance(request.body, str): - return b"" - - return request.body - - @beartype class MockVuforiaWebQueryAPI: - """A fake implementation of the Vuforia Web Query API. - - This implementation is tied to the implementation of ``responses``. - """ + """A fake implementation of the Vuforia Web Query API.""" def __init__( self, @@ -114,14 +101,14 @@ def __init__( self._query_match_checker = query_match_checker @route(path_pattern="/v1/query", http_methods={HTTPMethod.POST}) - def query(self, request: PreparedRequest) -> _ResponseType: + def query(self, request: RequestData) -> _ResponseType: """Perform an image recognition query.""" try: run_query_validators( - request_path=request.path_url, + request_path=request.path, request_headers=request.headers, - request_body=_body_bytes(request=request), - request_method=request.method or "", + request_body=request.body, + request_method=request.method, databases=self._target_manager.cloud_databases, ) except ValidatorError as exc: @@ -129,9 +116,9 @@ def query(self, request: PreparedRequest) -> _ResponseType: response_text = get_query_match_response_text( request_headers=request.headers, - request_body=_body_bytes(request=request), - request_method=request.method or "", - request_path=request.path_url, + request_body=request.body, + request_method=request.method, + request_path=request.path, databases=self._target_manager.cloud_databases, query_match_checker=self._query_match_checker, ) diff --git a/src/mock_vws/_requests_mock_server/mock_web_services_api.py b/src/mock_vws/_requests_mock_server/mock_web_services_api.py index deac884bf..1b0662cea 100644 --- a/src/mock_vws/_requests_mock_server/mock_web_services_api.py +++ b/src/mock_vws/_requests_mock_server/mock_web_services_api.py @@ -16,7 +16,6 @@ from zoneinfo import ZoneInfo from beartype import BeartypeConf, beartype -from requests.models import PreparedRequest from mock_vws._constants import ( VUMARK_PDF, @@ -26,7 +25,7 @@ TargetStatuses, ) from mock_vws._database_matchers import get_database_matching_server_keys -from mock_vws._mock_common import Route, json_dump +from mock_vws._mock_common import RequestData, Route, json_dump from mock_vws._services_validators import run_services_validators from mock_vws._services_validators.exceptions import ( FailError, @@ -105,24 +104,9 @@ def decorator( return decorator -@beartype -def _body_bytes(request: PreparedRequest) -> bytes: - """Return the body of a request as bytes.""" - if request.body is None: - return b"" - - if isinstance(request.body, str): - return request.body.encode(encoding="utf-8") - - return request.body - - @beartype(conf=BeartypeConf(is_pep484_tower=True)) class MockVuforiaWebServicesAPI: - """A fake implementation of the Vuforia Web Services API. - - This implementation is tied to the implementation of ``responses``. - """ + """A fake implementation of the Vuforia Web Services API.""" def __init__( self, @@ -157,7 +141,7 @@ def __init__( path_pattern="/targets", http_methods={HTTPMethod.POST}, ) - def add_target(self, request: PreparedRequest) -> _ResponseType: + def add_target(self, request: RequestData) -> _ResponseType: """Add a target. Fake implementation of @@ -166,9 +150,9 @@ def add_target(self, request: PreparedRequest) -> _ResponseType: try: run_services_validators( request_headers=request.headers, - request_body=_body_bytes(request=request), - request_method=request.method or "", - request_path=request.path_url, + request_body=request.body, + request_method=request.method, + request_path=request.path, databases=self._target_manager.cloud_databases, ) except ValidatorError as exc: @@ -176,13 +160,13 @@ def add_target(self, request: PreparedRequest) -> _ResponseType: database = get_database_matching_server_keys( request_headers=request.headers, - request_body=_body_bytes(request=request), - request_method=request.method or "", - request_path=request.path_url, + request_body=request.body, + request_method=request.method, + request_path=request.path, databases=self._target_manager.cloud_databases, ) - request_json: dict[str, Any] = json.loads(s=request.body or b"") + request_json: dict[str, Any] = json.loads(s=request.body) given_active_flag = request_json.get("active_flag") active_flag = { None: True, @@ -232,7 +216,7 @@ def add_target(self, request: PreparedRequest) -> _ResponseType: path_pattern=f"/targets/{_TARGET_ID_PATTERN}", http_methods={HTTPMethod.DELETE}, ) - def delete_target(self, request: PreparedRequest) -> _ResponseType: + def delete_target(self, request: RequestData) -> _ResponseType: """Delete a target. Fake implementation of @@ -241,9 +225,9 @@ def delete_target(self, request: PreparedRequest) -> _ResponseType: try: run_services_validators( request_headers=request.headers, - request_body=_body_bytes(request=request), - request_method=request.method or "", - request_path=request.path_url, + request_body=request.body, + request_method=request.method, + request_path=request.path, databases=self._target_manager.cloud_databases, ) except ValidatorError as exc: @@ -251,13 +235,13 @@ def delete_target(self, request: PreparedRequest) -> _ResponseType: database = get_database_matching_server_keys( request_headers=request.headers, - request_body=_body_bytes(request=request), - request_method=request.method or "", - request_path=request.path_url, + request_body=request.body, + request_method=request.method, + request_path=request.path, databases=self._target_manager.cloud_databases, ) - target_id = request.path_url.split(sep="/")[-1] + target_id = request.path.split(sep="/")[-1] target = database.get_target(target_id=target_id) if target.status == TargetStatuses.PROCESSING.value: @@ -304,9 +288,7 @@ def delete_target(self, request: PreparedRequest) -> _ResponseType: path_pattern=f"/targets/{_TARGET_ID_PATTERN}/instances", http_methods={HTTPMethod.POST}, ) - def generate_vumark_instance( - self, request: PreparedRequest - ) -> _ResponseType: + def generate_vumark_instance(self, request: RequestData) -> _ResponseType: """Generate a VuMark instance.""" valid_accept_types: dict[str, bytes] = { "image/png": VUMARK_PNG, @@ -320,17 +302,17 @@ def generate_vumark_instance( ] run_services_validators( request_headers=request.headers, - request_body=_body_bytes(request=request), - request_method=request.method or "", - request_path=request.path_url, + request_body=request.body, + request_method=request.method, + request_path=request.path, databases=all_databases, ) database = get_database_matching_server_keys( request_headers=request.headers, - request_body=_body_bytes(request=request), - request_method=request.method or "", - request_path=request.path_url, + request_body=request.body, + request_method=request.method, + request_path=request.path, databases=all_databases, ) if not isinstance(database, VuMarkDatabase): @@ -340,7 +322,7 @@ def generate_vumark_instance( if accept not in valid_accept_types: raise InvalidAcceptHeaderError - request_json = json.loads(s=_body_bytes(request=request)) + request_json = json.loads(s=request.body) instance_id = request_json.get("instance_id", "") if not instance_id: raise InvalidInstanceIdError @@ -367,7 +349,7 @@ def generate_vumark_instance( return HTTPStatus.OK, headers, response_body @route(path_pattern="/summary", http_methods={HTTPMethod.GET}) - def database_summary(self, request: PreparedRequest) -> _ResponseType: + def database_summary(self, request: RequestData) -> _ResponseType: """Get a database summary report. Fake implementation of @@ -376,9 +358,9 @@ def database_summary(self, request: PreparedRequest) -> _ResponseType: try: run_services_validators( request_headers=request.headers, - request_body=_body_bytes(request=request), - request_method=request.method or "", - request_path=request.path_url, + request_body=request.body, + request_method=request.method, + request_path=request.path, databases=self._target_manager.cloud_databases, ) except ValidatorError as exc: @@ -386,9 +368,9 @@ def database_summary(self, request: PreparedRequest) -> _ResponseType: database = get_database_matching_server_keys( request_headers=request.headers, - request_body=_body_bytes(request=request), - request_method=request.method or "", - request_path=request.path_url, + request_body=request.body, + request_method=request.method, + request_path=request.path, databases=self._target_manager.cloud_databases, ) @@ -428,7 +410,7 @@ def database_summary(self, request: PreparedRequest) -> _ResponseType: return HTTPStatus.OK, headers, body_json @route(path_pattern="/targets", http_methods={HTTPMethod.GET}) - def target_list(self, request: PreparedRequest) -> _ResponseType: + def target_list(self, request: RequestData) -> _ResponseType: """Get a list of all targets. Fake implementation of @@ -437,9 +419,9 @@ def target_list(self, request: PreparedRequest) -> _ResponseType: try: run_services_validators( request_headers=request.headers, - request_body=_body_bytes(request=request), - request_method=request.method or "", - request_path=request.path_url, + request_body=request.body, + request_method=request.method, + request_path=request.path, databases=self._target_manager.cloud_databases, ) except ValidatorError as exc: @@ -447,9 +429,9 @@ def target_list(self, request: PreparedRequest) -> _ResponseType: database = get_database_matching_server_keys( request_headers=request.headers, - request_body=_body_bytes(request=request), - request_method=request.method or "", - request_path=request.path_url, + request_body=request.body, + request_method=request.method, + request_path=request.path, databases=self._target_manager.cloud_databases, ) @@ -485,7 +467,7 @@ def target_list(self, request: PreparedRequest) -> _ResponseType: path_pattern=f"/targets/{_TARGET_ID_PATTERN}", http_methods={HTTPMethod.GET}, ) - def get_target(self, request: PreparedRequest) -> _ResponseType: + def get_target(self, request: RequestData) -> _ResponseType: """Get details of a target. Fake implementation of @@ -494,9 +476,9 @@ def get_target(self, request: PreparedRequest) -> _ResponseType: try: run_services_validators( request_headers=request.headers, - request_body=_body_bytes(request=request), - request_method=request.method or "", - request_path=request.path_url, + request_body=request.body, + request_method=request.method, + request_path=request.path, databases=self._target_manager.cloud_databases, ) except ValidatorError as exc: @@ -504,12 +486,12 @@ def get_target(self, request: PreparedRequest) -> _ResponseType: database = get_database_matching_server_keys( request_headers=request.headers, - request_body=_body_bytes(request=request), - request_method=request.method or "", - request_path=request.path_url, + request_body=request.body, + request_method=request.method, + request_path=request.path, databases=self._target_manager.cloud_databases, ) - target_id = request.path_url.split(sep="/")[-1] + target_id = request.path.split(sep="/")[-1] target = database.get_target(target_id=target_id) width = target.width @@ -553,7 +535,7 @@ def get_target(self, request: PreparedRequest) -> _ResponseType: path_pattern=f"/duplicates/{_TARGET_ID_PATTERN}", http_methods={HTTPMethod.GET}, ) - def get_duplicates(self, request: PreparedRequest) -> _ResponseType: + def get_duplicates(self, request: RequestData) -> _ResponseType: """Get targets which may be considered duplicates of a given target. @@ -563,9 +545,9 @@ def get_duplicates(self, request: PreparedRequest) -> _ResponseType: try: run_services_validators( request_headers=request.headers, - request_body=_body_bytes(request=request), - request_method=request.method or "", - request_path=request.path_url, + request_body=request.body, + request_method=request.method, + request_path=request.path, databases=self._target_manager.cloud_databases, ) except ValidatorError as exc: @@ -573,12 +555,12 @@ def get_duplicates(self, request: PreparedRequest) -> _ResponseType: database = get_database_matching_server_keys( request_headers=request.headers, - request_body=_body_bytes(request=request), - request_method=request.method or "", - request_path=request.path_url, + request_body=request.body, + request_method=request.method, + request_path=request.path, databases=self._target_manager.cloud_databases, ) - target_id = request.path_url.split(sep="/")[-1] + target_id = request.path.split(sep="/")[-1] target = database.get_target(target_id=target_id) other_targets = database.targets - {target} @@ -625,7 +607,7 @@ def get_duplicates(self, request: PreparedRequest) -> _ResponseType: path_pattern=f"/targets/{_TARGET_ID_PATTERN}", http_methods={HTTPMethod.PUT}, ) - def update_target(self, request: PreparedRequest) -> _ResponseType: + def update_target(self, request: RequestData) -> _ResponseType: """Update a target. Fake implementation of @@ -634,9 +616,9 @@ def update_target(self, request: PreparedRequest) -> _ResponseType: try: run_services_validators( request_headers=request.headers, - request_body=_body_bytes(request=request), - request_method=request.method or "", - request_path=request.path_url, + request_body=request.body, + request_method=request.method, + request_path=request.path, databases=self._target_manager.cloud_databases, ) except ValidatorError as exc: @@ -644,13 +626,13 @@ def update_target(self, request: PreparedRequest) -> _ResponseType: database = get_database_matching_server_keys( request_headers=request.headers, - request_body=_body_bytes(request=request), - request_method=request.method or "", - request_path=request.path_url, + request_body=request.body, + request_method=request.method, + request_path=request.path, databases=self._target_manager.cloud_databases, ) - target_id = request.path_url.split(sep="/")[-1] + target_id = request.path.split(sep="/")[-1] target = database.get_target(target_id=target_id) date = email.utils.formatdate( @@ -667,7 +649,7 @@ def update_target(self, request: PreparedRequest) -> _ResponseType: exception.response_text, ) - request_json: dict[str, Any] = json.loads(s=request.body or b"") + request_json: dict[str, Any] = json.loads(s=request.body) name = request_json.get("name", target.name) active_flag = request_json.get("active_flag", target.active_flag) @@ -739,7 +721,7 @@ def update_target(self, request: PreparedRequest) -> _ResponseType: path_pattern=f"/summary/{_TARGET_ID_PATTERN}", http_methods={HTTPMethod.GET}, ) - def target_summary(self, request: PreparedRequest) -> _ResponseType: + def target_summary(self, request: RequestData) -> _ResponseType: """Get a summary report for a target. Fake implementation of @@ -748,9 +730,9 @@ def target_summary(self, request: PreparedRequest) -> _ResponseType: try: run_services_validators( request_headers=request.headers, - request_body=_body_bytes(request=request), - request_method=request.method or "", - request_path=request.path_url, + request_body=request.body, + request_method=request.method, + request_path=request.path, databases=self._target_manager.cloud_databases, ) except ValidatorError as exc: @@ -758,12 +740,12 @@ def target_summary(self, request: PreparedRequest) -> _ResponseType: database = get_database_matching_server_keys( request_headers=request.headers, - request_body=_body_bytes(request=request), - request_method=request.method or "", - request_path=request.path_url, + request_body=request.body, + request_method=request.method, + request_path=request.path, databases=self._target_manager.cloud_databases, ) - target_id = request.path_url.split(sep="/")[-1] + target_id = request.path.split(sep="/")[-1] target = database.get_target(target_id=target_id) date = email.utils.formatdate(