diff --git a/benchmarks/bench_credential_cache.py b/benchmarks/bench_credential_cache.py
new file mode 100644
index 000000000..4d7403f06
--- /dev/null
+++ b/benchmarks/bench_credential_cache.py
@@ -0,0 +1,91 @@
+"""
+Benchmark: Credential Instance Caching for Azure AD Authentication
+
+Measures the performance difference between:
+  1. Creating a new DefaultAzureCredential + get_token() each call (old behavior)
+  2. Reusing a cached DefaultAzureCredential instance (new behavior)
+
+Prerequisites:
+  - pip install azure-identity azure-core
+  - az login (for AzureCliCredential to work)
+
+Usage:
+  python benchmarks/bench_credential_cache.py
+"""
+
+from __future__ import annotations
+
+import time
+import statistics
+
+
+def bench_no_cache(n: int) -> list[float]:
+    """Simulate the OLD behavior: new credential per call."""
+    from azure.identity import DefaultAzureCredential
+
+    times = []
+    for _ in range(n):
+        start = time.perf_counter()
+        cred = DefaultAzureCredential()
+        cred.get_token("https://database.windows.net/.default")
+        times.append(time.perf_counter() - start)
+    return times
+
+
+def bench_with_cache(n: int) -> list[float]:
+    """Simulate the NEW behavior: reuse a single credential instance."""
+    from azure.identity import DefaultAzureCredential
+
+    cred = DefaultAzureCredential()
+    times = []
+    for _ in range(n):
+        start = time.perf_counter()
+        cred.get_token("https://database.windows.net/.default")
+        times.append(time.perf_counter() - start)
+    return times
+
+
+def report(label: str, times: list[float]) -> None:
+    """Print call count, total (s) and mean/median/stdev/min/max (ms) for *times* under *label*."""
+    print(f"\n{'=' * 50}")
+    print(f" {label}")
+    print(f"{'=' * 50}")
+    print(f" Calls: {len(times)}")
+    print(f" Total: {sum(times):.3f}s")
+    print(f" Mean: {statistics.mean(times) * 1000:.1f}ms")
+    print(f" Median: {statistics.median(times) * 1000:.1f}ms")
+    if len(times) > 1:
+        # stdev needs at least two samples; skip the row instead of printing a blank line
+        print(f" Stdev: {statistics.stdev(times) * 1000:.1f}ms")
+    print(f" Min: {min(times) * 1000:.1f}ms")
+    print(f" Max: {max(times) * 1000:.1f}ms")
+
+
+def main() -> None:
+    """Run both scenarios N times, print a report for each, then the mean speedup."""
+    N = 10  # number of calls to benchmark
+
+    print("Credential Instance Cache Benchmark")
+    print(f"Running {N} sequential token acquisitions for each scenario...\n")
+
+    try:
+        print(">>> Without cache (new credential each call)...")
+        no_cache_times = bench_no_cache(N)
+        report("WITHOUT credential cache (old behavior)", no_cache_times)
+
+        print("\n>>> With cache (reuse credential instance)...")
+        cache_times = bench_with_cache(N)
+        report("WITH credential cache (new behavior)", cache_times)
+
+        speedup = statistics.mean(no_cache_times) / statistics.mean(cache_times)
+        saved = (statistics.mean(no_cache_times) - statistics.mean(cache_times)) * 1000
+        print(f"\n{'=' * 50}")
+        print(f" SPEEDUP: {speedup:.1f}x ({saved:.0f}ms saved per call)")
+        print(f"{'=' * 50}")
+    except Exception as e:
+        print(f"\nBenchmark failed: {e}")
+        print("Make sure you are logged in via 'az login' and have azure-identity installed.")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/mssql_python/auth.py b/mssql_python/auth.py
index e2ef6e7e1..40c3e06e2 100644
--- a/mssql_python/auth.py
+++ b/mssql_python/auth.py
@@ -6,11 +6,19 @@
 
 import platform
 import struct
+import threading
 from typing import Tuple, Dict, Optional, List
 
 from mssql_python.logging import logger
 from mssql_python.constants import AuthType, ConstantsDDBC
 
