From d469f05261ea8466e786b021193f1cda16a6e707 Mon Sep 17 00:00:00 2001 From: Molly He Date: Thu, 5 Mar 2026 13:47:06 -0800 Subject: [PATCH 01/10] Draft fix with unit tests --- .../remote_function/core/serialization.py | 321 +++++++++++++- .../remote_function/core/stored_function.py | 10 +- .../sagemaker/core/remote_function/errors.py | 4 +- .../core/remote_function/invoke_function.py | 1 + .../src/sagemaker/core/remote_function/job.py | 2 + .../remote_function/test_invoke_function.py | 1 + .../test_serialization_security.py | 401 ++++++++++++++++++ 7 files changed, 718 insertions(+), 22 deletions(-) create mode 100644 sagemaker-core/tests/unit/remote_function/test_serialization_security.py diff --git a/sagemaker-core/src/sagemaker/core/remote_function/core/serialization.py b/sagemaker-core/src/sagemaker/core/remote_function/core/serialization.py index 8871f6727f..c5efd78b4d 100644 --- a/sagemaker-core/src/sagemaker/core/remote_function/core/serialization.py +++ b/sagemaker-core/src/sagemaker/core/remote_function/core/serialization.py @@ -15,14 +15,17 @@ import dataclasses import json +import logging import io import sys import hashlib +import hmac import pickle +import secrets -from typing import Any, Callable, Union +from typing import Any, Callable, Union, Optional import cloudpickle from tblib import pickling_support @@ -38,6 +41,8 @@ # Note: do not use os.path.join for s3 uris, fails on windows +logger = logging.getLogger(__name__) + def _get_python_version(): """Returns the current python version.""" @@ -49,6 +54,7 @@ class _MetaData: """Metadata about the serialized data or functions.""" sha256_hash: str + secret_arn: Optional[str] = None # ARN to AWS Secrets Manager secret containing HMAC key version: str = "2023-04-24" python_version: str = _get_python_version() serialization_module: str = "cloudpickle" @@ -66,7 +72,8 @@ def from_json(s): raise DeserializationError("Corrupt metadata file. 
It is not a valid json file.") sha256_hash = obj.get("sha256_hash") - metadata = _MetaData(sha256_hash=sha256_hash) + secret_arn = obj.get("secret_arn") # May be None for legacy format + metadata = _MetaData(sha256_hash=sha256_hash, secret_arn=secret_arn) metadata.version = obj.get("version") metadata.python_version = obj.get("python_version") metadata.serialization_module = obj.get("serialization_module") @@ -155,16 +162,21 @@ def deserialize(s3_uri: str, bytes_to_deserialize: bytes) -> Any: # TODO: use dask serializer in case dask distributed is installed in users' environment. def serialize_func_to_s3( - func: Callable, sagemaker_session: Session, s3_uri: str, s3_kms_key: str = None + func: Callable, + sagemaker_session: Session, + s3_uri: str, + job_name: str, + s3_kms_key: str = None ): """Serializes function and uploads it to S3. Args: + func: function to be serialized and persisted sagemaker_session (sagemaker.core.helper.session.Session): The underlying Boto3 session which AWS service calls are delegated to. s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. + job_name (str): Remote function job name for secret management s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3. - func: function to be serialized and persisted Raises: SerializationError: when fail to serialize function to bytes. 
""" @@ -173,6 +185,7 @@ def serialize_func_to_s3( bytes_to_upload=CloudpickleSerializer.serialize(func), s3_uri=s3_uri, sagemaker_session=sagemaker_session, + job_name=job_name, s3_kms_key=s3_kms_key, ) @@ -199,23 +212,32 @@ def deserialize_func_from_s3(sagemaker_session: Session, s3_uri: str) -> Callabl bytes_to_deserialize = _read_bytes_from_s3(f"{s3_uri}/payload.pkl", sagemaker_session) _perform_integrity_check( - expected_hash_value=metadata.sha256_hash, buffer=bytes_to_deserialize + expected_hash_value=metadata.sha256_hash, + buffer=bytes_to_deserialize, + sagemaker_session=sagemaker_session, + secret_arn=metadata.secret_arn, + s3_uri=s3_uri ) return CloudpickleSerializer.deserialize(f"{s3_uri}/payload.pkl", bytes_to_deserialize) def serialize_obj_to_s3( - obj: Any, sagemaker_session: Session, s3_uri: str, s3_kms_key: str = None + obj: Any, + sagemaker_session: Session, + s3_uri: str, + job_name: str, + s3_kms_key: str = None ): """Serializes data object and uploads it to S3. Args: + obj: object to be serialized and persisted sagemaker_session (sagemaker.core.helper.session.Session): The underlying Boto3 session which AWS service calls are delegated to. s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. + job_name (str): Remote function job name for secret management s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3. - obj: object to be serialized and persisted Raises: SerializationError: when fail to serialize object to bytes. 
""" @@ -224,6 +246,7 @@ def serialize_obj_to_s3( bytes_to_upload=CloudpickleSerializer.serialize(obj), s3_uri=s3_uri, sagemaker_session=sagemaker_session, + job_name=job_name, s3_kms_key=s3_kms_key, ) @@ -288,23 +311,32 @@ def deserialize_obj_from_s3(sagemaker_session: Session, s3_uri: str) -> Any: bytes_to_deserialize = _read_bytes_from_s3(f"{s3_uri}/payload.pkl", sagemaker_session) _perform_integrity_check( - expected_hash_value=metadata.sha256_hash, buffer=bytes_to_deserialize + expected_hash_value=metadata.sha256_hash, + buffer=bytes_to_deserialize, + sagemaker_session=sagemaker_session, + secret_arn=metadata.secret_arn, + s3_uri=s3_uri ) return CloudpickleSerializer.deserialize(f"{s3_uri}/payload.pkl", bytes_to_deserialize) def serialize_exception_to_s3( - exc: Exception, sagemaker_session: Session, s3_uri: str, s3_kms_key: str = None + exc: Exception, + sagemaker_session: Session, + s3_uri: str, + job_name: str, + s3_kms_key: str = None ): """Serializes exception with traceback and uploads it to S3. Args: + exc: Exception to be serialized and persisted sagemaker_session (sagemaker.core.helper.session.Session): The underlying Boto3 session which AWS service calls are delegated to. s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. + job_name (str): Remote function job name for secret management s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3. - exc: Exception to be serialized and persisted Raises: SerializationError: when fail to serialize object to bytes. """ @@ -314,6 +346,7 @@ def serialize_exception_to_s3( bytes_to_upload=CloudpickleSerializer.serialize(exc), s3_uri=s3_uri, sagemaker_session=sagemaker_session, + job_name=job_name, s3_kms_key=s3_kms_key, ) @@ -322,6 +355,7 @@ def _upload_payload_and_metadata_to_s3( bytes_to_upload: Union[bytes, io.BytesIO], s3_uri: str, sagemaker_session: Session, + job_name: str, s3_kms_key, ): """Uploads serialized payload and metadata to s3. 
@@ -331,14 +365,22 @@ def _upload_payload_and_metadata_to_s3( s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. sagemaker_session (sagemaker.core.helper.session.Session): The underlying Boto3 session which AWS service calls are delegated to. + job_name (str): Remote function job name for secret management s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3. """ _upload_bytes_to_s3(bytes_to_upload, f"{s3_uri}/payload.pkl", s3_kms_key, sagemaker_session) - sha256_hash = _compute_hash(bytes_to_upload) + # Get or create HMAC secret in Secrets Manager + secret_arn, hmac_key = _get_or_create_hmac_secret(sagemaker_session, job_name) + + # Compute HMAC-SHA256 hash + sha256_hash = _compute_hmac(bytes_to_upload, hmac_key) + + # Store secret ARN in Parameter Store as trust anchor (Mitigation #3) + _store_secret_arn_in_parameter_store(sagemaker_session, job_name, secret_arn) _upload_bytes_to_s3( - _MetaData(sha256_hash).to_json(), + _MetaData(sha256_hash=sha256_hash, secret_arn=secret_arn).to_json(), f"{s3_uri}/metadata.json", s3_kms_key, sagemaker_session, @@ -365,7 +407,11 @@ def deserialize_exception_from_s3(sagemaker_session: Session, s3_uri: str) -> An bytes_to_deserialize = _read_bytes_from_s3(f"{s3_uri}/payload.pkl", sagemaker_session) _perform_integrity_check( - expected_hash_value=metadata.sha256_hash, buffer=bytes_to_deserialize + expected_hash_value=metadata.sha256_hash, + buffer=bytes_to_deserialize, + sagemaker_session=sagemaker_session, + secret_arn=metadata.secret_arn, + s3_uri=s3_uri ) return CloudpickleSerializer.deserialize(f"{s3_uri}/payload.pkl", bytes_to_deserialize) @@ -396,15 +442,252 @@ def _compute_hash(buffer: bytes) -> str: return hashlib.sha256(buffer).hexdigest() -def _perform_integrity_check(expected_hash_value: str, buffer: bytes): +def _get_or_create_hmac_secret(sagemaker_session: Session, job_name: str) -> tuple[str, str]: + """Get or create HMAC key in AWS Secrets Manager. 
+ + Args: + sagemaker_session: SageMaker session + job_name: Remote function job name + + Returns: + Tuple of (secret_arn, hmac_key) + """ + secret_name = f"sagemaker/remote-function/{job_name}/hmac-key" + secrets_client = sagemaker_session.boto_session.client('secretsmanager') + + try: + # Try to retrieve existing secret + response = secrets_client.get_secret_value(SecretId=secret_name) + return response['ARN'], response['SecretString'] + except secrets_client.exceptions.ResourceNotFoundException: + # Create new secret + hmac_key = secrets.token_hex(32) + + response = secrets_client.create_secret( + Name=secret_name, + SecretString=hmac_key, + Description=f"HMAC key for SageMaker remote function job {job_name}", + Tags=[ + {'Key': 'SageMaker:JobName', 'Value': job_name}, + {'Key': 'SageMaker:Purpose', 'Value': 'RemoteFunctionIntegrity'} + ] + ) + return response['ARN'], hmac_key + + +def _get_hmac_key_from_secret(sagemaker_session: Session, secret_arn: str) -> str: + """Retrieve HMAC key from AWS Secrets Manager. + + Args: + sagemaker_session: SageMaker session + secret_arn: ARN of the secret containing HMAC key + + Returns: + HMAC key string + """ + secrets_client = sagemaker_session.boto_session.client('secretsmanager') + response = secrets_client.get_secret_value(SecretId=secret_arn) + return response['SecretString'] + + +def _compute_hmac(buffer: bytes, hmac_key: str) -> str: + """Compute HMAC-SHA256 hash. + + Args: + buffer: Data to hash + hmac_key: HMAC secret key + + Returns: + HMAC-SHA256 hex digest + """ + return hmac.new(hmac_key.encode(), msg=buffer, digestmod=hashlib.sha256).hexdigest() + + +def _store_secret_arn_in_parameter_store( + sagemaker_session: Session, + job_name: str, + secret_arn: str +): + """Store secret ARN in Parameter Store as trust anchor. 
+ + Args: + sagemaker_session: SageMaker session + job_name: Remote function job name + secret_arn: ARN of the secret to store + """ + ssm_client = sagemaker_session.boto_session.client('ssm') + parameter_name = f"/sagemaker/remote-function/{job_name}/secret-arn" + + ssm_client.put_parameter( + Name=parameter_name, + Value=secret_arn, + Type="String", + Overwrite=True, + Description=f"Secret ARN for SageMaker remote function job {job_name}", + Tags=[ + {'Key': 'SageMaker:JobName', 'Value': job_name}, + {'Key': 'SageMaker:Purpose', 'Value': 'RemoteFunctionIntegrity'} + ] + ) + + +def _get_secret_arn_from_parameter_store( + sagemaker_session: Session, + job_name: str +) -> str: + """Retrieve secret ARN from Parameter Store. + + Args: + sagemaker_session: SageMaker session + job_name: Remote function job name + + Returns: + Secret ARN string + + Raises: + DeserializationError: If parameter not found + """ + ssm_client = sagemaker_session.boto_session.client('ssm') + parameter_name = f"/sagemaker/remote-function/{job_name}/secret-arn" + + try: + response = ssm_client.get_parameter(Name=parameter_name) + return response['Parameter']['Value'] + except ssm_client.exceptions.ParameterNotFound: + raise DeserializationError( + f"Secret ARN not found in Parameter Store for job {job_name}. " + "This may indicate the job was not properly initialized or artifacts were tampered with." + ) + + +def _extract_job_name_from_s3_uri(s3_uri: str) -> str: + """Extract job name from S3 URI. 
+ + S3 URI format: s3://bucket/path/to/job-name/results + or: s3://bucket/job-name/function + + Args: + s3_uri: S3 URI containing job name + + Returns: + Job name extracted from URI + """ + # Remove s3:// prefix and split by / + parts = s3_uri.replace("s3://", "").split("/") + + # Try to find a part that looks like a job name + # Job names typically contain execution IDs or timestamps + for part in reversed(parts): + if part and part not in ['function', 'arguments', 'results', 'exception', 'payload.pkl', 'metadata.json']: + return part + + # Fallback: use the last meaningful part + return parts[-2] if len(parts) > 1 else parts[0] + + +def _validate_secret_arn( + sagemaker_session: Session, + metadata_secret_arn: str, + job_name: str +): + """Validate secret ARN from metadata against trusted sources. + + Implements two mitigations: + 1. Validate secret is in same AWS account + 2. Validate secret ARN matches Parameter Store (trust anchor) + + Args: + sagemaker_session: SageMaker session + metadata_secret_arn: Secret ARN from S3 metadata (untrusted) + job_name: Remote function job name + + Raises: + DeserializationError: If validation fails + """ + # Mitigation #1: Validate same account + sts_client = sagemaker_session.boto_session.client('sts') + current_account_id = sts_client.get_caller_identity()['Account'] + + # Parse account ID from ARN: arn:aws:secretsmanager:region:ACCOUNT_ID:secret:name + arn_parts = metadata_secret_arn.split(":") + if len(arn_parts) < 5: + raise DeserializationError(f"Invalid secret ARN format: {metadata_secret_arn}") + + metadata_account_id = arn_parts[4] + + if metadata_account_id != current_account_id: + raise DeserializationError( + f"Secret must be in the same AWS account. " + f"Expected account {current_account_id}, but got {metadata_account_id}. " + "This may indicate a cross-account attack attempt." 
+ ) + + # Mitigation #3: Validate against Parameter Store (trust anchor) + expected_secret_arn = _get_secret_arn_from_parameter_store(sagemaker_session, job_name) + + if metadata_secret_arn != expected_secret_arn: + raise DeserializationError( + f"Secret ARN mismatch. Expected: {expected_secret_arn}, " + f"Got: {metadata_secret_arn}. " + "Possible tampering detected - metadata may have been modified." + ) + + +def _perform_integrity_check( + expected_hash_value: str, + buffer: bytes, + sagemaker_session: Optional[Session] = None, + secret_arn: Optional[str] = None, + s3_uri: Optional[str] = None +): """Performs integrity checks for serialized code/arguments uploaded to s3. Verifies whether the hash read from s3 matches the hash calculated during remote function execution. + + Args: + expected_hash_value: Expected hash value from metadata + buffer: Serialized data buffer + sagemaker_session: SageMaker session (required if secret_arn is provided) + secret_arn: ARN of secret containing HMAC key (None for legacy plain SHA-256) + s3_uri: S3 URI for extracting job name (required if secret_arn is provided) """ - actual_hash_value = _compute_hash(buffer=buffer) - if expected_hash_value != actual_hash_value: - raise DeserializationError( - "Integrity check for the serialized function or data failed. 
" - "Please restrict access to your S3 bucket" + if secret_arn: + # New secure method: HMAC with key from Secrets Manager + if not sagemaker_session: + raise DeserializationError( + "sagemaker_session is required for HMAC integrity check" + ) + + if not s3_uri: + raise DeserializationError( + "s3_uri is required for HMAC integrity check to extract job name" + ) + + # Extract job name from S3 URI + job_name = _extract_job_name_from_s3_uri(s3_uri) + + # Validate secret ARN (Mitigations #1 and #3) + _validate_secret_arn(sagemaker_session, secret_arn, job_name) + + # Now safe to retrieve HMAC key + hmac_key = _get_hmac_key_from_secret(sagemaker_session, secret_arn) + actual_hash_value = _compute_hmac(buffer, hmac_key) + + if not hmac.compare_digest(expected_hash_value, actual_hash_value): + raise DeserializationError( + "HMAC integrity check failed. Serialized data may have been tampered with. " + "Please restrict access to your S3 bucket." + ) + else: + # Legacy method: plain SHA-256 (backward compatibility) + logger.warning( + "Using legacy SHA-256 integrity check without HMAC authentication. " + "This provides weaker security guarantees. Please upgrade to the latest SDK version." ) + actual_hash_value = _compute_hash(buffer) + if expected_hash_value != actual_hash_value: + raise DeserializationError( + "Integrity check for the serialized function or data failed. " + "Please restrict access to your S3 bucket" + ) diff --git a/sagemaker-core/src/sagemaker/core/remote_function/core/stored_function.py b/sagemaker-core/src/sagemaker/core/remote_function/core/stored_function.py index c7ee86f8a7..1a45c378f4 100644 --- a/sagemaker-core/src/sagemaker/core/remote_function/core/stored_function.py +++ b/sagemaker-core/src/sagemaker/core/remote_function/core/stored_function.py @@ -57,6 +57,7 @@ def __init__( s3_base_uri: str, s3_kms_key: str = None, context: Context = Context(), + job_name: str = None, ): """Construct a StoredFunction object. 
@@ -66,11 +67,13 @@ def __init__( s3_base_uri: the base uri to which serialized artifacts will be uploaded. s3_kms_key: KMS key used to encrypt artifacts uploaded to S3. context: Build or run context of a pipeline step. + job_name: Remote function job name for secret management. """ self.sagemaker_session = sagemaker_session self.s3_base_uri = s3_base_uri self.s3_kms_key = s3_kms_key self.context = context + self.job_name = job_name or os.environ.get("TRAINING_JOB_NAME") # For pipeline steps, function code is at: base/step_name/build_timestamp/ # For results, path is: base/step_name/build_timestamp/execution_id/ @@ -110,6 +113,7 @@ def save(self, func, *args, **kwargs): func=func, sagemaker_session=self.sagemaker_session, s3_uri=s3_path_join(self.func_upload_path, FUNCTION_FOLDER), + job_name=self.job_name, s3_kms_key=self.s3_kms_key, ) @@ -123,7 +127,7 @@ def save(self, func, *args, **kwargs): obj=(args, kwargs), sagemaker_session=self.sagemaker_session, s3_uri=s3_path_join(self.func_upload_path, ARGUMENTS_FOLDER), - + job_name=self.job_name, s3_kms_key=self.s3_kms_key, ) @@ -144,6 +148,7 @@ def save_pipeline_step_function(self, serialized_data): s3_uri=s3_path_join(self.func_upload_path, FUNCTION_FOLDER), sagemaker_session=self.sagemaker_session, + job_name=self.job_name, s3_kms_key=self.s3_kms_key, ) @@ -156,6 +161,7 @@ def save_pipeline_step_function(self, serialized_data): s3_uri=s3_path_join(self.func_upload_path, ARGUMENTS_FOLDER), sagemaker_session=self.sagemaker_session, + job_name=self.job_name, s3_kms_key=self.s3_kms_key, ) @@ -203,7 +209,7 @@ def load_and_invoke(self) -> Any: obj=result, sagemaker_session=self.sagemaker_session, s3_uri=s3_path_join(self.results_upload_path, RESULTS_FOLDER), - + job_name=self.job_name, s3_kms_key=self.s3_kms_key, ) diff --git a/sagemaker-core/src/sagemaker/core/remote_function/errors.py b/sagemaker-core/src/sagemaker/core/remote_function/errors.py index 3f391570cf..6315c1c527 100644 --- 
a/sagemaker-core/src/sagemaker/core/remote_function/errors.py +++ b/sagemaker-core/src/sagemaker/core/remote_function/errors.py @@ -70,7 +70,7 @@ def _write_failure_reason_file(failure_msg): f.write(failure_msg) -def handle_error(error, sagemaker_session, s3_base_uri, s3_kms_key) -> int: +def handle_error(error, sagemaker_session, s3_base_uri, s3_kms_key, job_name=None) -> int: """Handle all exceptions raised during remote function execution. Args: @@ -79,6 +79,7 @@ def handle_error(error, sagemaker_session, s3_base_uri, s3_kms_key) -> int: AWS service calls are delegated to. s3_base_uri (str): S3 root uri to which resulting serialized exception will be uploaded. s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3. + job_name (str): Remote function job name for secret management. Returns : exit_code (int): Exit code to terminate current job. """ @@ -96,6 +97,7 @@ def handle_error(error, sagemaker_session, s3_base_uri, s3_kms_key) -> int: exc=error, sagemaker_session=sagemaker_session, s3_uri=s3_path_join(s3_base_uri, "exception"), + job_name=job_name or os.environ.get("TRAINING_JOB_NAME"), s3_kms_key=s3_kms_key, ) diff --git a/sagemaker-core/src/sagemaker/core/remote_function/invoke_function.py b/sagemaker-core/src/sagemaker/core/remote_function/invoke_function.py index 2e69f4f116..c43978f687 100644 --- a/sagemaker-core/src/sagemaker/core/remote_function/invoke_function.py +++ b/sagemaker-core/src/sagemaker/core/remote_function/invoke_function.py @@ -108,6 +108,7 @@ def _execute_remote_function( s3_base_uri=s3_base_uri, s3_kms_key=s3_kms_key, context=context, + job_name=os.environ.get("TRAINING_JOB_NAME"), ) if run_in_context: diff --git a/sagemaker-core/src/sagemaker/core/remote_function/job.py b/sagemaker-core/src/sagemaker/core/remote_function/job.py index 6e727d4b9c..b6ac5572b7 100644 --- a/sagemaker-core/src/sagemaker/core/remote_function/job.py +++ b/sagemaker-core/src/sagemaker/core/remote_function/job.py @@ -931,6 +931,7 @@ def compile( 
sagemaker_session=job_settings.sagemaker_session, s3_base_uri=s3_base_uri, s3_kms_key=job_settings.s3_kms_key, + job_name=job_name, ) stored_function.save(func, *func_args, **func_kwargs) else: @@ -942,6 +943,7 @@ def compile( step_name=step_compilation_context.step_name, func_step_s3_dir=step_compilation_context.pipeline_build_time, ), + job_name=job_name, ) stored_function.save_pipeline_step_function(serialized_data) diff --git a/sagemaker-core/tests/unit/remote_function/test_invoke_function.py b/sagemaker-core/tests/unit/remote_function/test_invoke_function.py index 4810eba2e0..7bd24489e7 100644 --- a/sagemaker-core/tests/unit/remote_function/test_invoke_function.py +++ b/sagemaker-core/tests/unit/remote_function/test_invoke_function.py @@ -188,6 +188,7 @@ def test_executes_without_run_context(self, mock_stored_function_class): s3_base_uri="s3://bucket/path", s3_kms_key="key-123", context=mock_context, + job_name=None, ) mock_stored_func.load_and_invoke.assert_called_once() diff --git a/sagemaker-core/tests/unit/remote_function/test_serialization_security.py b/sagemaker-core/tests/unit/remote_function/test_serialization_security.py new file mode 100644 index 0000000000..0478617ea1 --- /dev/null +++ b/sagemaker-core/tests/unit/remote_function/test_serialization_security.py @@ -0,0 +1,401 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""Unit tests for serialization security (HMAC + Secrets Manager + Parameter Store).""" +from __future__ import absolute_import + +import hashlib +import hmac as hmac_module +import json +from unittest.mock import Mock, patch, MagicMock + +import pytest + +from sagemaker.core.remote_function.core.serialization import ( + _MetaData, + _compute_hash, + _compute_hmac, + _get_or_create_hmac_secret, + _get_hmac_key_from_secret, + _store_secret_arn_in_parameter_store, + _get_secret_arn_from_parameter_store, + _extract_job_name_from_s3_uri, + _validate_secret_arn, + _perform_integrity_check, + _upload_payload_and_metadata_to_s3, + serialize_obj_to_s3, + deserialize_obj_from_s3, + serialize_func_to_s3, + serialize_exception_to_s3, + deserialize_func_from_s3, + deserialize_exception_from_s3, +) +from sagemaker.core.remote_function.errors import DeserializationError + + +MOCK_JOB_NAME = "test-remote-function-job" +MOCK_SECRET_ARN = "arn:aws:secretsmanager:us-west-2:123456789012:secret:sagemaker/remote-function/test-remote-function-job/hmac-key-AbCdEf" +MOCK_HMAC_KEY = "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2" +MOCK_ACCOUNT_ID = "123456789012" +MOCK_S3_URI = "s3://my-bucket/remote-function/test-remote-function-job/results" + + +def _mock_sagemaker_session(account_id=MOCK_ACCOUNT_ID): + """Create a mock SageMaker session with Secrets Manager, SSM, and STS clients.""" + session = Mock() + + # Mock Secrets Manager client + secrets_client = Mock() + secrets_client.get_secret_value.return_value = { + "ARN": MOCK_SECRET_ARN, + "SecretString": MOCK_HMAC_KEY, + } + secrets_client.create_secret.return_value = { + "ARN": MOCK_SECRET_ARN, + } + secrets_client.exceptions = Mock() + secrets_client.exceptions.ResourceNotFoundException = type( + "ResourceNotFoundException", (Exception,), {} + ) + + # Mock SSM client + ssm_client = Mock() + ssm_client.get_parameter.return_value = { + "Parameter": {"Value": MOCK_SECRET_ARN} + } + ssm_client.exceptions = Mock() + 
ssm_client.exceptions.ParameterNotFound = type( + "ParameterNotFound", (Exception,), {} + ) + + # Mock STS client + sts_client = Mock() + sts_client.get_caller_identity.return_value = {"Account": account_id} + + def client_factory(service_name): + if service_name == "secretsmanager": + return secrets_client + elif service_name == "ssm": + return ssm_client + elif service_name == "sts": + return sts_client + return Mock() + + session.boto_session.client = client_factory + return session, secrets_client, ssm_client, sts_client + + +class TestMetaData: + """Tests for _MetaData class.""" + + def test_metadata_with_secret_arn(self): + metadata = _MetaData(sha256_hash="abc123", secret_arn=MOCK_SECRET_ARN) + json_bytes = metadata.to_json() + parsed = _MetaData.from_json(json_bytes) + + assert parsed.sha256_hash == "abc123" + assert parsed.secret_arn == MOCK_SECRET_ARN + + def test_metadata_without_secret_arn_legacy(self): + metadata = _MetaData(sha256_hash="abc123") + json_bytes = metadata.to_json() + parsed = _MetaData.from_json(json_bytes) + + assert parsed.sha256_hash == "abc123" + assert parsed.secret_arn is None + + def test_metadata_missing_hash_raises(self): + with pytest.raises(DeserializationError, match="SHA256 hash"): + _MetaData.from_json(json.dumps({"version": "2023-04-24", "serialization_module": "cloudpickle"})) + + def test_metadata_invalid_json_raises(self): + with pytest.raises(DeserializationError, match="not a valid json"): + _MetaData.from_json(b"not json") + + +class TestComputeHmac: + """Tests for HMAC computation.""" + + def test_compute_hmac(self): + data = b"test data" + key = "test-key" + result = _compute_hmac(data, key) + expected = hmac_module.new(key.encode(), msg=data, digestmod=hashlib.sha256).hexdigest() + assert result == expected + + def test_compute_hmac_different_keys_produce_different_hashes(self): + data = b"test data" + hash1 = _compute_hmac(data, "key1") + hash2 = _compute_hmac(data, "key2") + assert hash1 != hash2 + + def 
test_compute_hash_plain_sha256(self): + data = b"test data" + result = _compute_hash(data) + expected = hashlib.sha256(data).hexdigest() + assert result == expected + + +class TestGetOrCreateHmacSecret: + """Tests for Secrets Manager integration.""" + + def test_get_existing_secret(self): + session, secrets_client, _, _ = _mock_sagemaker_session() + + arn, key = _get_or_create_hmac_secret(session, MOCK_JOB_NAME) + + assert arn == MOCK_SECRET_ARN + assert key == MOCK_HMAC_KEY + secrets_client.get_secret_value.assert_called_once_with( + SecretId=f"sagemaker/remote-function/{MOCK_JOB_NAME}/hmac-key" + ) + + def test_create_new_secret_when_not_found(self): + session, secrets_client, _, _ = _mock_sagemaker_session() + + # Simulate ResourceNotFoundException + secrets_client.get_secret_value.side_effect = ( + secrets_client.exceptions.ResourceNotFoundException("not found") + ) + + arn, key = _get_or_create_hmac_secret(session, MOCK_JOB_NAME) + + assert arn == MOCK_SECRET_ARN + assert len(key) == 64 # secrets.token_hex(32) produces 64 chars + secrets_client.create_secret.assert_called_once() + + +class TestParameterStore: + """Tests for Parameter Store trust anchor.""" + + def test_store_secret_arn(self): + session, _, ssm_client, _ = _mock_sagemaker_session() + + _store_secret_arn_in_parameter_store(session, MOCK_JOB_NAME, MOCK_SECRET_ARN) + + ssm_client.put_parameter.assert_called_once() + call_kwargs = ssm_client.put_parameter.call_args[1] + assert call_kwargs["Name"] == f"/sagemaker/remote-function/{MOCK_JOB_NAME}/secret-arn" + assert call_kwargs["Value"] == MOCK_SECRET_ARN + assert call_kwargs["Overwrite"] is True + + def test_get_secret_arn(self): + session, _, ssm_client, _ = _mock_sagemaker_session() + + result = _get_secret_arn_from_parameter_store(session, MOCK_JOB_NAME) + + assert result == MOCK_SECRET_ARN + ssm_client.get_parameter.assert_called_once_with( + Name=f"/sagemaker/remote-function/{MOCK_JOB_NAME}/secret-arn" + ) + + def 
test_get_secret_arn_not_found_raises(self): + session, _, ssm_client, _ = _mock_sagemaker_session() + ssm_client.get_parameter.side_effect = ( + ssm_client.exceptions.ParameterNotFound("not found") + ) + + with pytest.raises(DeserializationError, match="Secret ARN not found"): + _get_secret_arn_from_parameter_store(session, MOCK_JOB_NAME) + + +class TestExtractJobName: + """Tests for S3 URI job name extraction.""" + + def test_extract_from_results_uri(self): + result = _extract_job_name_from_s3_uri( + "s3://bucket/remote-function/my-job-123/results" + ) + assert result == "my-job-123" + + def test_extract_from_function_uri(self): + result = _extract_job_name_from_s3_uri( + "s3://bucket/remote-function/my-job-123/function" + ) + assert result == "my-job-123" + + def test_extract_from_exception_uri(self): + result = _extract_job_name_from_s3_uri( + "s3://bucket/remote-function/my-job-123/exception" + ) + assert result == "my-job-123" + + +class TestValidateSecretArn: + """Tests for secret ARN validation (Mitigations #1 and #3).""" + + def test_valid_secret_arn_passes(self): + """Valid ARN in same account matching Parameter Store should pass.""" + session, _, _, _ = _mock_sagemaker_session() + + # Should not raise + _validate_secret_arn(session, MOCK_SECRET_ARN, MOCK_JOB_NAME) + + def test_cross_account_arn_rejected(self): + """Mitigation #1: Secret ARN from different account should be rejected.""" + session, _, _, _ = _mock_sagemaker_session(account_id=MOCK_ACCOUNT_ID) + + attacker_arn = "arn:aws:secretsmanager:us-west-2:999999999999:secret:evil-secret" + + with pytest.raises(DeserializationError, match="same AWS account"): + _validate_secret_arn(session, attacker_arn, MOCK_JOB_NAME) + + def test_tampered_arn_rejected(self): + """Mitigation #3: ARN not matching Parameter Store should be rejected.""" + session, _, ssm_client, _ = _mock_sagemaker_session() + + # Parameter Store returns the legitimate ARN + ssm_client.get_parameter.return_value = { + "Parameter": 
{"Value": MOCK_SECRET_ARN} + } + + # Attacker's ARN (same account but different secret) + tampered_arn = f"arn:aws:secretsmanager:us-west-2:{MOCK_ACCOUNT_ID}:secret:attacker-created-secret" + + with pytest.raises(DeserializationError, match="Secret ARN mismatch"): + _validate_secret_arn(session, tampered_arn, MOCK_JOB_NAME) + + def test_invalid_arn_format_rejected(self): + """Malformed ARN should be rejected.""" + session, _, _, _ = _mock_sagemaker_session() + + with pytest.raises(DeserializationError, match="Invalid secret ARN format"): + _validate_secret_arn(session, "not-an-arn", MOCK_JOB_NAME) + + +class TestPerformIntegrityCheck: + """Tests for integrity check with HMAC.""" + + def test_hmac_integrity_check_passes(self): + """Valid HMAC should pass integrity check.""" + session, _, _, _ = _mock_sagemaker_session() + + payload = b"test payload" + expected_hmac = _compute_hmac(payload, MOCK_HMAC_KEY) + + # Should not raise + _perform_integrity_check( + expected_hash_value=expected_hmac, + buffer=payload, + sagemaker_session=session, + secret_arn=MOCK_SECRET_ARN, + s3_uri=MOCK_S3_URI, + ) + + def test_hmac_integrity_check_fails_on_tampered_payload(self): + """Tampered payload should fail HMAC check.""" + session, _, _, _ = _mock_sagemaker_session() + + original_payload = b"original payload" + tampered_payload = b"tampered payload" + expected_hmac = _compute_hmac(original_payload, MOCK_HMAC_KEY) + + with pytest.raises(DeserializationError, match="HMAC integrity check failed"): + _perform_integrity_check( + expected_hash_value=expected_hmac, + buffer=tampered_payload, + sagemaker_session=session, + secret_arn=MOCK_SECRET_ARN, + s3_uri=MOCK_S3_URI, + ) + + def test_legacy_sha256_check_passes_with_warning(self): + """Legacy SHA-256 check should pass with warning when no secret_arn.""" + payload = b"test payload" + expected_hash = _compute_hash(payload) + + # Should not raise (legacy path) + _perform_integrity_check( + expected_hash_value=expected_hash, + 
buffer=payload, + ) + + def test_legacy_sha256_check_fails_on_tampered_payload(self): + """Legacy SHA-256 check should fail on tampered payload.""" + original_payload = b"original payload" + tampered_payload = b"tampered payload" + expected_hash = _compute_hash(original_payload) + + with pytest.raises(DeserializationError, match="Integrity check"): + _perform_integrity_check( + expected_hash_value=expected_hash, + buffer=tampered_payload, + ) + + def test_hmac_check_requires_session(self): + """HMAC check should require sagemaker_session.""" + with pytest.raises(DeserializationError, match="sagemaker_session is required"): + _perform_integrity_check( + expected_hash_value="hash", + buffer=b"data", + secret_arn=MOCK_SECRET_ARN, + ) + + def test_hmac_check_requires_s3_uri(self): + """HMAC check should require s3_uri.""" + session, _, _, _ = _mock_sagemaker_session() + + with pytest.raises(DeserializationError, match="s3_uri is required"): + _perform_integrity_check( + expected_hash_value="hash", + buffer=b"data", + sagemaker_session=session, + secret_arn=MOCK_SECRET_ARN, + ) + + +class TestAttackScenarios: + """Tests simulating actual attack scenarios.""" + + def test_attacker_replaces_payload_and_metadata_plain_hash(self): + """Attacker replaces both files with plain SHA-256 - should fail HMAC check.""" + session, secrets_client, _, _ = _mock_sagemaker_session() + + # Attacker creates malicious payload + malicious_payload = b"malicious code" + + # Attacker computes plain SHA-256 (not HMAC) + plain_hash = hashlib.sha256(malicious_payload).hexdigest() + + # Attacker's HMAC won't match because they don't know the key + with pytest.raises(DeserializationError, match="HMAC integrity check failed"): + _perform_integrity_check( + expected_hash_value=plain_hash, + buffer=malicious_payload, + sagemaker_session=session, + secret_arn=MOCK_SECRET_ARN, + s3_uri=MOCK_S3_URI, + ) + + def test_attacker_points_to_cross_account_secret(self): + """Attacker points to their own secret 
in different account - should be rejected.""" + session, _, _, _ = _mock_sagemaker_session() + + attacker_secret_arn = "arn:aws:secretsmanager:us-west-2:999999999999:secret:attacker-secret" + + with pytest.raises(DeserializationError, match="same AWS account"): + _validate_secret_arn(session, attacker_secret_arn, MOCK_JOB_NAME) + + def test_attacker_creates_secret_in_same_account(self): + """Attacker creates secret in same account but ARN doesn't match Parameter Store.""" + session, _, ssm_client, _ = _mock_sagemaker_session() + + # Parameter Store has the legitimate ARN + ssm_client.get_parameter.return_value = { + "Parameter": {"Value": MOCK_SECRET_ARN} + } + + # Attacker's secret in same account + attacker_arn = f"arn:aws:secretsmanager:us-west-2:{MOCK_ACCOUNT_ID}:secret:sagemaker/remote-function/evil-job/hmac-key" + + with pytest.raises(DeserializationError, match="Secret ARN mismatch"): + _validate_secret_arn(session, attacker_arn, MOCK_JOB_NAME) From 72a398202103ef7a92db382038dca06c1ce386e3 Mon Sep 17 00:00:00 2001 From: Molly He Date: Fri, 6 Mar 2026 16:15:19 -0800 Subject: [PATCH 02/10] Update after testing --- .../sagemaker/core/remote_function/client.py | 6 -- .../remote_function/core/serialization.py | 84 +++++++++---------- .../remote_function/core/stored_function.py | 6 +- .../test_serialization_security.py | 56 ++----------- 4 files changed, 51 insertions(+), 101 deletions(-) diff --git a/sagemaker-core/src/sagemaker/core/remote_function/client.py b/sagemaker-core/src/sagemaker/core/remote_function/client.py index 3cfa5e3b23..85e2cda868 100644 --- a/sagemaker-core/src/sagemaker/core/remote_function/client.py +++ b/sagemaker-core/src/sagemaker/core/remote_function/client.py @@ -369,7 +369,6 @@ def wrapper(*args, **kwargs): s3_uri=s3_path_join( job_settings.s3_root_uri, job.job_name, EXCEPTION_FOLDER ), - ) except ServiceError as serr: chained_e = serr.__cause__ @@ -406,7 +405,6 @@ def wrapper(*args, **kwargs): return 
serialization.deserialize_obj_from_s3( sagemaker_session=job_settings.sagemaker_session, s3_uri=s3_path_join(job_settings.s3_root_uri, job.job_name, RESULTS_FOLDER), - ) if job.describe()["TrainingJobStatus"] == "Stopped": @@ -1008,7 +1006,6 @@ def from_describe_response(describe_training_job_response, sagemaker_session): job_return = serialization.deserialize_obj_from_s3( sagemaker_session=sagemaker_session, s3_uri=s3_path_join(job.s3_uri, RESULTS_FOLDER), - ) except DeserializationError as e: client_exception = e @@ -1020,7 +1017,6 @@ def from_describe_response(describe_training_job_response, sagemaker_session): job_exception = serialization.deserialize_exception_from_s3( sagemaker_session=sagemaker_session, s3_uri=s3_path_join(job.s3_uri, EXCEPTION_FOLDER), - ) except ServiceError as serr: chained_e = serr.__cause__ @@ -1110,7 +1106,6 @@ def result(self, timeout: float = None) -> Any: self._return = serialization.deserialize_obj_from_s3( sagemaker_session=self._job.sagemaker_session, s3_uri=s3_path_join(self._job.s3_uri, RESULTS_FOLDER), - ) self._state = _FINISHED return self._return @@ -1119,7 +1114,6 @@ def result(self, timeout: float = None) -> Any: self._exception = serialization.deserialize_exception_from_s3( sagemaker_session=self._job.sagemaker_session, s3_uri=s3_path_join(self._job.s3_uri, EXCEPTION_FOLDER), - ) except ServiceError as serr: chained_e = serr.__cause__ diff --git a/sagemaker-core/src/sagemaker/core/remote_function/core/serialization.py b/sagemaker-core/src/sagemaker/core/remote_function/core/serialization.py index c5efd78b4d..d5eab30f4f 100644 --- a/sagemaker-core/src/sagemaker/core/remote_function/core/serialization.py +++ b/sagemaker-core/src/sagemaker/core/remote_function/core/serialization.py @@ -216,7 +216,6 @@ def deserialize_func_from_s3(sagemaker_session: Session, s3_uri: str) -> Callabl buffer=bytes_to_deserialize, sagemaker_session=sagemaker_session, secret_arn=metadata.secret_arn, - s3_uri=s3_uri ) return 
CloudpickleSerializer.deserialize(f"{s3_uri}/payload.pkl", bytes_to_deserialize) @@ -315,7 +314,6 @@ def deserialize_obj_from_s3(sagemaker_session: Session, s3_uri: str) -> Any: buffer=bytes_to_deserialize, sagemaker_session=sagemaker_session, secret_arn=metadata.secret_arn, - s3_uri=s3_uri ) return CloudpickleSerializer.deserialize(f"{s3_uri}/payload.pkl", bytes_to_deserialize) @@ -411,7 +409,6 @@ def deserialize_exception_from_s3(sagemaker_session: Session, s3_uri: str) -> An buffer=bytes_to_deserialize, sagemaker_session=sagemaker_session, secret_arn=metadata.secret_arn, - s3_uri=s3_uri ) return CloudpickleSerializer.deserialize(f"{s3_uri}/payload.pkl", bytes_to_deserialize) @@ -518,17 +515,24 @@ def _store_secret_arn_in_parameter_store( ssm_client = sagemaker_session.boto_session.client('ssm') parameter_name = f"/sagemaker/remote-function/{job_name}/secret-arn" - ssm_client.put_parameter( - Name=parameter_name, - Value=secret_arn, - Type="String", - Overwrite=True, - Description=f"Secret ARN for SageMaker remote function job {job_name}", - Tags=[ - {'Key': 'SageMaker:JobName', 'Value': job_name}, - {'Key': 'SageMaker:Purpose', 'Value': 'RemoteFunctionIntegrity'} - ] - ) + try: + ssm_client.put_parameter( + Name=parameter_name, + Value=secret_arn, + Type="String", + Description=f"Secret ARN for SageMaker remote function job {job_name}", + Tags=[ + {'Key': 'SageMaker:JobName', 'Value': job_name}, + {'Key': 'SageMaker:Purpose', 'Value': 'RemoteFunctionIntegrity'} + ] + ) + except ssm_client.exceptions.ParameterAlreadyExists: + ssm_client.put_parameter( + Name=parameter_name, + Value=secret_arn, + Type="String", + Overwrite=True, + ) def _get_secret_arn_from_parameter_store( @@ -560,35 +564,34 @@ def _get_secret_arn_from_parameter_store( ) -def _extract_job_name_from_s3_uri(s3_uri: str) -> str: - """Extract job name from S3 URI. +def _extract_job_name_from_secret_arn(secret_arn: str) -> str: + """Extract job name from a Secrets Manager ARN. 
- S3 URI format: s3://bucket/path/to/job-name/results - or: s3://bucket/job-name/function + Secret name convention: sagemaker/remote-function/{job_name}/hmac-key + ARN format: arn:aws:secretsmanager:region:account:secret:sagemaker/remote-function/{job_name}/hmac-key-XXXXXX Args: - s3_uri: S3 URI containing job name + secret_arn: Full ARN of the secret Returns: - Job name extracted from URI + Extracted job name + + Raises: + DeserializationError: If ARN doesn't match expected format """ - # Remove s3:// prefix and split by / - parts = s3_uri.replace("s3://", "").split("/") - - # Try to find a part that looks like a job name - # Job names typically contain execution IDs or timestamps - for part in reversed(parts): - if part and part not in ['function', 'arguments', 'results', 'exception', 'payload.pkl', 'metadata.json']: - return part - - # Fallback: use the last meaningful part - return parts[-2] if len(parts) > 1 else parts[0] + import re + match = re.search(r":secret:sagemaker/remote-function/(.+)/hmac-key", secret_arn) + if not match: + raise DeserializationError( + f"Secret ARN does not match expected format " + f"'sagemaker/remote-function/{{job_name}}/hmac-key': {secret_arn}" + ) + return match.group(1) def _validate_secret_arn( sagemaker_session: Session, metadata_secret_arn: str, - job_name: str ): """Validate secret ARN from metadata against trusted sources. @@ -596,10 +599,12 @@ def _validate_secret_arn( 1. Validate secret is in same AWS account 2. Validate secret ARN matches Parameter Store (trust anchor) + The job_name is derived from the secret ARN's naming convention, then + independently validated against the SSM trust anchor. 
+ Args: sagemaker_session: SageMaker session metadata_secret_arn: Secret ARN from S3 metadata (untrusted) - job_name: Remote function job name Raises: DeserializationError: If validation fails @@ -623,6 +628,7 @@ def _validate_secret_arn( ) # Mitigation #3: Validate against Parameter Store (trust anchor) + job_name = _extract_job_name_from_secret_arn(metadata_secret_arn) expected_secret_arn = _get_secret_arn_from_parameter_store(sagemaker_session, job_name) if metadata_secret_arn != expected_secret_arn: @@ -638,7 +644,6 @@ def _perform_integrity_check( buffer: bytes, sagemaker_session: Optional[Session] = None, secret_arn: Optional[str] = None, - s3_uri: Optional[str] = None ): """Performs integrity checks for serialized code/arguments uploaded to s3. @@ -650,7 +655,6 @@ def _perform_integrity_check( buffer: Serialized data buffer sagemaker_session: SageMaker session (required if secret_arn is provided) secret_arn: ARN of secret containing HMAC key (None for legacy plain SHA-256) - s3_uri: S3 URI for extracting job name (required if secret_arn is provided) """ if secret_arn: # New secure method: HMAC with key from Secrets Manager @@ -659,16 +663,8 @@ def _perform_integrity_check( "sagemaker_session is required for HMAC integrity check" ) - if not s3_uri: - raise DeserializationError( - "s3_uri is required for HMAC integrity check to extract job name" - ) - - # Extract job name from S3 URI - job_name = _extract_job_name_from_s3_uri(s3_uri) - # Validate secret ARN (Mitigations #1 and #3) - _validate_secret_arn(sagemaker_session, secret_arn, job_name) + _validate_secret_arn(sagemaker_session, secret_arn) # Now safe to retrieve HMAC key hmac_key = _get_hmac_key_from_secret(sagemaker_session, secret_arn) diff --git a/sagemaker-core/src/sagemaker/core/remote_function/core/stored_function.py b/sagemaker-core/src/sagemaker/core/remote_function/core/stored_function.py index 1a45c378f4..d09c3737f5 100644 --- 
a/sagemaker-core/src/sagemaker/core/remote_function/core/stored_function.py +++ b/sagemaker-core/src/sagemaker/core/remote_function/core/stored_function.py @@ -145,10 +145,9 @@ def save_pipeline_step_function(self, serialized_data): ) serialization._upload_payload_and_metadata_to_s3( bytes_to_upload=serialized_data.func, - + job_name=self.job_name, s3_uri=s3_path_join(self.func_upload_path, FUNCTION_FOLDER), sagemaker_session=self.sagemaker_session, - job_name=self.job_name, s3_kms_key=self.s3_kms_key, ) @@ -158,10 +157,9 @@ def save_pipeline_step_function(self, serialized_data): ) serialization._upload_payload_and_metadata_to_s3( bytes_to_upload=serialized_data.args, - + job_name=self.job_name, s3_uri=s3_path_join(self.func_upload_path, ARGUMENTS_FOLDER), sagemaker_session=self.sagemaker_session, - job_name=self.job_name, s3_kms_key=self.s3_kms_key, ) diff --git a/sagemaker-core/tests/unit/remote_function/test_serialization_security.py b/sagemaker-core/tests/unit/remote_function/test_serialization_security.py index 0478617ea1..eb2c7cc9f7 100644 --- a/sagemaker-core/tests/unit/remote_function/test_serialization_security.py +++ b/sagemaker-core/tests/unit/remote_function/test_serialization_security.py @@ -24,11 +24,11 @@ _MetaData, _compute_hash, _compute_hmac, + _extract_job_name_from_secret_arn, _get_or_create_hmac_secret, _get_hmac_key_from_secret, _store_secret_arn_in_parameter_store, _get_secret_arn_from_parameter_store, - _extract_job_name_from_s3_uri, _validate_secret_arn, _perform_integrity_check, _upload_payload_and_metadata_to_s3, @@ -186,7 +186,7 @@ def test_store_secret_arn(self): call_kwargs = ssm_client.put_parameter.call_args[1] assert call_kwargs["Name"] == f"/sagemaker/remote-function/{MOCK_JOB_NAME}/secret-arn" assert call_kwargs["Value"] == MOCK_SECRET_ARN - assert call_kwargs["Overwrite"] is True + assert "Tags" in call_kwargs def test_get_secret_arn(self): session, _, ssm_client, _ = _mock_sagemaker_session() @@ -208,28 +208,6 @@ def 
test_get_secret_arn_not_found_raises(self): _get_secret_arn_from_parameter_store(session, MOCK_JOB_NAME) -class TestExtractJobName: - """Tests for S3 URI job name extraction.""" - - def test_extract_from_results_uri(self): - result = _extract_job_name_from_s3_uri( - "s3://bucket/remote-function/my-job-123/results" - ) - assert result == "my-job-123" - - def test_extract_from_function_uri(self): - result = _extract_job_name_from_s3_uri( - "s3://bucket/remote-function/my-job-123/function" - ) - assert result == "my-job-123" - - def test_extract_from_exception_uri(self): - result = _extract_job_name_from_s3_uri( - "s3://bucket/remote-function/my-job-123/exception" - ) - assert result == "my-job-123" - - class TestValidateSecretArn: """Tests for secret ARN validation (Mitigations #1 and #3).""" @@ -238,7 +216,7 @@ def test_valid_secret_arn_passes(self): session, _, _, _ = _mock_sagemaker_session() # Should not raise - _validate_secret_arn(session, MOCK_SECRET_ARN, MOCK_JOB_NAME) + _validate_secret_arn(session, MOCK_SECRET_ARN) def test_cross_account_arn_rejected(self): """Mitigation #1: Secret ARN from different account should be rejected.""" @@ -247,7 +225,7 @@ def test_cross_account_arn_rejected(self): attacker_arn = "arn:aws:secretsmanager:us-west-2:999999999999:secret:evil-secret" with pytest.raises(DeserializationError, match="same AWS account"): - _validate_secret_arn(session, attacker_arn, MOCK_JOB_NAME) + _validate_secret_arn(session, attacker_arn) def test_tampered_arn_rejected(self): """Mitigation #3: ARN not matching Parameter Store should be rejected.""" @@ -261,15 +239,15 @@ def test_tampered_arn_rejected(self): # Attacker's ARN (same account but different secret) tampered_arn = f"arn:aws:secretsmanager:us-west-2:{MOCK_ACCOUNT_ID}:secret:attacker-created-secret" - with pytest.raises(DeserializationError, match="Secret ARN mismatch"): - _validate_secret_arn(session, tampered_arn, MOCK_JOB_NAME) + with pytest.raises(DeserializationError, match="does not 
match expected format"): + _validate_secret_arn(session, tampered_arn) def test_invalid_arn_format_rejected(self): """Malformed ARN should be rejected.""" session, _, _, _ = _mock_sagemaker_session() with pytest.raises(DeserializationError, match="Invalid secret ARN format"): - _validate_secret_arn(session, "not-an-arn", MOCK_JOB_NAME) + _validate_secret_arn(session, "not-an-arn") class TestPerformIntegrityCheck: @@ -288,7 +266,6 @@ def test_hmac_integrity_check_passes(self): buffer=payload, sagemaker_session=session, secret_arn=MOCK_SECRET_ARN, - s3_uri=MOCK_S3_URI, ) def test_hmac_integrity_check_fails_on_tampered_payload(self): @@ -305,7 +282,6 @@ def test_hmac_integrity_check_fails_on_tampered_payload(self): buffer=tampered_payload, sagemaker_session=session, secret_arn=MOCK_SECRET_ARN, - s3_uri=MOCK_S3_URI, ) def test_legacy_sha256_check_passes_with_warning(self): @@ -340,19 +316,6 @@ def test_hmac_check_requires_session(self): secret_arn=MOCK_SECRET_ARN, ) - def test_hmac_check_requires_s3_uri(self): - """HMAC check should require s3_uri.""" - session, _, _, _ = _mock_sagemaker_session() - - with pytest.raises(DeserializationError, match="s3_uri is required"): - _perform_integrity_check( - expected_hash_value="hash", - buffer=b"data", - sagemaker_session=session, - secret_arn=MOCK_SECRET_ARN, - ) - - class TestAttackScenarios: """Tests simulating actual attack scenarios.""" @@ -373,7 +336,6 @@ def test_attacker_replaces_payload_and_metadata_plain_hash(self): buffer=malicious_payload, sagemaker_session=session, secret_arn=MOCK_SECRET_ARN, - s3_uri=MOCK_S3_URI, ) def test_attacker_points_to_cross_account_secret(self): @@ -383,7 +345,7 @@ def test_attacker_points_to_cross_account_secret(self): attacker_secret_arn = "arn:aws:secretsmanager:us-west-2:999999999999:secret:attacker-secret" with pytest.raises(DeserializationError, match="same AWS account"): - _validate_secret_arn(session, attacker_secret_arn, MOCK_JOB_NAME) + _validate_secret_arn(session, 
attacker_secret_arn) def test_attacker_creates_secret_in_same_account(self): """Attacker creates secret in same account but ARN doesn't match Parameter Store.""" @@ -398,4 +360,4 @@ def test_attacker_creates_secret_in_same_account(self): attacker_arn = f"arn:aws:secretsmanager:us-west-2:{MOCK_ACCOUNT_ID}:secret:sagemaker/remote-function/evil-job/hmac-key" with pytest.raises(DeserializationError, match="Secret ARN mismatch"): - _validate_secret_arn(session, attacker_arn, MOCK_JOB_NAME) + _validate_secret_arn(session, attacker_arn) From 90f1f6632bb17a37076b51f099bad4a78bfd0d51 Mon Sep 17 00:00:00 2001 From: Molly He Date: Mon, 9 Mar 2026 15:50:27 -0700 Subject: [PATCH 03/10] Add wheel install for remote function integ test to use local code --- .../tests/integ/remote_function/conftest.py | 70 +++++++++++ .../test_sagemaker_dependency_injection.py | 112 ++++++++---------- 2 files changed, 119 insertions(+), 63 deletions(-) create mode 100644 sagemaker-core/tests/integ/remote_function/conftest.py diff --git a/sagemaker-core/tests/integ/remote_function/conftest.py b/sagemaker-core/tests/integ/remote_function/conftest.py new file mode 100644 index 0000000000..55a2f48aa4 --- /dev/null +++ b/sagemaker-core/tests/integ/remote_function/conftest.py @@ -0,0 +1,70 @@ +"""Shared fixtures for remote function integration tests.""" + +import glob +import os +import subprocess +import tempfile + +import cloudpickle +import pytest + +from sagemaker.core.helper.session_helper import Session +from sagemaker.core.s3 import S3Uploader + + +def _get_repo_root(): + return os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", "..")) + + +def _build_and_upload_core_wheel(sagemaker_session): + """Build sagemaker-core wheel and upload to S3. 
Returns (s3_prefix, wheel_basename).""" + repo_root = _get_repo_root() + dist_dir = tempfile.mkdtemp(prefix="sagemaker_core_wheel_") + + subprocess.run( + f"python -m build --wheel --outdir {dist_dir}", + shell=True, + cwd=os.path.join(repo_root, "sagemaker-core"), + check=True, + ) + + matches = glob.glob(os.path.join(dist_dir, "sagemaker_core-*.whl")) + if not matches: + raise FileNotFoundError(f"No sagemaker-core wheel found in {dist_dir}") + wheel_path = matches[0] + + s3_prefix = f"s3://{sagemaker_session.default_bucket()}/remote-function-test/wheels" + S3Uploader.upload(wheel_path, s3_prefix, sagemaker_session=sagemaker_session) + + return s3_prefix, os.path.basename(wheel_path) + + +@pytest.fixture(scope="module") +def sagemaker_session(): + import boto3 + return Session(boto3.Session()) + + +@pytest.fixture(scope="module") +def role(sagemaker_session): + import boto3 + account_id = boto3.client("sts").get_caller_identity()["Account"] + return f"arn:aws:iam::{account_id}:role/Admin" + + +@pytest.fixture(scope="module") +def image_uri(sagemaker_session): + region = sagemaker_session.boto_region_name + return f"763104351884.dkr.ecr.{region}.amazonaws.com/pytorch-training:2.0.0-cpu-py310" + + +@pytest.fixture(scope="module") +def dev_sdk_pre_execution_commands(sagemaker_session): + """Build dev sagemaker-core wheel, upload to S3, and return pre_execution_commands.""" + s3_prefix, wheel_name = _build_and_upload_core_wheel(sagemaker_session) + cp_version = cloudpickle.__version__ + return [ + f"pip install cloudpickle=={cp_version}", + f"aws s3 cp {s3_prefix}/{wheel_name} /tmp/{wheel_name}", + f"pip install /tmp/{wheel_name}", + ] diff --git a/sagemaker-core/tests/integ/remote_function/test_sagemaker_dependency_injection.py b/sagemaker-core/tests/integ/remote_function/test_sagemaker_dependency_injection.py index b3d38c32a4..61d26f78e8 100644 --- a/sagemaker-core/tests/integ/remote_function/test_sagemaker_dependency_injection.py +++ 
b/sagemaker-core/tests/integ/remote_function/test_sagemaker_dependency_injection.py @@ -9,12 +9,6 @@ import tempfile import pytest -# Skip decorator for AWS configuration -# skip_if_no_aws_region = pytest.mark.skipif( -# not os.environ.get('AWS_DEFAULT_REGION'), -# reason="AWS credentials not configured" -# ) - # Add src to path sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../../../src')) @@ -25,58 +19,47 @@ class TestRemoteFunctionDependencyInjection: """Integration tests for dependency injection in remote functions.""" @pytest.mark.integ - # @skip_if_no_aws_region - def test_remote_function_without_dependencies(self): - """Test remote function execution without explicit dependencies. - - This test verifies that when no dependencies are provided, the remote - function still executes successfully because sagemaker>=3.2.0 is - automatically injected. - """ + def test_remote_function_without_dependencies( + self, dev_sdk_pre_execution_commands, role, image_uri, sagemaker_session + ): + """Test remote function execution without explicit dependencies.""" @remote( instance_type="ml.m5.large", - # No dependencies specified - sagemaker should be injected automatically + role=role, + image_uri=image_uri, + sagemaker_session=sagemaker_session, + pre_execution_commands=dev_sdk_pre_execution_commands, ) def simple_add(x, y): - """Simple function that adds two numbers.""" return x + y - - # Execute the function + result = simple_add(5, 3) - - # Verify result assert result == 8, f"Expected 8, got {result}" - print("✓ Remote function without dependencies executed successfully") @pytest.mark.integ - # @skip_if_no_aws_region - def test_remote_function_with_user_dependencies_no_sagemaker(self): - """Test remote function with user dependencies but no sagemaker. - - This test verifies that when user provides dependencies without sagemaker, - sagemaker>=3.2.0 is automatically appended. 
- """ - # Create a temporary requirements.txt without sagemaker + def test_remote_function_with_user_dependencies_no_sagemaker( + self, dev_sdk_pre_execution_commands, role, image_uri, sagemaker_session + ): + """Test remote function with user dependencies but no sagemaker.""" with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f: f.write("numpy>=1.20.0\npandas>=1.3.0\n") req_file = f.name - + try: @remote( instance_type="ml.m5.large", + role=role, + image_uri=image_uri, + sagemaker_session=sagemaker_session, dependencies=req_file, + pre_execution_commands=dev_sdk_pre_execution_commands, ) def compute_with_numpy(x): - """Function that uses numpy.""" import numpy as np return np.array([x, x*2, x*3]).sum() - - # Execute the function + result = compute_with_numpy(5) - - # Verify result (5 + 10 + 15 = 30) assert result == 30, f"Expected 30, got {result}" - print("✓ Remote function with user dependencies executed successfully") finally: os.remove(req_file) @@ -85,52 +68,55 @@ class TestRemoteFunctionVersionCompatibility: """Tests for version compatibility between local and remote environments.""" @pytest.mark.integ - # @skip_if_no_aws_region - def test_deserialization_with_injected_sagemaker(self): - """Test that deserialization works with injected sagemaker dependency. - - This test verifies that the remote environment can properly deserialize - functions when sagemaker>=3.2.0 is available. 
- """ + def test_deserialization_with_injected_sagemaker( + self, dev_sdk_pre_execution_commands, role, image_uri, sagemaker_session + ): + """Test that deserialization works with injected sagemaker dependency.""" @remote( instance_type="ml.m5.large", + role=role, + image_uri=image_uri, + sagemaker_session=sagemaker_session, + pre_execution_commands=dev_sdk_pre_execution_commands, ) def complex_computation(data): - """Function that performs complex computation.""" result = sum(data) * len(data) return result - - # Execute with various data types + test_data = [1, 2, 3, 4, 5] result = complex_computation(test_data) - - # Verify result (sum=15, len=5, 15*5=75) assert result == 75, f"Expected 75, got {result}" - print("✓ Deserialization with injected sagemaker works correctly") @pytest.mark.integ - # @skip_if_no_aws_region - def test_multiple_remote_functions_with_dependencies(self): - """Test multiple remote functions with different dependency configurations. - - This test verifies that the dependency injection works correctly - when multiple remote functions are defined and executed. 
- """ - @remote(instance_type="ml.m5.large") + def test_multiple_remote_functions_with_dependencies( + self, dev_sdk_pre_execution_commands, role, image_uri, sagemaker_session + ): + """Test multiple remote functions with different dependency configurations.""" + @remote( + instance_type="ml.m5.large", + role=role, + image_uri=image_uri, + sagemaker_session=sagemaker_session, + pre_execution_commands=dev_sdk_pre_execution_commands, + ) def func1(x): return x + 1 - - @remote(instance_type="ml.m5.large") + + @remote( + instance_type="ml.m5.large", + role=role, + image_uri=image_uri, + sagemaker_session=sagemaker_session, + pre_execution_commands=dev_sdk_pre_execution_commands, + ) def func2(x): return x * 2 - - # Execute both functions + result1 = func1(5) result2 = func2(5) - + assert result1 == 6, f"func1: Expected 6, got {result1}" assert result2 == 10, f"func2: Expected 10, got {result2}" - print("✓ Multiple remote functions with dependencies executed successfully") if __name__ == "__main__": From 0045281264b17dd644228f93551324dbecb60464 Mon Sep 17 00:00:00 2001 From: Molly He Date: Mon, 9 Mar 2026 16:28:36 -0700 Subject: [PATCH 04/10] Update confest for whell build --- sagemaker-core/tests/integ/remote_function/conftest.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sagemaker-core/tests/integ/remote_function/conftest.py b/sagemaker-core/tests/integ/remote_function/conftest.py index 55a2f48aa4..0b4beb77ab 100644 --- a/sagemaker-core/tests/integ/remote_function/conftest.py +++ b/sagemaker-core/tests/integ/remote_function/conftest.py @@ -3,6 +3,7 @@ import glob import os import subprocess +import sys import tempfile import cloudpickle @@ -22,8 +23,7 @@ def _build_and_upload_core_wheel(sagemaker_session): dist_dir = tempfile.mkdtemp(prefix="sagemaker_core_wheel_") subprocess.run( - f"python -m build --wheel --outdir {dist_dir}", - shell=True, + [sys.executable, "-m", "pip", "wheel", "--no-deps", "-w", dist_dir, "."], 
cwd=os.path.join(repo_root, "sagemaker-core"), check=True, ) From 9a4340977b030cc0b9314b05c79396458f1f0754 Mon Sep 17 00:00:00 2001 From: Molly He Date: Mon, 9 Mar 2026 19:26:57 -0700 Subject: [PATCH 05/10] fix build setup error --- sagemaker-core/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sagemaker-core/pyproject.toml b/sagemaker-core/pyproject.toml index eb2b5c4087..0f3130523a 100644 --- a/sagemaker-core/pyproject.toml +++ b/sagemaker-core/pyproject.toml @@ -40,10 +40,10 @@ dependencies = [ "tblib>=1.7.0", ] requires-python = ">=3.9" +license = "Apache-2.0" classifiers = [ "Development Status :: 3 - Alpha", "Intended Audience :: Developers", - "License :: OSI Approved :: Apache Software License", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", From b4351fd1e38f45bad324f7886af97e896795cec8 Mon Sep 17 00:00:00 2001 From: Molly He Date: Mon, 9 Mar 2026 19:37:58 -0700 Subject: [PATCH 06/10] Update license format --- sagemaker-core/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sagemaker-core/pyproject.toml b/sagemaker-core/pyproject.toml index 0f3130523a..b6e2e6d92e 100644 --- a/sagemaker-core/pyproject.toml +++ b/sagemaker-core/pyproject.toml @@ -40,7 +40,7 @@ dependencies = [ "tblib>=1.7.0", ] requires-python = ">=3.9" -license = "Apache-2.0" +license = {text = "Apache-2.0"} classifiers = [ "Development Status :: 3 - Alpha", "Intended Audience :: Developers", From 65ff9db6a6f1a5ec0fa59fed2508e7df896a0342 Mon Sep 17 00:00:00 2001 From: Molly He Date: Tue, 10 Mar 2026 10:04:13 -0700 Subject: [PATCH 07/10] Add upper bound to setuptools for sagemaker-core --- sagemaker-core/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sagemaker-core/pyproject.toml b/sagemaker-core/pyproject.toml index b6e2e6d92e..875ef44d8f 100644 --- a/sagemaker-core/pyproject.toml +++ b/sagemaker-core/pyproject.toml @@ 
-1,5 +1,5 @@ [build-system] -requires = ["setuptools>=64"] +requires = ["setuptools>=64,<77"] build-backend = "setuptools.build_meta" [project] From 35adf569801784b7ef4969f7eb3671aef60f46a4 Mon Sep 17 00:00:00 2001 From: Molly He Date: Tue, 10 Mar 2026 11:27:40 -0700 Subject: [PATCH 08/10] Try to resolve setuptools version conflict --- sagemaker-core/tests/integ/remote_function/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sagemaker-core/tests/integ/remote_function/conftest.py b/sagemaker-core/tests/integ/remote_function/conftest.py index 0b4beb77ab..6ab4786ad7 100644 --- a/sagemaker-core/tests/integ/remote_function/conftest.py +++ b/sagemaker-core/tests/integ/remote_function/conftest.py @@ -23,7 +23,7 @@ def _build_and_upload_core_wheel(sagemaker_session): dist_dir = tempfile.mkdtemp(prefix="sagemaker_core_wheel_") subprocess.run( - [sys.executable, "-m", "pip", "wheel", "--no-deps", "-w", dist_dir, "."], + [sys.executable, "-m", "pip", "wheel", "--no-build-isolation", "--no-deps", "-w", dist_dir, "."], cwd=os.path.join(repo_root, "sagemaker-core"), check=True, ) From 1bbcecb85ad97c6cd722f8ea60abff31f8f47590 Mon Sep 17 00:00:00 2001 From: Molly He Date: Tue, 10 Mar 2026 11:49:15 -0700 Subject: [PATCH 09/10] Further fix on confest --- .../tests/integ/remote_function/conftest.py | 37 ++++++++----------- 1 file changed, 15 insertions(+), 22 deletions(-) diff --git a/sagemaker-core/tests/integ/remote_function/conftest.py b/sagemaker-core/tests/integ/remote_function/conftest.py index 6ab4786ad7..8b00caa794 100644 --- a/sagemaker-core/tests/integ/remote_function/conftest.py +++ b/sagemaker-core/tests/integ/remote_function/conftest.py @@ -1,9 +1,7 @@ """Shared fixtures for remote function integration tests.""" -import glob import os -import subprocess -import sys +import shutil import tempfile import cloudpickle @@ -17,26 +15,20 @@ def _get_repo_root(): return os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", "..")) 
-def _build_and_upload_core_wheel(sagemaker_session): - """Build sagemaker-core wheel and upload to S3. Returns (s3_prefix, wheel_basename).""" +def _upload_core_source(sagemaker_session): + """Tar the sagemaker-core source and upload to S3. Returns (s3_prefix, tar_basename).""" repo_root = _get_repo_root() - dist_dir = tempfile.mkdtemp(prefix="sagemaker_core_wheel_") + core_dir = os.path.join(repo_root, "sagemaker-core") + dist_dir = tempfile.mkdtemp(prefix="sagemaker_core_src_") - subprocess.run( - [sys.executable, "-m", "pip", "wheel", "--no-build-isolation", "--no-deps", "-w", dist_dir, "."], - cwd=os.path.join(repo_root, "sagemaker-core"), - check=True, + archive_path = shutil.make_archive( + os.path.join(dist_dir, "sagemaker-core-src"), "gztar", root_dir=core_dir, base_dir="." ) - matches = glob.glob(os.path.join(dist_dir, "sagemaker_core-*.whl")) - if not matches: - raise FileNotFoundError(f"No sagemaker-core wheel found in {dist_dir}") - wheel_path = matches[0] + s3_prefix = f"s3://{sagemaker_session.default_bucket()}/remote-function-test/src" + S3Uploader.upload(archive_path, s3_prefix, sagemaker_session=sagemaker_session) - s3_prefix = f"s3://{sagemaker_session.default_bucket()}/remote-function-test/wheels" - S3Uploader.upload(wheel_path, s3_prefix, sagemaker_session=sagemaker_session) - - return s3_prefix, os.path.basename(wheel_path) + return s3_prefix, os.path.basename(archive_path) @pytest.fixture(scope="module") @@ -60,11 +52,12 @@ def image_uri(sagemaker_session): @pytest.fixture(scope="module") def dev_sdk_pre_execution_commands(sagemaker_session): - """Build dev sagemaker-core wheel, upload to S3, and return pre_execution_commands.""" - s3_prefix, wheel_name = _build_and_upload_core_wheel(sagemaker_session) + """Upload dev sagemaker-core source to S3 and return pre_execution_commands.""" + s3_prefix, tar_name = _upload_core_source(sagemaker_session) cp_version = cloudpickle.__version__ return [ f"pip install cloudpickle=={cp_version}", - f"aws s3 
cp {s3_prefix}/{wheel_name} /tmp/{wheel_name}", - f"pip install /tmp/{wheel_name}", + f"aws s3 cp {s3_prefix}/{tar_name} /tmp/{tar_name}", + "mkdir -p /tmp/sagemaker-core-src && tar xzf /tmp/{tar_name} -C /tmp/sagemaker-core-src".format(tar_name=tar_name), + "pip install --no-deps /tmp/sagemaker-core-src", ] From 820ec3274a26c922ba53f4d4c696d26ea5c283a3 Mon Sep 17 00:00:00 2001 From: Molly He Date: Wed, 11 Mar 2026 15:23:11 -0700 Subject: [PATCH 10/10] Remove legacy fallback, add length check to secret_arn, remove greedy check in regex --- .../remote_function/core/serialization.py | 77 +++++++++-------- .../test_serialization_security.py | 82 +++++++++++++------ 2 files changed, 98 insertions(+), 61 deletions(-) diff --git a/sagemaker-core/src/sagemaker/core/remote_function/core/serialization.py b/sagemaker-core/src/sagemaker/core/remote_function/core/serialization.py index d5eab30f4f..6a4aecfab0 100644 --- a/sagemaker-core/src/sagemaker/core/remote_function/core/serialization.py +++ b/sagemaker-core/src/sagemaker/core/remote_function/core/serialization.py @@ -579,12 +579,24 @@ def _extract_job_name_from_secret_arn(secret_arn: str) -> str: Raises: DeserializationError: If ARN doesn't match expected format """ + # Length guard to prevent ReDoS on crafted inputs. + # Real ARNs are ~165 chars (job names are max 63 chars per SageMaker API). + MAX_SECRET_ARN_LENGTH = 256 + if len(secret_arn) > MAX_SECRET_ARN_LENGTH: + raise DeserializationError( + f"Secret ARN exceeds maximum length of {MAX_SECRET_ARN_LENGTH} characters" + ) + import re - match = re.search(r":secret:sagemaker/remote-function/(.+)/hmac-key", secret_arn) + # Use [^/]+ (a single path segment, no slashes) to prevent path-traversal in job name, + # and anchor with $ to ensure the ARN ends with the expected suffix. 
+ match = re.search( + r":secret:sagemaker/remote-function/([^/]+)/hmac-key-[A-Za-z0-9]{6}$", secret_arn + ) if not match: raise DeserializationError( f"Secret ARN does not match expected format " - f"'sagemaker/remote-function/{{job_name}}/hmac-key': {secret_arn}" + f"'sagemaker/remote-function/{{job_name}}/hmac-key-XXXXXX': {secret_arn}" ) return match.group(1) @@ -653,37 +665,34 @@ def _perform_integrity_check( Args: expected_hash_value: Expected hash value from metadata buffer: Serialized data buffer - sagemaker_session: SageMaker session (required if secret_arn is provided) - secret_arn: ARN of secret containing HMAC key (None for legacy plain SHA-256) - """ - if secret_arn: - # New secure method: HMAC with key from Secrets Manager - if not sagemaker_session: - raise DeserializationError( - "sagemaker_session is required for HMAC integrity check" - ) - - # Validate secret ARN (Mitigations #1 and #3) - _validate_secret_arn(sagemaker_session, secret_arn) - - # Now safe to retrieve HMAC key - hmac_key = _get_hmac_key_from_secret(sagemaker_session, secret_arn) - actual_hash_value = _compute_hmac(buffer, hmac_key) + sagemaker_session: SageMaker session (required for HMAC integrity check) + secret_arn: ARN of secret containing HMAC key (required) - if not hmac.compare_digest(expected_hash_value, actual_hash_value): - raise DeserializationError( - "HMAC integrity check failed. Serialized data may have been tampered with. " - "Please restrict access to your S3 bucket." - ) - else: - # Legacy method: plain SHA-256 (backward compatibility) - logger.warning( - "Using legacy SHA-256 integrity check without HMAC authentication. " - "This provides weaker security guarantees. Please upgrade to the latest SDK version." + Raises: + DeserializationError: If integrity check fails or secret_arn is missing + """ + if not secret_arn: + raise DeserializationError( + "Missing secret_arn in metadata. HMAC integrity check is required. 
" + "Legacy SHA-256 integrity check is no longer supported due to security " + "vulnerabilities. Please upgrade to the latest SDK version on both " + "client and remote sides." + ) + + if not sagemaker_session: + raise DeserializationError( + "sagemaker_session is required for HMAC integrity check" + ) + + # Validate secret ARN (Mitigations #1 and #3) + _validate_secret_arn(sagemaker_session, secret_arn) + + # Now safe to retrieve HMAC key + hmac_key = _get_hmac_key_from_secret(sagemaker_session, secret_arn) + actual_hash_value = _compute_hmac(buffer, hmac_key) + + if not hmac.compare_digest(expected_hash_value, actual_hash_value): + raise DeserializationError( + "HMAC integrity check failed. Serialized data may have been tampered with. " + "Please restrict access to your S3 bucket." ) - actual_hash_value = _compute_hash(buffer) - if expected_hash_value != actual_hash_value: - raise DeserializationError( - "Integrity check for the serialized function or data failed. " - "Please restrict access to your S3 bucket" - ) diff --git a/sagemaker-core/tests/unit/remote_function/test_serialization_security.py b/sagemaker-core/tests/unit/remote_function/test_serialization_security.py index eb2c7cc9f7..d01d64cea1 100644 --- a/sagemaker-core/tests/unit/remote_function/test_serialization_security.py +++ b/sagemaker-core/tests/unit/remote_function/test_serialization_security.py @@ -250,7 +250,44 @@ def test_invalid_arn_format_rejected(self): _validate_secret_arn(session, "not-an-arn") -class TestPerformIntegrityCheck: +class TestExtractJobNameFromSecretArn: + """Tests for _extract_job_name_from_secret_arn regex hardening.""" + + def test_valid_arn(self): + result = _extract_job_name_from_secret_arn(MOCK_SECRET_ARN) + assert result == MOCK_JOB_NAME + + def test_rejects_greedy_path_traversal(self): + """Greedy .+ allowed evil/hmac-key/../ in job name — now rejected.""" + malicious_arn = ( + "arn:aws:secretsmanager:us-east-1:123456789012:secret:" + 
"sagemaker/remote-function/evil/hmac-key/../sagemaker/" + "remote-function/legit-job/hmac-key-AbCdEf" + ) + with pytest.raises(DeserializationError, match="does not match expected format"): + _extract_job_name_from_secret_arn(malicious_arn) + + def test_rejects_arn_exceeding_max_length(self): + """Long input caused ReDoS — now rejected by length check.""" + long_arn = ( + "arn:aws:secretsmanager:us-east-1:123456789012:" + + ":secret:sagemaker/remote-function/y" * 10100 + + "\n:secret:sagemaker/remote-function/c/hmac-key-AbCdEf" + ) + with pytest.raises(DeserializationError, match="exceeds maximum length"): + _extract_job_name_from_secret_arn(long_arn) + + def test_rejects_arn_without_6char_suffix(self): + """ARN must end with hmac-key-XXXXXX (6 alphanumeric chars).""" + bad_arn = "arn:aws:secretsmanager:us-west-2:123456789012:secret:sagemaker/remote-function/job/hmac-key" + with pytest.raises(DeserializationError, match="does not match expected format"): + _extract_job_name_from_secret_arn(bad_arn) + + def test_rejects_arn_with_trailing_content(self): + """$ anchor prevents matching when extra content follows.""" + bad_arn = MOCK_SECRET_ARN + "/extra" + with pytest.raises(DeserializationError, match="does not match expected format"): + _extract_job_name_from_secret_arn(bad_arn) """Tests for integrity check with HMAC.""" def test_hmac_integrity_check_passes(self): @@ -284,27 +321,26 @@ def test_hmac_integrity_check_fails_on_tampered_payload(self): secret_arn=MOCK_SECRET_ARN, ) - def test_legacy_sha256_check_passes_with_warning(self): - """Legacy SHA-256 check should pass with warning when no secret_arn.""" + def test_legacy_sha256_check_rejected(self): + """Legacy SHA-256 check without secret_arn is no longer supported.""" payload = b"test payload" expected_hash = _compute_hash(payload) - # Should not raise (legacy path) - _perform_integrity_check( - expected_hash_value=expected_hash, - buffer=payload, - ) + with pytest.raises(DeserializationError, match="HMAC 
integrity check is required"): + _perform_integrity_check( + expected_hash_value=expected_hash, + buffer=payload, + ) - def test_legacy_sha256_check_fails_on_tampered_payload(self): - """Legacy SHA-256 check should fail on tampered payload.""" - original_payload = b"original payload" - tampered_payload = b"tampered payload" - expected_hash = _compute_hash(original_payload) + def test_legacy_sha256_tampered_payload_also_rejected(self): + """Legacy path is rejected regardless of hash correctness.""" + payload = b"test payload" + expected_hash = _compute_hash(payload) - with pytest.raises(DeserializationError, match="Integrity check"): + with pytest.raises(DeserializationError, match="HMAC integrity check is required"): _perform_integrity_check( expected_hash_value=expected_hash, - buffer=tampered_payload, + buffer=b"tampered", ) def test_hmac_check_requires_session(self): @@ -320,22 +356,14 @@ class TestAttackScenarios: """Tests simulating actual attack scenarios.""" def test_attacker_replaces_payload_and_metadata_plain_hash(self): - """Attacker replaces both files with plain SHA-256 - should fail HMAC check.""" - session, secrets_client, _, _ = _mock_sagemaker_session() - - # Attacker creates malicious payload + """Attacker replaces both files with plain SHA-256 (no secret_arn) - should be rejected.""" malicious_payload = b"malicious code" - - # Attacker computes plain SHA-256 (not HMAC) plain_hash = hashlib.sha256(malicious_payload).hexdigest() - # Attacker's HMAC won't match because they don't know the key - with pytest.raises(DeserializationError, match="HMAC integrity check failed"): + with pytest.raises(DeserializationError, match="HMAC integrity check is required"): _perform_integrity_check( expected_hash_value=plain_hash, buffer=malicious_payload, - sagemaker_session=session, - secret_arn=MOCK_SECRET_ARN, ) def test_attacker_points_to_cross_account_secret(self): @@ -356,8 +384,8 @@ def test_attacker_creates_secret_in_same_account(self): "Parameter": {"Value": 
MOCK_SECRET_ARN} } - # Attacker's secret in same account - attacker_arn = f"arn:aws:secretsmanager:us-west-2:{MOCK_ACCOUNT_ID}:secret:sagemaker/remote-function/evil-job/hmac-key" + # Attacker's secret in same account (with valid suffix format) + attacker_arn = f"arn:aws:secretsmanager:us-west-2:{MOCK_ACCOUNT_ID}:secret:sagemaker/remote-function/evil-job/hmac-key-XyZ123" with pytest.raises(DeserializationError, match="Secret ARN mismatch"): _validate_secret_arn(session, attacker_arn)