Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions py/torch_tensorrt/_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar

import tensorrt

from packaging import version
from torch_tensorrt._utils import (
check_cross_compile_trt_win_lib,
check_native_trt_collectives,
load_tensorrt_llm_for_nccl,
sanitized_torch_version,
)

from packaging import version

FeatureSet = namedtuple(
"FeatureSet",
[
Expand Down Expand Up @@ -230,6 +230,18 @@ def not_implemented(*args: List[Any], **kwargs: Dict[str, Any]) -> Any:
return wrapper


def needs_native_collectives(f: Callable[..., Any]) -> Callable[..., Any]:
def wrapper(*args: List[Any], **kwargs: Dict[str, Any]) -> Any:
if ENABLED_FEATURES.native_trt_collectives:
return f(*args, **kwargs)
else:
raise NotImplementedError(
"TensorRT 11+ is required for native NCCL collectives"
)

return wrapper


def for_all_methods(
decorator: Callable[..., Any], exclude: Optional[List[str]] = None
) -> Callable[..., Any]:
Expand Down
62 changes: 58 additions & 4 deletions py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import (
tensorrt_fused_nccl_all_gather_op,
tensorrt_fused_nccl_all_reduce_op,
tensorrt_fused_nccl_all_to_all_op,
tensorrt_fused_nccl_gather_op,
tensorrt_fused_nccl_reduce_scatter_op,
tensorrt_fused_nccl_scatter_op,
)

_LOGGER: logging.Logger = logging.getLogger(__name__)
Expand All @@ -27,15 +30,15 @@
@dynamo_tensorrt_converter(
tensorrt_fused_nccl_all_gather_op, requires_native_multidevice=True
)
def fused_nccl_gather(
def fused_nccl_all_gather(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
"""All-gather using native TensorRT DistCollective API"""
return impl.nccl_ops.nccl_gather_native(
return impl.nccl_ops.nccl_all_gather_native(
ctx,
target,
SourceIR.ATEN,
Expand Down Expand Up @@ -86,6 +89,57 @@ def fused_nccl_all_reduce(
reduce_op=reduce_op,
)

@dynamo_tensorrt_converter(
tensorrt_fused_nccl_all_to_all_op, requires_native_multidevice=True
)
def fused_nccl_all_to_all(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
"""All-to-all using native TensorRT DistCollective API."""
return impl.nccl_ops.nccl_all_to_all_native(
ctx,
target,
SourceIR.ATEN,
name,
[args[0]],
)

@dynamo_tensorrt_converter(
tensorrt_fused_nccl_scatter_op, requires_native_multidevice=True
)
def fused_nccl_scatter(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
"""Scatter using native TensorRT DistCollective API."""
root = args[1] if len(args) > 1 else 0
return impl.nccl_ops.nccl_scatter_native(
ctx, target, SourceIR.ATEN, name, [args[0]], root=root
)

@dynamo_tensorrt_converter(
tensorrt_fused_nccl_gather_op, requires_native_multidevice=True
)
def fused_nccl_gather(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
"""Gather using native TensorRT DistCollective API."""
root = args[1] if len(args) > 1 else 0
return impl.nccl_ops.nccl_gather_native(
ctx, target, SourceIR.ATEN, name, [args[0]], root=root
)


# Conditionally register NCCL converters only if TensorRT-LLM plugin is available.
# We use an `if` statement instead of @needs_trtllm_for_nccl decorator because
Expand All @@ -101,14 +155,14 @@ def fused_nccl_all_reduce(
)

@dynamo_tensorrt_converter(tensorrt_fused_nccl_all_gather_op)
def fused_nccl_gather(
def fused_nccl_all_gather(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
return impl.nccl_ops.nccl_gather(
return impl.nccl_ops.nccl_all_gather(
ctx,
target,
SourceIR.ATEN,
Expand Down
Loading
Loading