Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 69 additions & 69 deletions pyiceberg/catalog/rest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,15 @@

from pyiceberg import __version__
from pyiceberg.catalog import BOTOCORE_SESSION, TOKEN, URI, WAREHOUSE_LOCATION, Catalog, PropertiesUpdateSummary
from pyiceberg.catalog.rest.auth import AUTH_MANAGER, AuthManager, AuthManagerAdapter, AuthManagerFactory, LegacyOAuth2AuthManager
from pyiceberg.catalog.rest.auth import (
AUTH_MANAGER,
AuthManager,
AuthManagerAdapter,
AuthManagerFactory,
LegacyOAuth2AuthManager,
NoopAuthManager,
SigV4AuthManager,
)
from pyiceberg.catalog.rest.response import _handle_non_200_response
from pyiceberg.catalog.rest.scan_planning import (
FetchScanTasksRequest,
Expand Down Expand Up @@ -251,11 +259,11 @@ class ScanPlanningMode(Enum):
CA_BUNDLE = "cabundle"
SSL = "ssl"
SIGV4 = "rest.sigv4-enabled"
SIGV4_AUTH_TYPE = "sigv4"
SIGV4_REGION = "rest.signing-region"
SIGV4_SERVICE = "rest.signing-name"
SIGV4_MAX_RETRIES = "rest.sigv4.max-retries"
SIGV4_MAX_RETRIES_DEFAULT = 10
EMPTY_BODY_SHA256: str = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
OAUTH2_SERVER_URI = "oauth2-server-uri"
SNAPSHOT_LOADING_MODE = "snapshot-loading-mode"
AUTH = "auth"
Expand Down Expand Up @@ -431,10 +439,49 @@ def _create_session(self) -> Session:
elif ssl_client_cert := ssl_client.get(CERT):
session.cert = ssl_client_cert

self._auth_manager = self._build_auth_manager(session)
session.auth = AuthManagerAdapter(self._auth_manager)

# SigV4 retry is decoupled from signing: mount a plain retry adapter.
if self._is_sigv4_enabled():
from requests.adapters import HTTPAdapter

max_retries = property_as_int(self.properties, SIGV4_MAX_RETRIES, SIGV4_MAX_RETRIES_DEFAULT)
session.mount(self.uri, HTTPAdapter(max_retries=max_retries))

return session

def _is_sigv4_enabled(self) -> bool:
"""Return True if SigV4 signing is requested via either config path."""
if property_as_bool(self.properties, SIGV4, False):
return True
auth_config = self.properties.get(AUTH)
return auth_config is not None and auth_config.get("type") == SIGV4_AUTH_TYPE

def _build_auth_manager(self, session: Session) -> AuthManager:
"""Build the AuthManager, wrapping the delegate in SigV4 when enabled."""
delegate = self._build_delegate_auth_manager(session)
if self._is_sigv4_enabled():
return self._build_sigv4_auth_manager(delegate)
return delegate

def _build_delegate_auth_manager(self, session: Session) -> AuthManager:
"""Build the header-based AuthManager (the SigV4 delegate, or the manager used directly)."""
if auth_config := self.properties.get(AUTH):
auth_type = auth_config.get("type")
if auth_type is None:
raise ValueError("auth.type must be defined")

if auth_type == SIGV4_AUTH_TYPE:
# The delegate is configured under auth.sigv4.delegate.*
sigv4_config = auth_config.get(SIGV4_AUTH_TYPE, {})
delegate_config = sigv4_config.get("delegate")
if not delegate_config or "type" not in delegate_config:
# No delegate configured: SigV4-only auth, with no header-based delegate.
return NoopAuthManager()
delegate_type = delegate_config["type"]
return AuthManagerFactory.create(delegate_type, delegate_config.get(delegate_type, {}))

auth_type_config = auth_config.get(auth_type, {})
auth_impl = auth_config.get("impl")

Expand All @@ -444,17 +491,28 @@ def _create_session(self) -> Session:
if auth_type != CUSTOM and auth_impl:
raise ValueError("auth.impl can only be specified when using custom auth.type")

self._auth_manager = AuthManagerFactory.create(auth_impl or auth_type, auth_type_config)
session.auth = AuthManagerAdapter(self._auth_manager)
else:
self._auth_manager = self._create_legacy_oauth2_auth_manager(session)
session.auth = AuthManagerAdapter(self._auth_manager)
return AuthManagerFactory.create(auth_impl or auth_type, auth_type_config)

# Configure SigV4 Request Signing
if property_as_bool(self.properties, SIGV4, False):
self._init_sigv4(session)
return self._create_legacy_oauth2_auth_manager(session)

return session
def _build_sigv4_auth_manager(self, delegate: AuthManager) -> AuthManager:
"""Wrap the delegate AuthManager in a SigV4AuthManager."""
import boto3

boto_session = boto3.Session(
profile_name=get_first_property_value(self.properties, AWS_PROFILE_NAME),
region_name=get_first_property_value(self.properties, AWS_REGION),
botocore_session=self.properties.get(BOTOCORE_SESSION),
aws_access_key_id=get_first_property_value(self.properties, AWS_ACCESS_KEY_ID),
aws_secret_access_key=get_first_property_value(self.properties, AWS_SECRET_ACCESS_KEY),
aws_session_token=get_first_property_value(self.properties, AWS_SESSION_TOKEN),
)
return SigV4AuthManager(
delegate=delegate,
boto_session=boto_session,
region=self.properties.get(SIGV4_REGION),
service=self.properties.get(SIGV4_SERVICE, "execute-api"),
)

@staticmethod
def _resolve_storage_credentials(storage_credentials: list[StorageCredential], location: str | None) -> Properties:
Expand Down Expand Up @@ -757,64 +815,6 @@ def _split_identifier_for_json(self, identifier: str | Identifier) -> dict[str,
identifier_tuple = self._identifier_to_validated_tuple(identifier)
return {"namespace": identifier_tuple[:-1], "name": identifier_tuple[-1]}

def _init_sigv4(self, session: Session) -> None:
from urllib import parse

import boto3
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest
from requests import PreparedRequest
from requests.adapters import HTTPAdapter

class SigV4Adapter(HTTPAdapter):
def __init__(self, **properties: str):
self._properties = properties
max_retries = property_as_int(self._properties, SIGV4_MAX_RETRIES, SIGV4_MAX_RETRIES_DEFAULT)
super().__init__(max_retries=max_retries)
self._boto_session = boto3.Session(
profile_name=get_first_property_value(self._properties, AWS_PROFILE_NAME),
region_name=get_first_property_value(self._properties, AWS_REGION),
botocore_session=self._properties.get(BOTOCORE_SESSION),
aws_access_key_id=get_first_property_value(self._properties, AWS_ACCESS_KEY_ID),
aws_secret_access_key=get_first_property_value(self._properties, AWS_SECRET_ACCESS_KEY),
aws_session_token=get_first_property_value(self._properties, AWS_SESSION_TOKEN),
)

def add_headers(self, request: PreparedRequest, **kwargs: Any) -> None: # pylint: disable=W0613
credentials = self._boto_session.get_credentials().get_frozen_credentials()
region = self._properties.get(SIGV4_REGION, self._boto_session.region_name)
service = self._properties.get(SIGV4_SERVICE, "execute-api")

url = str(request.url).split("?")[0]
query = str(parse.urlsplit(request.url).query)
params = dict(parse.parse_qsl(query))

# remove the connection header as it will be updated after signing
if "connection" in request.headers:
del request.headers["connection"]
# For empty bodies, explicitly set the content hash header to the SHA256 of an empty string
if not request.body:
request.headers["x-amz-content-sha256"] = EMPTY_BODY_SHA256

aws_request = AWSRequest(
method=request.method, url=url, params=params, data=request.body, headers=dict(request.headers)
)

SigV4Auth(credentials, service, region).add_auth(aws_request)
original_header = request.headers
signed_headers = aws_request.headers
relocated_headers = {}

# relocate headers if there is a conflict with signed headers
for header, value in original_header.items():
if header in signed_headers and signed_headers[header] != value:
relocated_headers[f"Original-{header}"] = value

request.headers.update(relocated_headers)
request.headers.update(signed_headers)

session.mount(self.uri, SigV4Adapter(**self.properties))

def _response_to_table(self, identifier_tuple: tuple[str, ...], table_response: TableResponse) -> Table:
# Per Iceberg spec: storage-credentials take precedence over config
credential_config = self._resolve_storage_credentials(
Expand Down
134 changes: 130 additions & 4 deletions pyiceberg/catalog/rest/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import threading
import time
from abc import ABC, abstractmethod
from functools import cached_property
from functools import cache, cached_property
from typing import Any

import requests
Expand All @@ -36,6 +36,37 @@
COLON = ":"
logger = logging.getLogger(__name__)

# SHA-256 of an empty payload. Used as the x-amz-content-sha256 header value for
# empty-body requests, matching Iceberg Java's RESTSigV4AuthSession workaround.
EMPTY_BODY_SHA256 = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"


@cache
def _iceberg_sigv4_auth_class() -> type:
"""Lazily build the botocore SigV4Auth subclass (botocore is an optional dependency)."""
from urllib import parse

from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest

class _IcebergSigV4Auth(SigV4Auth):
def canonical_request(self, request: AWSRequest) -> str:
# Override forces the hex payload hash in the canonical request even when
# the x-amz-content-sha256 header is base64 (see SigV4AuthManager.sign_request).
# Mirrors botocore <=1.42.x SigV4Auth.canonical_request layout:
# https://github.com/boto/botocore/blob/1.42.85/botocore/auth.py#L622-L637
cr = [request.method.upper()]
path = self._normalize_url_path(parse.urlsplit(request.url).path)
cr.append(path)
cr.append(self.canonical_query_string(request))
headers_to_sign = self.headers_to_sign(request)
cr.append(self.canonical_headers(headers_to_sign) + "\n")
cr.append(self.signed_headers(headers_to_sign))
cr.append(self.payload(request))
return "\n".join(cr)

return _IcebergSigV4Auth


class AuthManager(ABC):
"""
Expand All @@ -48,6 +79,14 @@ class AuthManager(ABC):
def auth_header(self) -> str | None:
"""Return the Authorization header value, or None if not applicable."""

def sign_request(self, request: PreparedRequest) -> PreparedRequest:
"""Optionally sign or otherwise modify the prepared request.

The default implementation is a no-op. Override for request-signing
schemes such as SigV4 that must inspect the full request.
"""
return request


class NoopAuthManager(AuthManager):
"""Auth Manager implementation with no auth."""
Expand Down Expand Up @@ -311,6 +350,91 @@ def auth_header(self) -> str:
return f"Bearer {self._get_token()}"


class SigV4AuthManager(AuthManager):
"""AuthManager that signs requests with AWS SigV4, wrapping a delegate AuthManager.

Mirrors Iceberg Java's RESTSigV4AuthManager: the delegate AuthManager handles
header-based auth (e.g. OAuth2), then SigV4 signs the resulting request.
"""

def __init__(
self,
delegate: AuthManager,
boto_session: Any,
region: str | None,
service: str = "execute-api",
):
"""Initialize SigV4AuthManager.

Args:
delegate: AuthManager that supplies header-based auth before signing.
boto_session: A boto3.Session used to resolve AWS credentials.
region: SigV4 signing region; falls back to the boto session's region.
service: SigV4 signing service name.
"""
self._delegate = delegate
self._boto_session = boto_session
self._region = region
self._service = service

def auth_header(self) -> str | None:
return self._delegate.auth_header()

def sign_request(self, request: PreparedRequest) -> PreparedRequest:
import hashlib
from urllib import parse

from botocore.awsrequest import AWSRequest

credentials = self._boto_session.get_credentials().get_frozen_credentials()
region = self._region or self._boto_session.region_name

url = str(request.url).split("?")[0]
query = str(parse.urlsplit(request.url).query)
params = dict(parse.parse_qsl(query))

# remove the connection header as it will be updated after signing
if "connection" in request.headers:
del request.headers["connection"]

# Match Iceberg Java's AWS SDK v2 flexible-checksum signing:
# x-amz-content-sha256 header is base64 for non-empty bodies, hex for empty.
# The SigV4 canonical request still uses hex (enforced in _iceberg_sigv4_auth_class).
# Ref: https://github.com/apache/iceberg/blob/main/aws/src/main/java/org/apache/iceberg/aws/RESTSigV4AuthSession.java
if request.body:
if isinstance(request.body, str):
body_bytes = request.body.encode("utf-8")
elif isinstance(request.body, (bytes, bytearray)):
body_bytes = bytes(request.body)
else:
raise TypeError(
f"Unsupported request body type for SigV4 signing: {type(request.body).__name__}; expected str or bytes."
)
content_sha256_header = base64.b64encode(hashlib.sha256(body_bytes).digest()).decode()
else:
content_sha256_header = EMPTY_BODY_SHA256

signing_headers = dict(request.headers)
signing_headers["x-amz-content-sha256"] = content_sha256_header

aws_request = AWSRequest(method=request.method, url=url, params=params, data=request.body, headers=signing_headers)

_iceberg_sigv4_auth_class()(credentials, self._service, region).add_auth(aws_request)

original_header = dict(request.headers)
signed_headers = dict(aws_request.headers)
relocated_headers = {}

# relocate headers if there is a conflict with signed headers
for header, value in original_header.items():
if header in signed_headers and signed_headers[header] != value:
relocated_headers[f"Original-{header}"] = value

request.headers.update(relocated_headers)
request.headers.update(signed_headers)
return request


class AuthManagerAdapter(AuthBase):
"""A `requests.auth.AuthBase` adapter for integrating an `AuthManager` into a `requests.Session`.

Expand All @@ -332,17 +456,19 @@ def __init__(self, auth_manager: AuthManager):

def __call__(self, request: PreparedRequest) -> PreparedRequest:
"""
Modify the outgoing request to include the Authorization header.
Modify the outgoing request to include the Authorization header and any signature.

Args:
request (requests.PreparedRequest): The HTTP request being prepared.

Returns:
requests.PreparedRequest: The modified request with Authorization header.
requests.PreparedRequest: The modified request.
"""
if auth_header := self.auth_manager.auth_header():
request.headers["Authorization"] = auth_header
return request
# Header first, then sign: a request-signing AuthManager (e.g. SigV4) must
# see the Authorization header so it can relocate it before signing.
return self.auth_manager.sign_request(request)


class AuthManagerFactory:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ sql-postgres = [
]
sql-sqlite = ["sqlalchemy>=2.0.18,<3"]
gcsfs = ["gcsfs>=2023.1.0"]
rest-sigv4 = ["boto3>=1.24.59"]
rest-sigv4 = ["boto3>=1.24.59", "botocore<2"]
hf = ["huggingface-hub>=0.24.0"]
pyiceberg-core = ["pyiceberg-core>=0.5.1,<0.10.0"]
datafusion = ["datafusion>=52,<53"]
Expand Down
Loading
Loading