Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 38 additions & 4 deletions auth0/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import base64
import platform
import sys
from contextvars import ContextVar
from json import dumps, loads
from random import randint
from time import sleep
Expand All @@ -19,6 +20,12 @@

UNKNOWN_ERROR = "a0.sdk.internal.unknown"

# Context variable to store response headers in a thread-safe and async-safe manner
# Each execution context (thread or async task) gets its own isolated copy
_response_headers: ContextVar[dict[str, str]] = ContextVar(
"response_headers", default={}
)


class RestClientOptions:
"""Configuration object for RestClient. Used for configuring
Expand Down Expand Up @@ -85,6 +92,9 @@ def __init__(
self._metrics = {"retries": 0, "backoff": []}
self._skip_sleep = False

# Initialize context variable for this client instance
_response_headers.set({})

self.base_headers = {
"Content-Type": "application/json",
}
Expand Down Expand Up @@ -121,6 +131,26 @@ def __init__(
self.telemetry = options.telemetry
self.timeout = options.timeout

@property
def last_response_headers(self) -> dict[str, str]:
"""Get the headers from the most recent API response.

This property is thread-safe and async-safe, using context variables
to isolate response headers per execution context (thread or async task).

Returns:
dict[str, str]: Response headers including rate-limit information
(X-RateLimit-Limit, X-RateLimit-Remaining, X-RateLimit-Reset).
Returns an empty dict if no request has been made yet.

Example:
>>> users = Users(domain="tenant.auth0.com", token="token")
>>> users.create({"email": "user@example.com"})
>>> headers = users.client.last_response_headers
>>> remaining = int(headers.get("X-RateLimit-Remaining", 0))
"""
return _response_headers.get()

# Returns a hard cap for the maximum number of retries allowed (10)
def MAX_REQUEST_RETRIES(self) -> int:
return 10
Expand Down Expand Up @@ -262,11 +292,15 @@ def _calculate_wait(self, attempt: int) -> int:
return wait

def _process_response(self, response: requests.Response) -> Any:
return self._parse(response).content()
parsed_response = self._parse(response)
content = parsed_response.content()
# Store headers in context variable for thread-safe/async-safe access
_response_headers.set(dict(parsed_response._headers))
return content

def _parse(self, response: requests.Response) -> Response:
if not response.text:
return EmptyResponse(response.status_code)
return EmptyResponse(response.status_code, response.headers)
try:
return JsonResponse(response)
except ValueError:
Expand Down Expand Up @@ -356,8 +390,8 @@ def _error_message(self) -> str:


class EmptyResponse(Response):
def __init__(self, status_code: int) -> None:
super().__init__(status_code, "", {})
def __init__(self, status_code: int, headers: Mapping[str, str] | None = None) -> None:
super().__init__(status_code, "", headers or {})

def _error_code(self) -> str:
return UNKNOWN_ERROR
Expand Down
234 changes: 234 additions & 0 deletions auth0/test/test_rest_headers_contextvar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,234 @@
"""Tests for context-var based response headers in RestClient.

Tests verify that headers are properly isolated across threads and async contexts.
"""

import threading
import time
import unittest
from unittest import mock

import responses

from auth0.rest import RestClient, RestClientOptions


class TestRestClientHeadersContextVar(unittest.TestCase):
"""Test that response headers are properly stored and accessed via contextvars."""

@responses.activate
def test_headers_accessible_after_request(self):
"""Test that headers are stored and accessible after a successful request."""
responses.add(
responses.GET,
"https://example.com/api/test",
json={"result": "ok"},
status=200,
headers={
"X-RateLimit-Limit": "60",
"X-RateLimit-Remaining": "59",
"X-RateLimit-Reset": "1640000000",
},
)

client = RestClient(jwt="test-token")
result = client.get("https://example.com/api/test")

self.assertEqual(result, {"result": "ok"})
headers = client.last_response_headers
self.assertEqual(headers.get("X-RateLimit-Limit"), "60")
self.assertEqual(headers.get("X-RateLimit-Remaining"), "59")
self.assertEqual(headers.get("X-RateLimit-Reset"), "1640000000")

@responses.activate
def test_headers_on_204_response(self):
"""Test that headers are preserved on 204 No Content responses."""
responses.add(
responses.DELETE,
"https://example.com/api/resource/123",
status=204,
headers={
"X-RateLimit-Limit": "30",
"X-RateLimit-Remaining": "25",
"X-RateLimit-Reset": "1640000100",
},
)

client = RestClient(jwt="test-token")
result = client.delete("https://example.com/api/resource/123")

# 204 returns empty content
self.assertEqual(result, "")
# But headers should still be accessible
headers = client.last_response_headers
self.assertEqual(headers.get("X-RateLimit-Limit"), "30")
self.assertEqual(headers.get("X-RateLimit-Remaining"), "25")

@responses.activate
def test_headers_updated_on_successive_requests(self):
"""Test that headers are updated with each new request."""
# First request
responses.add(
responses.GET,
"https://example.com/api/test1",
json={"id": 1},
status=200,
headers={"X-RateLimit-Remaining": "59"},
)

# Second request
responses.add(
responses.GET,
"https://example.com/api/test2",
json={"id": 2},
status=200,
headers={"X-RateLimit-Remaining": "58"},
)

client = RestClient(jwt="test-token")

# First request
client.get("https://example.com/api/test1")
self.assertEqual(client.last_response_headers.get("X-RateLimit-Remaining"), "59")

# Second request should update headers
client.get("https://example.com/api/test2")
self.assertEqual(client.last_response_headers.get("X-RateLimit-Remaining"), "58")

@responses.activate
def test_headers_empty_initially(self):
"""Test that headers are empty before any request is made."""
client = RestClient(jwt="test-token")
headers = client.last_response_headers
self.assertEqual(headers, {})

@responses.activate
def test_thread_isolation(self):
"""Test that response headers are isolated between threads.

