From 4a4cec1a0cc2738fd939aece6b36b5e8f2aa1941 Mon Sep 17 00:00:00 2001
From: apbose
Date: Fri, 26 Sep 2025 18:56:00 -0700
Subject: [PATCH 1/8] addresses the case when shape of upsample tensor contains ITensor

---
 .../dynamo/conversion/impl/shape.py    | 38 +++++++++++++++++++
 .../dynamo/conversion/impl/upsample.py | 12 ++++--
 2 files changed, 47 insertions(+), 3 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/impl/shape.py b/py/torch_tensorrt/dynamo/conversion/impl/shape.py
index 27af02e5bb..c487dfe598 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/shape.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/shape.py
@@ -123,3 +123,41 @@ def get_shape_with_dynamic_shape(
     select_layer = ctx.net.add_select(condition_val, input_shape, scale_res)
     set_layer_name(select_layer, target, f"{name}_select")
     return select_layer.get_output(0)
+
+
+def to_trt_shape_tensor(
+    ctx: ConversionContext, target: Target, name: str, shape_list: List[int | TRTTensor]
+) -> TRTTensor:
+    """
+    Convert a mixed shape list (ints + ITensors) into a single ITensor.
+
+    Args:
+        ctx: ConversionContext
+        target: fx node target (used for naming).
+        name (str): base name for layer naming.
+        shape_list (list[int | ITensor]): list containing static ints and/or ITensors.
+
+    Returns:
+        ITensor if shape_list contains any ITensors, else plain Python list of ints.
+    """
+    trt_tensors = []
+
+    for i, s in enumerate(shape_list):
+        if isinstance(s, int):
+            const = ctx.net.add_constant((1,), np.array([s], dtype=np.int32))
+            set_layer_name(const, target, f"{name}_dim{i}_const")
+            trt_tensors.append(const.get_output(0))
+        else:
+            # Assume it's already an ITensor
+            trt_tensors.append(s)
+
+    if trt_tensors:
+        if any(not isinstance(s, int) for s in shape_list):
+            # Concatenate everything into a single ITensor
+            concat_layer = ctx.net.add_concatenation(trt_tensors)
+            concat_layer.axis = 0
+            set_layer_name(concat_layer, target, f"{name}_shape_concat")
+            return concat_layer.get_output(0)
+
+    # If no ITensor found, return plain list of ints
+    return shape_list
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/upsample.py b/py/torch_tensorrt/dynamo/conversion/impl/upsample.py
index 4b47ca5dec..4cedb396d1 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/upsample.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/upsample.py
@@ -9,7 +9,10 @@
     has_dynamic_shape,
     set_layer_name,
 )
-from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
+from torch_tensorrt.dynamo.conversion.impl.shape import (
+    get_shape_with_dynamic_shape,
+    to_trt_shape_tensor,
+)
 
 
 def upsample(
@@ -28,14 +31,17 @@ def upsample(
     if scale_factor is not None:
         layer.scales = [1.0, 1.0] + list(scale_factor)
     else:
-        shape = list(input.shape)[:2] + list(size)
+        shape = list(input.shape)[:2]
+        if size is not None:
+            shape += list(size)
         if has_dynamic_shape(shape):
             shape = get_shape_with_dynamic_shape(
                 ctx, target, source_ir, name, shape, input
            )
             layer.set_input(1, shape)
         else:
-            layer.shape = shape
+            layer.shape = to_trt_shape_tensor(ctx, target, name, shape)
+            layer.set_input(1, layer.shape)
 
     if mode == "nearest":
         layer.resize_mode = trt.InterpolationMode.NEAREST
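
To make the mechanism concrete, a minimal standalone sketch follows (it is not part of the patch): the resize target shape is assembled from static Python ints plus a runtime ITensor, the static pieces become one-element int32 constants, and everything is concatenated into a single shape tensor that feeds the resize layer through set_input(1, ...). This is essentially what to_trt_shape_tensor does per element of a mixed list; the network setup and tensor names below are illustrative assumptions, not code from this PR.

import numpy as np
import tensorrt as trt

# Illustrative sketch only -- mirrors the constant + concatenation pattern used by
# to_trt_shape_tensor, but outside the torch_tensorrt conversion context.
logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)

x = network.add_input("x", trt.float32, (1, 1, 8, 8))
# Runtime part of the target shape: output H/W arrive as a 2-element int32 tensor.
out_hw = network.add_input("out_hw", trt.int32, (2,))

# Static parts (batch, channel) become one-element shape constants ...
dims = [
    network.add_constant((1,), np.array([d], dtype=np.int32)).get_output(0)
    for d in (1, 1)
]
# ... and are folded together with the runtime ITensor into one rank-4 shape tensor.
concat = network.add_concatenation(dims + [out_hw])
concat.axis = 0

resize = network.add_resize(x)
resize.resize_mode = trt.InterpolationMode.NEAREST
resize.set_input(1, concat.get_output(0))  # dynamic output shape for the resize layer
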
From 975b403d64ad70c20aa72a58ecc9e094e2e22fe2 Mon Sep 17 00:00:00 2001
From: apbose
Date: Tue, 14 Oct 2025 23:16:31 -0700
Subject: [PATCH 2/8] adding test case and correcting a case

---
 .../dynamo/conversion/impl/shape.py          | 18 ++++----
 .../dynamo/conversion/impl/upsample.py       | 11 +++--
 .../dynamo/conversion/test_upsample_aten.py  | 44 +++++++++++++++++++
 3 files changed, 59 insertions(+), 14 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/impl/shape.py b/py/torch_tensorrt/dynamo/conversion/impl/shape.py
index c487dfe598..4a0d838a44 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/shape.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/shape.py
@@ -132,8 +132,8 @@ def to_trt_shape_tensor(
     Convert a mixed shape list (ints + ITensors) into a single ITensor.
 
     Args:
-        ctx: ConversionContext
-        target: fx node target (used for naming).
+        ctx (ConversionContext): TensorRT ConversionContext object.
+        target (Target): Target of fx node.
         name (str): base name for layer naming.
         shape_list (list[int | ITensor]): list containing static ints and/or ITensors.
 
@@ -148,16 +148,14 @@ def to_trt_shape_tensor(
             set_layer_name(const, target, f"{name}_dim{i}_const")
             trt_tensors.append(const.get_output(0))
         else:
-            # Assume it's already an ITensor
             trt_tensors.append(s)
 
-    if trt_tensors:
-        if any(not isinstance(s, int) for s in shape_list):
-            # Concatenate everything into a single ITensor
-            concat_layer = ctx.net.add_concatenation(trt_tensors)
-            concat_layer.axis = 0
-            set_layer_name(concat_layer, target, f"{name}_shape_concat")
-            return concat_layer.get_output(0)
+    if any(not isinstance(s, int) for s in shape_list):
+        # Concatenate everything into a single ITensor if there are any ITensors/Tensors
+        concat_layer = ctx.net.add_concatenation(trt_tensors)
+        concat_layer.axis = 0
+        set_layer_name(concat_layer, target, f"{name}_shape_concat")
+        return concat_layer.get_output(0)
 
     # If no ITensor found, return plain list of ints
     return shape_list
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/upsample.py b/py/torch_tensorrt/dynamo/conversion/impl/upsample.py
index 4cedb396d1..55d1bfe0d7 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/upsample.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/upsample.py
@@ -32,16 +32,19 @@ def upsample(
         layer.scales = [1.0, 1.0] + list(scale_factor)
     else:
         shape = list(input.shape)[:2]
-        if size is not None:
-            shape += list(size)
+        if size is not None:
+            shape += list(size)
         if has_dynamic_shape(shape):
             shape = get_shape_with_dynamic_shape(
                 ctx, target, source_ir, name, shape, input
             )
             layer.set_input(1, shape)
         else:
-            layer.shape = to_trt_shape_tensor(ctx, target, name, shape)
-            layer.set_input(1, layer.shape)
+            trt_shape = to_trt_shape_tensor(ctx, target, name, shape)
+            if isinstance(trt_shape, list):
+                layer.shape = trt_shape
+            else:
+                layer.set_input(1, trt_shape)
 
     if mode == "nearest":
         layer.resize_mode = trt.InterpolationMode.NEAREST
diff --git a/tests/py/dynamo/conversion/test_upsample_aten.py b/tests/py/dynamo/conversion/test_upsample_aten.py
index 44c4af2a92..6646cfa63e 100644
--- a/tests/py/dynamo/conversion/test_upsample_aten.py
+++ b/tests/py/dynamo/conversion/test_upsample_aten.py
@@ -296,6 +296,50 @@ def forward(self, x):
         ]
         self.run_test_with_dynamic_shape(TestModule(), input_specs)
 
+    @parameterized.expand(
+        [
+            ([torch.tensor(3), 3], None),
+            (None, [torch.tensor(0.5), 1.5]),
+        ]
+    )
+    def test_nearest2d_mixed_dynamic_shape(self, output_size, scale_factors):
+        class TestModule(torch.nn.Module):
+            def forward(self, x):
+                out_size = output_size
+                scale = scale_factors
+
+                return torch.ops.aten.upsample_nearest2d.vec(x, out_size, scale)
+
+        input_specs = [
+            Input(
+                min_shape=(1, 1, 1, 1),
+                opt_shape=(5, 5, 5, 5),
+                max_shape=(9, 9, 9, 9),
+                dtype=torch.float32,
+            )
+        ]
+        self.run_test_with_dynamic_shape(TestModule(), input_specs)
+
+    @parameterized.expand(
+        [
+            # Mix of Tensor and int in output_size
+            ([torch.tensor(3), 3], None),
+            # Mix of Tensor and float in scale_factors
+            (None, [torch.tensor(0.5), 1.5]),
+        ]
+    )
+    def test_nearest2d_mixed_static_input(self, output_size, scale_factors):
+        class TestModule(torch.nn.Module):
+            def forward(self, x):
+                out_size = output_size
+                scale = scale_factors
+                return torch.ops.aten.upsample_nearest2d.vec(x, out_size, scale)
+
+        input_size = [7, 7]  # H, W
+        inputs = [torch.randn([1, 1] + input_size)]  # shape [1, 1, 7, 7]
+
+        self.run_test(TestModule(), inputs)
+
 if __name__ == "__main__":
     run_tests()
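
The new tests drive this through the converter test harness in tests/py/dynamo/conversion; a rough standalone equivalent is sketched below. The torch_tensorrt.compile call, the ir="dynamo" setting, and the CUDA placement are illustrative assumptions for exercising the converter outside the test suite, not code from this patch series.

import torch
import torch_tensorrt


class Up(torch.nn.Module):
    def forward(self, x):
        # output_size mixes a torch.Tensor with a plain int -- the case the new
        # converter path and the tests above cover (illustrative sketch only).
        return torch.ops.aten.upsample_nearest2d.vec(x, [torch.tensor(3), 3], None)


x = torch.randn(1, 1, 7, 7).cuda()
trt_mod = torch_tensorrt.compile(Up().eval().cuda(), ir="dynamo", inputs=[x])
print(trt_mod(x).shape)  # expected: torch.Size([1, 1, 3, 3])
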
From 8db4c74f411caba452d53f761e7a74847712c208 Mon Sep 17 00:00:00 2001
From: apbose
Date: Wed, 15 Oct 2025 09:38:56 -0700
Subject: [PATCH 3/8] adding torch Tensor in check condition for CI

---
 py/torch_tensorrt/dynamo/conversion/impl/shape.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/impl/shape.py b/py/torch_tensorrt/dynamo/conversion/impl/shape.py
index 4a0d838a44..517d33e0c8 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/shape.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/shape.py
@@ -143,7 +143,7 @@ def to_trt_shape_tensor(
     trt_tensors = []
 
     for i, s in enumerate(shape_list):
-        if isinstance(s, int):
+        if isinstance(s, (int, torch.Tensor)):
             const = ctx.net.add_constant((1,), np.array([s], dtype=np.int32))
             set_layer_name(const, target, f"{name}_dim{i}_const")
             trt_tensors.append(const.get_output(0))

From 3fcf398b9d7adc92ba65f55de6b4c481316388e4 Mon Sep 17 00:00:00 2001
From: apbose
Date: Mon, 20 Oct 2025 15:32:30 -0700
Subject: [PATCH 4/8] addressing review comment- unifying the shape functionality for upsample with concat

---
 .../dynamo/conversion/impl/cat.py      | 74 +++++++++++++++++--
 .../dynamo/conversion/impl/shape.py    | 51 ++++++++++++-
 .../dynamo/conversion/impl/upsample.py |  8 +-
 3 files changed, 125 insertions(+), 8 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/impl/cat.py b/py/torch_tensorrt/dynamo/conversion/impl/cat.py
index 68bbcc31d0..cac3e248a5 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/cat.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/cat.py
@@ -1,4 +1,4 @@
-from typing import Optional, Sequence, Union
+from typing import List, Optional, Sequence, Union
 
 import numpy as np
 import tensorrt as trt
@@ -16,6 +16,63 @@
 )
 
 
+def unify_trt_tensors(
+    ctx: ConversionContext,
+    target: Target,
+    name: str,
+    inputs: Sequence[Union[int, np.ndarray, torch.Tensor, TRTTensor]],
+    concat_axis: int,
+    cast_dtype: Union[_enums.dtype, trt.DataType, np.dtype] = None,
+    force_trt_output: bool = False,
+) -> Union[TRTTensor, List[int]]:
+    """
+    Normalize all inputs to TRT tensors if needed, optionally cast, and concat if any dynamic.
+
+    Args:
+        ctx: TensorRT conversion context.
+        target: FX target for naming.
+        name: Base name for layers.
+        inputs: Sequence of ints / numpy arrays / torch tensors / TRT tensors.
+        concat_axis: Axis along which to concatenate tensors if dynamic.
+        cast_dtype: Optional target dtype for casting TRT tensors.
+        force_trt_output: If True, return TRT tensor even if all inputs are static ints.
+    """
+    has_dynamic = any(not isinstance(x, int) for x in inputs)
+    trt_tensors = []
+
+    for i, x in enumerate(inputs):
+        # convert to TRTTensor
+        if isinstance(x, TRTTensor):
+            t = x
+        elif isinstance(x, int) and not has_dynamic and not force_trt_output:
+            t = x  # pure static path
+        else:
+            t = ctx.net.add_constant((1,), np.array([x], dtype=np.int32))
+            set_layer_name(t, target, f"{name}_dim{i}_const")
+            t = t.get_output(0)
+
+        # optional cast
+        if cast_dtype and isinstance(t, TRTTensor):
+            t = cast_trt_tensor(ctx, t, cast_dtype, f"{name}_cast_{i}")
+
+        trt_tensors.append(t)
+
+    if not has_dynamic and not force_trt_output:
+        return trt_tensors  # all ints
+
+    # promote remaining ints to TRT consts before concat
+    for i, t in enumerate(trt_tensors):
+        if isinstance(t, int):
+            const = ctx.net.add_constant((1,), np.array([t], dtype=np.int32))
+            set_layer_name(const, target, f"{name}_static_{i}_const")
+            trt_tensors[i] = const.get_output(0)
+
+    concat = ctx.net.add_concatenation(trt_tensors)
+    concat.axis = concat_axis
+    set_layer_name(concat, target, f"{name}_concat")
+    return concat.get_output(0)
+
+
 def cat(
     ctx: ConversionContext,
     target: Target,
@@ -54,9 +111,16 @@ def cat(
             )
             trt_casted_inputs.append(casted_input)
         trt_inputs = trt_casted_inputs
+    else:
+        trt_promoted_type = None
 
-    concat_layer = ctx.net.add_concatenation(trt_inputs)
     dim = get_positive_dim(dim, len(trt_inputs[0].shape))
-    concat_layer.axis = dim
-    set_layer_name(concat_layer, target, f"{name}_gather", source_ir)
-    return concat_layer.get_output(0)
+    return unify_trt_tensors(
+        ctx,
+        target,
+        name,
+        trt_inputs,
+        concat_axis=dim,
+        cast_dtype=trt_promoted_type,
+        force_trt_output=True,
+    )
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/shape.py b/py/torch_tensorrt/dynamo/conversion/impl/shape.py
index 517d33e0c8..515f6ab398 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/shape.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/shape.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from typing import List, Optional, Tuple
+from typing import List, Optional, Sequence, Tuple, Union
 
 import numpy as np
 import tensorrt as trt
@@ -159,3 +159,51 @@
 
     # If no ITensor found, return plain list of ints
     return shape_list
+
+
+def collect_and_concat_trt_inputs(
+    ctx: ConversionContext,
+    target: Target,
+    name: str,
+    inputs: Sequence[Union[int, TRTTensor, torch.Tensor, np.ndarray]],
+    concat_axis: int = 0,
+    allow_static_return: bool = False,
+) -> Union[TRTTensor, List[int]]:
+    """
+    Normalize a sequence of values into TRT ITensors and concatenate them.
+    If `allow_static_return=True` and all inputs are ints, return a Python
+    list of ints instead of creating any TRT layers.
+    """
+    trt_tensors = []
+    has_dynamic = False
+
+    for i, x in enumerate(inputs):
+        if isinstance(x, TRTTensor):
+            trt_tensors.append(x)
+            has_dynamic = True
+
+        elif isinstance(x, (int, np.integer)):
+            # keep raw for now, convert only if dynamic found
+            trt_tensors.append(int(x))
+
+        else:
+            # torch/np tensor -> TRT tensor
+            t = get_trt_tensor(ctx, x, f"{name}_tensor_{i}")
+            trt_tensors.append(t)
+            has_dynamic = True
+
+    # fully static shape case
+    if not has_dynamic and allow_static_return:
+        return [int(v) for v in trt_tensors]
+
+    # promote remaining ints to TRT constants
+    for i, v in enumerate(trt_tensors):
+        if isinstance(v, int):
+            const = ctx.net.add_constant((1,), np.array([v], dtype=np.int32))
+            set_layer_name(const, target, f"{name}_static_dim{i}_const")
+            trt_tensors[i] = const.get_output(0)
+
+    # concatenate
+    concat = ctx.net.add_concatenation(trt_tensors)
+    concat.axis = concat_axis
+    set_layer_name(concat, target, f"{name}_concat")
+    return concat.get_output(0)
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/upsample.py b/py/torch_tensorrt/dynamo/conversion/impl/upsample.py
index 55d1bfe0d7..8560fff9a5 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/upsample.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/upsample.py
@@ -9,9 +9,11 @@
     has_dynamic_shape,
     set_layer_name,
 )
+from torch_tensorrt.dynamo.conversion.impl.cat import (
+    unify_trt_tensors as unify_trt_shape_tensors,
+)
 from torch_tensorrt.dynamo.conversion.impl.shape import (
     get_shape_with_dynamic_shape,
-    to_trt_shape_tensor,
 )
@@ -40,7 +42,9 @@ def upsample(
             )
             layer.set_input(1, shape)
         else:
-            trt_shape = to_trt_shape_tensor(ctx, target, name, shape)
+            trt_shape = unify_trt_shape_tensors(
+                ctx, target, name, shape, concat_axis=0, force_trt_output=False
+            )
             if isinstance(trt_shape, list):
                 layer.shape = trt_shape
             else:
                 layer.set_input(1, trt_shape)
From 3d3a8ee3d169c3072454ac3ab3c8e6693f209d76 Mon Sep 17 00:00:00 2001
From: apbose
Date: Mon, 20 Oct 2025 15:34:10 -0700
Subject: [PATCH 5/8] changing function name

---
 .../dynamo/conversion/impl/cat.py      | 10 +--
 .../dynamo/conversion/impl/shape.py    | 87 +------------------
 .../dynamo/conversion/impl/upsample.py |  2 +-
 3 files changed, 7 insertions(+), 92 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/impl/cat.py b/py/torch_tensorrt/dynamo/conversion/impl/cat.py
index cac3e248a5..e50e57c22c 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/cat.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/cat.py
@@ -16,7 +16,7 @@
 )
 
 
-def unify_trt_tensors(
+def unify_and_concat_trt_tensors(
     ctx: ConversionContext,
     target: Target,
     name: str,
@@ -30,12 +30,12 @@
 
     Args:
         ctx: TensorRT conversion context.
-        target: FX target for naming.
-        name: Base name for layers.
+        target: Operation Target.
+        name: Operation Name.
         inputs: Sequence of ints / numpy arrays / torch tensors / TRT tensors.
         concat_axis: Axis along which to concatenate tensors if dynamic.
         cast_dtype: Optional target dtype for casting TRT tensors.
-        force_trt_output: If True, return TRT tensor even if all inputs are static ints.
+        force_trt_output: If True, return TRT tensor even if all inputs are static ints. (True for concat operations)
     """
     has_dynamic = any(not isinstance(x, int) for x in inputs)
     trt_tensors = []
@@ -115,7 +115,7 @@ def cat(
         trt_promoted_type = None
 
     dim = get_positive_dim(dim, len(trt_inputs[0].shape))
-    return unify_trt_tensors(
+    return unify_and_concat_trt_tensors(
         ctx,
         target,
         name,
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/shape.py b/py/torch_tensorrt/dynamo/conversion/impl/shape.py
index 515f6ab398..27af02e5bb 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/shape.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/shape.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from typing import List, Optional, Sequence, Tuple, Union
+from typing import List, Optional, Tuple
 
 import numpy as np
 import tensorrt as trt
@@ -123,88 +123,3 @@ def get_shape_with_dynamic_shape(
     select_layer = ctx.net.add_select(condition_val, input_shape, scale_res)
     set_layer_name(select_layer, target, f"{name}_select")
     return select_layer.get_output(0)
-
-
-def to_trt_shape_tensor(
-    ctx: ConversionContext, target: Target, name: str, shape_list: List[int | TRTTensor]
-) -> TRTTensor:
-    """
-    Convert a mixed shape list (ints + ITensors) into a single ITensor.
-
-    Args:
-        ctx (ConversionContext): TensorRT ConversionContext object.
-        target (Target): Target of fx node.
-        name (str): base name for layer naming.
-        shape_list (list[int | ITensor]): list containing static ints and/or ITensors.
-
-    Returns:
-        ITensor if shape_list contains any ITensors, else plain Python list of ints.
-    """
-    trt_tensors = []
-
-    for i, s in enumerate(shape_list):
-        if isinstance(s, (int, torch.Tensor)):
-            const = ctx.net.add_constant((1,), np.array([s], dtype=np.int32))
-            set_layer_name(const, target, f"{name}_dim{i}_const")
-            trt_tensors.append(const.get_output(0))
-        else:
-            trt_tensors.append(s)
-
-    if any(not isinstance(s, int) for s in shape_list):
-        # Concatenate everything into a single ITensor if there are any ITensors/Tensors
-        concat_layer = ctx.net.add_concatenation(trt_tensors)
-        concat_layer.axis = 0
-        set_layer_name(concat_layer, target, f"{name}_shape_concat")
-        return concat_layer.get_output(0)
-
-    # If no ITensor found, return plain list of ints
-    return shape_list
-
-
-def collect_and_concat_trt_inputs(
-    ctx: ConversionContext,
-    target: Target,
-    name: str,
-    inputs: Sequence[Union[int, TRTTensor, torch.Tensor, np.ndarray]],
-    concat_axis: int = 0,
-    allow_static_return: bool = False,
-) -> Union[TRTTensor, List[int]]:
-    """
-    Normalize a sequence of values into TRT ITensors and concatenate them.
-    If `allow_static_return=True` and all inputs are ints, return a Python
-    list of ints instead of creating any TRT layers.
-    """
-    trt_tensors = []
-    has_dynamic = False
-
-    for i, x in enumerate(inputs):
-        if isinstance(x, TRTTensor):
-            trt_tensors.append(x)
-            has_dynamic = True
-
-        elif isinstance(x, (int, np.integer)):
-            # keep raw for now, convert only if dynamic found
-            trt_tensors.append(int(x))
-
-        else:
-            # torch/np tensor -> TRT tensor
-            t = get_trt_tensor(ctx, x, f"{name}_tensor_{i}")
-            trt_tensors.append(t)
-            has_dynamic = True
-
-    # fully static shape case
-    if not has_dynamic and allow_static_return:
-        return [int(v) for v in trt_tensors]
-
-    # promote remaining ints to TRT constants
-    for i, v in enumerate(trt_tensors):
-        if isinstance(v, int):
-            const = ctx.net.add_constant((1,), np.array([v], dtype=np.int32))
-            set_layer_name(const, target, f"{name}_static_dim{i}_const")
-            trt_tensors[i] = const.get_output(0)
-
-    # concatenate
-    concat = ctx.net.add_concatenation(trt_tensors)
-    concat.axis = concat_axis
-    set_layer_name(concat, target, f"{name}_concat")
-    return concat.get_output(0)
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/upsample.py b/py/torch_tensorrt/dynamo/conversion/impl/upsample.py
index 8560fff9a5..ac54e18f3a 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/upsample.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/upsample.py
@@ -10,7 +10,7 @@
     set_layer_name,
 )
 from torch_tensorrt.dynamo.conversion.impl.cat import (
-    unify_trt_tensors as unify_trt_shape_tensors,
+    unify_and_concat_trt_tensors as unify_trt_shape_tensors,
 )
 from torch_tensorrt.dynamo.conversion.impl.shape import (
     get_shape_with_dynamic_shape,
 )
From 5b77d3f35d795b925162a2dcb636d7db2a5198f9 Mon Sep 17 00:00:00 2001
From: apbose
Date: Wed, 22 Oct 2025 21:29:46 -0700
Subject: [PATCH 6/8] add finalcast for cat case

---
 .../dynamo/conversion/impl/cat.py | 75 ++++++++++---------
 1 file changed, 40 insertions(+), 35 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/impl/cat.py b/py/torch_tensorrt/dynamo/conversion/impl/cat.py
index e50e57c22c..32eba00bf9 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/cat.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/cat.py
@@ -11,7 +11,6 @@
 from torch_tensorrt.dynamo.conversion.converter_utils import (
     cast_trt_tensor,
     get_positive_dim,
-    get_trt_tensor,
     set_layer_name,
 )
 
@@ -60,6 +59,31 @@
     if not has_dynamic and not force_trt_output:
         return trt_tensors  # all ints
 
+    final_dtype = None
+    if cast_dtype:
+        # Explicit cast requested
+        if isinstance(cast_dtype, _enums.dtype):
+            final_dtype = cast_dtype.to(trt.DataType)
+        elif isinstance(cast_dtype, np.dtype):
+            final_dtype = _enums.dtype._from(cast_dtype).to(trt.DataType)
+        else:
+            final_dtype = cast_dtype  # already trt.DataType
+    else:
+        # Automatic promotion
+        promoted_type = None
+        for t in trt_tensors:
+            if isinstance(t, TRTTensor):
+                if promoted_type is None:
+                    promoted_type = t.dtype
+                else:
+                    promoted_type = _enums.dtype._from(
+                        torch.promote_types(
+                            _enums.dtype._from(promoted_type).to(torch.dtype),
+                            _enums.dtype._from(t.dtype).to(torch.dtype),
+                        )
+                    ).to(trt.DataType)
+        final_dtype = promoted_type
+
     # promote remaining ints to TRT consts before concat
     for i, t in enumerate(trt_tensors):
         if isinstance(t, int):
@@ -67,6 +91,15 @@
             set_layer_name(const, target, f"{name}_static_{i}_const")
             trt_tensors[i] = const.get_output(0)
 
+    # final cast
+    if final_dtype is not None:
+        casted = []
+        for i, t in enumerate(trt_tensors):
+            if isinstance(t, TRTTensor):
+                t = cast_trt_tensor(ctx, t, final_dtype, f"{name}_cast_{i}")
+            casted.append(t)
+        trt_tensors = casted
+
     concat = ctx.net.add_concatenation(trt_tensors)
     concat.axis = concat_axis
     set_layer_name(concat, target, f"{name}_concat")
@@ -82,45 +115,17 @@ def cat(
     dim: int,
     cast_dtype: Union[_enums.dtype, trt.DataType, np.dtype] = None,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    trt_inputs = []
-    for i, each_input in enumerate(input):
-        if not isinstance(each_input, TRTTensor):
-            each_input = get_trt_tensor(ctx, each_input, f"{name}_tensor_{i}")
-        if cast_dtype:
-            each_input = cast_trt_tensor(
-                ctx, each_input, cast_dtype, f"{name}_tensor_int32_cast_{i}"
-            )
-        trt_inputs.append(each_input)
-
-    if len(trt_inputs) > 1:
-        # Cast to promoted type for all inputs
-        promoted_type = trt_inputs[0].dtype
-        for each_input in trt_inputs[1:]:
-            promoted_type = _enums.dtype._from(
-                torch.promote_types(
-                    _enums.dtype._from(promoted_type).to(torch.dtype),
-                    _enums.dtype._from(each_input.dtype).to(torch.dtype),
-                )
-            )
-        trt_promoted_type = promoted_type.to(trt.DataType)
-
-        trt_casted_inputs = []
-        for i, each_input in enumerate(trt_inputs):
-            casted_input = cast_trt_tensor(
-                ctx, each_input, trt_promoted_type, f"{name}_input_casted_{i}"
-            )
-            trt_casted_inputs.append(casted_input)
-        trt_inputs = trt_casted_inputs
+    # int is only when cat called in other ops like pad
+    if not isinstance(input[0], int):
+        dim = get_positive_dim(dim, len(input[0].shape))
     else:
-        trt_promoted_type = None
-
-    dim = get_positive_dim(dim, len(trt_inputs[0].shape))
+        dim = 0
     return unify_and_concat_trt_tensors(
         ctx,
         target,
         name,
-        trt_inputs,
+        input,
         concat_axis=dim,
-        cast_dtype=trt_promoted_type,
+        cast_dtype=cast_dtype,
         force_trt_output=True,
     )

From ce56ca603b0d78c76b63deef91dfc51eb5fd6202 Mon Sep 17 00:00:00 2001
From: apbose
Date: Mon, 27 Oct 2025 21:13:15 -0700
Subject: [PATCH 7/8] addressing embedding bag CI error

---
 py/torch_tensorrt/dynamo/conversion/impl/cat.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/impl/cat.py b/py/torch_tensorrt/dynamo/conversion/impl/cat.py
index 32eba00bf9..d5be564284 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/cat.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/cat.py
@@ -46,7 +46,7 @@ def unify_and_concat_trt_tensors(
         elif isinstance(x, int) and not has_dynamic and not force_trt_output:
             t = x  # pure static path
         else:
-            t = ctx.net.add_constant((1,), np.array([x], dtype=np.int32))
+            t = ctx.net.add_constant((x.numel(),), np.array([x], dtype=np.int32))
             set_layer_name(t, target, f"{name}_dim{i}_const")
             t = t.get_output(0)

From 91a2519b15e364804c4991ae4d16417d9391567f Mon Sep 17 00:00:00 2001
From: apbose
Date: Tue, 28 Oct 2025 13:15:33 -0700
Subject: [PATCH 8/8] correcting the int case

---
 py/torch_tensorrt/dynamo/conversion/impl/cat.py | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/impl/cat.py b/py/torch_tensorrt/dynamo/conversion/impl/cat.py
index d5be564284..be2743ad5b 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/cat.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/cat.py
@@ -46,9 +46,15 @@ def unify_and_concat_trt_tensors(
         elif isinstance(x, int) and not has_dynamic and not force_trt_output:
             t = x  # pure static path
         else:
-            t = ctx.net.add_constant((x.numel(),), np.array([x], dtype=np.int32))
-            set_layer_name(t, target, f"{name}_dim{i}_const")
-            t = t.get_output(0)
+            const_arr = np.array([x], dtype=np.int32)
+            shape = (1,)
+            if not isinstance(x, int):
+                const_arr = np.array(x, dtype=np.int32)
+                shape = (x.numel(),)
+
+            layer = ctx.net.add_constant(shape, const_arr)
+            set_layer_name(layer, target, f"{name}_dim{i}_const")
+            t = layer.get_output(0)