From b7f2deeae9a85c5cd7addb4e8a24c3a119319ae1 Mon Sep 17 00:00:00 2001
From: apbose
Date: Thu, 16 Oct 2025 12:45:11 -0700
Subject: [PATCH 1/3] addressing cat empty tensor case. Fixes gpt2 data distributed example

---
 .../data_parallel_stable_diffusion.py         |  2 -
 .../dynamo/conversion/aten_ops_converters.py  | 12 ++++-
 .../dynamo/conversion/impl/cat.py             | 10 ++++
 tests/py/dynamo/conversion/test_cat_aten.py   | 54 +++++++++++++++++++
 4 files changed, 75 insertions(+), 3 deletions(-)

diff --git a/examples/distributed_inference/data_parallel_stable_diffusion.py b/examples/distributed_inference/data_parallel_stable_diffusion.py
index 5c0e3113e5..023d7e8e63 100644
--- a/examples/distributed_inference/data_parallel_stable_diffusion.py
+++ b/examples/distributed_inference/data_parallel_stable_diffusion.py
@@ -53,7 +53,5 @@
 # Assume there are 2 processes (2 devices)
 with distributed_state.split_between_processes(["a dog", "a cat"]) as prompt:
-    print("before \n")
     result = pipe(prompt).images[0]
-    print("after ")
     result.save(f"result_{distributed_state.process_index}.png")

diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 164f0c1065..af9c34d08c 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -217,7 +217,17 @@ def aten_ops_native_group_norm(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.cat.default, supports_dynamic_shapes=True)
+def cat_validator(node: Node, settings: Optional[CompilationSettings] = None) -> bool:
+    # Validate only one user, which is a getitem node that accesses the first element in the list
+    for each_input in node.args[0]:
+        if isinstance(each_input, TRTTensor) and any(s == 0 for s in each_input.shape):
+            return False
+    return True
+
+
+@dynamo_tensorrt_converter(
+    torch.ops.aten.cat.default, supports_dynamic_shapes=True, validator=cat_validator
+)
 def aten_ops_cat(
     ctx: ConversionContext,
     target: Target,
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/cat.py b/py/torch_tensorrt/dynamo/conversion/impl/cat.py
index 68bbcc31d0..e3dd01477f 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/cat.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/cat.py
@@ -1,3 +1,4 @@
+import logging
 from typing import Optional, Sequence, Union
 
 import numpy as np
@@ -15,6 +16,8 @@
     set_layer_name,
 )
 
+logger = logging.getLogger(__name__)
+
 
 def cat(
     ctx: ConversionContext,
@@ -27,6 +30,13 @@ def cat(
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
     trt_inputs = []
     for i, each_input in enumerate(input):
+        if isinstance(each_input, torch.Tensor) and each_input.numel() == 0:
+            logger.warning(
+                f"Empty tensor in cat input {i}, skipping it"
+            )
+            # An ITensor with the same condition leads to: [RemoveDeadLayers] Input Tensor y is unused or used only at compile-time, but is not being removed.
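+            # (skipping is safe: a tensor with zero elements contributes nothing to the concatenation)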
+            # hence the validator rejects graphs in which an empty cat input is already an ITensor
+            continue
         if not isinstance(each_input, TRTTensor):
             each_input = get_trt_tensor(ctx, each_input, f"{name}_tensor_{i}")
         if cast_dtype:
diff --git a/tests/py/dynamo/conversion/test_cat_aten.py b/tests/py/dynamo/conversion/test_cat_aten.py
index a9e4a45c81..4d7bc02d1f 100644
--- a/tests/py/dynamo/conversion/test_cat_aten.py
+++ b/tests/py/dynamo/conversion/test_cat_aten.py
@@ -25,6 +25,60 @@ def forward(self, x, y, z):
             inputs,
         )
 
+    @parameterized.expand(
+        [
+            ("pos", 0),
+            ("neg", -3),
+        ]
+    )
+    def test_cat_with_scalar_inputs(self, _, dim):
+        # Ensure tensors built from scalar values are wrapped correctly
+        class Cat(nn.Module):
+            def forward(self, x, y):
+                # y is a tensor filled with a scalar value
+                return torch.ops.aten.cat.default((x, y), dim)
+
+        x = torch.randn(1, 2, 3, device="cuda")
+        y = torch.ones_like(x) * 5.0  # a scalar value broadcast to a full tensor
+        inputs = [x, y]
+        self.run_test(Cat(), inputs)
+
+    @parameterized.expand(
+        [
+            ("pos", 0),
+            ("neg", -3),
+        ]
+    )
+    def test_cat_with_empty_tensor(self, _, dim):
+        # Ensure an empty tensor among the concat inputs is handled
+        class Cat(nn.Module):
+            def forward(self, x):
+                y = torch.empty(0, 2, 3, device="cuda")
+                return torch.ops.aten.cat.default((x, y), dim)
+
+        inputs = [
+            torch.randn(1, 2, 3, device="cuda"),
+        ]
+        self.run_test(Cat(), inputs)
+
+    @parameterized.expand(
+        [
+            ("pos", 2),
+            ("neg", -1),
+        ]
+    )
+    def test_cat_with_different_dtypes(self, _, dim):
+        # Check the dtype promotion path in concat
+        class Cat(nn.Module):
+            def forward(self, x, y):
+                return torch.ops.aten.cat.default((x, y), dim)
+
+        inputs = [
+            torch.ones(1, 2, 3, dtype=torch.float32, device="cuda"),
+            torch.ones(1, 2, 3, dtype=torch.float16, device="cuda"),
+        ]
+        self.run_test(Cat(), inputs)
+
     @parameterized.expand(
         [
             ("pos", 1),

From 88659a161a111615a10a7b229c46641a02955659 Mon Sep 17 00:00:00 2001
From: apbose
Date: Thu, 16 Oct 2025 12:50:19 -0700
Subject: [PATCH 2/3] correcting the validator comment and registration keyword

---
 py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index af9c34d08c..080a751a5a 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -218,7 +218,7 @@ def aten_ops_native_group_norm(
     )
 
 
 def cat_validator(node: Node, settings: Optional[CompilationSettings] = None) -> bool:
-    # Validate only one user, which is a getitem node that accesses the first element in the list
+    # An empty tensor passed to cat as an ITensor leads to: [RemoveDeadLayers] Input Tensor y is unused or used only at compile-time, but is not being removed.
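+    # Rejecting the node here makes it fall back to native PyTorch execution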
     for each_input in node.args[0]:
         if isinstance(each_input, TRTTensor) and any(s == 0 for s in each_input.shape):
             return False
     return True
@@ -226,7 +226,9 @@ def cat_validator(node: Node, settings: Optional[CompilationSettings] = None) ->
 
 
 @dynamo_tensorrt_converter(
-    torch.ops.aten.cat.default, supports_dynamic_shapes=True, validator=cat_validator
+    torch.ops.aten.cat.default,
+    capability_validator=cat_validator,
+    supports_dynamic_shapes=True,
 )
 def aten_ops_cat(
     ctx: ConversionContext,

From e6fc22b260cbdf679fe8edc83edf030d47c3439a Mon Sep 17 00:00:00 2001
From: apbose
Date: Fri, 17 Oct 2025 13:45:34 -0700
Subject: [PATCH 3/3] expanding cat converter to address CI error

---
 .../dynamo/conversion/aten_ops_converters.py  | 44 ++++++++++++++++---
 tests/py/dynamo/conversion/test_cat_aten.py   | 17 +++++++
 2 files changed, 56 insertions(+), 5 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 080a751a5a..c6e07d481f 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -2,7 +2,7 @@
 
 import logging
 import operator
-from typing import Callable, Dict, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
 
 import numpy as np
 import torch
@@ -217,9 +217,42 @@ def aten_ops_native_group_norm(
     )
 
 
+def parse_cat_args(
+    args: Tuple[Argument, ...], kwargs: Dict[str, Any]
+) -> Tuple[List[Any], int]:
+    """
+    Process inputs for torch.ops.aten.cat.default.
+
+    Handles these valid patterns:
+    1. args = ((t1, t2, ...), dim)
+    2. args = ((t1, t2, ...),) with optional dim in kwargs, e.g. kwargs = {"dim": X}
+
+    Returns:
+        (input_tensors, dim)
+        input_tensors: list of tensor arguments
+        dim: integer concatenation dimension (default 0)
+    """
+
+    if len(args) > 1 and isinstance(args[0], (list, tuple)):
+        input_tensors = list(args[0])
+        dim = args_bounds_check(args, 1, 0)
+
+    else:
+        # If the single positional arg is itself a tuple/list, unwrap it
+        if len(args) == 1 and isinstance(args[0], (list, tuple)):
+            input_tensors = list(args[0])
+        else:
+            input_tensors = list(args)
+
+        dim = kwargs.get("dim", 0)
+
+    return input_tensors, dim
+
+
 def cat_validator(node: Node, settings: Optional[CompilationSettings] = None) -> bool:
     # An empty tensor passed to cat as an ITensor leads to: [RemoveDeadLayers] Input Tensor y is unused or used only at compile-time, but is not being removed.
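     # Rejecting the node here makes it fall back to native PyTorch execution
+    # dim may also be passed as a keyword argument, so inspect args and kwargs together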
- for each_input in node.args[0]: + inputs, _ = parse_cat_args(node.args, node.kwargs) + for each_input in inputs: if isinstance(each_input, TRTTensor) and any(s == 0 for s in each_input.shape): return False return True @@ -227,8 +260,8 @@ def cat_validator(node: Node, settings: Optional[CompilationSettings] = None) -> @dynamo_tensorrt_converter( torch.ops.aten.cat.default, - capability_validator=cat_validator, supports_dynamic_shapes=True, + capability_validator=cat_validator, ) def aten_ops_cat( ctx: ConversionContext, @@ -237,13 +270,14 @@ def aten_ops_cat( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: + inputs, dim = parse_cat_args(args, kwargs) return impl.cat.cat( ctx, target, SourceIR.ATEN, name, - input=args[0], - dim=args_bounds_check(args, 1, 0), + input=inputs, + dim=dim, ) diff --git a/tests/py/dynamo/conversion/test_cat_aten.py b/tests/py/dynamo/conversion/test_cat_aten.py index 4d7bc02d1f..15aa8b0d80 100644 --- a/tests/py/dynamo/conversion/test_cat_aten.py +++ b/tests/py/dynamo/conversion/test_cat_aten.py @@ -25,6 +25,23 @@ def forward(self, x, y, z): inputs, ) + @parameterized.expand( + [ + ("pos", 1), + ("neg", -2), + ] + ) + def test_cat_dim_in_kwargs(self, _, dim): + class Cat(nn.Module): + def forward(self, x, y, z): + return torch.ops.aten.cat.default((x, y, z), dim=dim) + + inputs = [torch.randn(1, 2, 3), torch.randn(1, 1, 3), torch.randn(1, 3, 3)] + self.run_test( + Cat(), + inputs, + ) + @parameterized.expand( [ ("pos", 0),