4 changes: 0 additions & 4 deletions docsrc/tutorials/runtime_opt/python_runtime.rst
@@ -130,10 +130,6 @@ produced outside Torch-TensorRT):
``settings`` (:class:`~torch_tensorrt.dynamo._settings.CompilationSettings`, optional)
Device and runtime options (must match how the engine was built).

``weight_name_map`` (``dict``, optional)
Mapping of TRT weight names to PyTorch state dict names. Required for refit
support via :func:`~torch_tensorrt.dynamo.refit_module_weights`.

``requires_output_allocator`` (``bool``, default ``False``)
Set to ``True`` if the engine contains data-dependent-shape ops (``nonzero``,
``unique``, etc.) that require TRT's output allocator.
7 changes: 0 additions & 7 deletions docsrc/tutorials/weight_refit/refit.rst
@@ -88,7 +88,6 @@ API
arg_inputs=None,
kwarg_inputs=None,
verify_output=False,
use_weight_map_cache=True,
in_place=False,
)

@@ -114,12 +113,6 @@ API
PyTorch on the provided sample inputs. Useful for catching silent refit failures
during development.

``use_weight_map_cache`` (``bool``, default ``True``)
When torch-tensorrt programs are compiled, the TRTIntpereter builds a map of which
exported program nodes correspond to which TensorRT layers. This mapping is stored as metadata in serialized
torch-tensorrt programs. This cache is not gaurenteed to be an exact match but to a new
unseen exported program but when it does, it reduces refit time by ~50%.

``in_place`` (``bool``, default ``False``)
If ``True``, modify the compiled module in-place rather than returning a copy.
Not supported for ``ExportedProgram`` inputs (use the returned module instead).
16 changes: 5 additions & 11 deletions examples/dynamo/refit_engine_example.py
@@ -114,21 +114,15 @@
#
# There are a number of settings you can use to control the refit process
#
# Weight Map Cache
# Output Verification
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# Weight refitting works by matching the weights of the compiled module with the new weights from
# the user supplied ExportedProgram. Since 1:1 name matching from PyTorch to TensorRT is hard to accomplish,
# the only gaurenteed way to match weights at *refit-time* is to pass the new ExportedProgram through the
# early phases of the compilation process to generate near identical weight names. This can be expensive
# and is not always necessary.
# the user supplied ExportedProgram. To do this, the new ExportedProgram is passed through the early
# phases of the compilation process to generate near identical weight names, which are then used to
# refit the existing TensorRT engine in place without rebuilding it.
#
# To avoid this, **At initial compile**, Torch-TensorRt will attempt to cache a direct mapping from PyTorch
# weights to TensorRT weights. This cache is stored in the compiled module as metadata and can be used
# to speed up refit. If the cache is not present, the refit system will fallback to rebuilding the mapping at
# refit-time. Use of this cache is controlled by the ``use_weight_map_cache`` parameter.
#
# Since the cache uses a heuristic based system for matching PyTorch and TensorRT weights, you may want to verify the refitting. This can be done by setting
# You may want to verify the refitting. This can be done by setting
# ``verify_output`` to True and providing sample ``arg_inputs`` and ``kwarg_inputs``. When this is done, the refit
# system will run the refitted module and the user supplied module on the same inputs and compare the outputs.
#
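The comparison described above can be sketched in plain Python. This is an illustration only, not how Torch-TensorRT implements ``verify_output``: ``outputs_match`` is a hypothetical helper, the callables stand in for the refitted module and the user-supplied module, and flat lists of floats with arbitrarily chosen tolerances stand in for tensors.

```python
import math
from typing import Any, Callable, Dict, Optional, Sequence


def outputs_match(
    refitted: Callable[..., Sequence[float]],
    reference: Callable[..., Sequence[float]],
    arg_inputs: Sequence[Any],
    kwarg_inputs: Optional[Dict[str, Any]] = None,
    rel_tol: float = 1e-3,
    abs_tol: float = 1e-3,
) -> bool:
    """Run both callables on the same inputs and compare outputs element-wise."""
    kwargs = kwarg_inputs or {}
    got = refitted(*arg_inputs, **kwargs)
    expected = reference(*arg_inputs, **kwargs)
    if len(got) != len(expected):
        return False
    return all(
        math.isclose(g, e, rel_tol=rel_tol, abs_tol=abs_tol)
        for g, e in zip(got, expected)
    )


def reference_model(xs, scale=1.0):
    # Stand-in for the user-supplied module with the new weights.
    return [scale * x for x in xs]


def good_refit(xs, scale=1.0):
    # Stand-in for a correctly refitted module: same result up to small drift.
    return [scale * x + 1e-6 for x in xs]


def bad_refit(xs, scale=1.0):
    # Stand-in for a silent refit failure: the old weights are still in use.
    return [0.5 * scale * x for x in xs]


sample_args = ([1.0, 2.0, 3.0],)
sample_kwargs = {"scale": 2.0}

ok = outputs_match(good_refit, reference_model, sample_args, sample_kwargs)
stale = outputs_match(bad_refit, reference_model, sample_args, sample_kwargs)
```

In this sketch ``ok`` is true and ``stale`` is false, which is the kind of silent mismatch the verification step is meant to surface during development.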
11 changes: 3 additions & 8 deletions py/torch_tensorrt/dynamo/_engine_cache.py
@@ -24,7 +24,6 @@
List[str],
Sequence[Input],
CompilationSettings,
Optional[Dict[str, Any]],
bool,
bool,
]
@@ -108,19 +107,17 @@ def pack(
output_names: List[str],
input_specs: Sequence[Input],
compilation_settings: CompilationSettings,
weight_name_map: Optional[Dict[Any, Any]],
requires_output_allocator: bool,
requires_native_multidevice: bool,
) -> bytes:
"""Pack serialized engine, input names, output names, and weight map into a single blob
"""Pack serialized engine, input names, and output names into a single blob

Args:
serialized_engine (bytes): serialized TRT engine
input_names (List[str]): input names of TRT engine
output_names (List[str]): output names of TRT engine
input_specs (Sequence[Input]): input specs of TRT engine
compilation_settings (CompilationSettings): compilation settings of TRT engine
weight_name_map (Optional[Dict[Any, Any]]): weight name map for refitting
requires_output_allocator (bool): Boolean flag indicating if the converter creates operators which require an Output Allocator to run (e.g. data dependent operators)
requires_native_multidevice (bool): Boolean flag indicating if the converter creates operators which require multiple devices to run (e.g. multi-device collective operations)
Returns:
@@ -135,21 +132,20 @@ def pack(
"output_names": output_names,
"input_specs": input_specs,
"compilation_settings": settings,
"weight_name_map": weight_name_map,
"requires_output_allocator": requires_output_allocator,
"requires_native_multidevice": requires_native_multidevice,
}
)

@staticmethod
def unpack(packed_obj: bytes) -> UnpackedCacheHit:
"""Unpack packed blob into serialized engine, input names, output names, and weight map
"""Unpack packed blob into serialized engine, input names, and output names

Args:
packed_obj (bytes): packed blob

Returns:
Tuple[bytes, List[str], List[str], Sequence[Input], CompilationSettings, Optional[Dict[str, Any]]]: serialized engine, input names, output names, input specs, CompilationSettings, weight name map
Tuple[bytes, List[str], List[str], Sequence[Input], CompilationSettings, bool, bool]: serialized engine, input names, output names, input specs, CompilationSettings, requires_output_allocator, requires_native_multidevice
"""
unpacked = pickle.loads(packed_obj)
return (
@@ -158,7 +154,6 @@ def unpack(packed_obj: bytes) -> UnpackedCacheHit:
unpacked["output_names"],
unpacked["input_specs"],
unpacked["compilation_settings"],
unpacked["weight_name_map"],
unpacked["requires_output_allocator"],
unpacked.get("requires_native_multidevice", False),
)
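The effect of this hunk on the cache blob can be illustrated with a minimal standalone sketch. It is not the library's code: ``pack_blob``, ``unpack_blob`` and ``UnpackedBlob`` are hypothetical names, the dictionary keys are assumed from the visible lines of the diff, and plain lists and dicts stand in for ``Input`` and ``CompilationSettings``. It shows the unpacked tuple shrinking to seven fields, and that, in this sketch at least, a blob written in the older layout (with a ``weight_name_map`` key and no multidevice flag) still unpacks because the extra key is never read.

```python
import pickle
from typing import Any, Dict, List, Tuple

# Hypothetical stand-in for UnpackedCacheHit after the change: 7 fields, no weight map.
UnpackedBlob = Tuple[bytes, List[str], List[str], List[Any], Dict[str, Any], bool, bool]


def pack_blob(
    serialized_engine: bytes,
    input_names: List[str],
    output_names: List[str],
    input_specs: List[Any],
    compilation_settings: Dict[str, Any],
    requires_output_allocator: bool,
    requires_native_multidevice: bool,
) -> bytes:
    """Pickle the engine and its metadata into one blob, without a weight map."""
    return pickle.dumps(
        {
            "serialized_engine": serialized_engine,
            "input_names": input_names,
            "output_names": output_names,
            "input_specs": input_specs,
            "compilation_settings": compilation_settings,
            "requires_output_allocator": requires_output_allocator,
            "requires_native_multidevice": requires_native_multidevice,
        }
    )


def unpack_blob(packed_obj: bytes) -> UnpackedBlob:
    """Unpickle a blob and return its fields in a fixed order."""
    unpacked = pickle.loads(packed_obj)
    return (
        unpacked["serialized_engine"],
        unpacked["input_names"],
        unpacked["output_names"],
        unpacked["input_specs"],
        unpacked["compilation_settings"],
        unpacked["requires_output_allocator"],
        # Older blobs may lack this key, so fall back to False.
        unpacked.get("requires_native_multidevice", False),
    )


blob = pack_blob(b"engine-bytes", ["x"], ["y"], [(1, 3)], {"device": "cuda:0"}, False, False)
fields = unpack_blob(blob)

# A blob in the older layout: extra "weight_name_map" key, no multidevice flag.
legacy_blob = pickle.dumps(
    {
        "serialized_engine": b"old-engine",
        "input_names": ["x"],
        "output_names": ["y"],
        "input_specs": [(1, 3)],
        "compilation_settings": {"device": "cuda:0"},
        "weight_name_map": None,
        "requires_output_allocator": True,
    }
)
legacy_fields = unpack_blob(legacy_blob)
```

Any caller that destructures the result positionally needs to drop the sixth element it previously received, since the two boolean flags now sit at indices 5 and 6.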