This is the key thread-safety test: each thread should have its own
response headers when using contextvars.
"""
results = {}
errors = []

# Setup responses with different headers for each endpoint
responses.add(
responses.GET,
"https://example.com/api/thread1",
json={"thread": 1},
status=200,
headers={"X-RateLimit-Remaining": "100"},
)

responses.add(
responses.GET,
"https://example.com/api/thread2",
json={"thread": 2},
status=200,
headers={"X-RateLimit-Remaining": "200"},
)

def thread_worker(thread_id: int, endpoint: str, remaining: str):
"""Worker function for thread test."""
try:
client = RestClient(jwt="test-token")
client.get(f"https://example.com/api/{endpoint}")
# Each thread should see its own headers, not contaminated by other threads
results[thread_id] = client.last_response_headers.get(
"X-RateLimit-Remaining"
)
except Exception as e:
errors.append(str(e))

# Start two threads that make requests simultaneously
thread1 = threading.Thread(target=thread_worker, args=(1, "thread1", "100"))
thread2 = threading.Thread(target=thread_worker, args=(2, "thread2", "200"))

thread1.start()
thread2.start()

thread1.join()
thread2.join()

# Verify no errors occurred
self.assertEqual(errors, [], f"Errors in threads: {errors}")

# Verify each thread got the correct headers for its request
self.assertEqual(results[1], "100", "Thread 1 should see its own headers")
self.assertEqual(results[2], "200", "Thread 2 should see its own headers")

@responses.activate
def test_headers_in_same_context_reflect_latest_request(self):
"""Test that in the same execution context, headers reflect the latest request.

Contextvars are context-specific (thread or async task), not client-specific.
When multiple clients make requests in the same context, the contextvar reflects
the most recent response. For isolation per client, use different threads.
"""
responses.add(
responses.GET,
"https://example.com/api/request1",
json={"request": 1},
status=200,
headers={"X-Request-ID": "request1"},
)

responses.add(
responses.GET,
"https://example.com/api/request2",
json={"request": 2},
status=200,
headers={"X-Request-ID": "request2"},
)

client1 = RestClient(jwt="token1")
client2 = RestClient(jwt="token2")

# First request
client1.get("https://example.com/api/request1")
self.assertEqual(client1.last_response_headers.get("X-Request-ID"), "request1")

# Second request in same context overwrites the contextvar
client2.get("https://example.com/api/request2")
# Both clients see the latest headers because they're in the same context
self.assertEqual(client1.last_response_headers.get("X-Request-ID"), "request2")
self.assertEqual(client2.last_response_headers.get("X-Request-ID"), "request2")

@responses.activate
def test_post_request_headers(self):
"""Test that headers are captured on POST requests."""
responses.add(
responses.POST,
"https://example.com/api/create",
json={"id": "new-id"},
status=201,
headers={"X-RateLimit-Remaining": "55"},
)

client = RestClient(jwt="test-token")
result = client.post("https://example.com/api/create", data={"name": "test"})

self.assertEqual(result["id"], "new-id")
self.assertEqual(client.last_response_headers.get("X-RateLimit-Remaining"), "55")

@responses.activate
def test_patch_request_headers(self):
"""Test that headers are captured on PATCH requests."""
responses.add(
responses.PATCH,
"https://example.com/api/update/123",
json={"id": "123", "updated": True},
status=200,
headers={"X-RateLimit-Remaining": "54"},
)

client = RestClient(jwt="test-token")
result = client.patch("https://example.com/api/update/123", data={"name": "updated"})

self.assertEqual(result["updated"], True)
self.assertEqual(client.last_response_headers.get("X-RateLimit-Remaining"), "54")


if __name__ == "__main__":
unittest.main()