143 changes: 106 additions & 37 deletions py/torch_tensorrt/dynamo/conversion/impl/cat.py
@@ -1,4 +1,4 @@
from typing import Optional, Sequence, Union
from typing import List, Optional, Sequence, Union

import numpy as np
import tensorrt as trt
@@ -11,11 +11,101 @@
from torch_tensorrt.dynamo.conversion.converter_utils import (
cast_trt_tensor,
get_positive_dim,
get_trt_tensor,
set_layer_name,
)


def unify_and_concat_trt_tensors(
ctx: ConversionContext,
target: Target,
name: str,
inputs: Sequence[Union[int, np.ndarray, torch.Tensor, TRTTensor]],
concat_axis: int,
cast_dtype: Union[_enums.dtype, trt.DataType, np.dtype] = None,
force_trt_output: bool = False,
) -> Union[TRTTensor, List[int]]:
"""
    Normalize inputs to TRT tensors where needed, optionally cast them, and concatenate when any input is dynamic (or when a TRT output is forced).

Args:
ctx: TensorRT conversion context.
target: Operation Target.
name: Operation Name.
inputs: Sequence of ints / numpy arrays / torch tensors / TRT tensors.
concat_axis: Axis along which to concatenate tensors if dynamic.
cast_dtype: Optional target dtype for casting TRT tensors.
        force_trt_output: If True, return a TRT tensor even when all inputs are static ints (set True for concat operations).
"""
has_dynamic = any(not isinstance(x, int) for x in inputs)
trt_tensors = []

for i, x in enumerate(inputs):
# convert to TRTTensor
if isinstance(x, TRTTensor):
t = x
elif isinstance(x, int) and not has_dynamic and not force_trt_output:
t = x # pure static path
else:
t = ctx.net.add_constant((1,), np.array([x], dtype=np.int32))
set_layer_name(t, target, f"{name}_dim{i}_const")
t = t.get_output(0)

# optional cast
if cast_dtype and isinstance(t, TRTTensor):
t = cast_trt_tensor(ctx, t, cast_dtype, f"{name}_cast_{i}")

trt_tensors.append(t)

if not has_dynamic and not force_trt_output:
return trt_tensors # all ints

final_dtype = None
if cast_dtype:
# Explicit cast requested
if isinstance(cast_dtype, _enums.dtype):
final_dtype = cast_dtype.to(trt.DataType)
elif isinstance(cast_dtype, np.dtype):
final_dtype = _enums.dtype._from(cast_dtype).to(trt.DataType)
else:
final_dtype = cast_dtype # already trt.DataType
else:
# Automatic promotion
promoted_type = None
for t in trt_tensors:
if isinstance(t, TRTTensor):
if promoted_type is None:
promoted_type = t.dtype
else:
promoted_type = _enums.dtype._from(
torch.promote_types(
_enums.dtype._from(promoted_type).to(torch.dtype),
_enums.dtype._from(t.dtype).to(torch.dtype),
)
).to(trt.DataType)
final_dtype = promoted_type

# promote remaining ints to TRT consts before concat
for i, t in enumerate(trt_tensors):
if isinstance(t, int):
const = ctx.net.add_constant((1,), np.array([t], dtype=np.int32))
set_layer_name(const, target, f"{name}_static_{i}_const")
trt_tensors[i] = const.get_output(0)

# final cast
if final_dtype is not None:
casted = []
for i, t in enumerate(trt_tensors):
if isinstance(t, TRTTensor):
t = cast_trt_tensor(ctx, t, final_dtype, f"{name}_cast_{i}")
casted.append(t)
trt_tensors = casted

concat = ctx.net.add_concatenation(trt_tensors)
concat.axis = concat_axis
set_layer_name(concat, target, f"{name}_concat")
return concat.get_output(0)


def cat(
ctx: ConversionContext,
target: Target,
@@ -25,38 +115,17 @@ def cat(
dim: int,
cast_dtype: Union[_enums.dtype, trt.DataType, np.dtype] = None,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
trt_inputs = []
for i, each_input in enumerate(input):
if not isinstance(each_input, TRTTensor):
each_input = get_trt_tensor(ctx, each_input, f"{name}_tensor_{i}")
if cast_dtype:
each_input = cast_trt_tensor(
ctx, each_input, cast_dtype, f"{name}_tensor_int32_cast_{i}"
)
trt_inputs.append(each_input)

if len(trt_inputs) > 1:
# Cast to promoted type for all inputs
promoted_type = trt_inputs[0].dtype
for each_input in trt_inputs[1:]:
promoted_type = _enums.dtype._from(
torch.promote_types(
_enums.dtype._from(promoted_type).to(torch.dtype),
_enums.dtype._from(each_input.dtype).to(torch.dtype),
)
)
trt_promoted_type = promoted_type.to(trt.DataType)

trt_casted_inputs = []
for i, each_input in enumerate(trt_inputs):
casted_input = cast_trt_tensor(
ctx, each_input, trt_promoted_type, f"{name}_input_casted_{i}"
)
trt_casted_inputs.append(casted_input)
trt_inputs = trt_casted_inputs

concat_layer = ctx.net.add_concatenation(trt_inputs)
dim = get_positive_dim(dim, len(trt_inputs[0].shape))
concat_layer.axis = dim
set_layer_name(concat_layer, target, f"{name}_gather", source_ir)
return concat_layer.get_output(0)
    # Plain ints appear only when cat is invoked from other ops (e.g., pad)
if not isinstance(input[0], int):
dim = get_positive_dim(dim, len(input[0].shape))
else:
dim = 0
return unify_and_concat_trt_tensors(
ctx,
target,
name,
input,
concat_axis=dim,
cast_dtype=cast_dtype,
force_trt_output=True,
)
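
Note: when no explicit cast_dtype is supplied, unify_and_concat_trt_tensors falls back to pairwise dtype promotion across the TRT inputs, and purely static int inputs skip TensorRT entirely unless force_trt_output=True. A minimal, standalone sketch of the promotion rule the helper relies on (torch.promote_types, stock PyTorch; not part of this diff):

import torch

pairs = [
    (torch.int32, torch.float32),    # promotes to torch.float32
    (torch.float16, torch.float32),  # promotes to torch.float32
    (torch.int32, torch.int64),      # promotes to torch.int64
]
for a, b in pairs:
    print(a, b, "->", torch.promote_types(a, b))

This mirrors the promoted_type loop in the helper, which converts each TRT dtype to a torch dtype, promotes pairwise, and converts the result back to trt.DataType.
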
19 changes: 16 additions & 3 deletions py/torch_tensorrt/dynamo/conversion/impl/upsample.py
@@ -9,7 +9,12 @@
has_dynamic_shape,
set_layer_name,
)
from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
from torch_tensorrt.dynamo.conversion.impl.cat import (
unify_and_concat_trt_tensors as unify_trt_shape_tensors,
)
from torch_tensorrt.dynamo.conversion.impl.shape import (
get_shape_with_dynamic_shape,
)


def upsample(
@@ -28,14 +33,22 @@ def upsample(
if scale_factor is not None:
layer.scales = [1.0, 1.0] + list(scale_factor)
else:
shape = list(input.shape)[:2] + list(size)
shape = list(input.shape)[:2]
if size is not None:
shape += list(size)
if has_dynamic_shape(shape):
shape = get_shape_with_dynamic_shape(
ctx, target, source_ir, name, shape, input
)
layer.set_input(1, shape)
else:
layer.shape = shape
trt_shape = unify_trt_shape_tensors(
ctx, target, name, shape, concat_axis=0, force_trt_output=False
)
if isinstance(trt_shape, list):
layer.shape = trt_shape
else:
layer.set_input(1, trt_shape)

if mode == "nearest":
layer.resize_mode = trt.InterpolationMode.NEAREST
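
Note: with this change the fully static path assigns a plain Python int list to layer.shape, while any tensor element in the target shape routes through the shape-tensor input (layer.set_input(1, ...)). For reference, the eager-mode behavior the converter must reproduce, using the same 7x7 input as the new static test (standard PyTorch, not part of this diff):

import torch

x = torch.randn(1, 1, 7, 7)
# explicit output_size path
y1 = torch.ops.aten.upsample_nearest2d.vec(x, [3, 3], None)
# scale_factors path: output dims are floor(7 * 0.5) = 3 and floor(7 * 1.5) = 10
y2 = torch.ops.aten.upsample_nearest2d.vec(x, None, [0.5, 1.5])
print(y1.shape)  # torch.Size([1, 1, 3, 3])
print(y2.shape)  # torch.Size([1, 1, 3, 10])
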
44 changes: 44 additions & 0 deletions tests/py/dynamo/conversion/test_upsample_aten.py
@@ -296,6 +296,50 @@ def forward(self, x):
]
self.run_test_with_dynamic_shape(TestModule(), input_specs)

@parameterized.expand(
[
([torch.tensor(3), 3], None),
(None, [torch.tensor(0.5), 1.5]),
]
)
def test_nearest2d_mixed_dynamic_shape(self, output_size, scale_factors):
class TestModule(torch.nn.Module):
def forward(self, x):
out_size = output_size
scale = scale_factors

return torch.ops.aten.upsample_nearest2d.vec(x, out_size, scale)

input_specs = [
Input(
min_shape=(1, 1, 1, 1),
opt_shape=(5, 5, 5, 5),
max_shape=(9, 9, 9, 9),
dtype=torch.float32,
)
]
self.run_test_with_dynamic_shape(TestModule(), input_specs)

@parameterized.expand(
[
# Mix of Tensor and int in output_size
([torch.tensor(3), 3], None),
# Mix of Tensor and float in scale_factors
(None, [torch.tensor(0.5), 1.5]),
]
)
def test_nearest2d_mixed_static_input(self, output_size, scale_factors):
class TestModule(torch.nn.Module):
def forward(self, x):
out_size = output_size
scale = scale_factors
return torch.ops.aten.upsample_nearest2d.vec(x, out_size, scale)

input_size = [7, 7] # H, W
inputs = [torch.randn([1, 1] + input_size)] # shape [1, 1, 7, 7]

self.run_test(TestModule(), inputs)


if __name__ == "__main__":
run_tests()
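
Note: for readability, the first parameterization of test_nearest2d_mixed_static_input expands to the module below (illustrative only; it mirrors the test body above and exercises an output_size that mixes a 0-d Tensor with a plain int):

import torch

class MixedSizeModule(torch.nn.Module):
    def forward(self, x):
        # output_size mixes torch.tensor(3) and the int 3; scale_factors stays None
        return torch.ops.aten.upsample_nearest2d.vec(x, [torch.tensor(3), 3], None)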