From 6004e8b69cd9b1a4fd37a3e589436b182761c276 Mon Sep 17 00:00:00 2001 From: Joshua Towner Date: Thu, 12 Mar 2026 21:31:07 -0700 Subject: [PATCH 1/4] fixes for model builder --- .../src/sagemaker/core/shapes/shapes.py | 2 +- .../src/sagemaker/serve/model_builder.py | 295 +++++++++++------- 2 files changed, 177 insertions(+), 120 deletions(-) diff --git a/sagemaker-core/src/sagemaker/core/shapes/shapes.py b/sagemaker-core/src/sagemaker/core/shapes/shapes.py index c0715fc4ae..adbcf6ec67 100644 --- a/sagemaker-core/src/sagemaker/core/shapes/shapes.py +++ b/sagemaker-core/src/sagemaker/core/shapes/shapes.py @@ -8577,7 +8577,7 @@ class InferenceComponentComputeResourceRequirements(Base): max_memory_required_in_mb: The maximum MB of memory to allocate to run a model that you assign to an inference component. """ - min_memory_required_in_mb: int + min_memory_required_in_mb: Optional[int] = Unassigned() number_of_cpu_cores_required: Optional[float] = Unassigned() number_of_accelerator_devices_required: Optional[float] = Unassigned() max_memory_required_in_mb: Optional[int] = Unassigned() diff --git a/sagemaker-serve/src/sagemaker/serve/model_builder.py b/sagemaker-serve/src/sagemaker/serve/model_builder.py index c2ba0c36eb..b8193d53fc 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_builder.py +++ b/sagemaker-serve/src/sagemaker/serve/model_builder.py @@ -957,6 +957,10 @@ def _fetch_and_cache_recipe_config(self): if not self.image_uri: self.image_uri = config.get("EcrAddress") + # Cache environment variables from recipe config + if not self.env_vars: + self.env_vars = config.get("Environment", {}) + # Infer instance type from JumpStart metadata if not provided # This is only called for model_customization deployments if not self.instance_type: @@ -2211,21 +2215,57 @@ def _build_single_modelbuilder( "Only SageMaker Endpoint Mode is supported for Model Customization use cases" ) model_package = self._fetch_model_package() - # Fetch recipe config first to set image_uri, instance_type, and s3_upload_path + # Fetch recipe config first to set image_uri, instance_type, env_vars, and s3_upload_path self._fetch_and_cache_recipe_config() - self.s3_upload_path = model_package.inference_specification.containers[ - 0 - ].model_data_source.s3_data_source.s3_uri - container_def = ContainerDefinition( - image=self.image_uri, - model_data_source={ - "s3_data_source": { - "s3_uri": f"{self.s3_upload_path}/", - "s3_data_type": "S3Prefix", - "compression_type": "None", - } - }, - ) + peft_type = self._fetch_peft() + + if peft_type == "LORA": + # For LORA: Model points at JumpStart base model, not training output + hub_document = self._fetch_hub_document_for_custom_model() + hosting_artifact_uri = hub_document.get("HostingArtifactUri") + if not hosting_artifact_uri: + raise ValueError( + "HostingArtifactUri not found in JumpStart hub metadata. " + "Cannot deploy LORA adapter without base model artifacts." + ) + container_def = ContainerDefinition( + image=self.image_uri, + environment=self.env_vars, + model_data_source={ + "s3_data_source": { + "s3_uri": hosting_artifact_uri, + "s3_data_type": "S3Prefix", + "compression_type": "None", + "model_access_config": {"accept_eula": True}, + } + }, + ) + # Store adapter path for use during deploy + if isinstance(self.model, TrainingJob): + self._adapter_s3_uri = ( + f"{self.model.model_artifacts.s3_model_artifacts}/checkpoints/hf/" + ) + elif isinstance(self.model, ModelTrainer): + self._adapter_s3_uri = ( + f"{self.model._latest_training_job.model_artifacts.s3_model_artifacts}" + "/checkpoints/hf/" + ) + else: + # Non-LORA: Model points at training output + self.s3_upload_path = model_package.inference_specification.containers[ + 0 + ].model_data_source.s3_data_source.s3_uri + container_def = ContainerDefinition( + image=self.image_uri, + model_data_source={ + "s3_data_source": { + "s3_uri": self.s3_upload_path.rstrip("/") + "/", + "s3_data_type": "S3Prefix", + "compression_type": "None", + } + }, + ) + model_name = self.model_name or f"model-{uuid.uuid4().hex[:10]}" # Create model self.built_model = Model.create( @@ -4142,17 +4182,13 @@ def _deploy_model_customization( """Deploy a model customization (fine-tuned) model to an endpoint with inference components. This method handles the special deployment flow for fine-tuned models, creating: - 1. Core Model resource - 2. EndpointConfig - 3. Endpoint - 4. InferenceComponent + 1. EndpointConfig and Endpoint + 2. Base model InferenceComponent (for LORA: from JumpStart base model) + 3. Adapter InferenceComponent (for LORA: referencing base IC with adapter weights) Args: endpoint_name (str): Name of the endpoint to create or update - instance_type (str): EC2 instance type for deployment initial_instance_count (int): Number of instances (default: 1) - wait (bool): Whether to wait for deployment to complete (default: True) - container_timeout_in_seconds (int): Container timeout in seconds (default: 300) inference_component_name (Optional[str]): Name for the inference component inference_config (Optional[ResourceRequirements]): Inference configuration including resource requirements (accelerator count, memory, CPU cores) @@ -4161,21 +4197,15 @@ def _deploy_model_customization( Returns: Endpoint: The deployed sagemaker.core.resources.Endpoint """ - from sagemaker.core.resources import ( - Model as CoreModel, - EndpointConfig as CoreEndpointConfig, - ) - from sagemaker.core.shapes import ContainerDefinition, ProductionVariant from sagemaker.core.shapes import ( InferenceComponentSpecification, InferenceComponentContainerSpecification, InferenceComponentRuntimeConfig, InferenceComponentComputeResourceRequirements, - ModelDataSource, - S3ModelDataSource, ) + from sagemaker.core.shapes import ProductionVariant from sagemaker.core.resources import InferenceComponent - from sagemaker.core.utils.utils import Unassigned + from sagemaker.core.resources import Tag as CoreTag # Fetch model package model_package = self._fetch_model_package() @@ -4183,9 +4213,6 @@ def _deploy_model_customization( # Check if endpoint exists is_existing_endpoint = self._does_endpoint_exist(endpoint_name) - # Generate model name if not set - model_name = self.model_name or f"model-{uuid.uuid4().hex[:10]}" - if not is_existing_endpoint: EndpointConfig.create( endpoint_config_name=endpoint_name, @@ -4206,48 +4233,16 @@ def _deploy_model_customization( else: endpoint = Endpoint.get(endpoint_name=endpoint_name) - # Set inference component name - if not inference_component_name: - if not is_existing_endpoint: - inference_component_name = f"{endpoint_name}-inference-component" - else: - inference_component_name = f"{endpoint_name}-inference-component-adapter" - - # Get PEFT type and base model recipe name peft_type = self._fetch_peft() base_model_recipe_name = model_package.inference_specification.containers[ 0 ].base_model.recipe_name - base_inference_component_name = None - tag = None - - # Resolve the correct model artifact URI based on deployment type - artifact_url = self._resolve_model_artifact_uri() - - # Determine if this is a base model deployment - # A base model deployment uses HostingArtifactUri from JumpStart (not from model package) - is_base_model_deployment = False - if artifact_url and not peft_type: - # Check if artifact_url comes from JumpStart (not from model package) - # If model package has model_data_source, it's a full fine-tuned model - if ( - hasattr(model_package.inference_specification.containers[0], "model_data_source") - and model_package.inference_specification.containers[0].model_data_source - ): - is_base_model_deployment = False # Full fine-tuned model - else: - is_base_model_deployment = True # Base model from JumpStart - - # Handle tagging and base component lookup - if not is_existing_endpoint and is_base_model_deployment: - # Only tag as "Base" if we're actually deploying a base model - from sagemaker.core.resources import Tag as CoreTag - tag = CoreTag(key="Base", value=base_model_recipe_name) - elif peft_type == "LORA": - # For LORA adapters, look up the existing base component - from sagemaker.core.resources import Tag as CoreTag + if peft_type == "LORA": + # LORA deployment: base IC + adapter IC + # Find or create base IC + base_ic_name = None for component in InferenceComponent.get_all( endpoint_name_equals=endpoint_name, status_equals="InService" ): @@ -4255,65 +4250,128 @@ def _deploy_model_customization( if any( t.key == "Base" and t.value == base_model_recipe_name for t in component_tags ): - base_inference_component_name = component.inference_component_name + base_ic_name = component.inference_component_name break - ic_spec = InferenceComponentSpecification( - container=InferenceComponentContainerSpecification( - image=self.image_uri, artifact_url=artifact_url, environment=self.env_vars + if not base_ic_name: + # Deploy base model IC + base_ic_name = f"{endpoint_name}-inference-component" + + base_ic_spec = InferenceComponentSpecification( + model_name=self.built_model.model_name, + ) + if inference_config is not None: + base_ic_spec.compute_resource_requirements = ( + InferenceComponentComputeResourceRequirements( + min_memory_required_in_mb=inference_config.min_memory, + max_memory_required_in_mb=inference_config.max_memory, + number_of_cpu_cores_required=inference_config.num_cpus, + number_of_accelerator_devices_required=inference_config.num_accelerators, + ) + ) + else: + base_ic_spec.compute_resource_requirements = self._cached_compute_requirements + + InferenceComponent.create( + inference_component_name=base_ic_name, + endpoint_name=endpoint_name, + variant_name=endpoint_name, + specification=base_ic_spec, + runtime_config=InferenceComponentRuntimeConfig(copy_count=1), + tags=[{"key": "Base", "value": base_model_recipe_name}], + ) + logger.info("Created base model InferenceComponent: '%s'", base_ic_name) + + # Wait for base IC to be InService before creating adapter + base_ic = InferenceComponent.get(inference_component_name=base_ic_name) + base_ic.wait_for_status("InService") + + # Deploy adapter IC + adapter_ic_name = inference_component_name or f"{endpoint_name}-adapter" + adapter_s3_uri = getattr(self, "_adapter_s3_uri", None) + + adapter_ic_spec = InferenceComponentSpecification( + base_inference_component_name=base_ic_name, + container=InferenceComponentContainerSpecification( + artifact_url=adapter_s3_uri, + ), ) - ) - if peft_type == "LORA": - ic_spec.base_inference_component_name = base_inference_component_name - - # Use inference_config if provided, otherwise fall back to cached requirements - if inference_config is not None: - # Extract compute requirements from inference_config (ResourceRequirements) - ic_spec.compute_resource_requirements = InferenceComponentComputeResourceRequirements( - min_memory_required_in_mb=inference_config.min_memory, - max_memory_required_in_mb=inference_config.max_memory, - number_of_cpu_cores_required=inference_config.num_cpus, - number_of_accelerator_devices_required=inference_config.num_accelerators, + InferenceComponent.create( + inference_component_name=adapter_ic_name, + endpoint_name=endpoint_name, + specification=adapter_ic_spec, ) + logger.info("Created adapter InferenceComponent: '%s'", adapter_ic_name) + else: - # Fall back to resolved compute requirements from build() - ic_spec.compute_resource_requirements = self._cached_compute_requirements + # Non-LORA deployment: single IC + if not inference_component_name: + inference_component_name = f"{endpoint_name}-inference-component" - InferenceComponent.create( - inference_component_name=inference_component_name, - endpoint_name=endpoint_name, - variant_name=endpoint_name, - specification=ic_spec, - runtime_config=InferenceComponentRuntimeConfig(copy_count=1), - tags=[{"key": tag.key, "value": tag.value}] if tag else [], - ) + artifact_url = self._resolve_model_artifact_uri() + + ic_spec = InferenceComponentSpecification( + container=InferenceComponentContainerSpecification( + image=self.image_uri, artifact_url=artifact_url, environment=self.env_vars + ) + ) + + if inference_config is not None: + ic_spec.compute_resource_requirements = ( + InferenceComponentComputeResourceRequirements( + min_memory_required_in_mb=inference_config.min_memory, + max_memory_required_in_mb=inference_config.max_memory, + number_of_cpu_cores_required=inference_config.num_cpus, + number_of_accelerator_devices_required=inference_config.num_accelerators, + ) + ) + else: + ic_spec.compute_resource_requirements = self._cached_compute_requirements + + InferenceComponent.create( + inference_component_name=inference_component_name, + endpoint_name=endpoint_name, + variant_name=endpoint_name, + specification=ic_spec, + runtime_config=InferenceComponentRuntimeConfig(copy_count=1), + ) # Create lineage tracking for new endpoints if not is_existing_endpoint: - from sagemaker.core.resources import Action, Association, Artifact - from sagemaker.core.shapes import ActionSource, MetadataProperties + try: + from sagemaker.core.resources import Action, Association, Artifact + from sagemaker.core.shapes import ActionSource, MetadataProperties - inference_component = InferenceComponent.get( - inference_component_name=inference_component_name - ) + ic_name = ( + inference_component_name + if not peft_type == "LORA" + else adapter_ic_name + ) + inference_component = InferenceComponent.get( + inference_component_name=ic_name + ) - action = Action.create( - source=ActionSource( - source_uri=self._fetch_model_package_arn(), source_type="SageMaker" - ), - action_name=f"{endpoint_name}-action", - action_type="ModelDeployment", - properties={"EndpointConfigName": endpoint_name}, - metadata_properties=MetadataProperties( - generated_by=inference_component.inference_component_arn - ), - ) + action = Action.create( + source=ActionSource( + source_uri=self._fetch_model_package_arn(), source_type="SageMaker" + ), + action_name=f"{endpoint_name}-action", + action_type="ModelDeployment", + properties={"EndpointConfigName": endpoint_name}, + metadata_properties=MetadataProperties( + generated_by=inference_component.inference_component_arn + ), + ) - artifacts = Artifact.get_all(source_uri=model_package.model_package_arn) - for artifact in artifacts: - Association.add(source_arn=artifact.artifact_arn, destination_arn=action.action_arn) - break + artifacts = Artifact.get_all(source_uri=model_package.model_package_arn) + for artifact in artifacts: + Association.add( + source_arn=artifact.artifact_arn, destination_arn=action.action_arn + ) + break + except Exception as e: + logger.warning(f"Failed to create lineage tracking: {e}") logger.info("✅ Model customization deployment successful: Endpoint '%s'", endpoint_name) return endpoint @@ -4329,11 +4387,10 @@ def _fetch_peft(self) -> Optional[str]: from sagemaker.core.utils.utils import Unassigned - if ( - training_job.serverless_job_config != Unassigned() - and training_job.serverless_job_config.job_spec != Unassigned() - ): - return training_job.serverless_job_config.job_spec.get("PEFT") + if training_job.serverless_job_config != Unassigned(): + peft = getattr(training_job.serverless_job_config, "peft", None) + if peft and not isinstance(peft, Unassigned): + return peft return None def _does_endpoint_exist(self, endpoint_name: str) -> bool: From e0c912b35923bfb7d05d1fbfd3a2fa30ac66cefe Mon Sep 17 00:00:00 2001 From: Joshua Towner Date: Thu, 12 Mar 2026 22:51:29 -0700 Subject: [PATCH 2/4] add nova model support --- .../src/sagemaker/serve/model_builder.py | 223 ++++++++++++++++++ 1 file changed, 223 insertions(+) diff --git a/sagemaker-serve/src/sagemaker/serve/model_builder.py b/sagemaker-serve/src/sagemaker/serve/model_builder.py index b8193d53fc..943e3a2953 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_builder.py +++ b/sagemaker-serve/src/sagemaker/serve/model_builder.py @@ -985,12 +985,112 @@ def _fetch_and_cache_recipe_config(self): ) return + # Fallback: Nova recipes don't have hosting configs in the hub document + if self._is_nova_model(): + nova_config = self._get_nova_hosting_config(instance_type=self.instance_type) + if not self.image_uri: + self.image_uri = nova_config["image_uri"] + if not self.env_vars: + self.env_vars = nova_config["env_vars"] + if not self.instance_type: + self.instance_type = nova_config["instance_type"] + return + raise ValueError( f"Model with recipe '{recipe_name}' is not supported for deployment. " f"The recipe does not have hosting configuration. " f"Please use a model that supports deployment or contact AWS support for assistance." ) + # Nova escrow ECR accounts per region + _NOVA_ESCROW_ACCOUNTS = { + "us-east-1": "708977205387", + "us-west-2": "176779409107", + "eu-west-2": "470633809225", + "ap-northeast-1": "878185805882", + } + + # Nova hosting configs per model (from Rhinestone modelDeployment.ts) + _NOVA_HOSTING_CONFIGS = { + "nova-textgeneration-micro": [ + {"InstanceType": "ml.g5.12xlarge", "Environment": {"CONTEXT_LENGTH": "4096", "MAX_CONCURRENCY": "16"}}, + {"InstanceType": "ml.g5.24xlarge", "Profile": "Default", "Environment": {"CONTEXT_LENGTH": "8192", "MAX_CONCURRENCY": "16"}}, + {"InstanceType": "ml.g6.12xlarge", "Environment": {"CONTEXT_LENGTH": "10000", "MAX_CONCURRENCY": "16"}}, + {"InstanceType": "ml.g6.24xlarge", "Environment": {"CONTEXT_LENGTH": "10000", "MAX_CONCURRENCY": "16"}}, + {"InstanceType": "ml.g6.48xlarge", "Environment": {"CONTEXT_LENGTH": "12000", "MAX_CONCURRENCY": "16"}}, + {"InstanceType": "ml.p5.48xlarge", "Environment": {"CONTEXT_LENGTH": "12000", "MAX_CONCURRENCY": "16"}}, + ], + "nova-textgeneration-lite": [ + {"InstanceType": "ml.g6.48xlarge", "Profile": "Default", "Environment": {"CONTEXT_LENGTH": "20000", "MAX_CONCURRENCY": "16"}}, + {"InstanceType": "ml.p5.48xlarge", "Environment": {"CONTEXT_LENGTH": "12000", "MAX_CONCURRENCY": "16"}}, + ], + "nova-textgeneration-pro": [ + {"InstanceType": "ml.g6.48xlarge", "Environment": {"CONTEXT_LENGTH": "12000", "MAX_CONCURRENCY": "16"}}, + {"InstanceType": "ml.p5.48xlarge", "Profile": "Default", "Environment": {"CONTEXT_LENGTH": "50000", "MAX_CONCURRENCY": "16"}}, + ], + "nova-textgeneration-lite-v2": [ + {"InstanceType": "ml.p5.48xlarge", "Profile": "Default", "Environment": {"CONTEXT_LENGTH": "50000", "MAX_CONCURRENCY": "16"}}, + ], + } + + def _is_nova_model(self) -> bool: + """Check if the model is a Nova model based on recipe name or hub content name.""" + model_package = self._fetch_model_package() + if not model_package: + return False + containers = getattr(model_package.inference_specification, "containers", None) + if not containers: + return False + base_model = getattr(containers[0], "base_model", None) + if not base_model: + return False + recipe_name = getattr(base_model, "recipe_name", "") or "" + hub_content_name = getattr(base_model, "hub_content_name", "") or "" + return "nova" in recipe_name.lower() or "nova" in hub_content_name.lower() + + def _get_nova_hosting_config(self, instance_type=None): + """Get Nova hosting config (image URI, env vars, instance type). + + Nova training recipes don't have hosting configs in the JumpStart hub document. + This provides the hardcoded fallback, matching Rhinestone's getNovaHostingConfigs(). + """ + model_package = self._fetch_model_package() + hub_content_name = model_package.inference_specification.containers[0].base_model.hub_content_name + + configs = self._NOVA_HOSTING_CONFIGS.get(hub_content_name) + if not configs: + raise ValueError( + f"Nova model '{hub_content_name}' is not supported for deployment. " + f"Supported: {list(self._NOVA_HOSTING_CONFIGS.keys())}" + ) + + region = self.sagemaker_session.boto_region_name + escrow_account = self._NOVA_ESCROW_ACCOUNTS.get(region) + if not escrow_account: + raise ValueError( + f"Nova deployment is not supported in region '{region}'. " + f"Supported: {list(self._NOVA_ESCROW_ACCOUNTS.keys())}" + ) + + image_uri = f"{escrow_account}.dkr.ecr.{region}.amazonaws.com/nova-inference-repo:SM-Inference-latest" + + if instance_type: + config = next((c for c in configs if c["InstanceType"] == instance_type), None) + if not config: + supported = [c["InstanceType"] for c in configs] + raise ValueError( + f"Instance type '{instance_type}' not supported for '{hub_content_name}'. " + f"Supported: {supported}" + ) + else: + config = next((c for c in configs if c.get("Profile") == "Default"), configs[0]) + + return { + "image_uri": image_uri, + "env_vars": config["Environment"], + "instance_type": config["InstanceType"], + } + def _initialize_jumpstart_config(self) -> None: """Initialize JumpStart-specific configuration.""" if hasattr(self, "hub_name") and self.hub_name and not self.hub_arn: @@ -2217,6 +2317,36 @@ def _build_single_modelbuilder( model_package = self._fetch_model_package() # Fetch recipe config first to set image_uri, instance_type, env_vars, and s3_upload_path self._fetch_and_cache_recipe_config() + + # Nova models use a completely different deployment architecture + if self._is_nova_model(): + escrow_uri = self._resolve_nova_escrow_uri() + base_model = model_package.inference_specification.containers[0].base_model + + container_def = ContainerDefinition( + image=self.image_uri, + environment=self.env_vars, + model_data_source={ + "s3_data_source": { + "s3_uri": escrow_uri.rstrip("/") + "/", + "s3_data_type": "S3Prefix", + "compression_type": "None", + } + }, + ) + model_name = self.model_name or f"model-{uuid.uuid4().hex[:10]}" + self.built_model = Model.create( + execution_role_arn=self.role_arn, + model_name=model_name, + containers=[container_def], + enable_network_isolation=True, + tags=[ + {"key": "sagemaker-studio:jumpstart-model-id", + "value": base_model.hub_content_name}, + ], + ) + return self.built_model + peft_type = self._fetch_peft() if peft_type == "LORA": @@ -4207,6 +4337,14 @@ def _deploy_model_customization( from sagemaker.core.resources import InferenceComponent from sagemaker.core.resources import Tag as CoreTag + # Nova models use direct model-on-variant, no InferenceComponents + if self._is_nova_model(): + return self._deploy_nova_model( + endpoint_name=endpoint_name, + initial_instance_count=initial_instance_count, + wait=kwargs.get("wait", True), + ) + # Fetch model package model_package = self._fetch_model_package() @@ -4403,6 +4541,91 @@ def _does_endpoint_exist(self, endpoint_name: str) -> bool: return False raise + def _resolve_nova_escrow_uri(self) -> str: + """Resolve the escrow S3 URI for Nova model artifacts from manifest.json. + + Nova training jobs write artifacts to an escrow S3 bucket. The location + is recorded in manifest.json in the training job output directory. + """ + import json + from urllib.parse import urlparse + + if isinstance(self.model, TrainingJob): + training_job = self.model + elif isinstance(self.model, ModelTrainer): + training_job = self.model._latest_training_job + else: + raise ValueError("Nova escrow URI resolution requires a TrainingJob or ModelTrainer") + + output_path = training_job.output_data_config.s3_output_path.rstrip("/") + manifest_s3 = f"{output_path}/{training_job.training_job_name}/output/output/manifest.json" + + parsed = urlparse(manifest_s3) + bucket = parsed.netloc + key = parsed.path.lstrip("/") + + s3_client = self.sagemaker_session.boto_session.client("s3") + resp = s3_client.get_object(Bucket=bucket, Key=key) + manifest = json.loads(resp["Body"].read().decode()) + + escrow_uri = manifest.get("checkpoint_s3_bucket") + if not escrow_uri: + raise ValueError( + f"'checkpoint_s3_bucket' not found in manifest.json. " + f"Available keys: {list(manifest.keys())}" + ) + return escrow_uri + + def _deploy_nova_model( + self, + endpoint_name: str, + initial_instance_count: int = 1, + wait: bool = True, + ) -> Endpoint: + """Deploy a Nova model directly to an endpoint without inference components. + + Nova models use a model-on-variant architecture: + - ModelName is embedded in the ProductionVariant + - No InferenceComponents are created + - EnableNetworkIsolation is set on the Model (during build) + """ + from sagemaker.core.shapes import ProductionVariant + + model_package = self._fetch_model_package() + base_model = model_package.inference_specification.containers[0].base_model + + if not endpoint_name: + endpoint_name = f"endpoint-{uuid.uuid4().hex[:8]}" + + EndpointConfig.create( + endpoint_config_name=endpoint_name, + production_variants=[ + ProductionVariant( + variant_name="AllTraffic", + model_name=self.built_model.model_name, + instance_type=self.instance_type, + initial_instance_count=initial_instance_count, + ) + ], + ) + + tags = [ + {"key": "sagemaker-studio:jumpstart-model-id", "value": base_model.hub_content_name}, + ] + if base_model.recipe_name: + tags.append({"key": "sagemaker-studio:recipe-name", "value": base_model.recipe_name}) + + endpoint = Endpoint.create( + endpoint_name=endpoint_name, + endpoint_config_name=endpoint_name, + tags=tags, + ) + + if wait: + endpoint.wait_for_status("InService") + + return endpoint + @_telemetry_emitter(feature=Feature.MODEL_CUSTOMIZATION, func_name="model_builder.deploy_local") def deploy_local( self, endpoint_name: str = "endpoint", container_timeout_in_seconds: int = 300, **kwargs From 3debf58d21a76b5d10402c305e8e74dc7d13a758 Mon Sep 17 00:00:00 2001 From: Roja Reddy Sareddy Date: Sun, 15 Mar 2026 11:03:42 -0700 Subject: [PATCH 3/4] fix env_vars merge, update integ test for LORA two-step deployment, fix unit tests for nova model support - env_vars: append recipe/nova config to existing env_vars instead of skipping - integ test: verify both base IC and adapter IC creation for LORA models - unit tests: add _is_nova_model mock to accommodate nova model support changes --- .../src/sagemaker/serve/model_builder.py | 8 +++- .../test_model_customization_deployment.py | 40 ++++++++++++------- .../unit/test_artifact_path_propagation.py | 32 +++++++++------ ...est_inference_config_parameter_handling.py | 21 +++++++--- .../tests/unit/test_model_builder.py | 14 +++---- .../tests/unit/test_two_stage_deployment.py | 22 ++++++---- 6 files changed, 88 insertions(+), 49 deletions(-) diff --git a/sagemaker-serve/src/sagemaker/serve/model_builder.py b/sagemaker-serve/src/sagemaker/serve/model_builder.py index 943e3a2953..2f32b8a7e7 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_builder.py +++ b/sagemaker-serve/src/sagemaker/serve/model_builder.py @@ -958,7 +958,9 @@ def _fetch_and_cache_recipe_config(self): self.image_uri = config.get("EcrAddress") # Cache environment variables from recipe config - if not self.env_vars: + if self.env_vars: + self.env_vars.update(config.get("Environment", {})) + else: self.env_vars = config.get("Environment", {}) # Infer instance type from JumpStart metadata if not provided @@ -990,7 +992,9 @@ def _fetch_and_cache_recipe_config(self): nova_config = self._get_nova_hosting_config(instance_type=self.instance_type) if not self.image_uri: self.image_uri = nova_config["image_uri"] - if not self.env_vars: + if self.env_vars: + self.env_vars.update(nova_config["env_vars"]) + else: self.env_vars = nova_config["env_vars"] if not self.instance_type: self.instance_type = nova_config["instance_type"] diff --git a/sagemaker-serve/tests/integ/test_model_customization_deployment.py b/sagemaker-serve/tests/integ/test_model_customization_deployment.py index 615bb67d2c..cc99828bc2 100644 --- a/sagemaker-serve/tests/integ/test_model_customization_deployment.py +++ b/sagemaker-serve/tests/integ/test_model_customization_deployment.py @@ -113,15 +113,26 @@ def test_build_from_training_job(self, training_job_name): assert model_builder.instance_type is not None def test_deploy_from_training_job(self, training_job_name, endpoint_name, cleanup_endpoints): - """Test deploying model from training job and adapter.""" - from sagemaker.core.resources import TrainingJob + """Test deploying model from training job. + + For LORA models, this verifies the two-step deployment: + base IC + adapter IC are both created on the same endpoint. + """ + from sagemaker.core.resources import TrainingJob, InferenceComponent from sagemaker.serve import ModelBuilder import time training_job = TrainingJob.get(training_job_name=training_job_name) - model_builder = ModelBuilder(model=training_job) - model = model_builder.build(model_name=f"test-model-{int(time.time())}-{random.randint(100, 10000)}") - endpoint = model_builder.deploy(endpoint_name=endpoint_name) + model_builder = ModelBuilder(model=training_job, instance_type="ml.g5.4xlarge") + model_builder.build(model_name=f"test-model-{int(time.time())}-{random.randint(100, 10000)}") + + peft_type = model_builder._fetch_peft() + adapter_name = f"{endpoint_name}-adapter" + + endpoint = model_builder.deploy( + endpoint_name=endpoint_name, + inference_component_name=adapter_name if peft_type == "LORA" else None, + ) cleanup_endpoints.append(endpoint_name) @@ -129,17 +140,16 @@ def test_deploy_from_training_job(self, training_job_name, endpoint_name, cleanu assert endpoint.endpoint_arn is not None assert endpoint.endpoint_status == "InService" - # Deploy adapter to the same endpoint - adapter_name = f"{endpoint_name}-adapter-{int(time.time())}-{random.randint(100, 100000)}" - model_builder2 = ModelBuilder(model=training_job) - model_builder2.build() - endpoint2 = model_builder2.deploy( - endpoint_name=endpoint_name, - inference_component_name=adapter_name - ) + if peft_type == "LORA": + # Verify base IC was created + base_ic_name = f"{endpoint_name}-inference-component" + base_ic = InferenceComponent.get(inference_component_name=base_ic_name) + assert base_ic is not None + assert base_ic.inference_component_status == "InService" - assert endpoint2 is not None - assert endpoint2.endpoint_name == endpoint_name + # Verify adapter IC was created + adapter_ic = InferenceComponent.get(inference_component_name=adapter_name) + assert adapter_ic is not None def test_fetch_endpoint_names_for_base_model(self, training_job_name): """Test fetching endpoint names for base model.""" diff --git a/sagemaker-serve/tests/unit/test_artifact_path_propagation.py b/sagemaker-serve/tests/unit/test_artifact_path_propagation.py index 89b511ab83..1094e76612 100644 --- a/sagemaker-serve/tests/unit/test_artifact_path_propagation.py +++ b/sagemaker-serve/tests/unit/test_artifact_path_propagation.py @@ -45,8 +45,10 @@ def setUp(self): @patch("sagemaker.serve.model_builder.ModelBuilder._resolve_model_artifact_uri") @patch("sagemaker.serve.model_builder.ModelBuilder._fetch_peft") @patch("sagemaker.serve.model_builder.ModelBuilder._is_model_customization") + @patch("sagemaker.serve.model_builder.ModelBuilder._is_nova_model", return_value=False) def test_base_model_artifact_uri_propagated_to_inference_component( self, + mock_is_nova_model, mock_is_model_customization, mock_fetch_peft, mock_resolve_artifact, @@ -133,8 +135,10 @@ def test_base_model_artifact_uri_propagated_to_inference_component( @patch("sagemaker.serve.model_builder.ModelBuilder._resolve_model_artifact_uri") @patch("sagemaker.serve.model_builder.ModelBuilder._fetch_peft") @patch("sagemaker.serve.model_builder.ModelBuilder._is_model_customization") + @patch("sagemaker.serve.model_builder.ModelBuilder._is_nova_model", return_value=False) def test_fine_tuned_model_artifact_uri_propagated_to_inference_component( self, + mock_is_nova_model, mock_is_model_customization, mock_fetch_peft, mock_resolve_artifact, @@ -220,8 +224,10 @@ def test_fine_tuned_model_artifact_uri_propagated_to_inference_component( @patch("sagemaker.serve.model_builder.ModelBuilder._resolve_model_artifact_uri") @patch("sagemaker.serve.model_builder.ModelBuilder._fetch_peft") @patch("sagemaker.serve.model_builder.ModelBuilder._is_model_customization") + @patch("sagemaker.serve.model_builder.ModelBuilder._is_nova_model", return_value=False) def test_lora_adapter_no_artifact_uri_propagated( self, + mock_is_nova_model, mock_is_model_customization, mock_fetch_peft, mock_resolve_artifact, @@ -298,21 +304,21 @@ def test_lora_adapter_no_artifact_uri_propagated( # Execute: Deploy to existing endpoint (LORA adapter) builder._deploy_model_customization(endpoint_name="test-endpoint", initial_instance_count=1) - # Verify: _resolve_model_artifact_uri was called - assert mock_resolve_artifact.called + # Verify: _resolve_model_artifact_uri is NOT called for LORA adapters + assert not mock_resolve_artifact.called - # Verify: InferenceComponent.create was called with artifact_url=None + # Verify: InferenceComponent.create was called assert mock_ic_create.called - call_kwargs = mock_ic_create.call_args[1] - - # Extract the specification - ic_spec = call_kwargs["specification"] - - # Verify artifact_url is None for LORA adapters - assert ic_spec.container.artifact_url is None - # Verify base_inference_component_name is set - assert ic_spec.base_inference_component_name == "base-component" + # Verify: adapter IC has base_inference_component_name set + # Find the adapter IC create call (the one with base_inference_component_name) + for c in mock_ic_create.call_args_list: + ic_spec = c[1]["specification"] + if ic_spec.base_inference_component_name: + assert ic_spec.base_inference_component_name == "base-component" + break + else: + pytest.fail("No adapter IC with base_inference_component_name found") @patch("sagemaker.core.resources.InferenceComponent.create") @patch("sagemaker.core.resources.Endpoint.get") @@ -323,8 +329,10 @@ def test_lora_adapter_no_artifact_uri_propagated( @patch("sagemaker.serve.model_builder.ModelBuilder._resolve_model_artifact_uri") @patch("sagemaker.serve.model_builder.ModelBuilder._fetch_peft") @patch("sagemaker.serve.model_builder.ModelBuilder._is_model_customization") + @patch("sagemaker.serve.model_builder.ModelBuilder._is_nova_model", return_value=False) def test_environment_variables_propagated_with_artifact_path( self, + mock_is_nova_model, mock_is_model_customization, mock_fetch_peft, mock_resolve_artifact, diff --git a/sagemaker-serve/tests/unit/test_inference_config_parameter_handling.py b/sagemaker-serve/tests/unit/test_inference_config_parameter_handling.py index 36f312e7f8..960e1fc459 100644 --- a/sagemaker-serve/tests/unit/test_inference_config_parameter_handling.py +++ b/sagemaker-serve/tests/unit/test_inference_config_parameter_handling.py @@ -50,8 +50,10 @@ def setUp(self): @patch("sagemaker.core.resources.Endpoint.get") @patch("sagemaker.core.resources.InferenceComponent.create") @patch("sagemaker.core.resources.InferenceComponent.get_all") + @patch("sagemaker.serve.model_builder.ModelBuilder._is_nova_model", return_value=False) def test_inference_config_provided_all_fields( self, + mock_is_nova_model, mock_ic_get_all, mock_ic_create, mock_endpoint_get, @@ -155,8 +157,10 @@ def test_inference_config_provided_all_fields( @patch("sagemaker.core.resources.Endpoint.get") @patch("sagemaker.core.resources.InferenceComponent.create") @patch("sagemaker.core.resources.InferenceComponent.get_all") + @patch("sagemaker.serve.model_builder.ModelBuilder._is_nova_model", return_value=False) def test_inference_config_provided_partial_fields( self, + mock_is_nova_model, mock_ic_get_all, mock_ic_create, mock_endpoint_get, @@ -258,8 +262,10 @@ def test_inference_config_provided_partial_fields( @patch("sagemaker.core.resources.InferenceComponent.get_all") @patch("sagemaker.serve.model_builder.ModelBuilder._fetch_hub_document_for_custom_model") @patch("sagemaker.serve.model_builder.ModelBuilder._get_instance_resources") + @patch("sagemaker.serve.model_builder.ModelBuilder._is_nova_model", return_value=False) def test_inference_config_not_provided_uses_cached_requirements( self, + mock_is_nova_model, mock_get_resources, mock_fetch_hub, mock_ic_get_all, @@ -378,8 +384,10 @@ def test_inference_config_not_provided_uses_cached_requirements( @patch("sagemaker.core.resources.Endpoint.get") @patch("sagemaker.core.resources.InferenceComponent.create") @patch("sagemaker.core.resources.InferenceComponent.get_all") + @patch("sagemaker.serve.model_builder.ModelBuilder._is_nova_model", return_value=False) def test_inference_config_overrides_cached_requirements( self, + mock_is_nova_model, mock_ic_get_all, mock_ic_create, mock_endpoint_get, @@ -486,8 +494,10 @@ def test_inference_config_overrides_cached_requirements( @patch("sagemaker.core.resources.Endpoint.get") @patch("sagemaker.core.resources.InferenceComponent.create") @patch("sagemaker.core.resources.InferenceComponent.get_all") + @patch("sagemaker.serve.model_builder.ModelBuilder._is_nova_model", return_value=False) def test_all_resource_requirements_fields_reach_api_call( self, + mock_is_nova_model, mock_ic_get_all, mock_ic_create, mock_endpoint_get, @@ -588,8 +598,10 @@ def test_all_resource_requirements_fields_reach_api_call( @patch("sagemaker.core.resources.InferenceComponent.create") @patch("sagemaker.core.resources.InferenceComponent.get_all") @patch("sagemaker.core.resources.Tag.get_all") + @patch("sagemaker.serve.model_builder.ModelBuilder._is_nova_model", return_value=False) def test_inference_config_with_existing_endpoint_lora_adapter( self, + mock_is_nova_model, mock_tag_get_all, mock_ic_get_all, mock_ic_create, @@ -653,15 +665,10 @@ def test_inference_config_with_existing_endpoint_lora_adapter( endpoint_name="existing-endpoint", inference_config=inference_config ) - # Verify: InferenceComponent.create was called with inference_config + # Verify: InferenceComponent.create was called assert mock_ic_create.called call_kwargs = mock_ic_create.call_args[1] ic_spec = call_kwargs["specification"] - compute_reqs = ic_spec.compute_resource_requirements - - # Verify inference_config values were used - assert compute_reqs.number_of_accelerator_devices_required == 1 - assert compute_reqs.min_memory_required_in_mb == 4096 # Verify base_inference_component_name is set for LORA assert ic_spec.base_inference_component_name == "base-component" @@ -679,8 +686,10 @@ def test_inference_config_with_existing_endpoint_lora_adapter( @patch("sagemaker.core.resources.Endpoint.get") @patch("sagemaker.core.resources.InferenceComponent.create") @patch("sagemaker.core.resources.InferenceComponent.get_all") + @patch("sagemaker.serve.model_builder.ModelBuilder._is_nova_model", return_value=False) def test_inference_config_with_zero_accelerators( self, + mock_is_nova_model, mock_ic_get_all, mock_ic_create, mock_endpoint_get, diff --git a/sagemaker-serve/tests/unit/test_model_builder.py b/sagemaker-serve/tests/unit/test_model_builder.py index 19f7cae31f..e5900f5562 100644 --- a/sagemaker-serve/tests/unit/test_model_builder.py +++ b/sagemaker-serve/tests/unit/test_model_builder.py @@ -371,10 +371,8 @@ def test_fetch_peft_from_training_job(self): """Test fetching PEFT from TrainingJob.""" from sagemaker.core.utils.utils import Unassigned - mock_job_spec = Mock() - mock_job_spec.get = Mock(return_value="LORA") self.mock_training_job.serverless_job_config = Mock() - self.mock_training_job.serverless_job_config.job_spec = mock_job_spec + self.mock_training_job.serverless_job_config.peft = "LORA" builder = ModelBuilder( model=self.mock_training_job, @@ -389,10 +387,8 @@ def test_fetch_peft_from_model_trainer(self): """Test fetching PEFT from ModelTrainer.""" from sagemaker.train.model_trainer import ModelTrainer - mock_job_spec = Mock() - mock_job_spec.get = Mock(return_value="LORA") self.mock_training_job.serverless_job_config = Mock() - self.mock_training_job.serverless_job_config.job_spec = mock_job_spec + self.mock_training_job.serverless_job_config.peft = "LORA" mock_trainer = Mock(spec=ModelTrainer) mock_trainer._latest_training_job = self.mock_training_job @@ -459,7 +455,8 @@ def test_build_single_modelbuilder_with_model_customization(self, mock_is_1p, mo with patch.object(builder, '_fetch_and_cache_recipe_config'): with patch.object(builder, '_get_client_translators', return_value=(Mock(), Mock())): with patch.object(builder, '_get_serve_setting', return_value=Mock()): - result = builder._build_single_modelbuilder() + with patch.object(builder, '_is_nova_model', return_value=False): + result = builder._build_single_modelbuilder() # Verify Model.create was called (indicating model customization path was taken) mock_model_class.create.assert_called_once() @@ -500,6 +497,7 @@ def test_deploy_model_customization_new_endpoint(self): with patch.object(builder, '_fetch_model_package', return_value=mock_model_package): with patch.object(builder, '_fetch_peft', return_value=None): + with patch.object(builder, '_is_nova_model', return_value=False): with patch.object(EndpointConfig, 'create', return_value=mock_endpoint_config): with patch.object(Endpoint, 'get', side_effect=ClientError({'Error': {'Code': 'ValidationException'}}, 'GetEndpoint')): with patch.object(Endpoint, 'create', return_value=mock_endpoint): @@ -574,6 +572,7 @@ def capture_ic_create(**kwargs): with patch.object(builder, '_fetch_model_package', return_value=mock_model_package): with patch.object(builder, '_fetch_peft', return_value=None): + with patch.object(builder, '_is_nova_model', return_value=False): with patch.object(EndpointConfig, 'create', return_value=mock_endpoint_config): with patch.object(Endpoint, 'get', side_effect=ClientError({'Error': {'Code': 'ValidationException'}}, 'GetEndpoint')): with patch.object(Endpoint, 'create', return_value=mock_endpoint): @@ -646,6 +645,7 @@ def capture_ic_create(**kwargs): with patch.object(builder, '_fetch_model_package', return_value=mock_model_package): with patch.object(builder, '_fetch_peft', return_value=None): + with patch.object(builder, '_is_nova_model', return_value=False): with patch.object(EndpointConfig, 'create', return_value=mock_endpoint_config): with patch.object(Endpoint, 'get', side_effect=ClientError({'Error': {'Code': 'ValidationException'}}, 'GetEndpoint')): with patch.object(Endpoint, 'create', return_value=mock_endpoint): diff --git a/sagemaker-serve/tests/unit/test_two_stage_deployment.py b/sagemaker-serve/tests/unit/test_two_stage_deployment.py index 827d3bea41..8d9f7b4f75 100644 --- a/sagemaker-serve/tests/unit/test_two_stage_deployment.py +++ b/sagemaker-serve/tests/unit/test_two_stage_deployment.py @@ -26,10 +26,12 @@ class TestTwoStageDeployment: @patch.object(ModelBuilder, "_fetch_peft") @patch.object(ModelBuilder, "_does_endpoint_exist") @patch.object(ModelBuilder, "_fetch_hub_document_for_custom_model") + @patch.object(ModelBuilder, "_is_nova_model", return_value=False) @patch.object(ModelBuilder, "_is_model_customization") def test_base_model_deployment_tagged_correctly( self, mock_is_customization, + mock_is_nova_model, mock_fetch_hub, mock_endpoint_exists, mock_fetch_peft, @@ -96,15 +98,13 @@ def test_base_model_deployment_tagged_correctly( ), patch("sagemaker.core.resources.Artifact"): model_builder._deploy_model_customization(endpoint_name="test-endpoint") - # Verify: InferenceComponent.create was called with Base tag + # Verify: InferenceComponent.create was called assert mock_ic_create.called create_call = mock_ic_create.call_args tags = create_call[1].get("tags", []) - # Should have exactly one tag with key="Base" - assert len(tags) == 1 - assert tags[0]["key"] == "Base" - assert tags[0]["value"] == "test-base-model" + # Non-LORA deployments do not get Base tags + assert len(tags) == 0 @patch("sagemaker.core.resources.InferenceComponent.get") @patch("sagemaker.core.resources.InferenceComponent.create") @@ -115,10 +115,12 @@ def test_base_model_deployment_tagged_correctly( @patch.object(ModelBuilder, "_fetch_model_package") @patch.object(ModelBuilder, "_fetch_peft") @patch.object(ModelBuilder, "_does_endpoint_exist") + @patch.object(ModelBuilder, "_is_nova_model", return_value=False) @patch.object(ModelBuilder, "_is_model_customization") def test_full_fine_tuned_model_not_tagged_as_base( self, mock_is_customization, + mock_is_nova_model, mock_endpoint_exists, mock_fetch_peft, mock_fetch_package, @@ -200,10 +202,12 @@ def test_full_fine_tuned_model_not_tagged_as_base( @patch.object(ModelBuilder, "_fetch_model_package") @patch.object(ModelBuilder, "_fetch_peft") @patch.object(ModelBuilder, "_does_endpoint_exist") + @patch.object(ModelBuilder, "_is_nova_model", return_value=False) @patch.object(ModelBuilder, "_is_model_customization") def test_lora_adapter_references_base_component( self, mock_is_customization, + mock_is_nova_model, mock_endpoint_exists, mock_fetch_peft, mock_fetch_package, @@ -286,10 +290,12 @@ def test_lora_adapter_references_base_component( @patch.object(ModelBuilder, "_fetch_peft") @patch.object(ModelBuilder, "_does_endpoint_exist") @patch.object(ModelBuilder, "_fetch_hub_document_for_custom_model") + @patch.object(ModelBuilder, "_is_nova_model", return_value=False) @patch.object(ModelBuilder, "_is_model_customization") def test_base_model_uses_hosting_artifact_uri( self, mock_is_customization, + mock_is_nova_model, mock_fetch_hub, mock_endpoint_exists, mock_fetch_peft, @@ -375,10 +381,12 @@ def test_base_model_uses_hosting_artifact_uri( @patch.object(ModelBuilder, "_fetch_peft") @patch.object(ModelBuilder, "_does_endpoint_exist") @patch.object(ModelBuilder, "_fetch_hub_document_for_custom_model") + @patch.object(ModelBuilder, "_is_nova_model", return_value=False) @patch.object(ModelBuilder, "_is_model_customization") def test_sequential_base_then_adapter_deployment( self, mock_is_customization, + mock_is_nova_model, mock_fetch_hub, mock_endpoint_exists, mock_fetch_peft, @@ -457,8 +465,8 @@ def test_sequential_base_then_adapter_deployment( assert mock_ic_create.call_count == 1 base_create_call = mock_ic_create.call_args base_tags = base_create_call[1].get("tags", []) - assert len(base_tags) == 1 - assert base_tags[0]["key"] == "Base" + # Non-LORA deployments do not get Base tags + assert len(base_tags) == 0 # Reset mocks for adapter deployment mock_ic_create.reset_mock() From a31427df2c19d0dfba5af225e3fa11f1d2e36ed6 Mon Sep 17 00:00:00 2001 From: Roja Reddy Sareddy Date: Sun, 15 Mar 2026 21:40:37 -0700 Subject: [PATCH 4/4] update codegen to mark MinMemoryRequiredInMb as optional DescribeInferenceComponent returns empty ComputeResourceRequirements for adapter ICs (created with BaseInferenceComponentName), but the service model still marks MinMemoryRequiredInMb as required. Add a REQUIRED_TO_OPTIONAL_OVERRIDES config in the codegen so re-running shapes generation produces the correct Optional field. --- sagemaker-core/src/sagemaker/core/tools/constants.py | 6 ++++++ .../src/sagemaker/core/tools/shapes_extractor.py | 11 ++++++++++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/sagemaker-core/src/sagemaker/core/tools/constants.py b/sagemaker-core/src/sagemaker/core/tools/constants.py index e785a7975d..0649eb980e 100644 --- a/sagemaker-core/src/sagemaker/core/tools/constants.py +++ b/sagemaker-core/src/sagemaker/core/tools/constants.py @@ -107,3 +107,9 @@ CONFIG_SCHEMA_FILE_NAME = "config_schema.py" API_COVERAGE_JSON_FILE_PATH = os.getcwd() + "/src/sagemaker/core/tools/api_coverage.json" + +# Members that the service model marks as required but the API returns as optional. +# E.g. DescribeInferenceComponent returns empty ComputeResourceRequirements for adapter ICs. +REQUIRED_TO_OPTIONAL_OVERRIDES = { + "InferenceComponentComputeResourceRequirements": ["MinMemoryRequiredInMb"], +} diff --git a/sagemaker-core/src/sagemaker/core/tools/shapes_extractor.py b/sagemaker-core/src/sagemaker/core/tools/shapes_extractor.py index e6d1e573d0..95cc359e59 100644 --- a/sagemaker-core/src/sagemaker/core/tools/shapes_extractor.py +++ b/sagemaker-core/src/sagemaker/core/tools/shapes_extractor.py @@ -16,7 +16,11 @@ from functools import lru_cache from typing import Optional, Any -from sagemaker.core.tools.constants import BASIC_JSON_TYPES_TO_PYTHON_TYPES, SHAPE_DAG_FILE_PATH +from sagemaker.core.tools.constants import ( + BASIC_JSON_TYPES_TO_PYTHON_TYPES, + REQUIRED_TO_OPTIONAL_OVERRIDES, + SHAPE_DAG_FILE_PATH, +) from sagemaker.core.utils.utils import ( reformat_file_with_black, convert_to_snake_case, @@ -216,6 +220,11 @@ def generate_shape_members(self, shape, required_override=()): shape_dict = self.combined_shapes[shape] members = shape_dict["members"] required_args = list(required_override) or shape_dict.get("required", []) + # Remove members that are known to be optional despite the service model + required_args = [ + r for r in required_args + if r not in REQUIRED_TO_OPTIONAL_OVERRIDES.get(shape, []) + ] init_data_body = {} # bring the required members in front ordered_members = {key: members[key] for key in required_args if key in members}