diff --git a/sdk/postgresql/azure-postgresql-auth/CHANGELOG.md b/sdk/postgresql/azure-postgresql-auth/CHANGELOG.md new file mode 100644 index 000000000000..740d75040e66 --- /dev/null +++ b/sdk/postgresql/azure-postgresql-auth/CHANGELOG.md @@ -0,0 +1,26 @@ +# Release History + +## 1.0.2 (Unreleased) + +### Features Added + +### Breaking Changes + +### Bugs Fixed + +- Removed dependency on `DefaultAzureCredential` in source library +- Fixed `get_entra_conninfo_async` and `get_entra_token_async` closing the credential by using it as a context manager + +### Other Changes + +## 1.0.1 (2025-11-26) + +### Other Changes + +- Update author to Microsoft + +## 1.0.0 (2025-11-14) + +### Features Added + +- Initial public release diff --git a/sdk/postgresql/azure-postgresql-auth/LICENSE b/sdk/postgresql/azure-postgresql-auth/LICENSE new file mode 100644 index 000000000000..b2f52a2bad4e --- /dev/null +++ b/sdk/postgresql/azure-postgresql-auth/LICENSE @@ -0,0 +1,21 @@ +Copyright (c) Microsoft Corporation. + +MIT License + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/sdk/postgresql/azure-postgresql-auth/MANIFEST.in b/sdk/postgresql/azure-postgresql-auth/MANIFEST.in new file mode 100644 index 000000000000..24fe9446010a --- /dev/null +++ b/sdk/postgresql/azure-postgresql-auth/MANIFEST.in @@ -0,0 +1,6 @@ +recursive-include tests *.py *.yaml +include *.md +include LICENSE +recursive-include samples *.py *.md +include azure_postgresql_auth/py.typed +recursive-include doc *.rst diff --git a/sdk/postgresql/azure-postgresql-auth/README.md b/sdk/postgresql/azure-postgresql-auth/README.md new file mode 100644 index 000000000000..db31b19a28a0 --- /dev/null +++ b/sdk/postgresql/azure-postgresql-auth/README.md @@ -0,0 +1,269 @@ +# Azure PostgreSQL Auth client library for Python + +The Azure PostgreSQL Auth client library provides Microsoft Entra ID authentication for Python database drivers connecting to Azure Database for PostgreSQL. It supports psycopg2, psycopg3, and SQLAlchemy with automatic token management and connection pooling. + +[Source code](https://github.com/Azure/azure-sdk-for-python/tree/main/sdk/postgresql/azure-postgresql-auth) +| [Package (PyPI)](https://pypi.org/project/azure-postgresql-auth/) +| [Samples](https://github.com/Azure/azure-sdk-for-python/tree/main/sdk/postgresql/azure-postgresql-auth/samples) + +## Getting started + +### Prerequisites + +- Python 3.9 or later +- An Azure subscription +- An Azure Database for PostgreSQL Server instance with Entra ID authentication enabled +- A credential object that implements the [TokenCredential](https://learn.microsoft.com/python/api/azure-core/azure.core.credentials.tokencredential) interface + +### Install the package + +Install the core package: + +```bash +pip install azure-postgresql-auth +``` + +Install with driver-specific extras: + +```bash +# For psycopg3 (recommended for new projects) +pip install "azure-postgresql-auth[psycopg3]" + +# For psycopg2 (legacy support) +pip install "azure-postgresql-auth[psycopg2]" + +# For SQLAlchemy +pip install "azure-postgresql-auth[sqlalchemy]" +``` + +Install Azure Identity for credential support: + +```bash +pip install azure-identity +``` + +## Key concepts + +### Authentication flow + +1. **Token Acquisition**: Uses Azure Identity credentials to acquire access tokens from Microsoft Entra ID. +2. **Automatic Refresh**: Tokens are acquired for each new database connection. +3. **Secure Transport**: Tokens are passed as passwords in PostgreSQL connection strings over SSL. +4. **Server Validation**: Azure Database for PostgreSQL validates the token and establishes the authenticated connection. +5. **User Mapping**: The token's user principal name (UPN) is mapped to a PostgreSQL user for authorization. + +### Token scopes + +The library requests the following OAuth2 scopes: + +- **Database scope**: `https://ossrdbms-aad.database.windows.net/.default` (primary) +- **Management scope**: `https://management.azure.com/.default` (fallback for managed identities) + +### Driver support + +- **psycopg3**: Modern PostgreSQL driver (recommended for new projects) — sync and async support +- **psycopg2**: Legacy PostgreSQL driver — synchronous only +- **SQLAlchemy**: ORM/Core interface using event listeners — sync and async engine support + +### Security features + +- **Token-based authentication**: No passwords stored or transmitted +- **Automatic expiration**: Tokens expire and are refreshed automatically +- **SSL enforcement**: All connections require SSL encryption +- **Principle of least privilege**: Only database-specific scopes are requested + +## Examples + +### Configuration + +The samples use environment variables to configure database connections. Copy `.env.example` into a `.env` file +in the same directory as the sample and update the variables: + +``` +POSTGRES_SERVER= +POSTGRES_DATABASE= +``` + +### psycopg2 — Connection pooling + +```python +from azure_postgresql_auth.psycopg2 import EntraConnection +from azure.identity import DefaultAzureCredential +from psycopg2 import pool +from functools import partial + +credential = DefaultAzureCredential() +connection_factory = partial(EntraConnection, credential=credential) + +connection_pool = pool.ThreadedConnectionPool( + minconn=1, + maxconn=5, + host="your-server.postgres.database.azure.com", + database="your_database", + connection_factory=connection_factory, +) +conn = connection_pool.getconn() +with conn.cursor() as cur: + cur.execute("SELECT 1") +``` + +### psycopg2 — Direct connection + +```python +from azure_postgresql_auth.psycopg2 import EntraConnection +from azure.identity import DefaultAzureCredential + +with EntraConnection( + "postgresql://your-server.postgres.database.azure.com:5432/your_database", + credential=DefaultAzureCredential(), +) as conn: + with conn.cursor() as cur: + cur.execute("SELECT 1") +``` + +### psycopg3 — Synchronous connection + +```python +from azure_postgresql_auth.psycopg3 import EntraConnection +from azure.identity import DefaultAzureCredential +from psycopg_pool import ConnectionPool + +with ConnectionPool( + conninfo="postgresql://your-server.postgres.database.azure.com:5432/your_database", + connection_class=EntraConnection, + kwargs={"credential": DefaultAzureCredential()}, + min_size=1, + max_size=5, +) as pg_pool: + with pg_pool.connection() as conn: + with conn.cursor() as cur: + cur.execute("SELECT 1") +``` + +### psycopg3 — Asynchronous connection + +```python +from azure_postgresql_auth.psycopg3 import AsyncEntraConnection +from azure.identity.aio import DefaultAzureCredential +from psycopg_pool import AsyncConnectionPool + +async with AsyncConnectionPool( + conninfo="postgresql://your-server.postgres.database.azure.com:5432/your_database", + connection_class=AsyncEntraConnection, + kwargs={"credential": DefaultAzureCredential()}, + min_size=1, + max_size=5, +) as pg_pool: + async with pg_pool.connection() as conn: + async with conn.cursor() as cur: + await cur.execute("SELECT 1") +``` + +### SQLAlchemy — Synchronous engine + +> For more information, see SQLAlchemy's documentation on +> [controlling how parameters are passed to the DBAPI connect function](https://docs.sqlalchemy.org/en/20/core/engines.html#controlling-how-parameters-are-passed-to-the-dbapi-connect-function). + +```python +from sqlalchemy import create_engine +from azure_postgresql_auth.sqlalchemy import enable_entra_authentication +from azure.identity import DefaultAzureCredential + +engine = create_engine( + "postgresql+psycopg://your-server.postgres.database.azure.com/your_database", + connect_args={"credential": DefaultAzureCredential()}, +) +enable_entra_authentication(engine) + +with engine.connect() as conn: + result = conn.execute(text("SELECT 1")) +``` + +### SQLAlchemy — Asynchronous engine + +```python +from sqlalchemy.ext.asyncio import create_async_engine +from azure_postgresql_auth.sqlalchemy import enable_entra_authentication_async +from azure.identity import DefaultAzureCredential + +engine = create_async_engine( + "postgresql+psycopg://your-server.postgres.database.azure.com/your_database", + connect_args={"credential": DefaultAzureCredential()}, +) +enable_entra_authentication_async(engine) + +async with engine.connect() as conn: + result = await conn.execute(text("SELECT 1")) +``` + +## Troubleshooting + +### Authentication errors + +If you get "password authentication failed", ensure your Azure identity has been granted access to the database: + +```sql +-- Run as a database administrator +CREATE ROLE "your-user@your-domain.com" WITH LOGIN; +GRANT ALL PRIVILEGES ON DATABASE your_database TO "your-user@your-domain.com"; +``` + +### Connection timeouts + +Increase the connection timeout for slow networks: + +```python +conn = EntraConnection.connect( + "postgresql://server:5432/db", + credential=DefaultAzureCredential(), + connect_timeout=30, +) +``` + +### Windows async compatibility + +On Windows, you may need to set the event loop policy for async usage: + +```python +import asyncio +import sys + +if sys.platform == "win32": + asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) +``` + +### Debug logging + +Enable debug logging to troubleshoot authentication issues: + +```python +import logging +logging.basicConfig(level=logging.DEBUG) +``` + +## Next steps + +### Additional documentation + +For more information about Azure Database for PostgreSQL Entra ID authentication, see the +[Azure documentation](https://learn.microsoft.com/azure/postgresql/security/security-entra-configure). + +### Samples + +Explore [sample code](https://github.com/Azure/azure-sdk-for-python/tree/main/sdk/postgresql/azure-postgresql-auth/samples) for psycopg2, psycopg3, and SQLAlchemy. + +## Contributing + +This project welcomes contributions and suggestions. Most contributions require you to agree to a +Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us +the rights to use your contribution. For details, visit [https://cla.microsoft.com](https://cla.microsoft.com). + +When you submit a pull request, a CLA-bot will automatically determine whether you need to provide +a CLA and decorate the PR appropriately (e.g., label, comment). Simply follow the instructions +provided by the bot. You will only need to do this once across all repos using our CLA. + +This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). +For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or +contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. + +![Impressions](https://azure-sdk-impressions.azurewebsites.net/api/impressions/azure-sdk-for-python%2Fsdk%2Fpostgresql%2Fazure-postgresql-auth%2FREADME.png) diff --git a/sdk/postgresql/azure-postgresql-auth/azure_postgresql_auth/__init__.py b/sdk/postgresql/azure-postgresql-auth/azure_postgresql_auth/__init__.py new file mode 100644 index 000000000000..e5e343e2ba20 --- /dev/null +++ b/sdk/postgresql/azure-postgresql-auth/azure_postgresql_auth/__init__.py @@ -0,0 +1,22 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +""" +Azure PostgreSQL Auth client library for Python. + +This library provides Microsoft Entra ID authentication for Python database drivers +connecting to Azure Database for PostgreSQL. It supports psycopg2, psycopg3, +and SQLAlchemy with automatic token management. + +Available submodules (with optional dependencies): + - psycopg2: Support for psycopg2 driver (pip install azure-postgresql-auth[psycopg2]) + - psycopg3: Support for psycopg (v3) driver (pip install azure-postgresql-auth[psycopg3]) + - sqlalchemy: Support for SQLAlchemy ORM (pip install azure-postgresql-auth[sqlalchemy]) +""" + +from azure_postgresql_auth._version import VERSION + +__version__ = VERSION diff --git a/sdk/postgresql/azure-postgresql-auth/azure_postgresql_auth/_version.py b/sdk/postgresql/azure-postgresql-auth/azure_postgresql_auth/_version.py new file mode 100644 index 000000000000..3b770bdbf149 --- /dev/null +++ b/sdk/postgresql/azure-postgresql-auth/azure_postgresql_auth/_version.py @@ -0,0 +1,7 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +VERSION = "1.0.2" diff --git a/sdk/postgresql/azure-postgresql-auth/azure_postgresql_auth/core.py b/sdk/postgresql/azure-postgresql-auth/azure_postgresql_auth/core.py new file mode 100644 index 000000000000..710d54a85747 --- /dev/null +++ b/sdk/postgresql/azure-postgresql-auth/azure_postgresql_auth/core.py @@ -0,0 +1,192 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +from __future__ import annotations + +import base64 +import json +from typing import Any, cast, Optional + +from azure.core.credentials import TokenCredential +from azure.core.credentials_async import AsyncTokenCredential +from azure.core.exceptions import ClientAuthenticationError + +from azure_postgresql_auth.errors import ( + ScopePermissionError, + TokenDecodeError, + UsernameExtractionError, +) + +AZURE_DB_FOR_POSTGRES_SCOPE = "https://ossrdbms-aad.database.windows.net/.default" +AZURE_MANAGEMENT_SCOPE = "https://management.azure.com/.default" + + +def get_entra_token(credential: TokenCredential, scope: str) -> str: + """Acquires an Entra authentication token for Azure PostgreSQL synchronously. + + :param credential: Credential object used to obtain the token. + :type credential: ~azure.core.credentials.TokenCredential + :param scope: The scope for the token request. + :type scope: str + :return: The acquired authentication token to be used as the database password. + :rtype: str + """ + cred = credential.get_token(scope) + return cred.token + + +async def get_entra_token_async(credential: AsyncTokenCredential, scope: str) -> str: + """Asynchronously acquires an Entra authentication token for Azure PostgreSQL. + + :param credential: Asynchronous credential used to obtain the token. + :type credential: ~azure.core.credentials_async.AsyncTokenCredential + :param scope: The scope for the token request. + :type scope: str + :return: The acquired authentication token to be used as the database password. + :rtype: str + """ + cred = await credential.get_token(scope) + return cred.token + + +def decode_jwt(token: str) -> dict[str, Any]: + """Decodes a JWT token to extract its payload claims. + + :param token: The JWT token string in the standard three-part format. + :type token: str + :return: A dictionary containing the claims extracted from the token payload. + :rtype: dict[str, Any] + :raises ~azure_postgresql_auth.TokenDecodeError: If the token format is invalid or cannot be decoded. + """ + try: + payload = token.split(".")[1] + padding = "=" * (-len(payload) % 4) + decoded_payload = base64.urlsafe_b64decode(payload + padding) + return cast(dict[str, Any], json.loads(decoded_payload)) + except Exception as e: + raise TokenDecodeError("Invalid JWT token format") from e + + +def parse_principal_name(xms_mirid: str) -> Optional[str]: + """Parses the principal name from an Azure resource path. + + :param xms_mirid: The xms_mirid claim value containing the Azure resource path. + :type xms_mirid: str + :return: The extracted principal name, or None if parsing fails. + :rtype: str or None + """ + if not xms_mirid: + return None + + last_slash_index = xms_mirid.rfind("/") + if last_slash_index == -1: + return None + + beginning = xms_mirid[:last_slash_index] + principal_name = xms_mirid[last_slash_index + 1 :] + + if not principal_name or not beginning.lower().endswith( + "providers/microsoft.managedidentity/userassignedidentities" + ): + return None + + return principal_name + + +def get_entra_conninfo(credential: TokenCredential) -> dict[str, str]: + """Synchronously obtains connection information from Entra authentication for Azure PostgreSQL. + + This function acquires an access token from Microsoft Entra ID and extracts the username + from the token claims. It tries multiple claim sources to determine the username. + + :param credential: The credential used for token acquisition. + :type credential: ~azure.core.credentials.TokenCredential + :return: A dictionary with 'user' and 'password' keys for database authentication. + :rtype: dict[str, str] + :raises ~azure_postgresql_auth.TokenDecodeError: If the JWT token cannot be decoded. + :raises ~azure_postgresql_auth.UsernameExtractionError: If the username cannot be extracted. + :raises ~azure_postgresql_auth.ScopePermissionError: If the management scope token cannot be acquired. + """ + db_token = get_entra_token(credential, AZURE_DB_FOR_POSTGRES_SCOPE) + db_claims = decode_jwt(db_token) + xms_mirid = db_claims.get("xms_mirid") + username = ( + parse_principal_name(xms_mirid) + if isinstance(xms_mirid, str) + else None or db_claims.get("upn") or db_claims.get("preferred_username") or db_claims.get("unique_name") + ) + + if not username: + try: + mgmt_token = get_entra_token(credential, AZURE_MANAGEMENT_SCOPE) + except ClientAuthenticationError as e: + raise ScopePermissionError("Failed to acquire token from management scope") from e + mgmt_claims = decode_jwt(mgmt_token) + xms_mirid = mgmt_claims.get("xms_mirid") + username = ( + parse_principal_name(xms_mirid) + if isinstance(xms_mirid, str) + else None + or mgmt_claims.get("upn") + or mgmt_claims.get("preferred_username") + or mgmt_claims.get("unique_name") + ) + + if not username: + raise UsernameExtractionError( + "Could not determine username from token claims. Ensure the identity has the proper Entra ID attributes." + ) + + return {"user": username, "password": db_token} + + +async def get_entra_conninfo_async( + credential: AsyncTokenCredential, +) -> dict[str, str]: + """Asynchronously obtains connection information from Entra authentication for Azure PostgreSQL. + + This function acquires an access token from Microsoft Entra ID and extracts the username + from the token claims. It tries multiple claim sources to determine the username. + + :param credential: The async credential used for token acquisition. + :type credential: ~azure.core.credentials_async.AsyncTokenCredential + :return: A dictionary with 'user' and 'password' keys for database authentication. + :rtype: dict[str, str] + :raises ~azure_postgresql_auth.TokenDecodeError: If the JWT token cannot be decoded. + :raises ~azure_postgresql_auth.UsernameExtractionError: If the username cannot be extracted. + :raises ~azure_postgresql_auth.ScopePermissionError: If the management scope token cannot be acquired. + """ + db_token = await get_entra_token_async(credential, AZURE_DB_FOR_POSTGRES_SCOPE) + db_claims = decode_jwt(db_token) + xms_mirid = db_claims.get("xms_mirid") + username = ( + parse_principal_name(xms_mirid) + if isinstance(xms_mirid, str) + else None or db_claims.get("upn") or db_claims.get("preferred_username") or db_claims.get("unique_name") + ) + + if not username: + try: + mgmt_token = await get_entra_token_async(credential, AZURE_MANAGEMENT_SCOPE) + except ClientAuthenticationError as e: + raise ScopePermissionError("Failed to acquire token from management scope") from e + mgmt_claims = decode_jwt(mgmt_token) + xms_mirid = mgmt_claims.get("xms_mirid") + username = ( + parse_principal_name(xms_mirid) + if isinstance(xms_mirid, str) + else None + or mgmt_claims.get("upn") + or mgmt_claims.get("preferred_username") + or mgmt_claims.get("unique_name") + ) + + if not username: + raise UsernameExtractionError( + "Could not determine username from token claims. Ensure the identity has the proper Entra ID attributes." + ) + + return {"user": username, "password": db_token} diff --git a/sdk/postgresql/azure-postgresql-auth/azure_postgresql_auth/errors.py b/sdk/postgresql/azure-postgresql-auth/azure_postgresql_auth/errors.py new file mode 100644 index 000000000000..3ed71a712b42 --- /dev/null +++ b/sdk/postgresql/azure-postgresql-auth/azure_postgresql_auth/errors.py @@ -0,0 +1,31 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +from __future__ import annotations + + +class AzurePgEntraError(Exception): + """Base class for all custom exceptions in the project.""" + + +class TokenDecodeError(AzurePgEntraError): + """Raised when a token value is invalid.""" + + +class UsernameExtractionError(AzurePgEntraError): + """Raised when username cannot be extracted from token.""" + + +class CredentialValueError(AzurePgEntraError): + """Raised when token credential is invalid.""" + + +class EntraConnectionValueError(AzurePgEntraError): + """Raised when Entra connection credentials are invalid.""" + + +class ScopePermissionError(AzurePgEntraError): + """Raised when the provided scope does not have sufficient permissions.""" diff --git a/sdk/postgresql/azure-postgresql-auth/azure_postgresql_auth/psycopg2/__init__.py b/sdk/postgresql/azure-postgresql-auth/azure_postgresql_auth/psycopg2/__init__.py new file mode 100644 index 000000000000..3b4959aa3f8c --- /dev/null +++ b/sdk/postgresql/azure-postgresql-auth/azure_postgresql_auth/psycopg2/__init__.py @@ -0,0 +1,29 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +""" +Psycopg2 support for Microsoft Entra ID authentication with Azure Database for PostgreSQL. + +This module provides a connection class that handles Microsoft Entra ID token acquisition +and authentication for synchronous PostgreSQL connections. + +Requirements: + Install with: pip install azure-postgresql-auth[psycopg2] + + This will install: + - psycopg2-binary>=2.9.0 + +Classes: + EntraConnection: Synchronous connection class with Entra ID authentication +""" + +from .entra_connection import ( + EntraConnection, +) + +__all__ = [ + "EntraConnection", +] diff --git a/sdk/postgresql/azure-postgresql-auth/azure_postgresql_auth/psycopg2/entra_connection.py b/sdk/postgresql/azure-postgresql-auth/azure_postgresql_auth/psycopg2/entra_connection.py new file mode 100644 index 000000000000..3357c5887381 --- /dev/null +++ b/sdk/postgresql/azure-postgresql-auth/azure_postgresql_auth/psycopg2/entra_connection.py @@ -0,0 +1,80 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +from __future__ import annotations + +from typing import Any + +from azure.core.credentials import TokenCredential + +from azure_postgresql_auth.core import get_entra_conninfo +from azure_postgresql_auth.errors import ( + CredentialValueError, + EntraConnectionValueError, +) + +try: + from psycopg2.extensions import connection, make_dsn, parse_dsn +except ImportError as e: + # Provide a helpful error message if psycopg2 dependencies are missing + raise ImportError( + "psycopg2 dependencies are not installed. Install them with: pip install azure-postgresql-auth[psycopg2]" + ) from e + + +class EntraConnection(connection): + """Synchronous connection class for using Entra authentication with Azure PostgreSQL. + + This connection class automatically acquires Microsoft Entra ID credentials when user + or password are not provided in the DSN or connection parameters. + + :param dsn: PostgreSQL connection string (Data Source Name). + :type dsn: str + :keyword credential: Azure credential for token acquisition. + :paramtype credential: ~azure.core.credentials.TokenCredential + :keyword user: Database username. If not provided, extracted from Entra token. + :paramtype user: str or None + :keyword password: Database password. If not provided, uses Entra access token. + :paramtype password: str or None + :raises ~azure_postgresql_auth.CredentialValueError: + If the provided credential is not a valid TokenCredential. + :raises ~azure_postgresql_auth.EntraConnectionValueError: + If Entra connection credentials cannot be retrieved. + """ + + def __init__(self, dsn: str, **kwargs: Any) -> None: + # Extract current DSN params + dsn_params = parse_dsn(dsn) if dsn else {} + + credential = kwargs.pop("credential", None) + if credential is None or not isinstance(credential, (TokenCredential)): + raise CredentialValueError("credential is required and must be a TokenCredential for sync connections") + + # Check if user and password are already provided + has_user = "user" in dsn_params or "user" in kwargs + has_password = "password" in dsn_params or "password" in kwargs + + # Only get Entra credentials if user or password is missing + if not has_user or not has_password: + try: + entra_creds = get_entra_conninfo(credential) + except Exception as e: + raise EntraConnectionValueError("Could not retrieve Entra credentials") from e + + # Only update missing credentials + if not has_user and "user" in entra_creds: + dsn_params["user"] = entra_creds["user"] + if not has_password and "password" in entra_creds: + dsn_params["password"] = entra_creds["password"] + + # Update DSN params with any kwargs (kwargs take precedence) + dsn_params.update(kwargs) + + # Create new DSN with updated credentials + new_dsn = make_dsn(**dsn_params) + + # Call parent constructor with updated DSN only + super().__init__(new_dsn) diff --git a/sdk/postgresql/azure-postgresql-auth/azure_postgresql_auth/psycopg3/__init__.py b/sdk/postgresql/azure-postgresql-auth/azure_postgresql_auth/psycopg3/__init__.py new file mode 100644 index 000000000000..a118c9476c0d --- /dev/null +++ b/sdk/postgresql/azure-postgresql-auth/azure_postgresql_auth/psycopg3/__init__.py @@ -0,0 +1,27 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +""" +Psycopg3 support for Microsoft Entra ID authentication with Azure Database for PostgreSQL. + +This module provides connection classes that extend psycopg's Connection and AsyncConnection +to automatically handle Microsoft Entra ID token acquisition and authentication. + +Requirements: + Install with: pip install azure-postgresql-auth[psycopg3] + + This will install: + - psycopg[binary]>=3.1.0 + +Classes: + EntraConnection: Synchronous connection class with Entra ID authentication + AsyncEntraConnection: Asynchronous connection class with Entra ID authentication +""" + +from .async_entra_connection import AsyncEntraConnection +from .entra_connection import EntraConnection + +__all__ = ["EntraConnection", "AsyncEntraConnection"] diff --git a/sdk/postgresql/azure-postgresql-auth/azure_postgresql_auth/psycopg3/async_entra_connection.py b/sdk/postgresql/azure-postgresql-auth/azure_postgresql_auth/psycopg3/async_entra_connection.py new file mode 100644 index 000000000000..b8225669c223 --- /dev/null +++ b/sdk/postgresql/azure-postgresql-auth/azure_postgresql_auth/psycopg3/async_entra_connection.py @@ -0,0 +1,62 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +from __future__ import annotations + +from typing import Any + +from azure.core.credentials_async import AsyncTokenCredential +from azure_postgresql_auth.core import get_entra_conninfo_async +from azure_postgresql_auth.errors import ( + CredentialValueError, + EntraConnectionValueError, +) + +try: + from psycopg import AsyncConnection +except ImportError as e: + raise ImportError( + "psycopg3 dependencies are not installed. Install them with: pip install azure-postgresql-auth[psycopg3]" + ) from e + + +class AsyncEntraConnection(AsyncConnection): + """Asynchronous connection class for using Entra authentication with Azure PostgreSQL.""" + + @classmethod + async def connect(cls, *args: Any, **kwargs: Any) -> "AsyncEntraConnection": + """Establishes an asynchronous PostgreSQL connection using Entra authentication. + + This method automatically acquires Microsoft Entra ID credentials when user or password + are not provided in the connection parameters. + + :param args: Positional arguments forwarded to the parent connection method. + :type args: Any + :return: An open asynchronous connection to the PostgreSQL database. + :rtype: AsyncEntraConnection + :raises ~azure_postgresql_auth.CredentialValueError: + If the provided credential is not a valid AsyncTokenCredential. + :raises ~azure_postgresql_auth.EntraConnectionValueError: + If Entra connection credentials cannot be retrieved. + """ + credential = kwargs.pop("credential", None) + if credential is None or not isinstance(credential, (AsyncTokenCredential)): + raise CredentialValueError( + "credential is required and must be an AsyncTokenCredential for async connections" + ) + + # Check if we need to acquire Entra authentication info + if not kwargs.get("user") or not kwargs.get("password"): + try: + entra_conninfo = await get_entra_conninfo_async(credential) + except Exception as e: + raise EntraConnectionValueError("Could not retrieve Entra credentials") from e + # Always use the token password when Entra authentication is needed + kwargs["password"] = entra_conninfo["password"] + if not kwargs.get("user"): + # If user isn't already set, use the username from the token + kwargs["user"] = entra_conninfo["user"] + return await super().connect(*args, **kwargs) diff --git a/sdk/postgresql/azure-postgresql-auth/azure_postgresql_auth/psycopg3/entra_connection.py b/sdk/postgresql/azure-postgresql-auth/azure_postgresql_auth/psycopg3/entra_connection.py new file mode 100644 index 000000000000..e0d9c62222d8 --- /dev/null +++ b/sdk/postgresql/azure-postgresql-auth/azure_postgresql_auth/psycopg3/entra_connection.py @@ -0,0 +1,60 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +from __future__ import annotations + +from typing import Any + +from azure.core.credentials import TokenCredential +from azure_postgresql_auth.core import get_entra_conninfo +from azure_postgresql_auth.errors import ( + CredentialValueError, + EntraConnectionValueError, +) + +try: + from psycopg import Connection +except ImportError as e: + raise ImportError( + "psycopg3 dependencies are not installed. Install them with: pip install azure-postgresql-auth[psycopg3]" + ) from e + + +class EntraConnection(Connection): + """Synchronous connection class for using Entra authentication with Azure PostgreSQL.""" + + @classmethod + def connect(cls, *args: Any, **kwargs: Any) -> "EntraConnection": + """Establishes a synchronous PostgreSQL connection using Entra authentication. + + This method automatically acquires Microsoft Entra ID credentials when user or password + are not provided in the connection parameters. + + :param args: Positional arguments forwarded to the parent connection method. + :type args: Any + :return: An open synchronous connection to the PostgreSQL database. + :rtype: EntraConnection + :raises ~azure_postgresql_auth.CredentialValueError: + If the provided credential is not a valid TokenCredential. + :raises ~azure_postgresql_auth.EntraConnectionValueError: + If Entra connection credentials cannot be retrieved. + """ + credential = kwargs.pop("credential", None) + if credential is None or not isinstance(credential, (TokenCredential)): + raise CredentialValueError("credential is required and must be a TokenCredential for sync connections") + + # Check if we need to acquire Entra authentication info + if not kwargs.get("user") or not kwargs.get("password"): + try: + entra_conninfo = get_entra_conninfo(credential) + except Exception as e: + raise EntraConnectionValueError("Could not retrieve Entra credentials") from e + # Always use the token password when Entra authentication is needed + kwargs["password"] = entra_conninfo["password"] + if not kwargs.get("user"): + # If user isn't already set, use the username from the token + kwargs["user"] = entra_conninfo["user"] + return super().connect(*args, **kwargs) diff --git a/sdk/postgresql/azure-postgresql-auth/azure_postgresql_auth/py.typed b/sdk/postgresql/azure-postgresql-auth/azure_postgresql_auth/py.typed new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sdk/postgresql/azure-postgresql-auth/azure_postgresql_auth/sqlalchemy/__init__.py b/sdk/postgresql/azure-postgresql-auth/azure_postgresql_auth/sqlalchemy/__init__.py new file mode 100644 index 000000000000..029ad6487851 --- /dev/null +++ b/sdk/postgresql/azure-postgresql-auth/azure_postgresql_auth/sqlalchemy/__init__.py @@ -0,0 +1,31 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +""" +SQLAlchemy integration for Azure PostgreSQL with Entra ID authentication. + +This module provides integration between SQLAlchemy and Microsoft Entra ID +authentication for PostgreSQL connections. It automatically handles token acquisition +and credential injection through SQLAlchemy's event system. + +Requirements: + Install with: pip install azure-postgresql-auth[sqlalchemy] + + This will install: + - sqlalchemy>=2.0.0 + +Functions: + enable_entra_authentication: Enable Entra ID authentication for synchronous SQLAlchemy engines + enable_entra_authentication_async: Enable Entra ID authentication for asynchronous SQLAlchemy engines +""" + +from .async_entra_connection import enable_entra_authentication_async +from .entra_connection import enable_entra_authentication + +__all__ = [ + "enable_entra_authentication", + "enable_entra_authentication_async", +] diff --git a/sdk/postgresql/azure-postgresql-auth/azure_postgresql_auth/sqlalchemy/async_entra_connection.py b/sdk/postgresql/azure-postgresql-auth/azure_postgresql_auth/sqlalchemy/async_entra_connection.py new file mode 100644 index 000000000000..252631b860d4 --- /dev/null +++ b/sdk/postgresql/azure-postgresql-auth/azure_postgresql_auth/sqlalchemy/async_entra_connection.py @@ -0,0 +1,80 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +from __future__ import annotations + +from typing import Any + +from azure.core.credentials import TokenCredential +from azure_postgresql_auth.core import get_entra_conninfo +from azure_postgresql_auth.errors import ( + CredentialValueError, + EntraConnectionValueError, +) + +try: + from sqlalchemy import event + from sqlalchemy.engine import Dialect + from sqlalchemy.ext.asyncio import AsyncEngine +except ImportError as e: + # Provide a helpful error message if SQLAlchemy dependencies are missing + raise ImportError( + "SQLAlchemy dependencies are not installed. Install them with: pip install azure-postgresql-auth[sqlalchemy]" + ) from e + + +def enable_entra_authentication_async(engine: AsyncEngine) -> None: + """Enable Microsoft Entra ID authentication for an async SQLAlchemy engine. + + This function registers an event listener that automatically provides + Entra ID credentials for each database connection. A credential must be + provided via connect_args when creating the engine. Event handlers do not + support async behavior so the token fetching will still be synchronous. + + :param engine: The async SQLAlchemy Engine to enable Entra authentication for. + :type engine: ~sqlalchemy.ext.asyncio.AsyncEngine + """ + + @event.listens_for(engine.sync_engine, "do_connect") + def provide_token( + dialect: Dialect, conn_rec: Any, cargs: Any, cparams: dict[str, Any] # pylint: disable=unused-argument + ) -> None: + """Event handler that provides Entra credentials for each sync connection. + + :param dialect: The SQLAlchemy dialect being used. + :type dialect: ~sqlalchemy.engine.Dialect + :param conn_rec: The connection record. + :type conn_rec: Any + :param cargs: The positional connection arguments. + :type cargs: Any + :param cparams: The keyword connection parameters. + :type cparams: dict[str, Any] + """ + credential = cparams.get("credential", None) + if credential is None or not isinstance(credential, (TokenCredential)): + raise CredentialValueError( + "credential is required and must be a TokenCredential. " + "Pass it via connect_args={'credential': DefaultAzureCredential()}" + ) + # Check if credentials are already present + has_user = "user" in cparams + has_password = "password" in cparams + + # Only get Entra credentials if user or password is missing + if not has_user or not has_password: + try: + entra_creds = get_entra_conninfo(credential) + except Exception as e: + raise EntraConnectionValueError("Could not retrieve Entra credentials") from e + # Only update missing credentials + if not has_user and "user" in entra_creds: + cparams["user"] = entra_creds["user"] + if not has_password and "password" in entra_creds: + cparams["password"] = entra_creds["password"] + + # Strip helper-only param before DBAPI connect to avoid 'invalid connection option' + if "credential" in cparams: + del cparams["credential"] diff --git a/sdk/postgresql/azure-postgresql-auth/azure_postgresql_auth/sqlalchemy/entra_connection.py b/sdk/postgresql/azure-postgresql-auth/azure_postgresql_auth/sqlalchemy/entra_connection.py new file mode 100644 index 000000000000..ad3d4c4d9048 --- /dev/null +++ b/sdk/postgresql/azure-postgresql-auth/azure_postgresql_auth/sqlalchemy/entra_connection.py @@ -0,0 +1,78 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +from __future__ import annotations + +from typing import Any + +from azure.core.credentials import TokenCredential +from azure_postgresql_auth.core import get_entra_conninfo +from azure_postgresql_auth.errors import ( + CredentialValueError, + EntraConnectionValueError, +) + +try: + from sqlalchemy import Engine, event + from sqlalchemy.engine import Dialect +except ImportError as e: + raise ImportError( + "SQLAlchemy dependencies are not installed. Install them with: pip install azure-postgresql-auth[sqlalchemy]" + ) from e + + +def enable_entra_authentication(engine: Engine) -> None: + """Enable Microsoft Entra ID authentication for a SQLAlchemy engine. + + This function registers an event listener that automatically provides + Entra ID credentials for each database connection. A credential must be + provided via connect_args when creating the engine. + + :param engine: The SQLAlchemy Engine to enable Entra authentication for. + :type engine: ~sqlalchemy.Engine + """ + + @event.listens_for(engine, "do_connect") + def provide_token( + dialect: Dialect, conn_rec: Any, cargs: Any, cparams: dict[str, Any] # pylint: disable=unused-argument + ) -> None: + """Event handler that provides Entra credentials for each connection. + + :param dialect: The SQLAlchemy dialect being used. + :type dialect: ~sqlalchemy.engine.Dialect + :param conn_rec: The connection record. + :type conn_rec: Any + :param cargs: The positional connection arguments. + :type cargs: Any + :param cparams: The keyword connection parameters. + :type cparams: dict[str, Any] + """ + credential = cparams.get("credential", None) + if credential is None or not isinstance(credential, (TokenCredential)): + raise CredentialValueError( + "credential is required and must be a TokenCredential. " + "Pass it via connect_args={'credential': DefaultAzureCredential()}" + ) + # Check if credentials are already present + has_user = "user" in cparams + has_password = "password" in cparams + + # Only get Entra credentials if user or password is missing + if not has_user or not has_password: + try: + entra_creds = get_entra_conninfo(credential) + except Exception as e: + raise EntraConnectionValueError("Could not retrieve Entra credentials") from e + # Only update missing credentials + if not has_user and "user" in entra_creds: + cparams["user"] = entra_creds["user"] + if not has_password and "password" in entra_creds: + cparams["password"] = entra_creds["password"] + + # Remove the helper-only parameter so the DBAPI (psycopg/psycopg2) doesn't see an + # unknown connection option and raise 'invalid connection option "credential"'. + if "credential" in cparams: + del cparams["credential"] diff --git a/sdk/postgresql/azure-postgresql-auth/dev_requirements.txt b/sdk/postgresql/azure-postgresql-auth/dev_requirements.txt new file mode 100644 index 000000000000..d26cf14fc463 --- /dev/null +++ b/sdk/postgresql/azure-postgresql-auth/dev_requirements.txt @@ -0,0 +1,10 @@ +-e ../../../eng/tools/azure-sdk-tools +../../core/azure-core +../../identity/azure-identity +aiohttp +pytest +pytest-asyncio +psycopg2-binary>=2.9.0 +psycopg[binary]>=3.1.0 +psycopg-pool>=3.1.0 +sqlalchemy>=2.0.0 diff --git a/sdk/postgresql/azure-postgresql-auth/doc/azure_postgresql_auth.psycopg2.rst b/sdk/postgresql/azure-postgresql-auth/doc/azure_postgresql_auth.psycopg2.rst new file mode 100644 index 000000000000..e840440b33bc --- /dev/null +++ b/sdk/postgresql/azure-postgresql-auth/doc/azure_postgresql_auth.psycopg2.rst @@ -0,0 +1,7 @@ +azure\_postgresql\_auth.psycopg2 package +======================================== + +.. automodule:: azure_postgresql_auth.psycopg2 + :members: + :show-inheritance: + :undoc-members: diff --git a/sdk/postgresql/azure-postgresql-auth/doc/azure_postgresql_auth.psycopg3.rst b/sdk/postgresql/azure-postgresql-auth/doc/azure_postgresql_auth.psycopg3.rst new file mode 100644 index 000000000000..19a3c14adbd2 --- /dev/null +++ b/sdk/postgresql/azure-postgresql-auth/doc/azure_postgresql_auth.psycopg3.rst @@ -0,0 +1,7 @@ +azure\_postgresql\_auth.psycopg3 package +======================================== + +.. automodule:: azure_postgresql_auth.psycopg3 + :members: + :show-inheritance: + :undoc-members: diff --git a/sdk/postgresql/azure-postgresql-auth/doc/azure_postgresql_auth.rst b/sdk/postgresql/azure-postgresql-auth/doc/azure_postgresql_auth.rst new file mode 100644 index 000000000000..ae673ba87833 --- /dev/null +++ b/sdk/postgresql/azure-postgresql-auth/doc/azure_postgresql_auth.rst @@ -0,0 +1,36 @@ +azure\_postgresql\_auth package +=============================== + +.. automodule:: azure_postgresql_auth + :members: + :show-inheritance: + :undoc-members: + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + azure_postgresql_auth.psycopg2 + azure_postgresql_auth.psycopg3 + azure_postgresql_auth.sqlalchemy + +Submodules +---------- + +azure\_postgresql\_auth.core module +----------------------------------- + +.. automodule:: azure_postgresql_auth.core + :members: + :show-inheritance: + :undoc-members: + +azure\_postgresql\_auth.errors module +------------------------------------- + +.. automodule:: azure_postgresql_auth.errors + :members: + :show-inheritance: + :undoc-members: diff --git a/sdk/postgresql/azure-postgresql-auth/doc/azure_postgresql_auth.sqlalchemy.rst b/sdk/postgresql/azure-postgresql-auth/doc/azure_postgresql_auth.sqlalchemy.rst new file mode 100644 index 000000000000..0a8d54031c45 --- /dev/null +++ b/sdk/postgresql/azure-postgresql-auth/doc/azure_postgresql_auth.sqlalchemy.rst @@ -0,0 +1,7 @@ +azure\_postgresql\_auth.sqlalchemy package +========================================== + +.. automodule:: azure_postgresql_auth.sqlalchemy + :members: + :show-inheritance: + :undoc-members: diff --git a/sdk/postgresql/azure-postgresql-auth/pyproject.toml b/sdk/postgresql/azure-postgresql-auth/pyproject.toml new file mode 100644 index 000000000000..902a92a6a6f9 --- /dev/null +++ b/sdk/postgresql/azure-postgresql-auth/pyproject.toml @@ -0,0 +1,62 @@ +[build-system] +requires = ["setuptools>=77.0.3", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "azure-postgresql-auth" +authors = [ + {name = "Microsoft Corporation", email = "azpysdkhelp@microsoft.com"}, +] +description = "Microsoft Azure PostgreSQL Auth Library for Python" +keywords = ["azure", "azure sdk", "postgresql", "entra"] +requires-python = ">=3.9" +license = "MIT" +classifiers = [ + "Development Status :: 5 - Production/Stable", + "Programming Language :: Python", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", +] +dependencies = [ + "azure-core>=1.29.4", +] +dynamic = ["version", "readme"] + +[project.optional-dependencies] +psycopg3 = [ + "psycopg[binary]>=3.1.0", + "psycopg-pool>=3.1.0", +] +psycopg2 = [ + "psycopg2-binary>=2.9.0", +] +sqlalchemy = [ + "sqlalchemy>=2.0.0", +] + +[project.urls] +repository = "https://github.com/Azure/azure-sdk-for-python/tree/main/sdk/postgresql/azure-postgresql-auth" + +[tool.setuptools.dynamic] +version = {attr = "azure_postgresql_auth._version.VERSION"} +readme = {file = ["README.md", "CHANGELOG.md"], content-type = "text/markdown"} + +[tool.setuptools.packages.find] +include = ["azure_postgresql_auth*"] + +[tool.setuptools.package-data] +pytyped = ["py.typed"] + +[tool.azure-sdk-build] +mypy = true +pyright = false +black = true + +[tool.azure-sdk-conda] +in_bundle = false diff --git a/sdk/postgresql/azure-postgresql-auth/samples/.env.example b/sdk/postgresql/azure-postgresql-auth/samples/.env.example new file mode 100644 index 000000000000..3887833d45cd --- /dev/null +++ b/sdk/postgresql/azure-postgresql-auth/samples/.env.example @@ -0,0 +1,5 @@ +# Example environment configuration for PostgreSQL connection +# Copy this file to .env and update with your actual values + +POSTGRES_SERVER=your-server.postgres.database.azure.com +POSTGRES_DATABASE=your_database_name diff --git a/sdk/postgresql/azure-postgresql-auth/samples/sample_psycopg2_connection.py b/sdk/postgresql/azure-postgresql-auth/samples/sample_psycopg2_connection.py new file mode 100644 index 000000000000..c3781be43d76 --- /dev/null +++ b/sdk/postgresql/azure-postgresql-auth/samples/sample_psycopg2_connection.py @@ -0,0 +1,64 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +""" +FILE: sample_psycopg2_connection.py + +DESCRIPTION: + This sample demonstrates how to connect to Azure PostgreSQL using psycopg2 + with synchronous Entra ID authentication. + +USAGE: + python sample_psycopg2_connection.py +""" + +import os +from functools import partial + +from azure.identity import DefaultAzureCredential +from dotenv import load_dotenv +from psycopg2 import pool + +from azure_postgresql_auth.psycopg2 import EntraConnection + +# Load environment variables from .env file +load_dotenv() +SERVER = os.getenv("POSTGRES_SERVER") +DATABASE = os.getenv("POSTGRES_DATABASE", "postgres") + + +def main() -> None: + # We use the EntraConnection class to enable synchronous Entra-based authentication for database access. + # This class is applied whenever the connection pool creates a new connection, ensuring that Entra + # authentication tokens are properly managed and refreshed so that each connection uses a valid token. + # + # For more details, see: https://www.psycopg.org/docs/advanced.html#subclassing-connection + + # Create a connection factory with the credential bound using functools.partial + credential = DefaultAzureCredential() + connection_factory = partial(EntraConnection, credential=credential) + + connection_pool = pool.ThreadedConnectionPool( + minconn=1, + maxconn=5, + host=SERVER, + database=DATABASE, + connection_factory=connection_factory, + ) + + conn = connection_pool.getconn() + try: + with conn.cursor() as cur: + cur.execute("SELECT now()") + result = cur.fetchone() + print(f"Database time: {result[0]}") + finally: + connection_pool.putconn(conn) + connection_pool.closeall() + + +if __name__ == "__main__": + main() diff --git a/sdk/postgresql/azure-postgresql-auth/samples/sample_psycopg3_connection.py b/sdk/postgresql/azure-postgresql-auth/samples/sample_psycopg3_connection.py new file mode 100644 index 000000000000..4557fc9b48f1 --- /dev/null +++ b/sdk/postgresql/azure-postgresql-auth/samples/sample_psycopg3_connection.py @@ -0,0 +1,123 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +""" +FILE: sample_psycopg3_connection.py + +DESCRIPTION: + This sample demonstrates how to connect to Azure PostgreSQL using psycopg3 + with both synchronous and asynchronous Entra ID authentication. + +USAGE: + python sample_psycopg3_connection.py +""" + +import argparse +import asyncio +import os +import sys + +from azure.identity import DefaultAzureCredential +from azure.identity.aio import DefaultAzureCredential as AsyncDefaultAzureCredential +from dotenv import load_dotenv +from psycopg_pool import AsyncConnectionPool, ConnectionPool + +from azure_postgresql_auth.psycopg3 import AsyncEntraConnection, EntraConnection + +# Load environment variables from .env file +load_dotenv() +SERVER = os.getenv("POSTGRES_SERVER") +DATABASE = os.getenv("POSTGRES_DATABASE", "postgres") + + +def main_sync() -> None: + """Synchronous connection example using psycopg with Entra ID authentication.""" + + # We use the EntraConnection class to enable synchronous Entra-based authentication for database access. + # This class is applied whenever the connection pool creates a new connection, ensuring that Entra + # authentication tokens are properly managed and refreshed so that each connection uses a valid token. + # + # For more details, see: https://www.psycopg.org/psycopg3/docs/api/connections.html#psycopg.Connection.connect + pool = ConnectionPool( + conninfo=f"postgresql://{SERVER}:5432/{DATABASE}", + min_size=1, + max_size=5, + open=False, + connection_class=EntraConnection, + kwargs={"credential": DefaultAzureCredential()}, + ) + with pool, pool.connection() as conn, conn.cursor() as cur: + cur.execute("SELECT now()") + result = cur.fetchone() + print(f"Sync - Database time: {result}") + + +async def main_async() -> None: + """Asynchronous connection example using psycopg with Entra ID authentication.""" + + # We use the AsyncEntraConnection class to enable asynchronous Entra-based authentication for database access. + # This class is applied whenever the connection pool creates a new connection, ensuring that Entra + # authentication tokens are properly managed and refreshed so that each connection uses a valid token. + # + # For more details, see: https://www.psycopg.org/psycopg3/docs/api/connections.html#psycopg.Connection.connect + pool = AsyncConnectionPool( + conninfo=f"postgresql://{SERVER}:5432/{DATABASE}", + min_size=1, + max_size=5, + open=False, + connection_class=AsyncEntraConnection, + kwargs={"credential": AsyncDefaultAzureCredential()}, + ) + async with pool, pool.connection() as conn, conn.cursor() as cur: + await cur.execute("SELECT now()") + result = await cur.fetchone() + print(f"Async - Database time: {result}") + + +async def main(mode: str = "async") -> None: + """Main function that runs sync and/or async examples based on mode. + + Args: + mode: "sync", "async", or "both" to determine which examples to run + """ + if mode in ("sync", "both"): + print("=== Running Synchronous Example ===") + try: + main_sync() + print("Sync example completed successfully!") + except Exception as e: + print(f"Sync example failed: {e}") + + if mode in ("async", "both"): + if mode == "both": + print("\n=== Running Asynchronous Example ===") + else: + print("=== Running Asynchronous Example ===") + try: + await main_async() + print("Async example completed successfully!") + except Exception as e: + print(f"Async example failed: {e}") + + +if __name__ == "__main__": + # Parse command line arguments + parser = argparse.ArgumentParser( + description="Demonstrate psycopg connections with Microsoft Entra ID authentication" + ) + parser.add_argument( + "--mode", + choices=["sync", "async", "both"], + default="both", + help="Run synchronous, asynchronous, or both examples (default: both)", + ) + args = parser.parse_args() + + # Set Windows event loop policy for compatibility if needed + if sys.platform == "win32": + asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) + + asyncio.run(main(args.mode)) diff --git a/sdk/postgresql/azure-postgresql-auth/samples/sample_sqlalchemy_connection.py b/sdk/postgresql/azure-postgresql-auth/samples/sample_sqlalchemy_connection.py new file mode 100644 index 000000000000..03a54af820fd --- /dev/null +++ b/sdk/postgresql/azure-postgresql-auth/samples/sample_sqlalchemy_connection.py @@ -0,0 +1,134 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +""" +FILE: sample_sqlalchemy_connection.py + +DESCRIPTION: + This sample demonstrates how to connect to Azure PostgreSQL using SQLAlchemy + with both synchronous and asynchronous Entra ID authentication. + +USAGE: + python sample_sqlalchemy_connection.py +""" + +import argparse +import asyncio +import os +import sys + +from azure.identity import DefaultAzureCredential +from dotenv import load_dotenv +from sqlalchemy import create_engine, text +from sqlalchemy.ext.asyncio import create_async_engine + +from azure_postgresql_auth.sqlalchemy import ( + enable_entra_authentication, + enable_entra_authentication_async, +) + +# Load environment variables from .env file +load_dotenv() +SERVER = os.getenv("POSTGRES_SERVER") +DATABASE = os.getenv("POSTGRES_DATABASE", "postgres") + + +def main_sync() -> None: + """Synchronous connection example using SQLAlchemy with Entra ID authentication.""" + + # Create a synchronous engine + engine = create_engine( + f"postgresql+psycopg://{SERVER}/{DATABASE}", + connect_args={"credential": DefaultAzureCredential()}, + ) + + # We add an event listener to the engine to enable synchronous Entra authentication + # for database access. This event listener is triggered whenever the connection pool + # backing the engine creates a new connection, ensuring that Entra authentication tokens + # are properly managed and refreshed so that each connection uses a valid token. + # + # For more details, see: https://docs.sqlalchemy.org/en/20/core/engines.html#controlling-how-parameters-are-passed-to-the-dbapi-connect-function + enable_entra_authentication(engine) + + with engine.connect() as conn: + result = conn.execute(text("SELECT now()")) + row = result.fetchone() + print(f"Sync - Database time: {row[0] if row else 'Unknown'}") + + # Clean up the engine + engine.dispose() + + +async def main_async() -> None: + """Asynchronous connection example using SQLAlchemy with Entra ID authentication.""" + + # Create an asynchronous engine + engine = create_async_engine( + f"postgresql+psycopg://{SERVER}/{DATABASE}", + connect_args={"credential": DefaultAzureCredential()}, + ) + + # We add an event listener to the engine to enable asynchronous Entra authentication + # for database access. This event listener is triggered whenever the connection pool + # backing the engine creates a new connection, ensuring that Entra authentication tokens + # are properly managed and refreshed so that each connection uses a valid token. + # + # For more details, see: https://docs.sqlalchemy.org/en/20/core/engines.html#controlling-how-parameters-are-passed-to-the-dbapi-connect-function + enable_entra_authentication_async(engine) + + async with engine.connect() as conn: + result = await conn.execute(text("SELECT now()")) + row = result.fetchone() + print(f"Async Core - Database time: {row[0] if row else 'Unknown'}") + + # Clean up the engine + await engine.dispose() + + +async def main(mode: str = "async") -> None: + """Main function that runs sync and/or async examples based on mode. + + Args: + mode: "sync", "async", or "both" to determine which examples to run + """ + if mode in ("sync", "both"): + print("=== Running Synchronous SQLAlchemy Example ===") + try: + main_sync() + print("Sync example completed successfully!") + except Exception as e: + print(f"Sync example failed: {e}") + + if mode in ("async", "both"): + if mode == "both": + print("\n=== Running Asynchronous SQLAlchemy Example ===") + else: + print("=== Running Asynchronous SQLAlchemy Example ===") + try: + await main_async() + print("Async example completed successfully!") + except Exception as e: + print(f"Async example failed: {e}") + + +if __name__ == "__main__": + # Parse command line arguments + parser = argparse.ArgumentParser( + description="Demonstrate SQLAlchemy connections with Microsoft Entra ID authentication" + ) + parser.add_argument( + "--mode", + choices=["sync", "async", "both"], + default="both", + help="Run synchronous, asynchronous, or both examples (default: both)", + ) + args = parser.parse_args() + + # Set Windows event loop policy for compatibility if needed + if sys.platform == "win32": + asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) + + asyncio.run(main(args.mode)) diff --git a/sdk/postgresql/azure-postgresql-auth/sdk_packaging.toml b/sdk/postgresql/azure-postgresql-auth/sdk_packaging.toml new file mode 100644 index 000000000000..901bc8ccbfa6 --- /dev/null +++ b/sdk/postgresql/azure-postgresql-auth/sdk_packaging.toml @@ -0,0 +1,2 @@ +[packaging] +auto_update = false diff --git a/sdk/postgresql/azure-postgresql-auth/tests/conftest.py b/sdk/postgresql/azure-postgresql-auth/tests/conftest.py new file mode 100644 index 000000000000..17d1b344ec63 --- /dev/null +++ b/sdk/postgresql/azure-postgresql-auth/tests/conftest.py @@ -0,0 +1,86 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +"""Fixtures for azure-postgresql-auth live tests.""" + +from __future__ import annotations + +import asyncio +import os +import sys + +import pytest + +from devtools_testutils import get_credential + + +def _get_pg_config(): + """Read PostgreSQL connection config from environment variables.""" + host = os.environ.get("POSTGRESQL_HOST") + if not host: + pytest.skip("POSTGRESQL_HOST environment variable not set") + database = os.environ.get("POSTGRESQL_DATABASE", "testdb") + port = os.environ.get("POSTGRESQL_PORT", "5432") + return host, database, port + + +@pytest.fixture(scope="session") +def event_loop_policy(): + """Use SelectorEventLoop on Windows for psycopg async compatibility.""" + if sys.platform == "win32": + return asyncio.WindowsSelectorEventLoopPolicy() + return asyncio.DefaultEventLoopPolicy() + + +@pytest.fixture(scope="session") +def postgresql_database(): + """Get the PostgreSQL database from environment variables.""" + return os.environ.get("POSTGRESQL_DATABASE", "testdb") + + +@pytest.fixture(scope="session") +def credential(): + """Get a credential for live tests.""" + return get_credential() + + +@pytest.fixture(scope="session") +def async_credential(): + """Get an async credential for live tests.""" + return get_credential(is_async=True) + + +@pytest.fixture(scope="session") +def connection_dsn(): + """Get a psycopg2-style DSN connection string.""" + host, database, port = _get_pg_config() + return f"host={host} port={port} dbname={database} sslmode=require" + + +@pytest.fixture(scope="session") +def connection_params(): + """Get psycopg3-style connection parameters as a dict.""" + host, database, port = _get_pg_config() + return { + "host": host, + "port": port, + "dbname": database, + "sslmode": "require", + } + + +@pytest.fixture(scope="session") +def connection_url(): + """Get a SQLAlchemy sync connection URL (psycopg2 driver).""" + host, database, port = _get_pg_config() + return f"postgresql+psycopg2://{host}:{port}/{database}?sslmode=require" + + +@pytest.fixture(scope="session") +def async_connection_url(): + """Get a SQLAlchemy async connection URL (psycopg driver).""" + host, database, port = _get_pg_config() + return f"postgresql+psycopg://{host}:{port}/{database}?sslmode=require" diff --git a/sdk/postgresql/azure-postgresql-auth/tests/test_core.py b/sdk/postgresql/azure-postgresql-auth/tests/test_core.py new file mode 100644 index 000000000000..7a12ea0d25c5 --- /dev/null +++ b/sdk/postgresql/azure-postgresql-auth/tests/test_core.py @@ -0,0 +1,155 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +"""Unit tests for the core authentication module.""" + +from __future__ import annotations + +import pytest + +from azure_postgresql_auth.core import ( + AZURE_DB_FOR_POSTGRES_SCOPE, + AZURE_MANAGEMENT_SCOPE, + decode_jwt, + get_entra_conninfo, + parse_principal_name, +) +from azure_postgresql_auth.errors import TokenDecodeError + +from utils import ( + TEST_USERS, + MockTokenCredential, + create_jwt_token_with_preferred_username, + create_jwt_token_with_unique_name, + create_jwt_token_with_xms_mirid, + create_valid_jwt_token, +) + + +class TestDecodeJwt: + """Tests for decode_jwt function.""" + + def test_decode_valid_jwt(self): + """Test decoding a valid JWT token extracts payload claims.""" + token = create_valid_jwt_token("user@example.com") + claims = decode_jwt(token) + assert claims["upn"] == "user@example.com" + assert claims["iat"] == 1234567890 + assert claims["exp"] == 9999999999 + + def test_decode_jwt_with_xms_mirid(self): + """Test decoding a JWT with xms_mirid claim.""" + xms_mirid = TEST_USERS["MANAGED_IDENTITY_PATH"] + token = create_jwt_token_with_xms_mirid(xms_mirid) + claims = decode_jwt(token) + assert claims["xms_mirid"] == xms_mirid + + def test_decode_jwt_with_preferred_username(self): + """Test decoding a JWT with preferred_username claim.""" + token = create_jwt_token_with_preferred_username("user@example.com") + claims = decode_jwt(token) + assert claims["preferred_username"] == "user@example.com" + + def test_decode_jwt_with_unique_name(self): + """Test decoding a JWT with unique_name claim.""" + token = create_jwt_token_with_unique_name("user@example.com") + claims = decode_jwt(token) + assert claims["unique_name"] == "user@example.com" + + def test_decode_invalid_jwt_raises_error(self): + """Test that an invalid JWT token raises TokenDecodeError.""" + with pytest.raises(TokenDecodeError, match="Invalid JWT token format"): + decode_jwt("invalid-token") + + def test_decode_empty_string_raises_error(self): + """Test that an empty string raises TokenDecodeError.""" + with pytest.raises(TokenDecodeError): + decode_jwt("") + + +class TestParsePrincipalName: + """Tests for parse_principal_name function.""" + + def test_parse_valid_managed_identity_path(self): + """Test extracting principal name from a valid managed identity resource path.""" + xms_mirid = TEST_USERS["MANAGED_IDENTITY_PATH"] + result = parse_principal_name(xms_mirid) + assert result == TEST_USERS["MANAGED_IDENTITY_NAME"] + + def test_parse_empty_string_returns_none(self): + """Test that empty string returns None.""" + assert parse_principal_name("") is None + + def test_parse_no_slash_returns_none(self): + """Test that string with no slashes returns None.""" + assert parse_principal_name("no-slash-here") is None + + def test_parse_wrong_provider_returns_none(self): + """Test that wrong provider path returns None.""" + assert ( + parse_principal_name("/subscriptions/123/resourcegroups/rg/providers/Microsoft.Wrong/identities/name") + is None + ) + + def test_parse_trailing_slash_returns_none(self): + """Test that a path ending with a slash returns None.""" + assert ( + parse_principal_name( + "/subscriptions/123/resourcegroups/rg/providers/Microsoft.ManagedIdentity/userAssignedIdentities/" + ) + is None + ) + + +class TestGetEntraConninfo: + """Tests for get_entra_conninfo function.""" + + def test_conninfo_with_upn_claim(self): + """Test that UPN claim is extracted as username.""" + token = create_valid_jwt_token(TEST_USERS["ENTRA_USER"]) + credential = MockTokenCredential(token) + result = get_entra_conninfo(credential) + assert result["user"] == TEST_USERS["ENTRA_USER"] + assert result["password"] == token + + def test_conninfo_with_preferred_username(self): + """Test that preferred_username claim is extracted as username.""" + token = create_jwt_token_with_preferred_username(TEST_USERS["ENTRA_USER"]) + credential = MockTokenCredential(token) + result = get_entra_conninfo(credential) + assert result["user"] == TEST_USERS["ENTRA_USER"] + + def test_conninfo_with_unique_name(self): + """Test that unique_name claim is extracted as username.""" + token = create_jwt_token_with_unique_name(TEST_USERS["ENTRA_USER"]) + credential = MockTokenCredential(token) + result = get_entra_conninfo(credential) + assert result["user"] == TEST_USERS["ENTRA_USER"] + + def test_conninfo_with_managed_identity(self): + """Test that managed identity xms_mirid claim is parsed for username.""" + token = create_jwt_token_with_xms_mirid(TEST_USERS["MANAGED_IDENTITY_PATH"]) + credential = MockTokenCredential(token) + result = get_entra_conninfo(credential) + assert result["user"] == TEST_USERS["MANAGED_IDENTITY_NAME"] + + def test_conninfo_invalid_token_raises_error(self): + """Test that invalid token raises TokenDecodeError.""" + credential = MockTokenCredential("invalid-token") + with pytest.raises(TokenDecodeError): + get_entra_conninfo(credential) + + def test_credential_called_once_for_db_scope(self): + """Test that credential is called once when UPN found in DB token.""" + token = create_valid_jwt_token(TEST_USERS["ENTRA_USER"]) + credential = MockTokenCredential(token) + get_entra_conninfo(credential) + assert credential.get_call_count() == 1 + + def test_scope_constants_defined(self): + """Test that scope constants are defined correctly.""" + assert AZURE_DB_FOR_POSTGRES_SCOPE == "https://ossrdbms-aad.database.windows.net/.default" + assert AZURE_MANAGEMENT_SCOPE == "https://management.azure.com/.default" diff --git a/sdk/postgresql/azure-postgresql-auth/tests/test_core_async.py b/sdk/postgresql/azure-postgresql-auth/tests/test_core_async.py new file mode 100644 index 000000000000..1b32888bee09 --- /dev/null +++ b/sdk/postgresql/azure-postgresql-auth/tests/test_core_async.py @@ -0,0 +1,49 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +"""Async unit tests for the core authentication module.""" + +from __future__ import annotations + +import pytest + +from azure_postgresql_auth.core import get_entra_conninfo_async +from azure_postgresql_auth.errors import TokenDecodeError + +from utils import ( + TEST_USERS, + MockAsyncTokenCredential, + create_jwt_token_with_xms_mirid, + create_valid_jwt_token, +) + + +class TestGetEntraConninfoAsync: + """Tests for get_entra_conninfo_async function.""" + + @pytest.mark.asyncio + async def test_async_conninfo_with_upn_claim(self): + """Test async: UPN claim extracted as username.""" + token = create_valid_jwt_token(TEST_USERS["ENTRA_USER"]) + credential = MockAsyncTokenCredential(token) + result = await get_entra_conninfo_async(credential) + assert result["user"] == TEST_USERS["ENTRA_USER"] + assert result["password"] == token + + @pytest.mark.asyncio + async def test_async_conninfo_with_managed_identity(self): + """Test async: managed identity xms_mirid claim parsed for username.""" + token = create_jwt_token_with_xms_mirid(TEST_USERS["MANAGED_IDENTITY_PATH"]) + credential = MockAsyncTokenCredential(token) + result = await get_entra_conninfo_async(credential) + assert result["user"] == TEST_USERS["MANAGED_IDENTITY_NAME"] + + @pytest.mark.asyncio + async def test_async_conninfo_invalid_token_raises_error(self): + """Test async: invalid token raises TokenDecodeError.""" + credential = MockAsyncTokenCredential("invalid-token") + with pytest.raises(TokenDecodeError): + await get_entra_conninfo_async(credential) diff --git a/sdk/postgresql/azure-postgresql-auth/tests/test_psycopg2.py b/sdk/postgresql/azure-postgresql-auth/tests/test_psycopg2.py new file mode 100644 index 000000000000..9ae14cffce3c --- /dev/null +++ b/sdk/postgresql/azure-postgresql-auth/tests/test_psycopg2.py @@ -0,0 +1,113 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +"""Unit tests for psycopg2 EntraConnection.""" + +from __future__ import annotations +import sys + +from unittest.mock import patch + +import pytest + +from azure_postgresql_auth.errors import CredentialValueError, EntraConnectionValueError + +from utils import TEST_USERS, MockTokenCredential, create_valid_jwt_token + +if sys.implementation.name != "pypy": + import psycopg2 + from azure_postgresql_auth.psycopg2 import EntraConnection +else: + pytest.skip("psycopg2 not supported on PyPy", allow_module_level=True) + + +class TestPsycopg2EntraConnection: + """Tests for psycopg2 EntraConnection class.""" + + def test_missing_credential_raises_error(self): + """Test that missing credential raises CredentialValueError.""" + with pytest.raises(CredentialValueError, match="credential is required"): + EntraConnection("host=localhost dbname=test") + + def test_invalid_credential_type_raises_error(self): + """Test that non-TokenCredential raises CredentialValueError.""" + with pytest.raises(CredentialValueError, match="credential is required"): + EntraConnection("host=localhost dbname=test", credential="not-a-credential") + + @patch("azure_postgresql_auth.psycopg2.entra_connection.get_entra_conninfo") + def test_entra_credentials_injected(self, mock_get_conninfo): + """Test that Entra credentials are injected when user/password missing.""" + token = create_valid_jwt_token(TEST_USERS["ENTRA_USER"]) + mock_get_conninfo.return_value = { + "user": TEST_USERS["ENTRA_USER"], + "password": token, + } + credential = MockTokenCredential(token) + # super().__init__ will raise OperationalError (no real DB), but + # credential logic runs before the connection attempt. + with pytest.raises(psycopg2.OperationalError): + EntraConnection("host=localhost dbname=test", credential=credential) + mock_get_conninfo.assert_called_once_with(credential) + + @patch("azure_postgresql_auth.psycopg2.entra_connection.get_entra_conninfo") + def test_existing_credentials_preserved(self, mock_get_conninfo): + """Test that existing user/password in DSN are preserved.""" + token = create_valid_jwt_token(TEST_USERS["ENTRA_USER"]) + credential = MockTokenCredential(token) + with pytest.raises(psycopg2.OperationalError): + EntraConnection( + "host=localhost dbname=test user=existing password=secret", + credential=credential, + ) + mock_get_conninfo.assert_not_called() + + +class TestPsycopg2EntraConnectionErrors: + """Tests for error handling in psycopg2 EntraConnection.""" + + @patch("azure_postgresql_auth.psycopg2.entra_connection.get_entra_conninfo") + def test_entra_credential_failure_raises_error(self, mock_get_conninfo): + """Test that credential retrieval failure raises EntraConnectionValueError.""" + mock_get_conninfo.side_effect = Exception("auth failed") + token = create_valid_jwt_token(TEST_USERS["ENTRA_USER"]) + credential = MockTokenCredential(token) + + with pytest.raises(EntraConnectionValueError, match="Could not retrieve Entra credentials"): + EntraConnection("host=localhost dbname=test", credential=credential) + + +@pytest.mark.live_test_only +class TestPsycopg2EntraConnectionLive: + """Live tests for psycopg2 EntraConnection against deployed Azure PostgreSQL.""" + + def test_connect_with_entra_user(self, connection_dsn, credential, postgresql_database): + """Test connecting with an Entra user using EntraConnection.""" + with EntraConnection(connection_dsn, credential=credential) as conn: + with conn.cursor() as cur: + cur.execute("SELECT current_user, current_database()") + row = cur.fetchone() + assert row is not None + current_user, current_db = row + + assert current_db == postgresql_database + assert current_user is not None + + def test_execute_query(self, connection_dsn, credential): + """Test executing a basic query through EntraConnection.""" + with EntraConnection(connection_dsn, credential=credential) as conn: + with conn.cursor() as cur: + cur.execute("SELECT 1") + row = cur.fetchone() + assert row is not None + assert row[0] == 1 + + def test_multiple_sequential_connections(self, connection_dsn, credential): + """Test that the same credential works across multiple sequential connections.""" + for _ in range(3): + with EntraConnection(connection_dsn, credential=credential) as conn: + with conn.cursor() as cur: + cur.execute("SELECT 1") + assert cur.fetchone()[0] == 1 diff --git a/sdk/postgresql/azure-postgresql-auth/tests/test_psycopg3.py b/sdk/postgresql/azure-postgresql-auth/tests/test_psycopg3.py new file mode 100644 index 000000000000..66abfb0aba17 --- /dev/null +++ b/sdk/postgresql/azure-postgresql-auth/tests/test_psycopg3.py @@ -0,0 +1,107 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +"""Unit tests for psycopg3 EntraConnection.""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest + +from azure_postgresql_auth.errors import CredentialValueError, EntraConnectionValueError +from azure_postgresql_auth.psycopg3 import EntraConnection + +from utils import ( + TEST_USERS, + MockTokenCredential, + create_valid_jwt_token, +) + + +class TestPsycopg3EntraConnection: + """Tests for psycopg3 sync EntraConnection.""" + + def test_missing_credential_raises_error(self): + """Test that missing credential raises CredentialValueError.""" + with pytest.raises(CredentialValueError, match="credential is required"): + EntraConnection.connect(host="localhost", dbname="test") + + def test_invalid_credential_type_raises_error(self): + """Test that non-TokenCredential raises CredentialValueError.""" + with pytest.raises(CredentialValueError, match="credential is required"): + EntraConnection.connect(host="localhost", dbname="test", credential="invalid") + + @patch("azure_postgresql_auth.psycopg3.entra_connection.get_entra_conninfo") + @patch("psycopg.Connection.connect") + def test_entra_credentials_injected(self, mock_connect, mock_get_conninfo): + """Test that Entra credentials are injected when user/password missing.""" + token = create_valid_jwt_token(TEST_USERS["ENTRA_USER"]) + mock_get_conninfo.return_value = { + "user": TEST_USERS["ENTRA_USER"], + "password": token, + } + mock_connect.return_value = MagicMock() + credential = MockTokenCredential(token) + EntraConnection.connect(host="localhost", dbname="test", credential=credential) + mock_get_conninfo.assert_called_once_with(credential) + + @patch("azure_postgresql_auth.psycopg3.entra_connection.get_entra_conninfo") + @patch("psycopg.Connection.connect") + def test_existing_credentials_preserved(self, mock_connect, mock_get_conninfo): + """Test that existing user/password are preserved.""" + mock_connect.return_value = MagicMock() + token = create_valid_jwt_token(TEST_USERS["ENTRA_USER"]) + credential = MockTokenCredential(token) + EntraConnection.connect( + host="localhost", + dbname="test", + user="existing", + password="secret", + credential=credential, + ) + mock_get_conninfo.assert_not_called() + + @patch("azure_postgresql_auth.psycopg3.entra_connection.get_entra_conninfo") + def test_entra_credential_failure_raises_error(self, mock_get_conninfo): + """Test that credential failure raises EntraConnectionValueError.""" + mock_get_conninfo.side_effect = Exception("auth failed") + token = create_valid_jwt_token(TEST_USERS["ENTRA_USER"]) + credential = MockTokenCredential(token) + + with pytest.raises(EntraConnectionValueError): + EntraConnection.connect(host="localhost", dbname="test", credential=credential) + + +@pytest.mark.live_test_only +class TestPsycopg3EntraConnectionLive: + """Live tests for psycopg3 sync EntraConnection against deployed Azure PostgreSQL.""" + + def test_connect_with_entra_user(self, connection_params, credential, postgresql_database): + """Test connecting with an Entra user using EntraConnection.""" + with EntraConnection.connect(**connection_params, credential=credential) as conn: + with conn.cursor() as cur: + cur.execute("SELECT current_user, current_database()") + current_user, current_db = cur.fetchone() + + assert current_db == postgresql_database + assert current_user is not None + + def test_execute_query(self, connection_params, credential): + """Test executing a basic query through EntraConnection.""" + with EntraConnection.connect(**connection_params, credential=credential) as conn: + with conn.cursor() as cur: + cur.execute("SELECT 1") + result = cur.fetchone() + assert result[0] == 1 + + def test_multiple_sequential_connections(self, connection_params, credential): + """Test that the same credential works across multiple sequential connections.""" + for _ in range(3): + with EntraConnection.connect(**connection_params, credential=credential) as conn: + with conn.cursor() as cur: + cur.execute("SELECT 1") + assert cur.fetchone()[0] == 1 diff --git a/sdk/postgresql/azure-postgresql-auth/tests/test_psycopg3_async.py b/sdk/postgresql/azure-postgresql-auth/tests/test_psycopg3_async.py new file mode 100644 index 000000000000..334296ebc9eb --- /dev/null +++ b/sdk/postgresql/azure-postgresql-auth/tests/test_psycopg3_async.py @@ -0,0 +1,132 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +"""Async unit tests for psycopg3 AsyncEntraConnection.""" + +from __future__ import annotations + +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from azure_postgresql_auth.errors import CredentialValueError, EntraConnectionValueError +from azure_postgresql_auth.psycopg3 import AsyncEntraConnection + +from utils import ( + TEST_USERS, + MockAsyncTokenCredential, + create_valid_jwt_token, +) + + +class TestPsycopg3AsyncEntraConnection: + """Tests for psycopg3 async AsyncEntraConnection.""" + + @pytest.mark.asyncio + async def test_missing_credential_raises_error(self): + """Test that missing credential raises CredentialValueError.""" + with pytest.raises(CredentialValueError, match="credential is required"): + await AsyncEntraConnection.connect(host="localhost", dbname="test") + + @pytest.mark.asyncio + async def test_invalid_credential_type_raises_error(self): + """Test that non-AsyncTokenCredential raises CredentialValueError.""" + with pytest.raises(CredentialValueError, match="credential is required"): + await AsyncEntraConnection.connect(host="localhost", dbname="test", credential="invalid") + + @pytest.mark.asyncio + @patch("azure_postgresql_auth.psycopg3.async_entra_connection.get_entra_conninfo_async") + @patch("psycopg.AsyncConnection.connect", new_callable=AsyncMock) + async def test_entra_credentials_injected_async(self, mock_connect, mock_get_conninfo): + """Test that Entra credentials are injected when user/password missing (async).""" + token = create_valid_jwt_token(TEST_USERS["ENTRA_USER"]) + mock_get_conninfo.return_value = { + "user": TEST_USERS["ENTRA_USER"], + "password": token, + } + mock_connect.return_value = MagicMock() + credential = MockAsyncTokenCredential(token) + await AsyncEntraConnection.connect(host="localhost", dbname="test", credential=credential) + mock_get_conninfo.assert_called_once_with(credential) + + @pytest.mark.asyncio + @patch("azure_postgresql_auth.psycopg3.async_entra_connection.get_entra_conninfo_async") + @patch("psycopg.AsyncConnection.connect", new_callable=AsyncMock) + async def test_existing_credentials_preserved_async(self, mock_connect, mock_get_conninfo): + """Test that existing user/password are preserved (async).""" + mock_connect.return_value = MagicMock() + token = create_valid_jwt_token(TEST_USERS["ENTRA_USER"]) + credential = MockAsyncTokenCredential(token) + await AsyncEntraConnection.connect( + host="localhost", + dbname="test", + user="existing", + password="secret", + credential=credential, + ) + mock_get_conninfo.assert_not_called() + + @pytest.mark.asyncio + @patch("azure_postgresql_auth.psycopg3.async_entra_connection.get_entra_conninfo_async") + async def test_entra_credential_failure_raises_error_async(self, mock_get_conninfo): + """Test that credential failure raises EntraConnectionValueError (async).""" + mock_get_conninfo.side_effect = Exception("auth failed") + token = create_valid_jwt_token(TEST_USERS["ENTRA_USER"]) + credential = MockAsyncTokenCredential(token) + + with pytest.raises(EntraConnectionValueError): + await AsyncEntraConnection.connect(host="localhost", dbname="test", credential=credential) + + +@pytest.mark.live_test_only +class TestPsycopg3AsyncEntraConnectionLive: + """Live tests for psycopg3 async AsyncEntraConnection against deployed Azure PostgreSQL.""" + + @pytest.mark.asyncio + async def test_connect_with_entra_user_async(self, connection_params, async_credential, postgresql_database): + """Test connecting with an Entra user using AsyncEntraConnection.""" + async with await AsyncEntraConnection.connect(**connection_params, credential=async_credential) as conn: + async with conn.cursor() as cur: + await cur.execute("SELECT current_user, current_database()") + result = await cur.fetchone() + current_user, current_db = result + + assert current_db == postgresql_database + assert current_user is not None + + @pytest.mark.asyncio + async def test_execute_query_async(self, connection_params, async_credential): + """Test executing a basic query through AsyncEntraConnection.""" + async with await AsyncEntraConnection.connect(**connection_params, credential=async_credential) as conn: + async with conn.cursor() as cur: + await cur.execute("SELECT 1") + result = await cur.fetchone() + assert result[0] == 1 + + @pytest.mark.asyncio + async def test_multiple_sequential_connections_async(self, connection_params, async_credential): + """Test that the same credential works across multiple sequential connections.""" + for _ in range(3): + async with await AsyncEntraConnection.connect(**connection_params, credential=async_credential) as conn: + async with conn.cursor() as cur: + await cur.execute("SELECT 1") + result = await cur.fetchone() + assert result[0] == 1 + + @pytest.mark.asyncio + async def test_concurrent_async_connections(self, connection_params, async_credential): + """Test that multiple concurrent async connections work with the same credential.""" + + async def query(): + async with await AsyncEntraConnection.connect(**connection_params, credential=async_credential) as conn: + async with conn.cursor() as cur: + await cur.execute("SELECT 1") + result = await cur.fetchone() + return result[0] + + results = await asyncio.gather(query(), query(), query()) + assert results == [1, 1, 1] diff --git a/sdk/postgresql/azure-postgresql-auth/tests/test_sqlalchemy.py b/sdk/postgresql/azure-postgresql-auth/tests/test_sqlalchemy.py new file mode 100644 index 000000000000..80eeca8c8e11 --- /dev/null +++ b/sdk/postgresql/azure-postgresql-auth/tests/test_sqlalchemy.py @@ -0,0 +1,233 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +"""Unit tests for SQLAlchemy Entra authentication integration.""" + +from __future__ import annotations + +import sys +from unittest.mock import MagicMock, patch + +import pytest +from sqlalchemy import create_engine, text + +from azure_postgresql_auth.errors import CredentialValueError, EntraConnectionValueError +from azure_postgresql_auth.sqlalchemy import enable_entra_authentication + +from utils import TEST_USERS, MockTokenCredential, create_valid_jwt_token + + +class TestSqlalchemyEntraAuthentication: + """Tests for enable_entra_authentication function.""" + + @patch("azure_postgresql_auth.sqlalchemy.entra_connection.get_entra_conninfo") + @patch("azure_postgresql_auth.sqlalchemy.entra_connection.event.listens_for") + def test_provides_entra_credentials(self, mock_listens_for, mock_get_conninfo): + """Test that the event handler provides Entra credentials.""" + token = create_valid_jwt_token(TEST_USERS["ENTRA_USER"]) + mock_get_conninfo.return_value = { + "user": TEST_USERS["ENTRA_USER"], + "password": token, + } + + # Capture the decorated handler + registered_handlers = [] + + def capture_decorator(engine, event_name): + def decorator(fn): + registered_handlers.append(fn) + return fn + + return decorator + + mock_listens_for.side_effect = capture_decorator + mock_engine = MagicMock() + enable_entra_authentication(mock_engine) + + assert len(registered_handlers) == 1 + handler = registered_handlers[0] + credential = MockTokenCredential(token) + cparams = {"credential": credential} + handler(MagicMock(), MagicMock(), [], cparams) + mock_get_conninfo.assert_called_once_with(credential) + assert cparams["user"] == TEST_USERS["ENTRA_USER"] + assert cparams["password"] == token + + @patch("azure_postgresql_auth.sqlalchemy.entra_connection.event.listens_for") + def test_missing_credential_raises_error(self, mock_listens_for): + """Test that the event handler raises CredentialValueError when no credential.""" + registered_handlers = [] + + def capture_decorator(engine, event_name): + def decorator(fn): + registered_handlers.append(fn) + return fn + + return decorator + + mock_listens_for.side_effect = capture_decorator + mock_engine = MagicMock() + enable_entra_authentication(mock_engine) + + assert len(registered_handlers) == 1 + handler = registered_handlers[0] + with pytest.raises(CredentialValueError): + handler(MagicMock(), MagicMock(), [], {}) + + @patch("azure_postgresql_auth.sqlalchemy.entra_connection.get_entra_conninfo") + def test_credential_removed_from_cparams(self, mock_get_conninfo): + """Test that the credential parameter is removed before DBAPI connect.""" + token = create_valid_jwt_token(TEST_USERS["ENTRA_USER"]) + mock_get_conninfo.return_value = { + "user": TEST_USERS["ENTRA_USER"], + "password": token, + } + credential = MockTokenCredential(token) + + # Simulate what the event handler does + cparams = {"credential": credential} + + mock_engine = MagicMock() + # We need to capture the registered handler + registered_handlers = [] + + def capture_handler(engine, event_name): + def decorator(fn): + registered_handlers.append(fn) + return fn + + return decorator + + with patch("azure_postgresql_auth.sqlalchemy.entra_connection.event.listens_for", side_effect=capture_handler): + enable_entra_authentication(mock_engine) + + if registered_handlers: + handler = registered_handlers[0] + handler(MagicMock(), MagicMock(), [], cparams) + assert "credential" not in cparams + assert cparams["user"] == TEST_USERS["ENTRA_USER"] + assert cparams["password"] == token + + @patch("azure_postgresql_auth.sqlalchemy.entra_connection.get_entra_conninfo") + @patch("azure_postgresql_auth.sqlalchemy.entra_connection.event.listens_for") + def test_existing_credentials_preserved(self, mock_listens_for, mock_get_conninfo): + """Test that existing user/password in cparams are preserved.""" + registered_handlers = [] + + def capture_decorator(engine, event_name): + def decorator(fn): + registered_handlers.append(fn) + return fn + + return decorator + + mock_listens_for.side_effect = capture_decorator + mock_engine = MagicMock() + token = create_valid_jwt_token(TEST_USERS["ENTRA_USER"]) + credential = MockTokenCredential(token) + enable_entra_authentication(mock_engine) + + assert len(registered_handlers) == 1 + handler = registered_handlers[0] + + # cparams already has user and password + cparams = {"credential": credential, "user": "existing", "password": "secret"} + handler(MagicMock(), MagicMock(), [], cparams) + + # get_entra_conninfo should NOT be called + mock_get_conninfo.assert_not_called() + # Original credentials are preserved + assert cparams["user"] == "existing" + assert cparams["password"] == "secret" + # credential is still removed + assert "credential" not in cparams + + @patch("azure_postgresql_auth.sqlalchemy.entra_connection.get_entra_conninfo") + @patch("azure_postgresql_auth.sqlalchemy.entra_connection.event.listens_for") + def test_entra_credential_failure_raises_error(self, mock_listens_for, mock_get_conninfo): + """Test that credential retrieval failure raises EntraConnectionValueError.""" + mock_get_conninfo.side_effect = Exception("auth failed") + + registered_handlers = [] + + def capture_decorator(engine, event_name): + def decorator(fn): + registered_handlers.append(fn) + return fn + + return decorator + + mock_listens_for.side_effect = capture_decorator + mock_engine = MagicMock() + token = create_valid_jwt_token(TEST_USERS["ENTRA_USER"]) + credential = MockTokenCredential(token) + enable_entra_authentication(mock_engine) + + assert len(registered_handlers) == 1 + handler = registered_handlers[0] + + with pytest.raises(EntraConnectionValueError, match="Could not retrieve Entra credentials"): + handler(MagicMock(), MagicMock(), [], {"credential": credential}) + + +@pytest.mark.live_test_only +@pytest.mark.skipif(sys.implementation.name == "pypy", reason="psycopg2 not supported on PyPy") +class TestSqlalchemyEntraAuthenticationLive: + """Live tests for synchronous SQLAlchemy with enable_entra_authentication.""" + + def test_connect_with_entra_user(self, connection_url, credential, postgresql_database): + """Test connecting with an Entra user using enable_entra_authentication.""" + engine = create_engine(connection_url, connect_args={"credential": credential}) + enable_entra_authentication(engine) + + with engine.connect() as conn: + result = conn.execute(text("SELECT current_user, current_database()")) + current_user, current_db = result.fetchone() + + assert current_db == postgresql_database + assert current_user is not None + + engine.dispose() + + def test_execute_query(self, connection_url, credential): + """Test executing a basic query through SQLAlchemy with Entra auth.""" + engine = create_engine(connection_url, connect_args={"credential": credential}) + enable_entra_authentication(engine) + + with engine.connect() as conn: + result = conn.execute(text("SELECT 1")) + assert result.fetchone()[0] == 1 + + engine.dispose() + + def test_multiple_sequential_connections(self, connection_url, credential): + """Test that the same credential works across multiple sequential connections.""" + engine = create_engine(connection_url, connect_args={"credential": credential}) + enable_entra_authentication(engine) + + for _ in range(3): + with engine.connect() as conn: + result = conn.execute(text("SELECT 1")) + assert result.fetchone()[0] == 1 + + engine.dispose() + + def test_connection_pool_reuse(self, connection_url, credential): + """Test that a pooled connection still works after being returned and reacquired.""" + engine = create_engine(connection_url, connect_args={"credential": credential}, pool_size=1, max_overflow=0) + enable_entra_authentication(engine) + + # First connection: use and return to pool + with engine.connect() as conn: + result = conn.execute(text("SELECT 1")) + assert result.fetchone()[0] == 1 + + # Second connection: reuse the pooled connection + with engine.connect() as conn: + result = conn.execute(text("SELECT current_user")) + assert result.fetchone()[0] is not None + + engine.dispose() diff --git a/sdk/postgresql/azure-postgresql-auth/tests/test_sqlalchemy_async.py b/sdk/postgresql/azure-postgresql-auth/tests/test_sqlalchemy_async.py new file mode 100644 index 000000000000..d8b3b561861a --- /dev/null +++ b/sdk/postgresql/azure-postgresql-auth/tests/test_sqlalchemy_async.py @@ -0,0 +1,261 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +"""Async unit tests for SQLAlchemy Entra authentication integration.""" + +from __future__ import annotations + +import asyncio +from unittest.mock import MagicMock, patch + +import pytest +from sqlalchemy import text +from sqlalchemy.ext.asyncio import create_async_engine + +from azure_postgresql_auth.errors import CredentialValueError, EntraConnectionValueError +from azure_postgresql_auth.sqlalchemy import enable_entra_authentication_async + +from utils import TEST_USERS, MockTokenCredential, create_valid_jwt_token + + +class TestSqlalchemyAsyncEntraAuthentication: + """Tests for enable_entra_authentication_async function.""" + + def test_async_function_exists(self): + """Test that the async authentication function is importable.""" + assert callable(enable_entra_authentication_async) + + @patch("azure_postgresql_auth.sqlalchemy.async_entra_connection.get_entra_conninfo") + @patch("azure_postgresql_auth.sqlalchemy.async_entra_connection.event.listens_for") + def test_provides_entra_credentials_async(self, mock_listens_for, mock_get_conninfo): + """Test that the event handler provides Entra credentials (async engine).""" + token = create_valid_jwt_token(TEST_USERS["ENTRA_USER"]) + mock_get_conninfo.return_value = { + "user": TEST_USERS["ENTRA_USER"], + "password": token, + } + + registered_handlers = [] + + def capture_decorator(engine, event_name): + def decorator(fn): + registered_handlers.append(fn) + return fn + + return decorator + + mock_listens_for.side_effect = capture_decorator + mock_engine = MagicMock() + enable_entra_authentication_async(mock_engine) + + assert len(registered_handlers) == 1 + handler = registered_handlers[0] + credential = MockTokenCredential(token) + cparams = {"credential": credential} + handler(MagicMock(), MagicMock(), [], cparams) + mock_get_conninfo.assert_called_once_with(credential) + assert cparams["user"] == TEST_USERS["ENTRA_USER"] + assert cparams["password"] == token + + @patch("azure_postgresql_auth.sqlalchemy.async_entra_connection.event.listens_for") + def test_missing_credential_raises_error_async(self, mock_listens_for): + """Test that the event handler raises CredentialValueError when no credential.""" + registered_handlers = [] + + def capture_decorator(engine, event_name): + def decorator(fn): + registered_handlers.append(fn) + return fn + + return decorator + + mock_listens_for.side_effect = capture_decorator + mock_engine = MagicMock() + enable_entra_authentication_async(mock_engine) + + assert len(registered_handlers) == 1 + handler = registered_handlers[0] + with pytest.raises(CredentialValueError): + handler(MagicMock(), MagicMock(), [], {}) + + @patch("azure_postgresql_auth.sqlalchemy.async_entra_connection.get_entra_conninfo") + def test_credential_removed_from_cparams_async(self, mock_get_conninfo): + """Test that the credential parameter is removed before DBAPI connect (async).""" + token = create_valid_jwt_token(TEST_USERS["ENTRA_USER"]) + mock_get_conninfo.return_value = { + "user": TEST_USERS["ENTRA_USER"], + "password": token, + } + credential = MockTokenCredential(token) + + cparams = {"credential": credential} + + mock_engine = MagicMock() + registered_handlers = [] + + def capture_handler(engine, event_name): + def decorator(fn): + registered_handlers.append(fn) + return fn + + return decorator + + with patch( + "azure_postgresql_auth.sqlalchemy.async_entra_connection.event.listens_for", + side_effect=capture_handler, + ): + enable_entra_authentication_async(mock_engine) + + if registered_handlers: + handler = registered_handlers[0] + handler(MagicMock(), MagicMock(), [], cparams) + assert "credential" not in cparams + assert cparams["user"] == TEST_USERS["ENTRA_USER"] + assert cparams["password"] == token + + @patch("azure_postgresql_auth.sqlalchemy.async_entra_connection.get_entra_conninfo") + @patch("azure_postgresql_auth.sqlalchemy.async_entra_connection.event.listens_for") + def test_existing_credentials_preserved_async(self, mock_listens_for, mock_get_conninfo): + """Test that existing user/password in cparams are preserved (async).""" + registered_handlers = [] + + def capture_decorator(engine, event_name): + def decorator(fn): + registered_handlers.append(fn) + return fn + + return decorator + + mock_listens_for.side_effect = capture_decorator + mock_engine = MagicMock() + token = create_valid_jwt_token(TEST_USERS["ENTRA_USER"]) + credential = MockTokenCredential(token) + enable_entra_authentication_async(mock_engine) + + assert len(registered_handlers) == 1 + handler = registered_handlers[0] + + # cparams already has user and password + cparams = {"credential": credential, "user": "existing", "password": "secret"} + handler(MagicMock(), MagicMock(), [], cparams) + + # get_entra_conninfo should NOT be called + mock_get_conninfo.assert_not_called() + # Original credentials are preserved + assert cparams["user"] == "existing" + assert cparams["password"] == "secret" + # credential is still removed + assert "credential" not in cparams + + @patch("azure_postgresql_auth.sqlalchemy.async_entra_connection.get_entra_conninfo") + @patch("azure_postgresql_auth.sqlalchemy.async_entra_connection.event.listens_for") + def test_entra_credential_failure_raises_error_async(self, mock_listens_for, mock_get_conninfo): + """Test that credential retrieval failure raises EntraConnectionValueError (async).""" + mock_get_conninfo.side_effect = Exception("auth failed") + + registered_handlers = [] + + def capture_decorator(engine, event_name): + def decorator(fn): + registered_handlers.append(fn) + return fn + + return decorator + + mock_listens_for.side_effect = capture_decorator + mock_engine = MagicMock() + token = create_valid_jwt_token(TEST_USERS["ENTRA_USER"]) + credential = MockTokenCredential(token) + enable_entra_authentication_async(mock_engine) + + assert len(registered_handlers) == 1 + handler = registered_handlers[0] + + with pytest.raises(EntraConnectionValueError, match="Could not retrieve Entra credentials"): + handler(MagicMock(), MagicMock(), [], {"credential": credential}) + + +@pytest.mark.live_test_only +class TestSqlalchemyAsyncEntraAuthenticationLive: + """Live tests for asynchronous SQLAlchemy with enable_entra_authentication_async.""" + + @pytest.mark.asyncio + async def test_connect_with_entra_user_async(self, async_connection_url, credential, postgresql_database): + """Test connecting with an Entra user using enable_entra_authentication_async.""" + engine = create_async_engine(async_connection_url, connect_args={"credential": credential}) + enable_entra_authentication_async(engine) + + async with engine.connect() as conn: + result = await conn.execute(text("SELECT current_user, current_database()")) + row = result.fetchone() + current_user, current_db = row + + assert current_db == postgresql_database + assert current_user is not None + + await engine.dispose() + + @pytest.mark.asyncio + async def test_execute_query_async(self, async_connection_url, credential): + """Test executing a basic query through async SQLAlchemy with Entra auth.""" + engine = create_async_engine(async_connection_url, connect_args={"credential": credential}) + enable_entra_authentication_async(engine) + + async with engine.connect() as conn: + result = await conn.execute(text("SELECT 1")) + row = result.fetchone() + assert row[0] == 1 + + await engine.dispose() + + @pytest.mark.asyncio + async def test_multiple_sequential_connections_async(self, async_connection_url, credential): + """Test that the same credential works across multiple sequential connections.""" + engine = create_async_engine(async_connection_url, connect_args={"credential": credential}) + enable_entra_authentication_async(engine) + + for _ in range(3): + async with engine.connect() as conn: + result = await conn.execute(text("SELECT 1")) + assert result.fetchone()[0] == 1 + + await engine.dispose() + + @pytest.mark.asyncio + async def test_concurrent_async_connections(self, async_connection_url, credential): + """Test that multiple concurrent async connections work with the same credential.""" + engine = create_async_engine(async_connection_url, connect_args={"credential": credential}) + enable_entra_authentication_async(engine) + + async def query(): + async with engine.connect() as conn: + result = await conn.execute(text("SELECT 1")) + return result.fetchone()[0] + + results = await asyncio.gather(query(), query(), query()) + assert results == [1, 1, 1] + + await engine.dispose() + + @pytest.mark.asyncio + async def test_connection_pool_reuse_async(self, async_connection_url, credential): + """Test that a pooled connection still works after being returned and reacquired.""" + engine = create_async_engine( + async_connection_url, connect_args={"credential": credential}, pool_size=1, max_overflow=0 + ) + enable_entra_authentication_async(engine) + + # First connection: use and return to pool + async with engine.connect() as conn: + result = await conn.execute(text("SELECT 1")) + assert result.fetchone()[0] == 1 + + # Second connection: reuse the pooled connection + async with engine.connect() as conn: + result = await conn.execute(text("SELECT current_user")) + assert result.fetchone()[0] is not None + + await engine.dispose() diff --git a/sdk/postgresql/azure-postgresql-auth/tests/utils.py b/sdk/postgresql/azure-postgresql-auth/tests/utils.py new file mode 100644 index 000000000000..f121dbe0fc16 --- /dev/null +++ b/sdk/postgresql/azure-postgresql-auth/tests/utils.py @@ -0,0 +1,119 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +"""Common utility functions and test credentials for unit tests.""" + +from __future__ import annotations + +import base64 +import json +from datetime import datetime, timedelta, timezone + +from azure.core.credentials import AccessToken, TokenCredential +from azure.core.credentials_async import AsyncTokenCredential + +# Test user constants +TEST_USERS = { + "ENTRA_USER": "test@example.com", + "MANAGED_IDENTITY_PATH": "/subscriptions/12345/resourcegroups/group/providers/Microsoft.ManagedIdentity/userAssignedIdentities/managed-identity", + "MANAGED_IDENTITY_NAME": "managed-identity", + "FALLBACK_USER": "fallback@example.com", +} + + +def create_base64_url_string(input_str: str) -> str: + """Create a base64url encoded string.""" + encoded = base64.urlsafe_b64encode(input_str.encode()).decode() + return encoded.rstrip("=") + + +def create_valid_jwt_token(username: str) -> str: + """Create a fake JWT token with a UPN claim.""" + header = {"alg": "RS256", "typ": "JWT"} + payload = {"upn": username, "iat": 1234567890, "exp": 9999999999} + header_encoded = create_base64_url_string(json.dumps(header)) + payload_encoded = create_base64_url_string(json.dumps(payload)) + return f"{header_encoded}.{payload_encoded}.fake-signature" + + +def create_jwt_token_with_xms_mirid(xms_mirid: str) -> str: + """Create a fake JWT token with an xms_mirid claim for managed identity.""" + header = {"alg": "RS256", "typ": "JWT"} + payload = {"xms_mirid": xms_mirid, "iat": 1234567890, "exp": 9999999999} + header_encoded = create_base64_url_string(json.dumps(header)) + payload_encoded = create_base64_url_string(json.dumps(payload)) + return f"{header_encoded}.{payload_encoded}.fake-signature" + + +def create_jwt_token_with_preferred_username(username: str) -> str: + """Create a fake JWT token with a preferred_username claim.""" + header = {"alg": "RS256", "typ": "JWT"} + payload = {"preferred_username": username, "iat": 1234567890, "exp": 9999999999} + header_encoded = create_base64_url_string(json.dumps(header)) + payload_encoded = create_base64_url_string(json.dumps(payload)) + return f"{header_encoded}.{payload_encoded}.fake-signature" + + +def create_jwt_token_with_unique_name(username: str) -> str: + """Create a fake JWT token with a unique_name claim.""" + header = {"alg": "RS256", "typ": "JWT"} + payload = {"unique_name": username, "iat": 1234567890, "exp": 9999999999} + header_encoded = create_base64_url_string(json.dumps(header)) + payload_encoded = create_base64_url_string(json.dumps(payload)) + return f"{header_encoded}.{payload_encoded}.fake-signature" + + +class MockTokenCredential(TokenCredential): + """Mock token credential for synchronous operations.""" + + def __init__(self, token: str): + self._token = token + self._call_count = 0 + + def get_token(self, *scopes, **kwargs) -> AccessToken: + """Return a fake access token.""" + self._call_count += 1 + expires_on = datetime.now(timezone.utc) + timedelta(hours=1) + return AccessToken(self._token, int(expires_on.timestamp())) + + def get_call_count(self) -> int: + """Return the number of times get_token was called.""" + return self._call_count + + def reset_call_count(self) -> None: + """Reset the call count.""" + self._call_count = 0 + + +class MockAsyncTokenCredential(AsyncTokenCredential): + """Mock token credential for asynchronous operations.""" + + def __init__(self, token: str): + self._token = token + self._call_count = 0 + + async def get_token(self, *scopes, **kwargs) -> AccessToken: + """Return a fake access token asynchronously.""" + self._call_count += 1 + expires_on = datetime.now(timezone.utc) + timedelta(hours=1) + return AccessToken(self._token, int(expires_on.timestamp())) + + async def close(self) -> None: + pass + + async def __aenter__(self): + return self + + async def __aexit__(self, *args): + pass + + def get_call_count(self) -> int: + """Return the number of times get_token was called.""" + return self._call_count + + def reset_call_count(self) -> None: + """Reset the call count.""" + self._call_count = 0 diff --git a/sdk/postgresql/ci.yml b/sdk/postgresql/ci.yml new file mode 100644 index 000000000000..db2697e09971 --- /dev/null +++ b/sdk/postgresql/ci.yml @@ -0,0 +1,33 @@ +# NOTE: Please refer to https://aka.ms/azsdk/engsys/ci-yaml before editing this file. + +trigger: + branches: + include: + - main + - hotfix/* + - release/* + - restapi* + paths: + include: + - sdk/postgresql/ + +pr: + branches: + include: + - main + - feature/* + - hotfix/* + - release/* + - restapi* + paths: + include: + - sdk/postgresql/ + +extends: + template: ../../eng/pipelines/templates/stages/archetype-sdk-client.yml + parameters: + ServiceDirectory: postgresql + TestProxy: true + Artifacts: + - name: azure-postgresql-auth + safeName: azurepostgresqlauth diff --git a/sdk/postgresql/cspell.yaml b/sdk/postgresql/cspell.yaml new file mode 100644 index 000000000000..e81b84168faf --- /dev/null +++ b/sdk/postgresql/cspell.yaml @@ -0,0 +1,10 @@ +# Spell checking is not case sensitive +# Keep word lists in alphabetical order so the file is easier to manage +version: "0.2" +import: + - ../../.vscode/cspell.json +words: + - cargs + - mirid + - testdb + - undoc diff --git a/sdk/postgresql/test-resources-pre.ps1 b/sdk/postgresql/test-resources-pre.ps1 new file mode 100644 index 000000000000..5e74b3c75958 --- /dev/null +++ b/sdk/postgresql/test-resources-pre.ps1 @@ -0,0 +1,54 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +# IMPORTANT: Do not invoke this file directly. Please instead run eng/New-TestResources.ps1 from the repository root. + +# Use same parameter names as declared in eng/New-TestResources.ps1 (assume validation therein). +[CmdletBinding(SupportsShouldProcess = $true, ConfirmImpact = 'Medium')] +param ( + [Parameter()] + [string] $TestApplicationOid, + + [Parameter()] + [hashtable] $AdditionalParameters = @{}, + + # Captures any arguments from eng/New-TestResources.ps1 not declared here (no parameter errors). + [Parameter(ValueFromRemainingArguments = $true)] + $RemainingArguments +) + +# By default stop for any error. +if (!$PSBoundParameters.ContainsKey('ErrorAction')) { + $ErrorActionPreference = 'Stop' +} + +function Log($Message) { + Write-Host ('{0} - {1}' -f [DateTime]::Now.ToLongTimeString(), $Message) +} + +# Resolve the principal display name from the current context to use as the +# PostgreSQL Entra ID administrator. The test framework provides the OID but +# the PostgreSQL admin resource requires a display name / UPN. + +$UserPrincipalName = $AdditionalParameters['UserPrincipalName'] +if ($UserPrincipalName) { + # Principal name was explicitly provided — use it directly. + $principalName = $UserPrincipalName + $principalType = 'User' + Log "Using provided user principal name: '$principalName'" +} elseif ($CI) { + # In CI the test application is a service principal. Look up its display name. + Log "Resolving service principal display name for OID '$TestApplicationOid'" + $sp = Get-AzADServicePrincipal -ObjectId $TestApplicationOid + $principalName = $sp.DisplayName + $principalType = 'ServicePrincipal' + Log "Resolved service principal name: '$principalName'" +} else { + # Running locally — use the signed-in user's UPN. + $principalName = (Get-AzContext).Account.Id + $principalType = 'User' + Log "Using signed-in user principal name: '$principalName'" +} + +$templateFileParameters['principalName'] = $principalName +$templateFileParameters['principalType'] = $principalType diff --git a/sdk/postgresql/test-resources.bicep b/sdk/postgresql/test-resources.bicep new file mode 100644 index 000000000000..fe18ef8eddad --- /dev/null +++ b/sdk/postgresql/test-resources.bicep @@ -0,0 +1,83 @@ +@description('The base resource name.') +param baseName string = resourceGroup().name + +@description('The principal to assign the role to. This is the application object id.') +param testApplicationOid string + +@description('The display name or UPN of the test principal. Resolved by test-resources-pre.ps1.') +param principalName string = '' + +@description('The type of the test principal. Resolved by test-resources-pre.ps1.') +@allowed([ + 'User' + 'ServicePrincipal' +]) +param principalType string = 'ServicePrincipal' + +@description('The location of the resource group.') +param location string = resourceGroup().location + +@description('The tenant ID.') +param tenantId string = subscription().tenantId + +var serverName = '${baseName}-pg' +var databaseName = 'testdb' + +resource postgresServer 'Microsoft.DBforPostgreSQL/flexibleServers@2022-12-01' = { + name: serverName + location: location + sku: { + name: 'Standard_B1ms' + tier: 'Burstable' + } + properties: { + version: '17' + storage: { + storageSizeGB: 32 + } + backup: { + backupRetentionDays: 7 + geoRedundantBackup: 'Disabled' + } + highAvailability: { + mode: 'Disabled' + } + authConfig: { + activeDirectoryAuth: 'Enabled' + passwordAuth: 'Disabled' + tenantId: tenantId + } + } +} + +resource database 'Microsoft.DBforPostgreSQL/flexibleServers/databases@2022-12-01' = { + name: databaseName + parent: postgresServer + properties: {} +} + +resource entraAdmin 'Microsoft.DBforPostgreSQL/flexibleServers/administrators@2022-12-01' = { + name: testApplicationOid + parent: postgresServer + properties: { + principalName: principalName + principalType: principalType + tenantId: tenantId + } + dependsOn: [ + database + ] +} + +resource firewallAllowAzure 'Microsoft.DBforPostgreSQL/flexibleServers/firewallRules@2022-12-01' = { + name: 'AllowAllAzureServicesAndResourcesWithinAzureIps' + parent: postgresServer + properties: { + startIpAddress: '0.0.0.0' + endIpAddress: '0.0.0.0' + } +} + +output POSTGRESQL_HOST string = '${serverName}.postgres.database.azure.com' +output POSTGRESQL_DATABASE string = databaseName +output POSTGRESQL_PORT string = '5432' diff --git a/sdk/postgresql/tests.yml b/sdk/postgresql/tests.yml new file mode 100644 index 000000000000..dc57c73674af --- /dev/null +++ b/sdk/postgresql/tests.yml @@ -0,0 +1,11 @@ +trigger: none + +extends: + template: /eng/pipelines/templates/stages/archetype-sdk-tests.yml + parameters: + BuildTargetingString: azure-postgresql-auth + ServiceDirectory: postgresql + Location: eastus2 + EnvVars: + AZURE_SKIP_LIVE_RECORDING: 'true' + AZURE_TEST_RUN_LIVE: 'true'