From 82c2386b5e2fe4d2a0e3f5d4608709285ea4a040 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Wed, 10 Jun 2026 21:51:52 +0000 Subject: [PATCH 1/2] solved the weight streaming failure --- .../dynamo/runtime/_TRTEngine.py | 40 ++++++++++++------- 1 file changed, 25 insertions(+), 15 deletions(-) diff --git a/py/torch_tensorrt/dynamo/runtime/_TRTEngine.py b/py/torch_tensorrt/dynamo/runtime/_TRTEngine.py index aeda1aa1e4..d2ec9e702d 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TRTEngine.py +++ b/py/torch_tensorrt/dynamo/runtime/_TRTEngine.py @@ -696,10 +696,17 @@ def device_memory_budget(self) -> Any: def device_memory_budget(self, budget_bytes: int) -> None: if budget_bytes < 0: budget_bytes = self.streamable_device_memory_budget + # The weight streaming budget cannot be modified while an execution + # context is active, so release the current context first, then update + # the budget and recreate it (mirrors the C++ runtime's + # set_device_memory_budget). + self.context = None self.cuda_engine.weight_streaming_budget_v2 = budget_bytes if self.cuda_engine.weight_streaming_budget_v2 != budget_bytes: logger.error(f"Failed to set weight streaming budget to {budget_bytes}") self.context = self._create_execution_context() + if self._profile_execution: + self.enable_profiling() self.runtime_states.context_changed = True def reset_captured_graph(self) -> None: @@ -882,11 +889,12 @@ def _prepare_streams(self, contiguous_inputs: List[torch.Tensor]) -> bool: ): # Captured CUDA graph was recorded against the old stream. self.runtime_states.context_changed = True - return caller_on_default + return bool(caller_on_default) def _execute_standard( self, contiguous_inputs: List[torch.Tensor] ) -> torch.Tensor | Tuple[torch.Tensor, ...]: + cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode() if ( ENABLED_FEATURES.tensorrt_rtx @@ -913,6 +921,9 @@ def _execute_standard( # cudagraph recapture (set_runtime_states consumes and resets the # flag). caller_on_default = self._prepare_streams(contiguous_inputs) + engine_stream = self._engine_stream + caller_stream = self._caller_stream + assert engine_stream is not None and caller_stream is not None shape_changed = self.validate_input_shapes(contiguous_inputs) ( need_cudagraphs_record, @@ -970,8 +981,8 @@ def _execute_standard( with self._profile_section("TRTEngine:TensorRTRuntime"): if caller_on_default: - self._engine_stream.wait_stream(self._caller_stream) - with torch.cuda.stream(self._engine_stream): + engine_stream.wait_stream(caller_stream) + with torch.cuda.stream(engine_stream): if self.resource_allocation_strategy: self._dynamic_workspace = torch.empty( self.cuda_engine.device_memory_size_v2, @@ -985,22 +996,18 @@ def _execute_standard( self.cudagraph = torch.cuda.CUDAGraph() if self._profile_execution: self.cudagraph.enable_debug_mode() - with torch.cuda.graph( - self.cudagraph, stream=self._engine_stream - ): - self.context.execute_async_v3( - self._engine_stream.cuda_stream - ) + with torch.cuda.graph(self.cudagraph, stream=engine_stream): + self.context.execute_async_v3(engine_stream.cuda_stream) if self._profile_execution: self.cudagraph.debug_dump( f"{DEBUG_LOGGING_DIR}/{self.name}_cudagraph.dot" ) self.cudagraph.replay() # type: ignore[union-attr] else: - self.context.execute_async_v3(self._engine_stream.cuda_stream) + self.context.execute_async_v3(engine_stream.cuda_stream) if caller_on_default: - self._caller_stream.wait_stream(self._engine_stream) + caller_stream.wait_stream(engine_stream) if self.use_pre_allocated_outputs and ( self.output_tensors_are_unowned @@ -1040,14 +1047,17 @@ def _execute_output_allocator( ) caller_on_default = self._prepare_streams(contiguous_inputs) + engine_stream = self._engine_stream + caller_stream = self._caller_stream + assert engine_stream is not None and caller_stream is not None with self._profile_section("TRTEngine:TensorRTRuntime"): if caller_on_default: - self._engine_stream.wait_stream(self._caller_stream) - with torch.cuda.stream(self._engine_stream): - self.context.execute_async_v3(self._engine_stream.cuda_stream) + engine_stream.wait_stream(caller_stream) + with torch.cuda.stream(engine_stream): + self.context.execute_async_v3(engine_stream.cuda_stream) if caller_on_default: - self._caller_stream.wait_stream(self._engine_stream) + caller_stream.wait_stream(engine_stream) outputs = [] assert self.output_allocator is not None From 4deae4223501a842c891ff4a17af9fd76063171a Mon Sep 17 00:00:00 2001 From: cehongwang Date: Fri, 12 Jun 2026 23:52:28 +0000 Subject: [PATCH 2/2] removed weight name map functionality for refit --- .../tutorials/runtime_opt/python_runtime.rst | 4 - docsrc/tutorials/weight_refit/refit.rst | 7 - examples/dynamo/refit_engine_example.py | 16 +- py/torch_tensorrt/dynamo/_engine_cache.py | 11 +- py/torch_tensorrt/dynamo/_refit.py | 240 ++---------------- .../dynamo/conversion/_TRTInterpreter.py | 220 ---------------- .../dynamo/conversion/_conversion.py | 7 - .../runtime/_MutableTorchTensorRTModule.py | 1 - .../dynamo/runtime/_TRTEngine.py | 1 - .../dynamo/runtime/_TorchTensorRTModule.py | 6 - tests/py/dynamo/models/test_model_refit.py | 126 --------- 11 files changed, 30 insertions(+), 609 deletions(-) diff --git a/docsrc/tutorials/runtime_opt/python_runtime.rst b/docsrc/tutorials/runtime_opt/python_runtime.rst index 3f78bb05e5..0e59e3831f 100644 --- a/docsrc/tutorials/runtime_opt/python_runtime.rst +++ b/docsrc/tutorials/runtime_opt/python_runtime.rst @@ -130,10 +130,6 @@ produced outside Torch-TensorRT): ``settings`` (:class:`~torch_tensorrt.dynamo._settings.CompilationSettings`, optional) Device and runtime options (must match how the engine was built). -``weight_name_map`` (``dict``, optional) - Mapping of TRT weight names to PyTorch state dict names. Required for refit - support via :func:`~torch_tensorrt.dynamo.refit_module_weights`. - ``requires_output_allocator`` (``bool``, default ``False``) Set to ``True`` if the engine contains data-dependent-shape ops (``nonzero``, ``unique``, etc.) that require TRT's output allocator. diff --git a/docsrc/tutorials/weight_refit/refit.rst b/docsrc/tutorials/weight_refit/refit.rst index 62c464cf91..26d71a178c 100644 --- a/docsrc/tutorials/weight_refit/refit.rst +++ b/docsrc/tutorials/weight_refit/refit.rst @@ -88,7 +88,6 @@ API arg_inputs=None, kwarg_inputs=None, verify_output=False, - use_weight_map_cache=True, in_place=False, ) @@ -114,12 +113,6 @@ API PyTorch on the provided sample inputs. Useful for catching silent refit failures during development. -``use_weight_map_cache`` (``bool``, default ``True``) - When torch-tensorrt programs are compiled, the TRTIntpereter builds a map of which - exported program nodes correspond to which TensorRT layers. This mapping is stored as metadata in serialized - torch-tensorrt programs. This cache is not gaurenteed to be an exact match but to a new - unseen exported program but when it does, it reduces refit time by ~50%. - ``in_place`` (``bool``, default ``False``) If ``True``, modify the compiled module in-place rather than returning a copy. Not supported for ``ExportedProgram`` inputs (use the returned module instead). diff --git a/examples/dynamo/refit_engine_example.py b/examples/dynamo/refit_engine_example.py index 00bdbd0029..087e186673 100644 --- a/examples/dynamo/refit_engine_example.py +++ b/examples/dynamo/refit_engine_example.py @@ -114,21 +114,15 @@ # # There are a number of settings you can use to control the refit process # -# Weight Map Cache +# Output Verification # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ # # Weight refitting works by matching the weights of the compiled module with the new weights from -# the user supplied ExportedProgram. Since 1:1 name matching from PyTorch to TensorRT is hard to accomplish, -# the only gaurenteed way to match weights at *refit-time* is to pass the new ExportedProgram through the -# early phases of the compilation process to generate near identical weight names. This can be expensive -# and is not always necessary. +# the user supplied ExportedProgram. To do this, the new ExportedProgram is passed through the early +# phases of the compilation process to generate near identical weight names, which are then used to +# refit the existing TensorRT engine in place without rebuilding it. # -# To avoid this, **At initial compile**, Torch-TensorRt will attempt to cache a direct mapping from PyTorch -# weights to TensorRT weights. This cache is stored in the compiled module as metadata and can be used -# to speed up refit. If the cache is not present, the refit system will fallback to rebuilding the mapping at -# refit-time. Use of this cache is controlled by the ``use_weight_map_cache`` parameter. -# -# Since the cache uses a heuristic based system for matching PyTorch and TensorRT weights, you may want to verify the refitting. This can be done by setting +# You may want to verify the refitting. This can be done by setting # ``verify_output`` to True and providing sample ``arg_inputs`` and ``kwarg_inputs``. When this is done, the refit # system will run the refitted module and the user supplied module on the same inputs and compare the outputs. # diff --git a/py/torch_tensorrt/dynamo/_engine_cache.py b/py/torch_tensorrt/dynamo/_engine_cache.py index 141649fcd9..215c113697 100644 --- a/py/torch_tensorrt/dynamo/_engine_cache.py +++ b/py/torch_tensorrt/dynamo/_engine_cache.py @@ -24,7 +24,6 @@ List[str], Sequence[Input], CompilationSettings, - Optional[Dict[str, Any]], bool, bool, ] @@ -108,11 +107,10 @@ def pack( output_names: List[str], input_specs: Sequence[Input], compilation_settings: CompilationSettings, - weight_name_map: Optional[Dict[Any, Any]], requires_output_allocator: bool, requires_native_multidevice: bool, ) -> bytes: - """Pack serialized engine, input names, output names, and weight map into a single blob + """Pack serialized engine, input names, and output names into a single blob Args: serialized_engine (bytes): serialized TRT engine @@ -120,7 +118,6 @@ def pack( output_names (List[str]): output names of TRT engine input_specs (Sequence[Input]): input specs of TRT engine compilation_settings (CompilationSettings): compilation settings of TRT engine - weight_name_map (Optional[Dict[Any, Any]]): weight name map for refitting requires_output_allocator (bool): Boolean flag indicating if the converter creates operators which require an Output Allocator to run (e.g. data dependent operators) requires_native_multidevice (bool): Boolean flag indicating if the converter creates operators which require multiple devices to run (e.g. multi-device collective operations) Returns: @@ -135,7 +132,6 @@ def pack( "output_names": output_names, "input_specs": input_specs, "compilation_settings": settings, - "weight_name_map": weight_name_map, "requires_output_allocator": requires_output_allocator, "requires_native_multidevice": requires_native_multidevice, } @@ -143,13 +139,13 @@ def pack( @staticmethod def unpack(packed_obj: bytes) -> UnpackedCacheHit: - """Unpack packed blob into serialized engine, input names, output names, and weight map + """Unpack packed blob into serialized engine, input names, and output names Args: packed_obj (bytes): packed blob Returns: - Tuple[bytes, List[str], List[str], Sequence[Input], CompilationSettings, Optional[Dict[str, Any]]]: serialized engine, input names, output names, input specs, CompilationSettings, weight name map + Tuple[bytes, List[str], List[str], Sequence[Input], CompilationSettings, bool, bool]: serialized engine, input names, output names, input specs, CompilationSettings, requires_output_allocator, requires_native_multidevice """ unpacked = pickle.loads(packed_obj) return ( @@ -158,7 +154,6 @@ def unpack(packed_obj: bytes) -> UnpackedCacheHit: unpacked["output_names"], unpacked["input_specs"], unpacked["compilation_settings"], - unpacked["weight_name_map"], unpacked["requires_output_allocator"], unpacked.get("requires_native_multidevice", False), ) diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index 06e1b6dbd4..f6589f26b6 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -4,7 +4,7 @@ import copy import gc import logging -from typing import Any, List, Optional, Sequence, Tuple +from typing import Any, Optional, Sequence, Tuple import numpy as np import tensorrt as trt @@ -12,7 +12,7 @@ from torch.export import ExportedProgram from torch.fx.experimental.proxy_tensor import unset_fake_temporarily from torch_tensorrt._enums import dtype -from torch_tensorrt._features import ENABLED_FEATURES, needs_refit +from torch_tensorrt._features import needs_refit from torch_tensorrt._Input import Input from torch_tensorrt.dynamo import partitioning from torch_tensorrt.dynamo._exporter import inline_torch_modules @@ -22,9 +22,6 @@ DYNAMO_CONVERTERS as CONVERTERS, ) from torch_tensorrt.dynamo.conversion._TRTInterpreter import TRTInterpreter -from torch_tensorrt.dynamo.conversion.impl.normalization.ops import ( - batch_norm_constant_folding, -) from torch_tensorrt.dynamo.conversion.truncate_double import repair_double_inputs from torch_tensorrt.dynamo.lowering import ( clean_up_graph_after_modifications, @@ -41,7 +38,6 @@ from torch_tensorrt.dynamo.utils import ( check_module_output, check_output_equal, - get_model_device, get_torch_inputs, to_torch_device, to_torch_tensorrt_device, @@ -84,69 +80,12 @@ def construct_refit_mapping( return weight_refit_map -@needs_refit # type: ignore[misc] -def construct_refit_mapping_from_weight_name_map( - weight_name_map: dict[Any, Any], - state_dict: dict[Any, Any], - settings: CompilationSettings, -) -> dict[Any, Any]: - engine_weight_map = {} - for engine_weight_name, (sd_weight_name, np_weight_type) in weight_name_map.items(): - # Add more constant folding converters here - trt_dtype = dtype._from(np_weight_type).to(trt.DataType) - torch_dtype = dtype._from(np_weight_type).to(torch.dtype) - if engine_weight_name.split(" ")[-1] in ["SCALE", "SHIFT"]: - # Batch Norm Layer - params = {} - for w in sd_weight_name: - params[w.split(".")[-1]] = state_dict[w].cuda() - # Batch norm constant folding - - scale, shift = batch_norm_constant_folding(**params, eps=1e-5) - # Set scale to scale or shift to shift - engine_weight_map[engine_weight_name] = eval( - engine_weight_name.split(" ")[-1].lower() - ) - - elif isinstance(sd_weight_name, tuple): - # Buffer-slice mapping created by Stage 3 of _save_weight_mapping. - # Encodes (state_dict_key, dim, index) for weights that are slices - # of a source buffer (e.g. real/imag parts of an unpacked complex buffer). - sd_key, dim, idx = sd_weight_name - if sd_key not in state_dict: - continue - engine_weight_map[engine_weight_name] = ( - state_dict[sd_key].select(dim, idx).to(to_torch_device(settings.device)) - ) - - elif sd_weight_name not in state_dict: - # If weights is not in sd, we can leave it unchanged - continue - else: - - engine_weight_map[engine_weight_name] = state_dict[sd_weight_name].to( - to_torch_device(settings.device) - ) - - engine_weight_map[engine_weight_name] = ( - engine_weight_map[engine_weight_name] - .clone() - .reshape(-1) - .contiguous() - .to(torch_dtype), - trt_dtype, - ) - - return engine_weight_map - - @needs_refit # type: ignore[misc] def _refit_single_trt_engine_with_gm( new_gm: torch.fx.GraphModule, old_engine: trt.ICudaEngine, input_list: Sequence[Any], settings: CompilationSettings = CompilationSettings(), - weight_name_map: Optional[dict[str, List[str]]] = None, ) -> None: """ Refit a TensorRT Engine in place @@ -154,84 +93,25 @@ def _refit_single_trt_engine_with_gm( with unset_fake_temporarily(): refitted = set() - torch_device = get_model_device(new_gm) refitter = trt.Refitter(old_engine, TRT_LOGGER) weight_list = refitter.get_all_weights() - if weight_name_map: - # Get the refitting mapping - trt_wt_location = ( - trt.TensorLocation.DEVICE - if torch_device.type == "cuda" - else trt.TensorLocation.HOST - ) - - constant_mapping: dict[str, Any] = weight_name_map.pop( - "constant_mapping", {} - ) # type: ignore - mapping = construct_refit_mapping_from_weight_name_map( - weight_name_map, new_gm.state_dict(), settings + mapping = construct_refit_mapping(new_gm, input_list, settings) + trt_wt_location = trt.TensorLocation.HOST + for layer_name in weight_list: + if layer_name not in mapping: + raise AssertionError(f"{layer_name} is not found in weight mapping") + # Use Tensor to create weights + weight = mapping[layer_name] + trt_dtype = dtype._from(weight.dtype).to(trt.DataType) + trt_wt_tensor = trt.Weights( + trt_dtype, weight.data_ptr(), torch.numel(weight) ) - constant_mapping_with_type = {} - - for constant_name, val in constant_mapping.items(): - weight_dtype = val.dtype - val_tensor = val.cuda() - trt_dtype = dtype._from(weight_dtype).to(trt.DataType) - torch_dtype = dtype._from(weight_dtype).to(torch.dtype) - constant_mapping_with_type[constant_name] = ( - val_tensor.clone().reshape(-1).contiguous().to(torch_dtype), - trt_dtype, - ) + refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location) + refitted.add(layer_name) - mapping.update(constant_mapping_with_type) - - for layer_name in weight_list: - if layer_name not in mapping: - logger.warning(f"{layer_name} is not found in weight mapping.") - continue - # Use Numpy to create weights - weight, weight_dtype = mapping[layer_name] - trt_wt_tensor = trt.Weights( - weight_dtype, weight.data_ptr(), torch.numel(weight) - ) - refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location) - # get_missing_weights(): reports weights in connected engines - # that were not set. - missing_weights = refitter.get_missing_weights() - assert len(missing_weights) == 0, ( - f"Fast refit failed: refitter.get_missing_weights() reports " - f"{len(missing_weights)} of {len(weight_list)} engine weight(s) " - f"were never set." - ) - if ENABLED_FEATURES.tensorrt_rtx: - # Compare weights actually set vs all engine weights: catches - # weights in independent engines that get_missing_weights() may not report. - unset_weights = {w for w in weight_list if w not in mapping} - assert len(unset_weights) == 0, ( - f"Fast refit failed on TensorRT-RTX: {len(unset_weights)} of " - f"{len(weight_list)} engine weight(s) had no entry in " - f"weight_name_map. " - f"Unset (showing up to 5): {sorted(unset_weights)[:5]}" - ) - - else: - mapping = construct_refit_mapping(new_gm, input_list, settings) - trt_wt_location = trt.TensorLocation.HOST - for layer_name in weight_list: - if layer_name not in mapping: - raise AssertionError(f"{layer_name} is not found in weight mapping") - # Use Tensor to create weights - weight = mapping[layer_name] - trt_dtype = dtype._from(weight.dtype).to(trt.DataType) - trt_wt_tensor = trt.Weights( - trt_dtype, weight.data_ptr(), torch.numel(weight) - ) - refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location) - refitted.add(layer_name) - - if len(refitted) != len(weight_list): - logger.warning("Not all weights have been refitted!!!") + if len(refitted) != len(weight_list): + logger.warning("Not all weights have been refitted!!!") if not refitter.refit_cuda_engine(): logger.error("Error: failed to refit new weights.") @@ -245,7 +125,6 @@ def refit_module_weights( arg_inputs: Optional[Tuple[Any, ...]] = None, kwarg_inputs: Optional[dict[str, Any]] = None, verify_output: bool = False, - use_weight_map_cache: bool = True, in_place: bool = False, ) -> torch.fx.GraphModule: """ @@ -467,7 +346,6 @@ def refit_module_weights( # Extract engine from the submodule try: if inline_module: - weight_name_map = None compiled_submodule = compiled_submodules_map[name] # If this is a torch module, load the old state_dict if "_run_on_acc" not in name: @@ -478,59 +356,12 @@ def refit_module_weights( engine = get_engine_from_encoded_engine( engine_info[ENGINE_IDX], runtime ) - if use_weight_map_cache: - encoded_metadata = compiled_submodule.__getstate__()[0][ - SERIALIZED_METADATA_IDX - ] - weight_name_map = TorchTensorRTModule.decode_metadata( - encoded_metadata - )["weight_name_map"] - if not weight_name_map: - use_weight_map_cache = False - logger.warning( - "This engine does not have a weight map cache. Rebuilding the weight map" - ) else: compiled_submodule = getattr(compiled_module, name) if "_run_on_acc" not in name: compiled_submodule.load_state_dict(new_submodule.state_dict()) continue - weight_name_map = None - if use_weight_map_cache: - try: - weight_name_map = compiled_submodule.weight_name_map - except AttributeError: - if isinstance(compiled_submodule, torch.nn.Module): - # Torch retrace module - assert not isinstance( - compiled_submodule.engine, - TRTEngine, - ), ( - "Refitting a torch retraced module is only supported when " - "the engine uses the C++ Torch-TensorRT runtime" - ) - encoded_metadata = [ - engine - for name, engine in compiled_submodules - if name == "engine" - ][0].__getstate__()[0][SERIALIZED_METADATA_IDX] - weight_name_map = TorchTensorRTModule.decode_metadata( - encoded_metadata - )["weight_name_map"] - - if not isinstance( - compiled_submodule, torch.fx.graph_module.GraphModule - ): - logger.warning( - "The module was compiled with an old version of Torch-TensorRT. Rebuilding the weight map." - ) - if not weight_name_map: - use_weight_map_cache = False - logger.warning( - "This engine does not have a weight map cache. Rebuilding the weight map" - ) - # Rexporting the TRT compiled graph module and loading it back doesn't preserve # the instance type; choose the engine handle based on the actual engine object. if isinstance(compiled_submodule.engine, TRTEngine): @@ -561,26 +392,12 @@ def refit_module_weights( to_torch_device(settings.device), name, ) - try: - _refit_single_trt_engine_with_gm( - new_gm=new_submodule, - old_engine=engine, - input_list=submodule_inputs, - settings=settings, - weight_name_map=weight_name_map, - ) - - except AssertionError as e: - # If fast_refit is used and failed, we fall back to regular refit - logger.warning(e) - if use_weight_map_cache and weight_name_map: - _refit_single_trt_engine_with_gm( - new_gm=new_submodule, - old_engine=engine, - input_list=submodule_inputs, - settings=settings, - weight_name_map=None, - ) + _refit_single_trt_engine_with_gm( + new_gm=new_submodule, + old_engine=engine, + input_list=submodule_inputs, + settings=settings, + ) # clear EXCLUDE_WEIGHTS flag and set INCLUDE_REFIT flag to make the engine refittable serialization_config = engine.create_serialization_config() @@ -646,19 +463,6 @@ def refit_module_weights( if outputs_match: logger.info("Refitting Succeed!") else: - if weight_name_map: - logger.warning( - "Refitting with weight_name_map yielded incorrect result! The outputs do not match." - ) - return refit_module_weights( - compiled_module, - new_weight_module, - arg_inputs, - kwarg_inputs, - verify_output, - use_weight_map_cache=False, - in_place=in_place, - ) logger.error("Refitting Failed! The outputs do not match.") else: logger.info("Refitting Completed! Output verification skipped.") diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 1b7982f074..984ed6bfdf 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -1,4 +1,3 @@ -import gc import logging import os import warnings @@ -13,7 +12,6 @@ Sequence, Set, Tuple, - Union, ) import numpy as np @@ -26,7 +24,6 @@ from torch.utils._python_dispatch import _disable_current_modes from torch_tensorrt import ENABLED_FEATURES from torch_tensorrt._enums import dtype -from torch_tensorrt._features import needs_refit from torch_tensorrt._Input import Input from torch_tensorrt._utils import is_tensorrt_version_supported from torch_tensorrt.dynamo._engine_cache import BaseEngineCache @@ -52,7 +49,6 @@ DYNAMIC_DIM, deallocate_module, get_cpu_memory_usage, - to_torch_device, ) from torch_tensorrt.logging import TRT_LOGGER @@ -71,7 +67,6 @@ class TRTInterpreterResult(NamedTuple): engine: trt.ICudaEngine input_names: List[str] output_names: List[str] - weight_name_map: Optional[dict[Any, Any]] requires_output_allocator: bool requires_native_multidevice: bool @@ -146,7 +141,6 @@ def __init__( # Mapping of constants to shapes and dtypes self.const_mapping: Dict[str, Tuple[Sequence[int], str]] = {} - self.weight_name_map: Optional[Dict[str, Any]] = None # Engine cache for storing and reusing TRT engines self.engine_cache = engine_cache @@ -386,216 +380,6 @@ def _construct_trt_network_def(self) -> None: f"TRT INetwork construction elapsed time: {datetime.now() - run_module_start_time}" ) - @staticmethod - def find_weight( - weight_name: str, - weight_refit_map: dict[str, Any], - state_dict: dict[str, Any], - device: torch.device, - ) -> str: - """ - We need to build map from engine weight name to state_dict weight name. - The purpose of this function is to find the corresponding weight name in module state_dict. - - weight_name: the target weight name we want to search for - np_map: the map from weight name to np values in INetworkDefinition - state_dict: state of the graph module - """ - with unset_fake_temporarily(): - network_weight = weight_refit_map[weight_name].to(device) - for sd_w_name, sd_weight in state_dict.items(): - if TRTInterpreter.check_weight_equal(sd_weight, network_weight, device): - del state_dict[sd_w_name] - return sd_w_name - return "" - - @staticmethod - def check_weight_equal( - sd_weight: torch.tensor, - network_weight: Union[torch.Tensor, np.ndarray], - device: torch.device, - ) -> Any: - with unset_fake_temporarily(): - if network_weight.device != device: - network_weight = network_weight.to(device) - try: - return sd_weight.shape == network_weight.shape and torch.all( - torch.abs(sd_weight - network_weight) < 0.01 - ) - except Exception: - return torch.all(sd_weight == network_weight) - - @needs_refit # type: ignore - def _save_weight_mapping(self) -> None: - """ - Construct the weight name mapping from engine weight name to state_dict weight name. - Cache the weight name for future refitting usecases. - Two-stage weight name tracing: - 1. Name transformation from engine weight name to state_dict weight name - 2. Value mapping that, for each weight in INetworkDefinition search for identical weight in state_dict - """ - - MODULE_MAP = { - "SCALE": ( - trt.IScaleLayer, - [ - ( - "scale", - "SCALE", - ("weight", "bias", "running_mean", "running_var"), - ), - ( - "shift", - "SHIFT", - ("weight", "bias", "running_mean", "running_var"), - ), - ], - ), - "CONVOLUTION": ( - trt.IConvolutionLayer, - [("kernel", "KERNEL", "weight"), ("bias", "BIAS", "bias")], - ), - "DECONVOLUTION": ( - trt.IDeconvolutionLayer, - [("kernel", "KERNEL", "weight"), ("bias", "BIAS", "bias")], - ), - "CONSTANT": ( - trt.IConstantLayer, - [("weights", "CONSTANT", ("weight", "bias"))], - ), - } - """ - The structure of this map is: - { - layer_type: ( - Corresponding ILayer type to cast, - [ - ( - ILayer weight attribute, - Weight name postfix in TRT Engine, - Weight name postfix in state_dict - ), - ... - ] - ) - } - """ - _LOGGER.info("Building weight name mapping...") - # Stage 1: Name mapping - torch_device = to_torch_device(self.compilation_settings.device) - sd = {k: v.to(torch_device) for k, v in self.module.state_dict().items()} - weight_name_map: dict[str, Any] = {} - weight_refit_map = self.ctx.weight_refit_map - constant_mapping = {k: v for k, v in weight_refit_map.items() if v.numel() == 1} - net = self.ctx.net - for i in range(net.num_layers): - layer = net[i] - layer_type: str = layer.type.name - if layer_type in MODULE_MAP: - # Name mapping - for weight_type, weight_name, torch_attr in MODULE_MAP[layer_type][1]: - engine_weight_name = f"{layer.name} {weight_name}" - # Infer the corresponding weight name(s) in state_dict - sd_weight_name_list = ( - layer.name.split("-")[-1] - .replace("[", "") - .replace("]", "") - .split("/") - ) - sd_weight_name: Any = ".".join( - [i for i in sd_weight_name_list[:-1] if i] - ) - suffix = sd_weight_name_list[-1] - # Retrieve each weight name(s) in state_dict - if layer_type == "CONSTANT": - if ( - "embedding" in suffix - or "weight" in suffix - or "mm_other" in suffix - ): - sd_weight_name = f"{sd_weight_name}.weight" - elif "running_mean" in suffix: - sd_weight_name = f"{sd_weight_name}.running_mean" - elif "running_var" in suffix: - sd_weight_name = f"{sd_weight_name}.running_var" - elif "bias" in suffix: - sd_weight_name = f"{sd_weight_name}.bias" - else: - sd_weight_name = f"{sd_weight_name}.unknown" - elif layer_type == "SCALE": - # Batch norm needs all weights to calculate scale and shift - sd_weight_name = [f"{sd_weight_name}.{n}" for n in torch_attr] - else: - sd_weight_name = f"{sd_weight_name}.{torch_attr}" - - if engine_weight_name in weight_refit_map: - weight_name_map[engine_weight_name] = sd_weight_name - - # Stage 2: Value mapping - for engine_weight_name, sd_weight_name in weight_name_map.items(): - if "SCALE" in engine_weight_name: - # There is no direct connection in batch_norm layer. So skip it - pass - elif sd_weight_name not in sd or not TRTInterpreter.check_weight_equal( - sd[sd_weight_name], weight_refit_map[engine_weight_name], torch_device - ): - weight_name_map[engine_weight_name] = TRTInterpreter.find_weight( - engine_weight_name, weight_refit_map, sd, torch_device - ) - if ( - weight_name_map[engine_weight_name] != "" - and engine_weight_name in constant_mapping - ): - # If the weight is found in state_dict, remove it from constant_mapping - del constant_mapping[engine_weight_name] - - weight_name_map[engine_weight_name] = [ - weight_name_map[engine_weight_name], - weight_refit_map[engine_weight_name].dtype, - ] - - # Stage 3: Slice matching for unmatched non-scalar CONSTANT weights. - # complex_graph_detection unpacks complex buffers to real: - # freqs (S,D complex64) → freqs_unpacked_complex (S,D,2 float32) - # The real and imag slices (freqs_unpacked_complex[...,0] and [...,1]) are - # embedded as separate TRT constants, but their shapes differ from the source - # buffer, so Stage 2 value matching fails. Here we try selecting each slice - # along the last dimension of every sd entry to find the match. - for engine_weight_name, val in list(weight_name_map.items()): - if not isinstance(val, list) or len(val) != 2: - continue - sd_weight_name, dtype_val = val - if sd_weight_name != "" or engine_weight_name not in weight_refit_map: - continue - ew_tensor = weight_refit_map[engine_weight_name].to(torch_device) - if ew_tensor.numel() <= 1: - continue # scalars are handled via constant_mapping - matched = False - for sd_key, sd_tensor in sd.items(): - if sd_tensor.dim() < 1 or sd_tensor.shape[-1] < 2: - continue - last_dim = sd_tensor.dim() - 1 - for idx in range(sd_tensor.shape[last_dim]): - sd_slice = sd_tensor.select(last_dim, idx) - if TRTInterpreter.check_weight_equal( - sd_slice, ew_tensor, torch_device - ): - weight_name_map[engine_weight_name] = [ - (sd_key, last_dim, idx), - dtype_val, - ] - matched = True - break - if matched: - break - - weight_name_map["constant_mapping"] = constant_mapping - self.weight_name_map = weight_name_map - - del weight_refit_map, sd - gc.collect() - torch.cuda.empty_cache() - def run( self, strict_type_constraints: bool = False, @@ -614,9 +398,6 @@ def run( f"CPU memory usage after network construction: {get_cpu_memory_usage()} MB" ) - if not self.compilation_settings.immutable_weights: - self._save_weight_mapping() - if self.compilation_settings.offload_module_to_cpu: deallocate_module(self.module) @@ -669,7 +450,6 @@ def run( cuda_engine, self._input_names, self._output_names, - self.weight_name_map, self.ctx.requires_output_allocator, self.ctx.requires_native_multidevice, ) diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py index d712d7f150..1a2fd6ca2f 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -33,7 +33,6 @@ class SerializedInterpreterResult(NamedTuple): serialized_engine: bytes input_names: List[str] output_names: List[str] - weight_name_map: Optional[dict[Any, Any]] requires_output_allocator: bool symbolic_shape_expressions: Dict[str, List[Dict[str, Any]]] requires_native_multidevice: bool @@ -92,7 +91,6 @@ def insert_engine_to_cache( interpreter_result.output_names, inputs, settings, - interpreter_result.weight_name_map, interpreter_result.requires_output_allocator, interpreter_result.requires_native_multidevice, ), @@ -130,7 +128,6 @@ def pull_cached_engine( output_names, cached_engine_inputs, cached_engine_compilation_settings, - weight_name_map, requires_output_allocator, requires_native_multidevice, ) = cached_data @@ -170,7 +167,6 @@ def pull_cached_engine( old_engine=engine, input_list=inputs, settings=settings, - weight_name_map=weight_name_map, ) serialization_config = engine.create_serialization_config() @@ -189,7 +185,6 @@ def pull_cached_engine( serialized_engine=serialized_engine, input_names=input_names, output_names=output_names, - weight_name_map=weight_name_map, requires_output_allocator=requires_output_allocator, requires_native_multidevice=requires_native_multidevice, symbolic_shape_expressions=symbolic_shape_expressions, @@ -318,7 +313,6 @@ def interpret_module_to_result( serialized_engine=serialized_engine, input_names=interpreter_result.input_names, output_names=interpreter_result.output_names, - weight_name_map=interpreter_result.weight_name_map, requires_output_allocator=interpreter_result.requires_output_allocator, requires_native_multidevice=interpreter_result.requires_native_multidevice, symbolic_shape_expressions=symbolic_shape_expressions, @@ -375,7 +369,6 @@ def convert_module( output_binding_names=list(serialized_interpreter_result.output_names), name=name, settings=settings, - weight_name_map=serialized_interpreter_result.weight_name_map, requires_output_allocator=serialized_interpreter_result.requires_output_allocator, requires_native_multidevice=serialized_interpreter_result.requires_native_multidevice, symbolic_shape_expressions=serialized_interpreter_result.symbolic_shape_expressions, diff --git a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py index dc542363ae..b36d239edb 100644 --- a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py @@ -308,7 +308,6 @@ def refit_gm(self) -> None: self.exp_program, self.arg_inputs, self.kwarg_inputs, - use_weight_map_cache=True, in_place=True, ) diff --git a/py/torch_tensorrt/dynamo/runtime/_TRTEngine.py b/py/torch_tensorrt/dynamo/runtime/_TRTEngine.py index d2ec9e702d..fb28f16118 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TRTEngine.py +++ b/py/torch_tensorrt/dynamo/runtime/_TRTEngine.py @@ -373,7 +373,6 @@ def _load_serialized_info( metadata = self.decode_metadata(self.serialized_metadata) self.settings = metadata.get("settings", CompilationSettings()) - self.weight_name_map = metadata.get("weight_name_map") self.symbolic_shape_expressions = metadata.get("inout_symexprs") self.output_tensors_are_unowned = metadata.get( "output_tensors_are_unowned", False diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index 1d83bd646f..25df31dc7f 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -61,7 +61,6 @@ def __init__( *, name: str = "", settings: CompilationSettings = CompilationSettings(), - weight_name_map: Optional[dict[Any, Any]] = None, requires_output_allocator: bool = False, requires_native_multidevice: bool = False, symbolic_shape_expressions: Optional[Dict[str, List[Dict[str, Any]]]] = None, @@ -82,7 +81,6 @@ def __init__( Keyword Arguments: name (str): Name for module settings (torch_tensorrt.dynamo.CompilationSettings): Settings used to compile engine, assumes engine was built with default compilation settings if object not passed - weight_name_map (dict): Mapping of engine weight name to state_dict weight name requires_output_allocator (bool): Boolean flag indicating if the converter creates operators which require an Output Allocator to run (e.g. data dependent operators) requires_native_multidevice (bool): Boolean flag indicating if the converter creates operators which require multiple devices to run (e.g. multi-device collective operations) symbolic_shape_expressions (List[Any]): List of symbolic shape expressions for each input binding @@ -109,7 +107,6 @@ def __init__( output_binding_names: Output tensor names in return order. name: Logical name for logging and serialization. settings: Compilation/runtime settings (device, lazy init, cross-compile, etc.). - weight_name_map: Engine weight name to ``state_dict`` key mapping (refit). requires_output_allocator: Engine needs TRT dynamic output allocation. symbolic_shape_expressions: Optional symbolic shape metadata from compile. """ @@ -124,7 +121,6 @@ def __init__( self.name = name self.hardware_compatible = settings.hardware_compatible self.settings = copy.deepcopy(settings) - self.weight_name_map = weight_name_map self.serialized_engine = serialized_engine self.engine: Optional[Any] = None self.requires_output_allocator = requires_output_allocator @@ -175,7 +171,6 @@ def _pack_engine_info(self) -> List[str | bytes]: ) metadata = { "settings": self.settings, - "weight_name_map": self.weight_name_map, "inout_symexprs": self.symbolic_shape_expressions, "output_tensors_are_unowned": ( False @@ -383,7 +378,6 @@ def set_extra_state(self, state: SerializedTorchTensorRTModuleFmt) -> None: assert isinstance(serialized_metadata, bytes) metadata = TorchTensorRTModule.decode_metadata(serialized_metadata) self.settings = metadata["settings"] - self.weight_name_map = metadata["weight_name_map"] self.symbolic_shape_expressions = metadata["inout_symexprs"] if ENABLED_FEATURES.torch_tensorrt_runtime: diff --git a/tests/py/dynamo/models/test_model_refit.py b/tests/py/dynamo/models/test_model_refit.py index 2ef7c9ba3f..b024c627a2 100644 --- a/tests/py/dynamo/models/test_model_refit.py +++ b/tests/py/dynamo/models/test_model_refit.py @@ -125,7 +125,6 @@ def forward(self, x): compiled_module=trt_gm, new_weight_module=exp_program2, arg_inputs=inputs, - use_weight_map_cache=True, verify_output=True, ) @@ -185,7 +184,6 @@ def forward(self, x): compiled_module=trt_gm, new_weight_module=exp_program2, arg_inputs=inputs, - use_weight_map_cache=True, verify_output=True, ) @@ -245,7 +243,6 @@ def forward(self, x): compiled_module=trt_gm, new_weight_module=exp_program2, arg_inputs=inputs, - use_weight_map_cache=False, verify_output=True, ) @@ -296,7 +293,6 @@ def test_refit_one_engine_with_weightmap(): compiled_module=trt_gm, new_weight_module=exp_program2, arg_inputs=inputs, - use_weight_map_cache=True, verify_output=True, ) @@ -315,114 +311,6 @@ def test_refit_one_engine_with_weightmap(): torch._dynamo.reset() -@unittest.skipIf( - not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, - "TorchScript Frontend is not available", -) -@unittest.skipIf( - not torch_trt.ENABLED_FEATURES.refit, - "Refit feature is not supported in Python 3.13 or higher", -) -@unittest.skipIf( - not importlib.util.find_spec("torchvision"), - "torchvision is not installed", -) -@pytest.mark.unit -def test_refit_one_engine_no_map_with_weightmap(): - model = models.resnet18(pretrained=False).eval().to("cuda") - model2 = models.resnet18(pretrained=True).eval().to("cuda") - inputs = [torch.randn((1, 3, 224, 224)).to("cuda")] - min_block_size = 1 - exp_program = torch.export.export(model, tuple(inputs)) - exp_program2 = torch.export.export(model2, tuple(inputs)) - - trt_gm = torchtrt.dynamo.compile( - exp_program, - tuple(inputs), - min_block_size=min_block_size, - immutable_weights=False, - ) - - trt_gm._run_on_acc_0.weight_name_map = None - - new_trt_gm = refit_module_weights( - compiled_module=trt_gm, - new_weight_module=exp_program2, - arg_inputs=inputs, - use_weight_map_cache=True, - ) - - # Check the output - model2.to("cuda") - expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm( - *inputs - ) - for expected_output, refitted_output in zip(expected_outputs, refitted_outputs): - assertions.assertTrue( - torch.allclose(expected_output, refitted_output, 1e-2, 1e-2), - "Refit Result is not correct. Refit failed", - ) - # Clean up model env - - torch._dynamo.reset() - - -@unittest.skipIf( - not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, - "TorchScript Frontend is not available", -) -@unittest.skipIf( - not torch_trt.ENABLED_FEATURES.refit, - "Refit feature is not supported in Python 3.13 or higher", -) -@unittest.skipIf( - not importlib.util.find_spec("torchvision"), - "torchvision is not installed", -) -@pytest.mark.unit -def test_refit_one_engine_with_wrong_weightmap(): - model = models.resnet18(pretrained=False).eval().to("cuda") - model2 = models.resnet18(pretrained=True).eval().to("cuda") - inputs = [torch.randn((1, 3, 224, 224)).to("cuda")] - min_block_size = 1 - exp_program = torch.export.export(model, tuple(inputs)) - exp_program2 = torch.export.export(model2, tuple(inputs)) - - trt_gm = torchtrt.dynamo.compile( - exp_program, - tuple(inputs), - min_block_size=min_block_size, - immutable_weights=False, - ) - # Manually Deleted all batch norm layer. This suppose to fail the fast refit - trt_gm._run_on_acc_0.weight_name_map = { - k: v - for k, v in trt_gm._run_on_acc_0.weight_name_map.items() - if "[SCALE]" not in k - } - - new_trt_gm = refit_module_weights( - compiled_module=trt_gm, - new_weight_module=exp_program2, - arg_inputs=inputs, - use_weight_map_cache=True, - ) - - # Check the output - model2.to("cuda") - expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm( - *inputs - ) - for expected_output, refitted_output in zip(expected_outputs, refitted_outputs): - assertions.assertTrue( - torch.allclose(expected_output, refitted_output, 1e-2, 1e-2), - "Refit Result is not correct. Refit failed", - ) - # Clean up model env - - torch._dynamo.reset() - - @unittest.skipIf( not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, "TorchScript Frontend is not available", @@ -460,7 +348,6 @@ def test_refit_one_engine_bert_with_weightmap(): compiled_module=trt_gm, new_weight_module=exp_program2, arg_inputs=inputs, - use_weight_map_cache=True, ) # Check the output @@ -522,7 +409,6 @@ def test_refit_one_engine_inline_runtime_with_weightmap(tmpdir): compiled_module=trt_gm, new_weight_module=exp_program2, arg_inputs=inputs, - use_weight_map_cache=True, ) # Check the output @@ -573,7 +459,6 @@ def test_refit_one_engine_python_runtime_with_weightmap(): compiled_module=trt_gm, new_weight_module=exp_program2, arg_inputs=inputs, - use_weight_map_cache=True, ) # Check the output @@ -642,7 +527,6 @@ def forward(self, x): compiled_module=trt_gm, new_weight_module=exp_program2, arg_inputs=inputs, - use_weight_map_cache=True, ) # Check the output @@ -704,7 +588,6 @@ def forward(self, x): compiled_module=trt_gm, new_weight_module=exp_program2, arg_inputs=inputs, - use_weight_map_cache=True, ) model2.cuda() # Check the output @@ -757,7 +640,6 @@ def test_refit_one_engine_without_weightmap(): compiled_module=trt_gm, new_weight_module=exp_program2, arg_inputs=inputs, - use_weight_map_cache=False, ) # Check the output @@ -812,7 +694,6 @@ def test_refit_one_engine_bert_without_weightmap(): compiled_module=trt_gm, new_weight_module=exp_program2, arg_inputs=inputs, - use_weight_map_cache=False, ) # Check the output @@ -872,7 +753,6 @@ def test_refit_one_engine_inline_runtime_without_weightmap(tmpdir): compiled_module=trt_gm, new_weight_module=exp_program2, arg_inputs=inputs, - use_weight_map_cache=False, ) # Check the output @@ -922,7 +802,6 @@ def test_refit_one_engine_python_runtime_without_weightmap(): compiled_module=trt_gm, new_weight_module=exp_program2, arg_inputs=inputs, - use_weight_map_cache=False, ) # Check the output @@ -991,7 +870,6 @@ def forward(self, x): compiled_module=trt_gm, new_weight_module=exp_program2, arg_inputs=inputs, - use_weight_map_cache=False, ) # Check the output @@ -1051,7 +929,6 @@ def forward(self, x): compiled_module=trt_gm, new_weight_module=exp_program2, arg_inputs=inputs, - use_weight_map_cache=True, ) model2.to("cuda") @@ -1124,7 +1001,6 @@ def make_freqs() -> torch.Tensor: compiled_module=trt_gm, new_weight_module=exp_program2, arg_inputs=inputs, - use_weight_map_cache=True, verify_output=True, ) @@ -1195,7 +1071,6 @@ def make_freqs() -> torch.Tensor: compiled_module=trt_gm, new_weight_module=exp_program2, arg_inputs=inputs, - use_weight_map_cache=True, verify_output=True, ) @@ -1264,7 +1139,6 @@ def make_freqs() -> torch.Tensor: compiled_module=trt_gm, new_weight_module=exp_program2, arg_inputs=inputs, - use_weight_map_cache=True, verify_output=True, )