Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 2 additions & 7 deletions mpt_api_client/http/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from mpt_api_client.constants import APPLICATION_JSON
from mpt_api_client.exceptions import MPTError, transform_http_status_exception
from mpt_api_client.http.client import json_to_file_payload
from mpt_api_client.http.client_utils import validate_base_url
from mpt_api_client.http.types import (
HeaderTypes,
QueryParam,
Expand Down Expand Up @@ -38,13 +39,7 @@ def __init__(
"argument to MPTClient."
)

base_url = base_url or os.getenv("MPT_URL")
if not base_url:
raise ValueError(
"Base URL is required. "
"Set it up as env variable MPT_URL or pass it as `base_url` "
"argument to MPTClient."
)
base_url = validate_base_url(base_url or os.getenv("MPT_URL"))
base_headers = {
"User-Agent": "swo-marketplace-client/1.0",
"Authorization": f"Bearer {api_token}",
Expand Down
9 changes: 2 additions & 7 deletions mpt_api_client/http/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
MPTError,
transform_http_status_exception,
)
from mpt_api_client.http.client_utils import validate_base_url
from mpt_api_client.http.types import (
HeaderTypes,
QueryParam,
Expand Down Expand Up @@ -51,13 +52,7 @@ def __init__(
"argument to MPTClient."
)

base_url = base_url or os.getenv("MPT_URL")
if not base_url:
raise ValueError(
"Base URL is required. "
"Set it up as env variable MPT_URL or pass it as `base_url` "
"argument to MPTClient."
)
base_url = validate_base_url(base_url or os.getenv("MPT_URL"))
base_headers = {
"User-Agent": "swo-marketplace-client/1.0",
"Authorization": f"Bearer {api_token}",
Expand Down
47 changes: 47 additions & 0 deletions mpt_api_client/http/client_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import re
from urllib.parse import SplitResult, urlsplit, urlunparse

INVALID_ENV_URL_MESSAGE = (
"Base URL is required. "
"Set it up as env variable MPT_URL or pass it as `base_url` "
"argument to MPTClient. Expected format scheme://host[:port]"
)
PATHS_TO_REMOVE_RE = re.compile(r"^/$|^/public/?$|^/public/v1/?$")


def _format_host(hostname: str | None) -> str:
if not hostname or not isinstance(hostname, str):
raise ValueError(INVALID_ENV_URL_MESSAGE)

return f"[{hostname}]" if ":" in hostname else hostname


def _format_port(split_result: SplitResult) -> str:
try:
parsed_port = split_result.port
except ValueError as exc:
raise ValueError(INVALID_ENV_URL_MESSAGE) from exc
return f":{parsed_port}" if parsed_port else ""


def _sanitize_path(path: str) -> str:
return PATHS_TO_REMOVE_RE.sub("", path)


def _build_sanitized_base_url(split_result: SplitResult) -> str:
host = _format_host(split_result.hostname)
port = _format_port(split_result)
path = _sanitize_path(split_result.path)
return str(urlunparse((split_result.scheme, f"{host}{port}", path, "", "", "")))


def validate_base_url(base_url: str | None) -> str:
"""Validate base url."""
if not base_url or not isinstance(base_url, str):
raise ValueError(INVALID_ENV_URL_MESSAGE)

split_result = urlsplit(base_url, scheme="https")
if not split_result.scheme or not split_result.hostname:
raise ValueError(INVALID_ENV_URL_MESSAGE)

return _build_sanitized_base_url(split_result)
42 changes: 42 additions & 0 deletions tests/unit/http/test_client_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import pytest

from mpt_api_client.http.client_utils import validate_base_url


@pytest.mark.parametrize(
("input_url", "expected"),
[
("//[2001:db8:85a3::8a2e:370:7334]:80/a", "https://[2001:db8:85a3::8a2e:370:7334]:80/a"),
("//example.com", "https://example.com"),
("http://example.com", "http://example.com"),
("http://example.com:88/something/else", "http://example.com:88/something/else"),
("http://user@example.com:88/", "http://example.com:88"),
("http://user:pass@example.com:88/", "http://example.com:88"),
("http://example.com/public", "http://example.com"),
("http://example.com/public/", "http://example.com"),
("http://example.com/public/else", "http://example.com/public/else"),
("http://example.com/public/v1", "http://example.com"),
("http://example.com/public/v1/", "http://example.com"),
("http://example.com/else/public", "http://example.com/else/public"),
("http://example.com/elsepublic", "http://example.com/elsepublic"),
],
)
def test_protocol_and_host(input_url, expected):
result = validate_base_url(input_url)

assert result == expected


@pytest.mark.parametrize(
"input_url",
[
"",
"http//example.com",
"://example.com",
"http:example.com",
"http:/example.com",
],
)
def test_protocol_and_host_error(input_url):
with pytest.raises(ValueError):
validate_base_url(input_url)