diff --git a/docker/Dockerfile b/docker/Dockerfile index 2138c94100..739e0aecc1 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -16,8 +16,7 @@ ENV PYTHON_VERSION=${PYTHON_VERSION} ENV DEBIAN_FRONTEND=noninteractive # Install basic dependencies -RUN apt-get update -RUN apt install -y build-essential manpages-dev wget zlib1g software-properties-common git libssl-dev zlib1g-dev libbz2-dev libreadline-dev libsqlite3-dev wget ca-certificates curl llvm libncurses5-dev xz-utils tk-dev libxml2-dev libxmlsec1-dev libffi-dev liblzma-dev mecab-ipadic-utf8 +RUN apt-get update && apt-get install -y build-essential manpages-dev wget zlib1g software-properties-common git libssl-dev zlib1g-dev libbz2-dev libreadline-dev libsqlite3-dev wget ca-certificates curl llvm libncurses5-dev xz-utils tk-dev libxml2-dev libxmlsec1-dev libffi-dev liblzma-dev mecab-ipadic-utf8 # Install PyEnv and desired Python version ENV HOME="/root" @@ -34,8 +33,7 @@ RUN pyenv global ${PYTHON_VERSION} # Install TensorRT + dependencies RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/3bf863cc.pub RUN add-apt-repository "deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/ /" -RUN apt-get update -RUN TENSORRT_MAJOR_VERSION=`echo ${TENSORRT_VERSION} | cut -d '.' -f 1` && \ +RUN apt-get update && TENSORRT_MAJOR_VERSION=`echo ${TENSORRT_VERSION} | cut -d '.' -f 1` && \ apt-get install -y libnvinfer${TENSORRT_MAJOR_VERSION}=${TENSORRT_VERSION}.* \ libnvinfer-plugin${TENSORRT_MAJOR_VERSION}=${TENSORRT_VERSION}.* \ libnvinfer-dev=${TENSORRT_VERSION}.* \ @@ -55,9 +53,9 @@ FROM base as torch-tensorrt-builder-base ARG ARCH="x86_64" ARG TARGETARCH="amd64" -RUN apt-get update -RUN apt-get install -y python3-setuptools -RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/3bf863cc.pub +RUN apt-get update && \ + apt-get install -y python3-setuptools && \ + apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/3bf863cc.pub RUN apt-get update &&\ apt-get install -y --no-install-recommends locales ninja-build &&\ diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index 094de488ec..175979ccf9 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -446,53 +446,47 @@ def create_constant( else: shape = list(torch_value.shape) - if torch_value is not None: - - if torch_value.dtype == torch.uint8: - if is_tensorrt_version_supported("10.8.0"): - if ( - target_quantized_type is None - or target_quantized_type != trt.DataType.FP4 - ): - # Iconstant layer does not support Uint8, it only support that FP4 data packed in uint8 - raise ValueError( - "Currently supported target_quantized_type for uint8 is FP4, got {target_quantized_type=}" - ) - shape[-1] = shape[-1] * 2 - weights = to_trt_weights( - ctx, - torch_value, - name, - "CONSTANT", - "CONSTANT", - dtype=trt.DataType.FP4, - count=torch_value.numel() * 2, - ) - constant = ctx.net.add_constant( - shape, - weights, - ) - constant.name = name - return constant.get_output(0) - else: + if torch_value.dtype == torch.uint8: + if is_tensorrt_version_supported("10.8.0"): + if ( + target_quantized_type is None + or target_quantized_type != trt.DataType.FP4 + ): + # Iconstant layer does not support Uint8, it only support that FP4 data packed in uint8 raise ValueError( - "Currently FP4 is only supported in TensorRT 10.8.0 and above" + "Currently supported target_quantized_type for uint8 is FP4, got {target_quantized_type=}" ) - # Record the weight in ctx for refit and cpu memory reference + shape[-1] = shape[-1] * 2 + weights = to_trt_weights( + ctx, + torch_value, + name, + "CONSTANT", + "CONSTANT", + dtype=trt.DataType.FP4, + count=torch_value.numel() * 2, + ) + constant = ctx.net.add_constant( + shape, + weights, + ) + constant.name = name + return constant.get_output(0) + else: + raise ValueError( + "Currently FP4 is only supported in TensorRT 10.8.0 and above" + ) + # Record the weight in ctx for refit and cpu memory reference - # Convert the torch.Tensor to a trt.Weights object - trt_weights = to_trt_weights(ctx, torch_value, name, "CONSTANT", "CONSTANT") - constant = ctx.net.add_constant( - shape, - trt_weights, - ) - constant.name = name + # Convert the torch.Tensor to a trt.Weights object + trt_weights = to_trt_weights(ctx, torch_value, name, "CONSTANT", "CONSTANT") + constant = ctx.net.add_constant( + shape, + trt_weights, + ) + constant.name = name - return constant.get_output(0) - else: - raise ValueError( - f"Cannot convert tensor '{name}' to a TensorRT constant because its value is None." - ) + return constant.get_output(0) def get_trt_tensor( diff --git a/py/torch_tensorrt/dynamo/conversion/impl/deconv.py b/py/torch_tensorrt/dynamo/conversion/impl/deconv.py index b9ee582d26..9653bdf9f8 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/deconv.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/deconv.py @@ -109,13 +109,16 @@ def deconvNd( assert len(kernel_shape) > 0, "Deconvolution kernel shape must be non-empty" # add deconv layer + if groups is not None: + num_output_maps = num_output_maps * groups deconv_layer = ctx.net.add_deconvolution_nd( input=input, - num_output_maps=num_output_maps * groups, + num_output_maps=num_output_maps, kernel_shape=kernel_shape, kernel=trt.Weights() if isinstance(weight, TRTTensor) else weight, bias=trt.Weights() if isinstance(bias, TRTTensor) else bias, ) + assert deconv_layer is not None, "Deconvolution layer is None" set_layer_name(deconv_layer, target, name, source_ir) # If the weight is a TRTTensor, set it as an input of the layer @@ -145,7 +148,6 @@ def deconvNd( if output_padding is not None else output_padding ) - # Set relevant attributes of deconvolution layer if padding is not None: deconv_layer.padding_nd = padding @@ -156,19 +158,20 @@ def deconvNd( if groups is not None: deconv_layer.num_groups = groups - ndims = len(padding) - pre_padding_values = [] - post_padding_values = [] + if padding is not None: + ndims = len(padding) + pre_padding_values = [] + post_padding_values = [] - for dim in range(ndims): - pre_padding = padding[dim] - post_padding = padding[dim] - output_padding[dim] + for dim in range(ndims): + pre_padding = padding[dim] + post_padding = padding[dim] - output_padding[dim] - pre_padding_values.append(pre_padding) - post_padding_values.append(post_padding) + pre_padding_values.append(pre_padding) + post_padding_values.append(post_padding) - deconv_layer.pre_padding = tuple(pre_padding_values) - deconv_layer.post_padding = tuple(post_padding_values) + deconv_layer.pre_padding = tuple(pre_padding_values) + deconv_layer.post_padding = tuple(post_padding_values) result = deconv_layer.get_output(0) diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py index 7d68ea0d93..2372635a40 100644 --- a/py/torch_tensorrt/dynamo/lowering/_decompositions.py +++ b/py/torch_tensorrt/dynamo/lowering/_decompositions.py @@ -331,7 +331,7 @@ def reduce_operation_with_scatter( scatter_tensor = initial_tensor else: # This case would not be encountered from torch itself - print("Invalid Operation for Reduce op!!") + raise ValueError(f"Invalid Operation for Reduce op: {self}") operation_rhs = torch.scatter(scatter_tensor, dim, index_tensor, src_tensor) device = to_torch_device(scatter_tensor.device) diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index 97328acd6d..f822e40e1b 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -826,13 +826,13 @@ def get_output_metadata( return [node.meta for node in nodes] -def get_output_dtypes(output: Any, truncate_doulbe: bool = False) -> List[dtype]: +def get_output_dtypes(output: Any, truncate_double: bool = False) -> List[dtype]: output_dtypes = [] if isinstance(output, torch.fx.node.Node): if "val" in output.meta: output_meta = output.meta["val"] if isinstance(output_meta, (FakeTensor, torch.Tensor)): - if truncate_doulbe and output_meta.dtype == torch.float64: + if truncate_double and output_meta.dtype == torch.float64: output_dtypes.append(dtype.float32) else: output_dtypes.append(dtype._from(output_meta.dtype)) diff --git a/py/torch_tensorrt/fx/converters/converter_utils.py b/py/torch_tensorrt/fx/converters/converter_utils.py index 78ea125424..a63c7cde00 100644 --- a/py/torch_tensorrt/fx/converters/converter_utils.py +++ b/py/torch_tensorrt/fx/converters/converter_utils.py @@ -173,7 +173,10 @@ def to_numpy( """ output = None - if value is None or isinstance(value, np.ndarray): + if value is None: + return None + + elif isinstance(value, np.ndarray): output = value elif isinstance(value, torch.Tensor): diff --git a/py/torch_tensorrt/fx/tools/timing_cache_utils.py b/py/torch_tensorrt/fx/tools/timing_cache_utils.py index 4580843e98..dcd9bd0f50 100644 --- a/py/torch_tensorrt/fx/tools/timing_cache_utils.py +++ b/py/torch_tensorrt/fx/tools/timing_cache_utils.py @@ -28,12 +28,12 @@ def get_timing_cache_trt(self, timing_cache_file: str) -> bytearray: return None def update_timing_cache( - self, timing_cache_file: str, serilized_cache: bytearray + self, timing_cache_file: str, serialized_cache: bytearray ) -> None: if not self.save_timing_cache: return timing_cache_file = self.get_file_full_name(timing_cache_file) with open(timing_cache_file, "wb") as local_cache: local_cache.seek(0) - local_cache.write(serilized_cache) + local_cache.write(serialized_cache) local_cache.truncate()