Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions amazon_creatorsapi/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from amazon_creatorsapi.errors import ItemsNotFoundError
from creatorsapi_python_sdk.api.default_api import DefaultApi
from creatorsapi_python_sdk.api_client import ApiClient
from creatorsapi_python_sdk.configuration import Configuration
from creatorsapi_python_sdk.exceptions import ApiException
from creatorsapi_python_sdk.models.get_browse_nodes_request_content import (
GetBrowseNodesRequestContent,
Expand Down Expand Up @@ -58,6 +59,8 @@ class AmazonCreatorsApi:
country: Country code (e.g., "ES", "US"). Used to determine marketplace.
marketplace: Marketplace URL (e.g., "www.amazon.es"). Overrides country.
throttling: Wait time in seconds between API calls. Defaults to 1 second.
proxy: Optional HTTP proxy URL, e.g. ``"http://user:pass@proxy:3128"``.
Applied to both regular API calls and OAuth2 token refresh.

Raises:
InvalidArgumentError: If neither country nor marketplace is provided.
Expand All @@ -83,6 +86,7 @@ def __init__(
country: CountryCode | None = None,
marketplace: str | None = None,
throttling: float = DEFAULT_THROTTLING,
proxy: str | None = None,
) -> None:
"""Initialize the Amazon Creators API client."""
self._credential_id = credential_id
Expand All @@ -95,7 +99,11 @@ def __init__(
# Determine marketplace from country or direct value
self.marketplace = validate_and_get_marketplace(country, marketplace)

configuration = Configuration()
configuration.proxy = proxy

self._api_client = ApiClient(
configuration=configuration,
credential_id=credential_id,
credential_secret=credential_secret,
version=version,
Expand Down
4 changes: 3 additions & 1 deletion creatorsapi_python_sdk/api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,9 @@ def call_api(
self.credential_id, self.credential_secret,
self.version, self.auth_endpoint
)
self._token_manager = OAuth2TokenManager(config)
proxy = self.configuration.proxy
proxies = {"http": proxy, "https": proxy} if proxy else None
self._token_manager = OAuth2TokenManager(config, proxies=proxies)
# Get token (will use cached token if valid)
token = self._token_manager.get_token()
# Add Authorization headers - Version only for v2.x
Expand Down
14 changes: 10 additions & 4 deletions creatorsapi_python_sdk/auth/oauth2_token_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,15 @@
class OAuth2TokenManager:
"""Manages OAuth2 token lifecycle including acquisition, caching, and automatic refresh"""

def __init__(self, config):
def __init__(self, config, proxies=None):
"""
Creates an OAuth2TokenManager instance

:param config: The OAuth2Config instance
:param proxies: Optional dict of proxy URLs, e.g. {"http": "http://proxy:3128", "https": "http://proxy:3128"}
"""
self.config = config
self.proxies = proxies
self.access_token = None
self.expires_at = None

Expand Down Expand Up @@ -67,6 +69,10 @@ def refresh_token(self):
:raises Exception: If token refresh fails
"""
try:
session = requests.Session()
if self.proxies:
session.proxies.update(self.proxies)

if self.config.is_lwa():
# LWA (v3.x) uses JSON body
request_data = {
Expand All @@ -76,7 +82,7 @@ def refresh_token(self):
'scope': self.config.get_scope()
}
headers = {'Content-Type': 'application/json'}
response = requests.post(
response = session.post(
self.config.get_cognito_endpoint(),
json=request_data,
headers=headers
Expand All @@ -90,7 +96,7 @@ def refresh_token(self):
'scope': self.config.get_scope()
}
headers = {'Content-Type': 'application/x-www-form-urlencoded'}
response = requests.post(
response = session.post(
self.config.get_cognito_endpoint(),
data=request_data,
headers=headers
Expand Down
28 changes: 28 additions & 0 deletions tests/amazon_creatorsapi/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,34 @@ def test_init_no_country_or_marketplace(self) -> None:
tag=self.tag,
)

@mock.patch("amazon_creatorsapi.api.ApiClient")
def test_init_with_proxy(self, mock_client: MagicMock) -> None:
"""Test that proxy URL is passed through to ApiClient configuration."""
proxy_url = "http://user:pass@proxy.example.com:3128"
AmazonCreatorsApi(
credential_id=self.credential_id,
credential_secret=self.credential_secret,
version=self.version,
tag=self.tag,
country=self.country,
proxy=proxy_url,
)
call_kwargs = mock_client.call_args.kwargs
self.assertEqual(call_kwargs["configuration"].proxy, proxy_url)

@mock.patch("amazon_creatorsapi.api.ApiClient")
def test_init_without_proxy(self, mock_client: MagicMock) -> None:
"""Test that configuration.proxy is None when no proxy is provided."""
AmazonCreatorsApi(
credential_id=self.credential_id,
credential_secret=self.credential_secret,
version=self.version,
tag=self.tag,
country=self.country,
)
call_kwargs = mock_client.call_args.kwargs
self.assertIsNone(call_kwargs["configuration"].proxy)

@mock.patch("amazon_creatorsapi.api.ApiClient")
def test_throttling_disabled(self, _mock_client: MagicMock) -> None:
"""Test that API call is not delayed when throttling is 0."""
Expand Down
72 changes: 72 additions & 0 deletions tests/amazon_creatorsapi/oauth2_token_manager_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""Unit tests for OAuth2TokenManager proxy support."""

from __future__ import annotations

import unittest
from unittest import mock
from unittest.mock import MagicMock, patch

from creatorsapi_python_sdk.auth.oauth2_config import OAuth2Config
from creatorsapi_python_sdk.auth.oauth2_token_manager import OAuth2TokenManager


def _make_config(version: str = "2.2") -> OAuth2Config:
return OAuth2Config(
credential_id="test_id",
credential_secret="test_secret",
version=version,
auth_endpoint=None,
)


def _mock_token_response() -> MagicMock:
resp = MagicMock()
resp.status_code = 200
resp.json.return_value = {"access_token": "tok123", "expires_in": 3600}
return resp


class TestOAuth2TokenManagerProxy(unittest.TestCase):
"""Tests that OAuth2TokenManager routes token refresh through the proxy."""

@patch("creatorsapi_python_sdk.auth.oauth2_token_manager.requests.Session")
def test_refresh_token_sets_proxies_on_session(self, mock_session_cls: MagicMock) -> None:
"""When proxies are provided, Session.proxies.update is called with them."""
proxy_url = "http://user:pass@proxy.example.com:3128"
proxies = {"http": proxy_url, "https": proxy_url}

mock_session = MagicMock()
mock_session.post.return_value = _mock_token_response()
mock_session_cls.return_value = mock_session

manager = OAuth2TokenManager(_make_config(), proxies=proxies)
manager.refresh_token()

mock_session.proxies.update.assert_called_once_with(proxies)

@patch("creatorsapi_python_sdk.auth.oauth2_token_manager.requests.Session")
def test_refresh_token_no_proxy_skips_proxies_update(self, mock_session_cls: MagicMock) -> None:
"""When no proxy is configured, Session.proxies.update is not called."""
mock_session = MagicMock()
mock_session.post.return_value = _mock_token_response()
mock_session_cls.return_value = mock_session

manager = OAuth2TokenManager(_make_config())
manager.refresh_token()

mock_session.proxies.update.assert_not_called()

@patch("creatorsapi_python_sdk.auth.oauth2_token_manager.requests.Session")
def test_refresh_token_lwa_sets_proxies_on_session(self, mock_session_cls: MagicMock) -> None:
"""Proxy is also applied for LWA (v3.x) token refresh."""
proxy_url = "http://proxy.example.com:3128"
proxies = {"http": proxy_url, "https": proxy_url}

mock_session = MagicMock()
mock_session.post.return_value = _mock_token_response()
mock_session_cls.return_value = mock_session

manager = OAuth2TokenManager(_make_config(version="3.1"), proxies=proxies)
manager.refresh_token()

mock_session.proxies.update.assert_called_once_with(proxies)