Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sagemaker-core/src/sagemaker/core/shapes/shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8577,7 +8577,7 @@ class InferenceComponentComputeResourceRequirements(Base):
max_memory_required_in_mb: The maximum MB of memory to allocate to run a model that you assign to an inference component.
"""

min_memory_required_in_mb: int
min_memory_required_in_mb: Optional[int] = Unassigned()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we update the engine to mark this as Optional ?

We can add a condition for this specific case

number_of_cpu_cores_required: Optional[float] = Unassigned()
number_of_accelerator_devices_required: Optional[float] = Unassigned()
max_memory_required_in_mb: Optional[int] = Unassigned()
Expand Down
6 changes: 6 additions & 0 deletions sagemaker-core/src/sagemaker/core/tools/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,9 @@
CONFIG_SCHEMA_FILE_NAME = "config_schema.py"

API_COVERAGE_JSON_FILE_PATH = os.getcwd() + "/src/sagemaker/core/tools/api_coverage.json"

# Overrides for members the service model declares as required but which the
# API can return unset. Maps shape name -> list of member names to demote to
# Optional. Example: DescribeInferenceComponent returns an empty
# ComputeResourceRequirements for adapter inference components.
REQUIRED_TO_OPTIONAL_OVERRIDES = {
    "InferenceComponentComputeResourceRequirements": [
        "MinMemoryRequiredInMb",
    ],
}
11 changes: 10 additions & 1 deletion sagemaker-core/src/sagemaker/core/tools/shapes_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@
from functools import lru_cache
from typing import Optional, Any

from sagemaker.core.tools.constants import BASIC_JSON_TYPES_TO_PYTHON_TYPES, SHAPE_DAG_FILE_PATH
from sagemaker.core.tools.constants import (
BASIC_JSON_TYPES_TO_PYTHON_TYPES,
REQUIRED_TO_OPTIONAL_OVERRIDES,
SHAPE_DAG_FILE_PATH,
)
from sagemaker.core.utils.utils import (
reformat_file_with_black,
convert_to_snake_case,
Expand Down Expand Up @@ -216,6 +220,11 @@ def generate_shape_members(self, shape, required_override=()):
shape_dict = self.combined_shapes[shape]
members = shape_dict["members"]
required_args = list(required_override) or shape_dict.get("required", [])
# Remove members that are known to be optional despite the service model
required_args = [
r for r in required_args
if r not in REQUIRED_TO_OPTIONAL_OVERRIDES.get(shape, [])
]
init_data_body = {}
# bring the required members in front
ordered_members = {key: members[key] for key in required_args if key in members}
Expand Down
522 changes: 403 additions & 119 deletions sagemaker-serve/src/sagemaker/serve/model_builder.py

Large diffs are not rendered by default.

40 changes: 25 additions & 15 deletions sagemaker-serve/tests/integ/test_model_customization_deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,33 +113,43 @@ def test_build_from_training_job(self, training_job_name):
assert model_builder.instance_type is not None

def test_deploy_from_training_job(self, training_job_name, endpoint_name, cleanup_endpoints):
"""Test deploying model from training job and adapter."""
from sagemaker.core.resources import TrainingJob
"""Test deploying model from training job.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a test case for Nova ?


