From 2de5b6778cae3d0ffa85c79d404ac696611b869c Mon Sep 17 00:00:00 2001 From: Lukasz Lancucki Date: Thu, 12 Feb 2026 16:28:35 +0000 Subject: [PATCH] fix: validate, clean base url --- mpt_api_client/http/async_client.py | 9 ++---- mpt_api_client/http/client.py | 9 ++---- mpt_api_client/http/client_utils.py | 47 ++++++++++++++++++++++++++++ tests/unit/http/test_client_utils.py | 42 +++++++++++++++++++++++++ 4 files changed, 93 insertions(+), 14 deletions(-) create mode 100644 mpt_api_client/http/client_utils.py create mode 100644 tests/unit/http/test_client_utils.py diff --git a/mpt_api_client/http/async_client.py b/mpt_api_client/http/async_client.py index 128439ca..860793fa 100644 --- a/mpt_api_client/http/async_client.py +++ b/mpt_api_client/http/async_client.py @@ -11,6 +11,7 @@ from mpt_api_client.constants import APPLICATION_JSON from mpt_api_client.exceptions import MPTError, transform_http_status_exception from mpt_api_client.http.client import json_to_file_payload +from mpt_api_client.http.client_utils import validate_base_url from mpt_api_client.http.types import ( HeaderTypes, QueryParam, @@ -38,13 +39,7 @@ def __init__( "argument to MPTClient." ) - base_url = base_url or os.getenv("MPT_URL") - if not base_url: - raise ValueError( - "Base URL is required. " - "Set it up as env variable MPT_URL or pass it as `base_url` " - "argument to MPTClient." - ) + base_url = validate_base_url(base_url or os.getenv("MPT_URL")) base_headers = { "User-Agent": "swo-marketplace-client/1.0", "Authorization": f"Bearer {api_token}", diff --git a/mpt_api_client/http/client.py b/mpt_api_client/http/client.py index 29bd2566..c86ba4bf 100644 --- a/mpt_api_client/http/client.py +++ b/mpt_api_client/http/client.py @@ -14,6 +14,7 @@ MPTError, transform_http_status_exception, ) +from mpt_api_client.http.client_utils import validate_base_url from mpt_api_client.http.types import ( HeaderTypes, QueryParam, @@ -51,13 +52,7 @@ def __init__( "argument to MPTClient." ) - base_url = base_url or os.getenv("MPT_URL") - if not base_url: - raise ValueError( - "Base URL is required. " - "Set it up as env variable MPT_URL or pass it as `base_url` " - "argument to MPTClient." - ) + base_url = validate_base_url(base_url or os.getenv("MPT_URL")) base_headers = { "User-Agent": "swo-marketplace-client/1.0", "Authorization": f"Bearer {api_token}", diff --git a/mpt_api_client/http/client_utils.py b/mpt_api_client/http/client_utils.py new file mode 100644 index 00000000..ed64c2c4 --- /dev/null +++ b/mpt_api_client/http/client_utils.py @@ -0,0 +1,47 @@ +import re +from urllib.parse import SplitResult, urlsplit, urlunparse + +INVALID_ENV_URL_MESSAGE = ( + "Base URL is required. " + "Set it up as env variable MPT_URL or pass it as `base_url` " + "argument to MPTClient. Expected format scheme://host[:port]" +) +PATHS_TO_REMOVE_RE = re.compile(r"^/$|^/public/?$|^/public/v1/?$") + + +def _format_host(hostname: str | None) -> str: + if not hostname or not isinstance(hostname, str): + raise ValueError(INVALID_ENV_URL_MESSAGE) + + return f"[{hostname}]" if ":" in hostname else hostname + + +def _format_port(split_result: SplitResult) -> str: + try: + parsed_port = split_result.port + except ValueError as exc: + raise ValueError(INVALID_ENV_URL_MESSAGE) from exc + return f":{parsed_port}" if parsed_port else "" + + +def _sanitize_path(path: str) -> str: + return PATHS_TO_REMOVE_RE.sub("", path) + + +def _build_sanitized_base_url(split_result: SplitResult) -> str: + host = _format_host(split_result.hostname) + port = _format_port(split_result) + path = _sanitize_path(split_result.path) + return str(urlunparse((split_result.scheme, f"{host}{port}", path, "", "", ""))) + + +def validate_base_url(base_url: str | None) -> str: + """Validate base url.""" + if not base_url or not isinstance(base_url, str): + raise ValueError(INVALID_ENV_URL_MESSAGE) + + split_result = urlsplit(base_url, scheme="https") + if not split_result.scheme or not split_result.hostname: + raise ValueError(INVALID_ENV_URL_MESSAGE) + + return _build_sanitized_base_url(split_result) diff --git a/tests/unit/http/test_client_utils.py b/tests/unit/http/test_client_utils.py new file mode 100644 index 00000000..92ac018b --- /dev/null +++ b/tests/unit/http/test_client_utils.py @@ -0,0 +1,42 @@ +import pytest + +from mpt_api_client.http.client_utils import validate_base_url + + +@pytest.mark.parametrize( + ("input_url", "expected"), + [ + ("//[2001:db8:85a3::8a2e:370:7334]:80/a", "https://[2001:db8:85a3::8a2e:370:7334]:80/a"), + ("//example.com", "https://example.com"), + ("http://example.com", "http://example.com"), + ("http://example.com:88/something/else", "http://example.com:88/something/else"), + ("http://user@example.com:88/", "http://example.com:88"), + ("http://user:pass@example.com:88/", "http://example.com:88"), + ("http://example.com/public", "http://example.com"), + ("http://example.com/public/", "http://example.com"), + ("http://example.com/public/else", "http://example.com/public/else"), + ("http://example.com/public/v1", "http://example.com"), + ("http://example.com/public/v1/", "http://example.com"), + ("http://example.com/else/public", "http://example.com/else/public"), + ("http://example.com/elsepublic", "http://example.com/elsepublic"), + ], +) +def test_protocol_and_host(input_url, expected): + result = validate_base_url(input_url) + + assert result == expected + + +@pytest.mark.parametrize( + "input_url", + [ + "", + "http//example.com", + "://example.com", + "http:example.com", + "http:/example.com", + ], +) +def test_protocol_and_host_error(input_url): + with pytest.raises(ValueError): + validate_base_url(input_url)