+# Module-level credential instance cache.
+# Reusing credential objects allows the Azure Identity SDK's built-in
+# in-memory token cache to work, avoiding redundant token acquisitions.
+# See: https://github.com/Azure/azure-sdk-for-python/blob/main/sdk/identity/azure-identity/TOKEN_CACHING.md
+_credential_cache: Dict[str, object] = {}
+_credential_cache_lock = threading.Lock()
+
 
 class AADAuth:
     """Handles Azure Active Directory authentication"""
@@ -36,12 +44,11 @@ def get_token(auth_type: str) -> bytes:
 
     @staticmethod
     def get_raw_token(auth_type: str) -> str:
-        """Acquire a fresh raw JWT for the mssql-py-core connection (bulk copy).
+        """Acquire a raw JWT for the mssql-py-core connection (bulk copy).
 
-        This deliberately does NOT cache the credential or token — each call
-        creates a new Azure Identity credential instance and requests a token.
-        A fresh acquisition avoids expired-token errors when bulkcopy() is
-        called long after the original DDBC connect().
+        Uses the cached credential instance so the Azure Identity SDK's
+        built-in token cache can serve a valid token without a round-trip
+        when the previous token has not yet expired.
         """
         _, raw_token = AADAuth._acquire_token(auth_type)
         return raw_token
@@ -83,7 +90,19 @@ def _acquire_token(auth_type: str) -> Tuple[bytes, str]:
             )
 
         try:
-            credential = credential_class()
+            with _credential_cache_lock:
+                if auth_type not in _credential_cache:
+                    logger.debug(
+                        "get_token: Creating new credential instance for auth_type=%s",
+                        auth_type,
+                    )
+                    _credential_cache[auth_type] = credential_class()
+                else:
+                    logger.debug(
+                        "get_token: Reusing cached credential instance for auth_type=%s",
+                        auth_type,
+                    )
+                credential = _credential_cache[auth_type]
             raw_token = credential.get_token("https://database.windows.net/.default").token
             logger.info(
                 "get_token: Azure AD token acquired successfully - token_length=%d chars",
diff --git a/tests/test_008_auth.py b/tests/test_008_auth.py
index f44bf86e2..f680518bc 100644
--- a/tests/test_008_auth.py
+++ b/tests/test_008_auth.py
@@ -7,6 +7,7 @@
 import pytest
 import platform
 import sys
+import threading
 from unittest.mock import patch, MagicMock
 from mssql_python.auth import (
     AADAuth,
@@ -15,6 +16,8 @@
     get_auth_token,
     process_connection_string,
     extract_auth_type,
+    _credential_cache,
+    _credential_cache_lock,
 )
 from mssql_python.constants import AuthType, ConstantsDDBC
 import secrets
@@ -71,6 +74,14 @@ class exceptions:
             del sys.modules[module]
 
 
+@pytest.fixture(autouse=True)
+def clear_credential_cache():
+    """Clear the module-level credential cache between tests."""
+    _credential_cache.clear()
+    yield
+    _credential_cache.clear()
+
+
 class TestAuthType:
     def test_auth_type_constants(self):
         assert AuthType.INTERACTIVE.value == "activedirectoryinteractive"
@@ -403,6 +414,201 @@ def test_unsupported_auth(self):
         assert extract_auth_type("Server=test;Authentication=SqlPassword;") is None
 
 
+class TestCredentialInstanceCache:
+    """Tests for the credential instance caching behavior."""
+
+    def test_credential_reused_across_calls(self):
+        """The same credential instance should be returned for repeated calls."""
+        AADAuth.get_token("default")
+        assert "default" in _credential_cache
+        first_instance = _credential_cache["default"]
+
+        AADAuth.get_token("default")
+        assert _credential_cache["default"] is first_instance
+
+    def test_different_auth_types_get_separate_instances(self):
+        """Each auth type should have its own cached credential."""
+        AADAuth.get_token("default")
+        AADAuth.get_token("devicecode")
+
+        assert "default" in _credential_cache
+        assert "devicecode" in _credential_cache
+        assert _credential_cache["default"] is not _credential_cache["devicecode"]
+
+    def test_get_raw_token_uses_cached_credential(self):
+        """get_raw_token should also use the cached credential instance."""
+        AADAuth.get_token("default")
+        cached = _credential_cache["default"]
+
+        AADAuth.get_raw_token("default")
+        assert _credential_cache["default"] is cached
+
+    def test_cache_starts_empty(self):
+        """Cache should be empty at the start due to the clear_credential_cache fixture."""
+        assert len(_credential_cache) == 0
+
+    def test_cached_credential_refreshes_token_after_expiry(self):
+        """Verify that the cached credential instance returns fresh tokens on each call.
+
+        This simulates what happens when Azure Identity SDK refreshes an expired
+        token internally: because we cache the credential (not the token), each
+        _acquire_token() call invokes get_token() on the same instance, giving
+        the SDK the opportunity to return a refreshed token when the old one has
+        expired.
+        """
+        import sys
+
+        azure_identity = sys.modules["azure.identity"]
+        original = azure_identity.DefaultAzureCredential
+
+        call_count = 0
+        tokens = ["initial_token_abc123", "refreshed_token_xyz789"]
+
+        class MockCredentialWithRefresh:
+            def get_token(self, scope):
+                nonlocal call_count
+                idx = min(call_count, len(tokens) - 1)
+                call_count += 1
+
+                class Token:
+                    token = tokens[idx]
+
+                return Token()
+
+        try:
+            azure_identity.DefaultAzureCredential = MockCredentialWithRefresh
+
+            # First call — gets initial token
+            _, raw_token_1 = AADAuth._acquire_token("default")
+            assert raw_token_1 == "initial_token_abc123"
+            assert call_count == 1
+
+            # Same credential instance is cached
+            cached = _credential_cache["default"]
+            assert isinstance(cached, MockCredentialWithRefresh)
+
+            # Second call — same credential instance, but SDK returns refreshed token
+            # (simulating post-expiry refresh)
+            _, raw_token_2 = AADAuth._acquire_token("default")
+            assert raw_token_2 == "refreshed_token_xyz789"
+            assert call_count == 2
+
+            # Credential instance is still the same (not recreated)
+            assert _credential_cache["default"] is cached
+        finally:
+            azure_identity.DefaultAzureCredential = original
+
+
+class TestAcquireTokenImportError:
+    """Test the ImportError path when azure-identity is not installed."""
+
+    def test_import_error_raises_runtime_error(self):
+        """_acquire_token raises RuntimeError when azure.identity is missing."""
+        import sys
+
+        # Temporarily remove the mocked azure modules
+        saved = {}
+        for mod_name in list(sys.modules):
+            if mod_name == "azure" or mod_name.startswith("azure."):
+                saved[mod_name] = sys.modules.pop(mod_name)
+
+        # Make the import fail
+        import builtins
+
+        real_import = builtins.__import__
+
+        def blocked_import(name, *args, **kwargs):
+            if name.startswith("azure"):
+                raise ImportError("No module named 'azure'")
+            return real_import(name, *args, **kwargs)
+
+        builtins.__import__ = blocked_import
+        try:
+            with pytest.raises(
+                RuntimeError, match="Azure authentication libraries are not installed"
+            ):
+                AADAuth._acquire_token("default")
+        finally:
+            builtins.__import__ = real_import
+            sys.modules.update(saved)
+
+
+class TestAcquireTokenClientAuthError:
+    """Test the ClientAuthenticationError path inside _acquire_token."""
+
+    def test_client_auth_error_in_acquire_token(self):
+        """ClientAuthenticationError during get_token is wrapped in RuntimeError."""
+        import sys
+
+        azure_identity = sys.modules["azure.identity"]
+        original = azure_identity.DefaultAzureCredential
+
+        from azure.core.exceptions import ClientAuthenticationError
+
+        class FailingCredential:
+            def get_token(self, scope):
+                raise ClientAuthenticationError("token request denied")
+
+        try:
+            azure_identity.DefaultAzureCredential = FailingCredential
+            with pytest.raises(RuntimeError, match="Azure AD authentication failed"):
+                AADAuth._acquire_token("default")
+        finally:
+            azure_identity.DefaultAzureCredential = original
+
+
+class TestProcessAuthParametersEdgeCases:
+    """Cover empty-param and no-equals-sign branches."""
+
+    def test_empty_and_whitespace_params_skipped(self):
+        params = ["Server=test", "", " ", "Database=db"]
+        modified, auth_type = process_auth_parameters(params)
+        assert "Server=test" in modified
+        assert "Database=db" in modified
+        assert auth_type is None
+
+    def test_param_without_equals_kept(self):
+        params = ["Server=test", "SomeFlag", "Database=db"]
+        modified, auth_type = process_auth_parameters(params)
+        assert "SomeFlag" in modified
+        assert "Server=test" in modified
+
+
+class TestGetAuthTokenEdgeCases:
+    """Cover the Windows-interactive and token-failure branches."""
+
+    def test_no_auth_type_returns_none(self):
+        result = get_auth_token(None)
+        assert result is None
+
+    def test_empty_auth_type_returns_none(self):
+        result = get_auth_token("")
+        assert result is None
+
+    def test_windows_interactive_returns_none(self, monkeypatch):
+        monkeypatch.setattr(platform, "system", lambda: "Windows")
+        result = get_auth_token("interactive")
+        assert result is None
+
+    def test_token_acquisition_failure_returns_none(self):
+        """When AADAuth.get_token raises, get_auth_token returns None."""
+        import sys
+
+        azure_identity = sys.modules["azure.identity"]
+        original = azure_identity.DefaultAzureCredential
+
+        class FailingCredential:
+            def __init__(self):
+                raise RuntimeError("credential creation exploded")
+
+        try:
+            azure_identity.DefaultAzureCredential = FailingCredential
+            result = get_auth_token("default")
+            assert result is None
+        finally:
+            azure_identity.DefaultAzureCredential = original
+
+
 def test_acquire_token_unsupported_auth_type():
     with pytest.raises(ValueError, match="Unsupported auth_type 'bogus'"):
         AADAuth._acquire_token("bogus")
@@ -417,3 +623,167 @@ def test_auth_type_stored_on_connection(self, mock_ddbc_conn):
         conn = connect("Server=test;Database=testdb;Authentication=ActiveDirectoryDefault")
         assert conn._auth_type == "default"
         conn.close()
+
+
+class TestCredentialCacheThreadSafety:
+    """Verify thread-safe behavior of credential instance cache."""
+
+    def test_concurrent_access_creates_only_one_instance(self):
+        """Multiple threads calling get_token concurrently should result in
+        exactly one credential instance per auth type in the cache."""
+        import sys
+
+        azure_identity = sys.modules["azure.identity"]
+        original = azure_identity.DefaultAzureCredential
+
+        instances_created = []
+
+        class TrackingCredential:
+            def __init__(self):
+                instances_created.append(self)
+
+            def get_token(self, scope):
+                class Token:
+                    token = SAMPLE_TOKEN
+
+                return Token()
+
+        try:
+            azure_identity.DefaultAzureCredential = TrackingCredential
+
+            errors = []
+            barrier = threading.Barrier(10)
+
+            def worker():
+                try:
+                    barrier.wait(timeout=5)
+                    AADAuth.get_token("default")
+                except Exception as e:
+                    errors.append(e)
+
+            threads = [threading.Thread(target=worker) for _ in range(10)]
+            for t in threads:
+                t.start()
+            for t in threads:
+                t.join(timeout=10)
+            # join(timeout=...) returns silently on timeout, so detect hung workers explicitly
+            assert not any(t.is_alive() for t in threads), "Worker threads did not finish"
+
+            assert not errors, f"Threads raised errors: {errors}"
+            # Only one credential instance should exist in the cache
+            assert "default" in _credential_cache
+            # All threads should use the same cached instance
+            cached = _credential_cache["default"]
+            assert isinstance(cached, TrackingCredential)
+            # Due to the lock, only one instance should have been created
+            assert len(instances_created) == 1
+        finally:
+            azure_identity.DefaultAzureCredential = original
+
+
+class TestCacheStateAfterErrors:
+    """Verify credential cache state after various error scenarios."""
+
+    def test_client_auth_error_leaves_credential_in_cache(self):
+        """When get_token raises ClientAuthenticationError, the credential
+        instance should still remain in the cache since it was created
+        successfully — only the token acquisition failed."""
+        import sys
+
+        azure_identity = sys.modules["azure.identity"]
+        original = azure_identity.DefaultAzureCredential
+        from azure.core.exceptions import ClientAuthenticationError
+
+        class CredentialThatFailsGetToken:
+            def get_token(self, scope):
+                raise ClientAuthenticationError("token denied")
+
+        try:
+            azure_identity.DefaultAzureCredential = CredentialThatFailsGetToken
+
+            with pytest.raises(RuntimeError, match="Azure AD authentication failed"):
+                AADAuth._acquire_token("default")
+
+            # Credential was created and cached before get_token failed
+            assert "default" in _credential_cache
+            assert isinstance(_credential_cache["default"], CredentialThatFailsGetToken)
+        finally:
+            azure_identity.DefaultAzureCredential = original
+
+    def test_init_error_does_not_leave_stale_entry_in_cache(self):
+        """When credential_class() raises during __init__, no entry should
+        be left in _credential_cache since the dict assignment never completes."""
+        import sys
+
+        azure_identity = sys.modules["azure.identity"]
+        original = azure_identity.DefaultAzureCredential
+
+        class CredentialThatFailsInit:
+            def __init__(self):
+                raise ValueError("init exploded")
+
+        try:
+            azure_identity.DefaultAzureCredential = CredentialThatFailsInit
+
+            with pytest.raises(RuntimeError, match="Failed to create"):
+                AADAuth.get_token("default")
+
+            # The cache should NOT contain a stale entry
+            assert "default" not in _credential_cache
+        finally:
+            azure_identity.DefaultAzureCredential = original
+
+
+class TestCacheOutputCorrectness:
+    """Verify the returned token bytes are correct on both cache-miss and cache-hit."""
+
+    def test_token_output_correct_on_cache_miss_and_hit(self):
+        """get_token should return correct token bytes on both
+        the initial (cache-miss) and subsequent (cache-hit) calls."""
+        # First call — cache miss
+        token_1 = AADAuth.get_token("default")
+        first_instance = _credential_cache["default"]
+        assert isinstance(token_1, bytes)
+        assert len(token_1) > 4
+        expected = AADAuth.get_token_struct(SAMPLE_TOKEN)
+        assert token_1 == expected
+
+        # Second call — cache hit
+        token_2 = AADAuth.get_token("default")
+        assert isinstance(token_2, bytes)
+        assert token_2 == expected
+
+        # Same credential instance for both
+        assert _credential_cache["default"] is first_instance
+
+
+class TestProcessConnectionStringTokenFailureFallthrough:
+    """Cover the path where get_auth_token returns None and
+    process_connection_string falls through without attrs."""
+
+    def test_returns_none_attrs_when_token_acquisition_fails(self):
+        """When auth type is detected but token acquisition fails,
+        process_connection_string should return (conn_str, None, auth_type)."""
+        import sys
+
+        azure_identity = sys.modules["azure.identity"]
+        original = azure_identity.DefaultAzureCredential
+
+        class CredentialThatAlwaysFails:
+            def __init__(self):
+                raise RuntimeError("cannot create credential")
+
+        try:
+            azure_identity.DefaultAzureCredential = CredentialThatAlwaysFails
+            conn_str = "Server=test;Authentication=ActiveDirectoryDefault;Database=testdb"
+            result_str, attrs, auth_type = process_connection_string(conn_str)
+
+            # Auth type was detected
+            assert auth_type == "default"
+            # But token acquisition failed, so attrs is None
+            assert attrs is None
+            # Connection string is still returned (sensitive params removed)
+            assert "Server=test" in result_str
+            assert "Database=testdb" in result_str
+        finally:
+            azure_identity.DefaultAzureCredential = original