Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions mkdocs/docs/configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -685,6 +685,9 @@ catalog:
| hive.kerberos-authentication | true | Using authentication via Kerberos |
| hive.kerberos-service-name | hive | Kerberos service name (default hive) |
| ugi | t-1234:secret | Hadoop UGI for Hive client. |
| hive.metastore.authentication | DIGEST-MD5 | Auth mechanism: `NONE` (default), `KERBEROS`, or `DIGEST-MD5` |

When using DIGEST-MD5 authentication, PyIceberg reads a Hive delegation token from the file pointed to by the `$HADOOP_TOKEN_FILE_LOCATION` environment variable. This is the standard mechanism used in secure Hadoop environments where delegation tokens are distributed to jobs. Install PyIceberg with `pip install "pyiceberg[hive]"` to get the required `pure-sasl` dependency.

When using Hive 2.x, make sure to set the compatibility flag:

Expand Down
51 changes: 47 additions & 4 deletions pyiceberg/catalog/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
)
from pyiceberg.exceptions import (
CommitFailedException,
HiveAuthError,
NamespaceAlreadyExistsError,
NamespaceNotEmptyError,
NoSuchIcebergTableError,
Expand Down Expand Up @@ -109,6 +110,7 @@
UnknownType,
UUIDType,
)
from pyiceberg.utils.hadoop_credentials import read_hive_delegation_token
from pyiceberg.utils.properties import property_as_bool, property_as_float

if TYPE_CHECKING:
Expand All @@ -127,6 +129,9 @@
HIVE_KERBEROS_SERVICE_NAME = "hive.kerberos-service-name"
HIVE_KERBEROS_SERVICE_NAME_DEFAULT = "hive"

HIVE_METASTORE_AUTH = "hive.metastore.authentication"
HIVE_METASTORE_AUTH_DEFAULT = "NONE"

LOCK_CHECK_MIN_WAIT_TIME = "lock-check-min-wait-time"
LOCK_CHECK_MAX_WAIT_TIME = "lock-check-max-wait-time"
LOCK_CHECK_RETRIES = "lock-check-retries"
Expand All @@ -139,6 +144,20 @@
logger = logging.getLogger(__name__)


class _DigestMD5SaslTransport(TTransport.TSaslClientTransport):
"""TSaslClientTransport subclass that works around THRIFT-5926.

The upstream ``TSaslClientTransport.open()`` passes the first
``sasl.process()`` response directly to ``_send_sasl_message()``,
but for DIGEST-MD5 the initial response is ``None`` (challenge-first
mechanism). Sending ``None`` causes a ``TypeError``. This subclass
coerces ``None`` to ``b""`` so the SASL handshake proceeds normally.
"""

def send_sasl_msg(self, status: int, body: bytes | None) -> None: # type: ignore[override]
super().send_sasl_msg(status, body if body is not None else b"")


class _HiveClient:
"""Helper class to nicely open and close the transport."""

Expand All @@ -150,21 +169,44 @@ def __init__(
uri: str,
ugi: str | None = None,
kerberos_auth: bool | None = HIVE_KERBEROS_AUTH_DEFAULT,
kerberos_service_name: str | None = HIVE_KERBEROS_SERVICE_NAME,
kerberos_service_name: str | None = HIVE_KERBEROS_SERVICE_NAME_DEFAULT,
auth_mechanism: str | None = None,
):
self._uri = uri
self._kerberos_auth = kerberos_auth
self._kerberos_service_name = kerberos_service_name
self._ugi = ugi.split(":") if ugi else None

# Resolve auth mechanism: explicit auth_mechanism takes precedence,
# then fall back to legacy kerberos_auth boolean for backward compat.
if auth_mechanism is not None:
self._auth_mechanism = auth_mechanism.upper()
elif kerberos_auth:
self._auth_mechanism = "KERBEROS"
else:
self._auth_mechanism = HIVE_METASTORE_AUTH_DEFAULT

self._transport = self._init_thrift_transport()

def _init_thrift_transport(self) -> TTransport:
url_parts = urlparse(self._uri)
socket = TSocket.TSocket(url_parts.hostname, url_parts.port)
if not self._kerberos_auth:

if self._auth_mechanism == "NONE":
return TTransport.TBufferedTransport(socket)
else:
elif self._auth_mechanism == "KERBEROS":
return TTransport.TSaslClientTransport(socket, host=url_parts.hostname, service=self._kerberos_service_name)
elif self._auth_mechanism == "DIGEST-MD5":
identifier, password = read_hive_delegation_token()
return _DigestMD5SaslTransport(
socket,
host=url_parts.hostname,
service=self._kerberos_service_name,
mechanism="DIGEST-MD5",
username=identifier,
password=password,
)
else:
raise HiveAuthError(f"Unknown auth mechanism: {self._auth_mechanism!r}. Valid values: NONE, KERBEROS, DIGEST-MD5")

def _client(self) -> Client:
protocol = TBinaryProtocol.TBinaryProtocol(self._transport)
Expand Down Expand Up @@ -319,6 +361,7 @@ def _create_hive_client(properties: dict[str, str]) -> _HiveClient:
properties.get("ugi"),
property_as_bool(properties, HIVE_KERBEROS_AUTH, HIVE_KERBEROS_AUTH_DEFAULT),
properties.get(HIVE_KERBEROS_SERVICE_NAME, HIVE_KERBEROS_SERVICE_NAME_DEFAULT),
auth_mechanism=properties.get(HIVE_METASTORE_AUTH),
)
except BaseException as e:
last_exception = e
Expand Down
4 changes: 4 additions & 0 deletions pyiceberg/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,7 @@ class WaitingForLockException(Exception):