For LORA models, this verifies the two-step deployment:
base IC + adapter IC are both created on the same endpoint.
"""
from sagemaker.core.resources import TrainingJob, InferenceComponent
from sagemaker.serve import ModelBuilder
import time

training_job = TrainingJob.get(training_job_name=training_job_name)
model_builder = ModelBuilder(model=training_job)
model = model_builder.build(model_name=f"test-model-{int(time.time())}-{random.randint(100, 10000)}")
endpoint = model_builder.deploy(endpoint_name=endpoint_name)
model_builder = ModelBuilder(model=training_job, instance_type="ml.g5.4xlarge")
model_builder.build(model_name=f"test-model-{int(time.time())}-{random.randint(100, 10000)}")

peft_type = model_builder._fetch_peft()
adapter_name = f"{endpoint_name}-adapter"

endpoint = model_builder.deploy(
endpoint_name=endpoint_name,
inference_component_name=adapter_name if peft_type == "LORA" else None,
)

cleanup_endpoints.append(endpoint_name)

assert endpoint is not None
assert endpoint.endpoint_arn is not None
assert endpoint.endpoint_status == "InService"

# Deploy adapter to the same endpoint
adapter_name = f"{endpoint_name}-adapter-{int(time.time())}-{random.randint(100, 100000)}"
model_builder2 = ModelBuilder(model=training_job)
model_builder2.build()
endpoint2 = model_builder2.deploy(
endpoint_name=endpoint_name,
inference_component_name=adapter_name
)
if peft_type == "LORA":
# Verify base IC was created
base_ic_name = f"{endpoint_name}-inference-component"
base_ic = InferenceComponent.get(inference_component_name=base_ic_name)
assert base_ic is not None
assert base_ic.inference_component_status == "InService"

assert endpoint2 is not None
assert endpoint2.endpoint_name == endpoint_name
# Verify adapter IC was created
adapter_ic = InferenceComponent.get(inference_component_name=adapter_name)
assert adapter_ic is not None

def test_fetch_endpoint_names_for_base_model(self, training_job_name):
"""Test fetching endpoint names for base model."""
Expand Down
32 changes: 20 additions & 12 deletions sagemaker-serve/tests/unit/test_artifact_path_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,10 @@ def setUp(self):
@patch("sagemaker.serve.model_builder.ModelBuilder._resolve_model_artifact_uri")
@patch("sagemaker.serve.model_builder.ModelBuilder._fetch_peft")
@patch("sagemaker.serve.model_builder.ModelBuilder._is_model_customization")
@patch("sagemaker.serve.model_builder.ModelBuilder._is_nova_model", return_value=False)
def test_base_model_artifact_uri_propagated_to_inference_component(
self,
mock_is_nova_model,
mock_is_model_customization,
mock_fetch_peft,
mock_resolve_artifact,
Expand Down Expand Up @@ -133,8 +135,10 @@ def test_base_model_artifact_uri_propagated_to_inference_component(
@patch("sagemaker.serve.model_builder.ModelBuilder._resolve_model_artifact_uri")
@patch("sagemaker.serve.model_builder.ModelBuilder._fetch_peft")
@patch("sagemaker.serve.model_builder.ModelBuilder._is_model_customization")
@patch("sagemaker.serve.model_builder.ModelBuilder._is_nova_model", return_value=False)
def test_fine_tuned_model_artifact_uri_propagated_to_inference_component(
self,
mock_is_nova_model,
mock_is_model_customization,
mock_fetch_peft,
mock_resolve_artifact,
Expand Down Expand Up @@ -220,8 +224,10 @@ def test_fine_tuned_model_artifact_uri_propagated_to_inference_component(
@patch("sagemaker.serve.model_builder.ModelBuilder._resolve_model_artifact_uri")
@patch("sagemaker.serve.model_builder.ModelBuilder._fetch_peft")
@patch("sagemaker.serve.model_builder.ModelBuilder._is_model_customization")
@patch("sagemaker.serve.model_builder.ModelBuilder._is_nova_model", return_value=False)
def test_lora_adapter_no_artifact_uri_propagated(
self,
mock_is_nova_model,
mock_is_model_customization,
mock_fetch_peft,
mock_resolve_artifact,
Expand Down Expand Up @@ -298,21 +304,21 @@ def test_lora_adapter_no_artifact_uri_propagated(
# Execute: Deploy to existing endpoint (LORA adapter)
builder._deploy_model_customization(endpoint_name="test-endpoint", initial_instance_count=1)

# Verify: _resolve_model_artifact_uri was called
assert mock_resolve_artifact.called
# Verify: _resolve_model_artifact_uri is NOT called for LORA adapters
assert not mock_resolve_artifact.called

# Verify: InferenceComponent.create was called with artifact_url=None
# Verify: InferenceComponent.create was called
assert mock_ic_create.called
call_kwargs = mock_ic_create.call_args[1]

# Extract the specification
ic_spec = call_kwargs["specification"]

# Verify artifact_url is None for LORA adapters
assert ic_spec.container.artifact_url is None

# Verify base_inference_component_name is set
assert ic_spec.base_inference_component_name == "base-component"
# Verify: adapter IC has base_inference_component_name set
# Find the adapter IC create call (the one with base_inference_component_name)
for c in mock_ic_create.call_args_list:
ic_spec = c[1]["specification"]
if ic_spec.base_inference_component_name:
assert ic_spec.base_inference_component_name == "base-component"
break
else:
pytest.fail("No adapter IC with base_inference_component_name found")

@patch("sagemaker.core.resources.InferenceComponent.create")
@patch("sagemaker.core.resources.Endpoint.get")
Expand All @@ -323,8 +329,10 @@ def test_lora_adapter_no_artifact_uri_propagated(
@patch("sagemaker.serve.model_builder.ModelBuilder._resolve_model_artifact_uri")
@patch("sagemaker.serve.model_builder.ModelBuilder._fetch_peft")
@patch("sagemaker.serve.model_builder.ModelBuilder._is_model_customization")
@patch("sagemaker.serve.model_builder.ModelBuilder._is_nova_model", return_value=False)
def test_environment_variables_propagated_with_artifact_path(
self,
mock_is_nova_model,
mock_is_model_customization,
mock_fetch_peft,
mock_resolve_artifact,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,10 @@ def setUp(self):
@patch("sagemaker.core.resources.Endpoint.get")
@patch("sagemaker.core.resources.InferenceComponent.create")
@patch("sagemaker.core.resources.InferenceComponent.get_all")
@patch("sagemaker.serve.model_builder.ModelBuilder._is_nova_model", return_value=False)
def test_inference_config_provided_all_fields(
self,
mock_is_nova_model,
mock_ic_get_all,
mock_ic_create,
mock_endpoint_get,
Expand Down Expand Up @@ -155,8 +157,10 @@ def test_inference_config_provided_all_fields(
@patch("sagemaker.core.resources.Endpoint.get")
@patch("sagemaker.core.resources.InferenceComponent.create")
@patch("sagemaker.core.resources.InferenceComponent.get_all")
@patch("sagemaker.serve.model_builder.ModelBuilder._is_nova_model", return_value=False)
def test_inference_config_provided_partial_fields(
self,
mock_is_nova_model,
mock_ic_get_all,
mock_ic_create,
mock_endpoint_get,
Expand Down Expand Up @@ -258,8 +262,10 @@ def test_inference_config_provided_partial_fields(
@patch("sagemaker.core.resources.InferenceComponent.get_all")
@patch("sagemaker.serve.model_builder.ModelBuilder._fetch_hub_document_for_custom_model")
@patch("sagemaker.serve.model_builder.ModelBuilder._get_instance_resources")
@patch("sagemaker.serve.model_builder.ModelBuilder._is_nova_model", return_value=False)
def test_inference_config_not_provided_uses_cached_requirements(
self,
mock_is_nova_model,
mock_get_resources,
mock_fetch_hub,
mock_ic_get_all,
Expand Down Expand Up @@ -378,8 +384,10 @@ def test_inference_config_not_provided_uses_cached_requirements(
@patch("sagemaker.core.resources.Endpoint.get")
@patch("sagemaker.core.resources.InferenceComponent.create")
@patch("sagemaker.core.resources.InferenceComponent.get_all")
@patch("sagemaker.serve.model_builder.ModelBuilder._is_nova_model", return_value=False)
def test_inference_config_overrides_cached_requirements(
self,
mock_is_nova_model,
mock_ic_get_all,
mock_ic_create,
mock_endpoint_get,
Expand Down Expand Up @@ -486,8 +494,10 @@ def test_inference_config_overrides_cached_requirements(
@patch("sagemaker.core.resources.Endpoint.get")
@patch("sagemaker.core.resources.InferenceComponent.create")
@patch("sagemaker.core.resources.InferenceComponent.get_all")
@patch("sagemaker.serve.model_builder.ModelBuilder._is_nova_model", return_value=False)
def test_all_resource_requirements_fields_reach_api_call(
self,
mock_is_nova_model,
mock_ic_get_all,
mock_ic_create,
mock_endpoint_get,
Expand Down Expand Up @@ -588,8 +598,10 @@ def test_all_resource_requirements_fields_reach_api_call(
@patch("sagemaker.core.resources.InferenceComponent.create")
@patch("sagemaker.core.resources.InferenceComponent.get_all")
@patch("sagemaker.core.resources.Tag.get_all")
@patch("sagemaker.serve.model_builder.ModelBuilder._is_nova_model", return_value=False)
def test_inference_config_with_existing_endpoint_lora_adapter(
self,
mock_is_nova_model,
mock_tag_get_all,
mock_ic_get_all,
mock_ic_create,
Expand Down Expand Up @@ -653,15 +665,10 @@ def test_inference_config_with_existing_endpoint_lora_adapter(
endpoint_name="existing-endpoint", inference_config=inference_config
)

# Verify: InferenceComponent.create was called with inference_config
# Verify: InferenceComponent.create was called
assert mock_ic_create.called
call_kwargs = mock_ic_create.call_args[1]
ic_spec = call_kwargs["specification"]
compute_reqs = ic_spec.compute_resource_requirements

# Verify inference_config values were used
assert compute_reqs.number_of_accelerator_devices_required == 1
assert compute_reqs.min_memory_required_in_mb == 4096

# Verify base_inference_component_name is set for LORA
assert ic_spec.base_inference_component_name == "base-component"
Expand All @@ -679,8 +686,10 @@ def test_inference_config_with_existing_endpoint_lora_adapter(
@patch("sagemaker.core.resources.Endpoint.get")
@patch("sagemaker.core.resources.InferenceComponent.create")
@patch("sagemaker.core.resources.InferenceComponent.get_all")
@patch("sagemaker.serve.model_builder.ModelBuilder._is_nova_model", return_value=False)
def test_inference_config_with_zero_accelerators(
self,
mock_is_nova_model,
mock_ic_get_all,
mock_ic_create,
mock_endpoint_get,
Expand Down
14 changes: 7 additions & 7 deletions sagemaker-serve/tests/unit/test_model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,10 +371,8 @@ def test_fetch_peft_from_training_job(self):
"""Test fetching PEFT from TrainingJob."""
from sagemaker.core.utils.utils import Unassigned

mock_job_spec = Mock()
mock_job_spec.get = Mock(return_value="LORA")
self.mock_training_job.serverless_job_config = Mock()
self.mock_training_job.serverless_job_config.job_spec = mock_job_spec
self.mock_training_job.serverless_job_config.peft = "LORA"

builder = ModelBuilder(
model=self.mock_training_job,
Expand All @@ -389,10 +387,8 @@ def test_fetch_peft_from_model_trainer(self):
"""Test fetching PEFT from ModelTrainer."""
from sagemaker.train.model_trainer import ModelTrainer

mock_job_spec = Mock()
mock_job_spec.get = Mock(return_value="LORA")
self.mock_training_job.serverless_job_config = Mock()
self.mock_training_job.serverless_job_config.job_spec = mock_job_spec
self.mock_training_job.serverless_job_config.peft = "LORA"

mock_trainer = Mock(spec=ModelTrainer)
mock_trainer._latest_training_job = self.mock_training_job
Expand Down Expand Up @@ -459,7 +455,8 @@ def test_build_single_modelbuilder_with_model_customization(self, mock_is_1p, mo
with patch.object(builder, '_fetch_and_cache_recipe_config'):
with patch.object(builder, '_get_client_translators', return_value=(Mock(), Mock())):
with patch.object(builder, '_get_serve_setting', return_value=Mock()):
result = builder._build_single_modelbuilder()
with patch.object(builder, '_is_nova_model', return_value=False):
result = builder._build_single_modelbuilder()

# Verify Model.create was called (indicating model customization path was taken)
mock_model_class.create.assert_called_once()
Expand Down Expand Up @@ -500,6 +497,7 @@ def test_deploy_model_customization_new_endpoint(self):

with patch.object(builder, '_fetch_model_package', return_value=mock_model_package):
with patch.object(builder, '_fetch_peft', return_value=None):
with patch.object(builder, '_is_nova_model', return_value=False):
with patch.object(EndpointConfig, 'create', return_value=mock_endpoint_config):
with patch.object(Endpoint, 'get', side_effect=ClientError({'Error': {'Code': 'ValidationException'}}, 'GetEndpoint')):
with patch.object(Endpoint, 'create', return_value=mock_endpoint):
Expand Down Expand Up @@ -574,6 +572,7 @@ def capture_ic_create(**kwargs):

with patch.object(builder, '_fetch_model_package', return_value=mock_model_package):
with patch.object(builder, '_fetch_peft', return_value=None):
with patch.object(builder, '_is_nova_model', return_value=False):
with patch.object(EndpointConfig, 'create', return_value=mock_endpoint_config):
with patch.object(Endpoint, 'get', side_effect=ClientError({'Error': {'Code': 'ValidationException'}}, 'GetEndpoint')):
with patch.object(Endpoint, 'create', return_value=mock_endpoint):
Expand Down Expand Up @@ -646,6 +645,7 @@ def capture_ic_create(**kwargs):

with patch.object(builder, '_fetch_model_package', return_value=mock_model_package):
with patch.object(builder, '_fetch_peft', return_value=None):
with patch.object(builder, '_is_nova_model', return_value=False):
with patch.object(EndpointConfig, 'create', return_value=mock_endpoint_config):
with patch.object(Endpoint, 'get', side_effect=ClientError({'Error': {'Code': 'ValidationException'}}, 'GetEndpoint')):
with patch.object(Endpoint, 'create', return_value=mock_endpoint):
Expand Down
Loading
Loading