diff --git a/src/auth0/authentication/base.py b/src/auth0/authentication/base.py index 9d69d969..5bb3b9bc 100644 --- a/src/auth0/authentication/base.py +++ b/src/auth0/authentication/base.py @@ -2,11 +2,10 @@ from typing import Any +from .client_authentication import add_client_authentication from .rest import RestClient, RestClientOptions from .types import RequestData, TimeoutType -from .client_authentication import add_client_authentication - UNKNOWN_ERROR = "a0.sdk.internal.unknown" @@ -22,6 +21,9 @@ class AuthenticationBase: telemetry (bool, optional): Enable or disable telemetry (defaults to True) timeout (float or tuple, optional): Change the requests connect and read timeout. Pass a tuple to specify both values separately or a float to set both to it. (defaults to 5.0 for both) protocol (str, optional): Useful for testing. (defaults to 'https') + client_info (dict, optional): Custom telemetry data for the Auth0-Client header. + When provided, overrides the default SDK telemetry. Useful for wrapper + SDKs that need to identify themselves. Ignored when telemetry is False. """ def __init__( @@ -34,6 +36,7 @@ def __init__( telemetry: bool = True, timeout: TimeoutType = 5.0, protocol: str = "https", + client_info: dict[str, Any] | None = None, ) -> None: self.domain = domain self.client_id = client_id @@ -43,7 +46,9 @@ def __init__( self.protocol = protocol self.client = RestClient( None, - options=RestClientOptions(telemetry=telemetry, timeout=timeout, retries=0), + options=RestClientOptions( + telemetry=telemetry, timeout=timeout, retries=0, client_info=client_info + ), ) def _add_client_authentication(self, payload: dict[str, Any]) -> dict[str, Any]: diff --git a/src/auth0/authentication/rest.py b/src/auth0/authentication/rest.py index 78f0422b..09572096 100644 --- a/src/auth0/authentication/rest.py +++ b/src/auth0/authentication/rest.py @@ -10,7 +10,6 @@ from urllib.parse import urlencode import requests - from .exceptions import Auth0Error, RateLimitError from .types import RequestData, TimeoutType @@ -38,6 +37,13 @@ class RestClientOptions: times using an exponential backoff strategy, before raising a RateLimitError exception. 10 retries max. (defaults to 3) + client_info (dict, optional): Custom telemetry data to send + in the Auth0-Client header instead of the default SDK + info. Useful for wrapper SDKs that need to identify + themselves. When provided, this dict is JSON-encoded + and base64-encoded as the header value. Ignored when + telemetry is False. + (defaults to None) """ def __init__( @@ -45,10 +51,12 @@ def __init__( telemetry: bool = True, timeout: TimeoutType = 5.0, retries: int = 3, + client_info: dict[str, Any] | None = None, ) -> None: self.telemetry = telemetry self.timeout = timeout self.retries = retries + self.client_info = client_info class RestClient: @@ -94,17 +102,20 @@ def __init__( if options.telemetry: py_version = platform.python_version() - version = sys.modules["auth0"].__version__ - auth0_client = dumps( - { + if options.client_info is not None: + auth0_client_dict = options.client_info + else: + version = sys.modules["auth0"].__version__ + auth0_client_dict = { "name": "auth0-python", "version": version, "env": { "python": py_version, }, } - ).encode("utf-8") + + auth0_client = dumps(auth0_client_dict).encode("utf-8") self.base_headers.update( { diff --git a/src/auth0/management/management_client.py b/src/auth0/management/management_client.py index 43d4d1d7..8aa83295 100644 --- a/src/auth0/management/management_client.py +++ b/src/auth0/management/management_client.py @@ -1,6 +1,8 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Callable, Dict, Optional, Union +import base64 +from json import dumps +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union import httpx from .client import AsyncAuth0, Auth0 @@ -86,6 +88,10 @@ class ManagementClient: The API audience. Defaults to https://{domain}/api/v2/ headers : Optional[Dict[str, str]] Additional headers to send with requests. + client_info : Optional[Dict[str, Any]] + Custom telemetry data for the Auth0-Client header. When provided, + overrides the default SDK telemetry. Useful for wrapper SDKs that + need to identify themselves (e.g., ``{"name": "my-sdk", "version": "1.0.0"}``). timeout : Optional[float] Request timeout in seconds. Defaults to 60. httpx_client : Optional[httpx.Client] @@ -106,6 +112,7 @@ def __init__( client_secret: Optional[str] = None, audience: Optional[str] = None, headers: Optional[Dict[str, str]] = None, + client_info: Optional[Dict[str, Any]] = None, timeout: Optional[float] = None, httpx_client: Optional[httpx.Client] = None, ): @@ -128,6 +135,13 @@ def __init__( else: resolved_token = token # type: ignore[assignment] + # Encode client_info into Auth0-Client header to override default telemetry + if client_info is not None: + encoded = base64.b64encode( + dumps(client_info).encode("utf-8") + ).decode() + headers = {**(headers or {}), "Auth0-Client": encoded} + # Create underlying client self._api = Auth0( base_url=f"https://{domain}/api/v2", @@ -333,6 +347,10 @@ class AsyncManagementClient: The API audience. Defaults to https://{domain}/api/v2/ headers : Optional[Dict[str, str]] Additional headers to send with requests. + client_info : Optional[Dict[str, Any]] + Custom telemetry data for the Auth0-Client header. When provided, + overrides the default SDK telemetry. Useful for wrapper SDKs that + need to identify themselves (e.g., ``{"name": "my-sdk", "version": "1.0.0"}``). timeout : Optional[float] Request timeout in seconds. Defaults to 60. httpx_client : Optional[httpx.AsyncClient] @@ -353,6 +371,7 @@ def __init__( client_secret: Optional[str] = None, audience: Optional[str] = None, headers: Optional[Dict[str, str]] = None, + client_info: Optional[Dict[str, Any]] = None, timeout: Optional[float] = None, httpx_client: Optional[httpx.AsyncClient] = None, ): @@ -378,6 +397,13 @@ def __init__( else: resolved_token = token # type: ignore[assignment] + # Encode client_info into Auth0-Client header to override default telemetry + if client_info is not None: + encoded = base64.b64encode( + dumps(client_info).encode("utf-8") + ).decode() + headers = {**(headers or {}), "Auth0-Client": encoded} + # Create underlying client self._api = AsyncAuth0( base_url=f"https://{domain}/api/v2", diff --git a/tests/authentication/test_base.py b/tests/authentication/test_base.py index c98dc00b..0772625a 100644 --- a/tests/authentication/test_base.py +++ b/tests/authentication/test_base.py @@ -4,8 +4,6 @@ import unittest from unittest import mock -import requests - from auth0.authentication.base import AuthenticationBase from auth0.authentication.exceptions import Auth0Error, RateLimitError @@ -42,6 +40,39 @@ def test_telemetry_disabled(self): self.assertEqual(ab.client.base_headers, {"Content-Type": "application/json"}) + def test_telemetry_with_custom_client_info(self): + custom_info = { + "name": "auth0-ai-langchain", + "version": "1.0.0", + "env": {"python": "3.11.0"}, + } + ab = AuthenticationBase("auth0.com", "cid", client_info=custom_info) + base_headers = ab.client.base_headers + + auth0_client_bytes = base64.b64decode(base_headers["Auth0-Client"]) + auth0_client = json.loads(auth0_client_bytes.decode("utf-8")) + + self.assertEqual(auth0_client, custom_info) + + def test_telemetry_disabled_ignores_client_info(self): + custom_info = {"name": "my-sdk", "version": "2.0.0"} + ab = AuthenticationBase( + "auth0.com", "cid", telemetry=False, client_info=custom_info + ) + + self.assertNotIn("Auth0-Client", ab.client.base_headers) + self.assertNotIn("User-Agent", ab.client.base_headers) + + def test_custom_client_info_preserves_user_agent(self): + custom_info = {"name": "my-sdk", "version": "1.0.0"} + ab = AuthenticationBase("auth0.com", "cid", client_info=custom_info) + base_headers = ab.client.base_headers + + python_version = "{}.{}.{}".format( + sys.version_info.major, sys.version_info.minor, sys.version_info.micro + ) + self.assertEqual(base_headers["User-Agent"], f"Python/{python_version}") + @mock.patch("requests.request") def test_post(self, mock_request): ab = AuthenticationBase("auth0.com", "cid", telemetry=False, timeout=(10, 2)) diff --git a/tests/management/test_management_client.py b/tests/management/test_management_client.py index c09c9c58..764cc77e 100644 --- a/tests/management/test_management_client.py +++ b/tests/management/test_management_client.py @@ -1,3 +1,5 @@ +import base64 +import json import time from unittest.mock import MagicMock, patch @@ -78,6 +80,53 @@ def test_init_with_custom_headers(self): ) assert client._api is not None + def test_init_with_custom_client_info(self): + """Should encode client_info as Auth0-Client header.""" + custom_info = { + "name": "auth0-ai-langchain", + "version": "1.0.0", + "env": {"python": "3.11.0"}, + } + client = ManagementClient( + domain="test.auth0.com", + token="my-token", + client_info=custom_info, + ) + # Verify the header was set on the underlying client wrapper + custom_headers = client._api._client_wrapper.get_custom_headers() + assert custom_headers is not None + encoded_header = custom_headers.get("Auth0-Client") + assert encoded_header is not None + decoded = json.loads(base64.b64decode(encoded_header).decode("utf-8")) + assert decoded == custom_info + + def test_init_with_client_info_and_custom_headers(self): + """Should merge client_info with custom headers.""" + custom_info = {"name": "my-sdk", "version": "2.0.0"} + client = ManagementClient( + domain="test.auth0.com", + token="my-token", + headers={"X-Custom": "value"}, + client_info=custom_info, + ) + custom_headers = client._api._client_wrapper.get_custom_headers() + assert custom_headers is not None + assert custom_headers.get("X-Custom") == "value" + assert "Auth0-Client" in custom_headers + + def test_init_without_client_info_uses_default_telemetry(self): + """Should use default auth0-python telemetry when client_info is not provided.""" + client = ManagementClient( + domain="test.auth0.com", + token="my-token", + ) + # get_headers() includes the default Auth0-Client telemetry + headers = client._api._client_wrapper.get_headers() + encoded = headers.get("Auth0-Client") + assert encoded is not None + decoded = json.loads(base64.b64decode(encoded).decode("utf-8")) + assert decoded["name"] == "auth0-python" + class TestManagementClientProperties: """Tests for ManagementClient sub-client properties.""" @@ -173,6 +222,25 @@ def test_init_requires_auth(self): with pytest.raises(ValueError): AsyncManagementClient(domain="test.auth0.com") + def test_init_with_custom_client_info(self): + """Should encode client_info as Auth0-Client header.""" + custom_info = { + "name": "auth0-ai-langchain", + "version": "1.0.0", + "env": {"python": "3.11.0"}, + } + client = AsyncManagementClient( + domain="test.auth0.com", + token="my-token", + client_info=custom_info, + ) + custom_headers = client._api._client_wrapper.get_custom_headers() + assert custom_headers is not None + encoded_header = custom_headers.get("Auth0-Client") + assert encoded_header is not None + decoded = json.loads(base64.b64decode(encoded_header).decode("utf-8")) + assert decoded == custom_info + class TestTokenProvider: """Tests for TokenProvider."""