class ValidationException(Exception):
"""Raised when validation fails."""


class HiveAuthError(Exception):
"""Raised when Hive Metastore authentication fails."""
142 changes: 142 additions & 0 deletions pyiceberg/utils/hadoop_credentials.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Hadoop Delegation Token Service (HDTS) file parser.

Reads delegation tokens from the binary token file pointed to by
the ``$HADOOP_TOKEN_FILE_LOCATION`` environment variable.
"""

from __future__ import annotations

import base64
import os
from io import BytesIO

from pyiceberg.exceptions import HiveAuthError

HADOOP_TOKEN_FILE_LOCATION = "HADOOP_TOKEN_FILE_LOCATION"
HIVE_DELEGATION_TOKEN_KIND = "HIVE_DELEGATION_TOKEN"
HDTS_MAGIC = b"HDTS"
HDTS_SUPPORTED_VERSION = 0


def _read_hadoop_vint(stream: BytesIO) -> int:
"""Decode a Hadoop WritableUtils VInt/VLong from a byte stream.

Matches the encoding in Java's ``WritableUtils.readVInt``/``readVLong``:
- If the first byte (interpreted as signed) is >= -112, it *is* the value.
- Otherwise the first byte encodes both a negativity flag and the number
of additional big-endian payload bytes that carry the actual value.
"""
first = stream.read(1)
if not first:
raise HiveAuthError("Unexpected end of token file while reading VInt")
# Reinterpret as signed byte to match Java's signed-byte semantics
b = first[0]
if b > 127:
b -= 256
if b >= -112:
return b
negative = b < -120
length = (-119 - b) if negative else (-111 - b)
extra = stream.read(length)
if len(extra) != length:
raise HiveAuthError("Unexpected end of token file while reading VInt")
result = 0
for byte_val in extra:
result = (result << 8) | byte_val
if negative:
result = ~result
return result


def _read_hadoop_bytes(stream: BytesIO) -> bytes:
"""Read a VInt-prefixed byte array from a Hadoop token stream."""
length = _read_hadoop_vint(stream)
if length < 0:
raise HiveAuthError(f"Invalid byte array length: {length}")
data = stream.read(length)
if len(data) != length:
raise HiveAuthError("Unexpected end of token file while reading byte array")
return data


def _read_hadoop_text(stream: BytesIO) -> str:
"""Read a VInt-prefixed UTF-8 string from a Hadoop token stream."""
raw = _read_hadoop_bytes(stream)
try:
return raw.decode("utf-8")
except UnicodeDecodeError as e:
raise HiveAuthError(f"Token file contains invalid UTF-8 in text field: {e}") from e


def read_hive_delegation_token() -> tuple[str, str]:
"""Read a Hive delegation token from ``$HADOOP_TOKEN_FILE_LOCATION``.

Returns:
A ``(identifier, password)`` tuple where both values are
base64-encoded strings suitable for SASL DIGEST-MD5 auth.

Raises:
HiveAuthError: If the token file is missing, malformed, or
does not contain a ``HIVE_DELEGATION_TOKEN``.
"""
token_file = os.environ.get(HADOOP_TOKEN_FILE_LOCATION)
if not token_file:
raise HiveAuthError(
f"${HADOOP_TOKEN_FILE_LOCATION} environment variable is not set. "
"A Hadoop delegation token file is required for DIGEST-MD5 authentication."
)

try:
with open(token_file, "rb") as f:
data = f.read()
except OSError as e:
raise HiveAuthError(f"Cannot read Hadoop token file {token_file}: {e}") from e

stream = BytesIO(data)

magic = stream.read(4)
if magic != HDTS_MAGIC:
raise HiveAuthError(f"Invalid Hadoop token file magic: expected {HDTS_MAGIC!r}, got {magic!r}")

version_byte = stream.read(1)
if not version_byte:
raise HiveAuthError("Unexpected end of token file while reading version")
version = version_byte[0]
if version != HDTS_SUPPORTED_VERSION:
raise HiveAuthError(f"Unsupported Hadoop token file version: {version}")

num_tokens = _read_hadoop_vint(stream)

for _ in range(num_tokens):
# Each token entry: identifier_bytes, password_bytes, kind_text, service_text
identifier_bytes = _read_hadoop_bytes(stream)
password_bytes = _read_hadoop_bytes(stream)
kind = _read_hadoop_text(stream)
_service = _read_hadoop_text(stream)

if kind == HIVE_DELEGATION_TOKEN_KIND:
return (
base64.b64encode(identifier_bytes).decode("ascii"),
base64.b64encode(password_bytes).decode("ascii"),
)

raise HiveAuthError(
f"No {HIVE_DELEGATION_TOKEN_KIND} found in token file: {token_file}. File contains {num_tokens} token(s)."
)
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,10 @@ bodo = ["bodo>=2025.7.4"]
daft = ["daft>=0.5.0"]
polars = ["polars>=1.21.0,<2"]
snappy = ["python-snappy>=0.6.0,<1.0.0"]
hive = ["thrift>=0.13.0,<1.0.0"]
hive = [
"thrift>=0.13.0,<1.0.0",
"pure-sasl>=0.6.0,<1.0.0",
]
hive-kerberos = [
"thrift>=0.13.0,<1.0.0",
"thrift-sasl>=0.4.3",
Expand Down
Loading
Loading