diff --git a/docsrc/contributors/resource_management.rst b/docsrc/contributors/resource_management.rst
new file mode 100644
index 0000000000..e30fcd0a75
--- /dev/null
+++ b/docsrc/contributors/resource_management.rst
@@ -0,0 +1,77 @@
+.. _resource_management:
+
+Resource Management
+===================
+
+Overview
+--------
+
+Efficient control of CPU and GPU memory is essential for successful model compilation,
+especially when working with large models such as LLMs or diffusion models.
+Uncontrolled memory growth can cause compilation failures or process termination.
+This guide describes the symptoms of excessive memory usage and provides methods
+to reduce both CPU and GPU memory consumption.
+
+Memory Usage Control
+--------------------
+
+CPU Memory
+^^^^^^^^^^
+
+By default, Torch-TensorRT may consume up to **5x** the model size in CPU memory.
+This can exceed system limits when compiling large models.
+
+**Common symptoms of high CPU memory usage:**
+
+- The program freezes or hangs
+- The process is terminated by the operating system
+
+**Ways to lower CPU memory usage:**
+
+1. **Enable memory trimming**
+
+   Set the following environment variable:
+
+   .. code-block:: bash
+
+      export TORCHTRT_ENABLE_BUILDER_MALLOC_TRIM=1
+
+   This releases roughly **2x** the model size in redundant model copies, limiting
+   total CPU memory usage to about **3x** the model size.
+
+2. **Disable CPU offloading**
+
+   In the compilation settings, set:
+
+   .. code-block:: python
+
+      offload_module_to_cpu = False
+
+   This removes another **1x** model copy, reducing peak CPU memory
+   usage to about **2x** the model size.
+
+GPU Memory
+^^^^^^^^^^
+
+By default, Torch-TensorRT may consume up to **2x** the model size in GPU memory.
+
+**Common symptoms of high GPU memory usage:**
+
+- CUDA out-of-memory errors
+- TensorRT compilation errors
+
+**Ways to lower GPU memory usage:**
+
+1. **Enable offloading to CPU**
+
+   In the compilation settings, set:
+
+   .. code-block:: python
+
+      offload_module_to_cpu = True
+
+   This shifts one model copy from GPU to CPU memory, so peak GPU memory usage
+   decreases to about **1x** the model size, while CPU memory usage increases by
+   roughly **1x** the model size because the offloaded copy now resides in host memory.
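+
+   For illustration, a minimal sketch of a compile call that applies this setting
+   is shown below; ``MyModel`` is a placeholder for your own module, and the input
+   shape and remaining settings should be adapted to your workload:
+
+   .. code-block:: python
+
+      import torch
+      import torch_tensorrt
+
+      model = MyModel().eval().cuda()  # placeholder for a user-defined model
+      inputs = [torch.randn(1, 3, 224, 224).cuda()]
+
+      # Offloading the PyTorch module to CPU frees GPU memory for the TensorRT
+      # build: peak GPU usage drops by about 1x the model size, while CPU usage
+      # grows by about 1x.
+      trt_module = torch_tensorrt.compile(
+          model,
+          ir="dynamo",
+          inputs=inputs,
+          offload_module_to_cpu=True,
+      )
+
+   If CPU memory is also constrained, combine this with the
+   ``TORCHTRT_ENABLE_BUILDER_MALLOC_TRIM=1`` environment variable described above.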
+ + diff --git a/docsrc/index.rst b/docsrc/index.rst index 671379d004..ed58ed4d52 100644 --- a/docsrc/index.rst +++ b/docsrc/index.rst @@ -234,6 +234,7 @@ Contributor Documentation contributors/writing_dynamo_aten_lowering_passes contributors/ts_converters contributors/useful_links + contributors/resource_management Indices ---------------- diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index c8ad938032..bc345947d3 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -42,6 +42,7 @@ ) from torch_tensorrt.dynamo.utils import ( deallocate_module, + get_cpu_memory_usage, get_flat_args_with_check, get_output_metadata, parse_graph_io, @@ -681,7 +682,7 @@ def compile( "offload_module_to_cpu": offload_module_to_cpu, "use_distributed_mode_trace": use_distributed_mode_trace, } - + logger.debug(f"CPU memory usage before lowering: {get_cpu_memory_usage()} MB") settings = CompilationSettings(**compilation_options) logger.info("Compilation Settings: %s\n", settings) exported_program = pre_export_lowering(exported_program, settings) @@ -695,14 +696,17 @@ def compile( # Apply lowering on the graph module gm = post_lowering(gm, settings) + logger.debug(f"CPU memory usage after post_lowering: {get_cpu_memory_usage()} MB") logger.debug("Lowered Input graph: " + str(gm.graph)) # Move the weights in the state_dict to CPU if offload_module_to_cpu: + deallocate_module(gm, delete_module=False) deallocate_module(exported_program.module(), delete_module=False) logger.info( "The PyTorch model was moved to the CPU to allocate all GPU memory to TensorRT. To retain the model on the GPU, set offload_module_to_cpu=False" ) + logger.debug(f"CPU memory usage after CPU offload: {get_cpu_memory_usage()} MB") else: remaining_memory, total_memory = torch.cuda.mem_get_info() if remaining_memory < total_memory // 2: @@ -868,6 +872,11 @@ def preserve_module_specs( # Iterate over all components that can be accelerated # Generate the corresponding TRT Module for those + # Here we delete the frozen parameters from the graph module. Note this does not affect the submodules. We are going to delete the frozen parameters from the submodules in the convert_module function. + # This is done to release CPU memory. 
+ for attr in dir(gm): + if attr.startswith("_frozen_param"): + delattr(gm, attr) for name, _ in partitioned_module.named_children(): submodule = getattr(partitioned_module, name) # filter on the GraphModule @@ -1243,7 +1252,7 @@ def convert_exported_program_to_serialized_trt_engine( # Prepare torch_trt inputs trt_arg_inputs: Sequence[Input] = prepare_inputs(arg_inputs) - trt_kwarg_inputs: Optional[dict[Any, Any]] = prepare_inputs(kwarg_inputs) + trt_kwarg_inputs: Optional[dict[str, Any]] = prepare_inputs(kwarg_inputs) device = to_torch_tensorrt_device(device) enabled_precisions = {dtype._from(p) for p in enabled_precisions} @@ -1330,7 +1339,7 @@ def convert_exported_program_to_serialized_trt_engine( ) flattened_input_list = get_flat_args_with_check( - exported_program, list(trt_arg_inputs), trt_kwarg_inputs + exported_program, list(trt_arg_inputs), trt_kwarg_inputs # type: ignore )[0] try: diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 73af09448e..e31c423581 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -1,5 +1,4 @@ import gc -import io import logging import os import warnings @@ -50,7 +49,12 @@ from torch_tensorrt.dynamo.debug._DebuggerConfig import DebuggerConfig from torch_tensorrt.dynamo.debug._supports_debugger import cls_supports_debugger from torch_tensorrt.dynamo.observer import Observer -from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, deallocate_module, to_torch_device +from torch_tensorrt.dynamo.utils import ( + DYNAMIC_DIM, + deallocate_module, + get_cpu_memory_usage, + to_torch_device, +) from torch_tensorrt.logging import TRT_LOGGER _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -65,7 +69,7 @@ class UnsupportedOperatorException(RuntimeError): class TRTInterpreterResult(NamedTuple): - serialized_engine: bytes + engine: trt.ICudaEngine input_names: Sequence[str] output_names: Sequence[str] weight_name_map: Optional[dict[Any, Any]] @@ -512,8 +516,7 @@ def _save_weight_mapping(self) -> None: _LOGGER.info("Building weight name mapping...") # Stage 1: Name mapping torch_device = to_torch_device(self.compilation_settings.device) - self.module.to(torch_device) - sd = self.module.state_dict() + sd = {k: v.to(torch_device) for k, v in self.module.state_dict().items()} weight_name_map: dict[str, Any] = {} weight_refit_map = self.ctx.weight_refit_map constant_mapping = {k: v for k, v in weight_refit_map.items() if v.size == 1} @@ -591,34 +594,6 @@ def _save_weight_mapping(self) -> None: gc.collect() torch.cuda.empty_cache() - @needs_refit # type: ignore[misc] - def _insert_engine_to_cache(self, hash_val: str, serialized_engine: bytes) -> None: - # TODO: @Evan is waiting for TRT's feature to cache the weight-stripped engine - # if not self.compilation_settings.strip_engine_weights: - # # set EXCLUDE_WEIGHTS flag to strip weights - # runtime = trt.Runtime(TRT_LOGGER) - # engine = runtime.deserialize_cuda_engine(serialized_engine) - - # serialization_config = engine.create_serialization_config() - # serialization_config.set_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS) - # serialized_engine = engine.serialize_with_config( - # serialization_config - # ) - - # Cache weighted engine for now - self.engine_cache.insert( # type: ignore[union-attr] - hash_val, - ( - serialized_engine, - self._input_names, - self._output_names, - self.input_specs, - self.compilation_settings, - self.weight_name_map, - 
self.ctx.requires_output_allocator, - ), - ) - @needs_refit # type: ignore[misc] def _pull_cached_engine(self, hash_val: str) -> Optional[TRTInterpreterResult]: # query the cached TRT engine @@ -671,7 +646,6 @@ def _pull_cached_engine(self, hash_val: str) -> Optional[TRTInterpreterResult]: settings=self.compilation_settings, weight_name_map=self.weight_name_map, ) - serialized_engine = engine.serialize() # TODO: @Evan is waiting for TRT's feature to load the weight-stripped engine # # EXCLUDE_WEIGHTS flag must be cleared @@ -684,12 +658,8 @@ def _pull_cached_engine(self, hash_val: str) -> Optional[TRTInterpreterResult]: # ) # # As of now, the engine becomes non-refittable because when EXCLUDE_WEIGHTS flag is cleared, the REFIT flag is also cleared by TRT to make the plan file smaller - with io.BytesIO() as engine_bytes: - engine_bytes.write(serialized_engine) - engine_str = engine_bytes.getvalue() - return TRTInterpreterResult( - engine_str, + engine, self._input_names, self._output_names, self.weight_name_map, @@ -733,6 +703,9 @@ def run( return interpreter_result # type: ignore[no-any-return] self._construct_trt_network_def() + _LOGGER.debug( + f"CPU memory usage after network construction: {get_cpu_memory_usage()} MB" + ) if not self.compilation_settings.immutable_weights: self._save_weight_mapping() @@ -750,36 +723,39 @@ def run( self._create_timing_cache( builder_config, self.compilation_settings.timing_cache_path ) - serialized_engine = self.builder.build_serialized_network( - self.ctx.net, builder_config + + if ( + ENABLED_FEATURES.tensorrt_rtx + or self.compilation_settings.version_compatible + ): + # TODO: When TRT-RTX matures, change it to build_engine_with_config + serialized_engine = self.builder.build_serialized_network( + self.ctx.net, builder_config + ) + runtime = trt.Runtime(TRT_LOGGER) + cuda_engine = runtime.deserialize_cuda_engine(serialized_engine) + else: + + cuda_engine = self.builder.build_engine_with_config( + self.ctx.net, builder_config + ) + assert cuda_engine + + _LOGGER.debug( + f"CPU memory usage after engine building: {get_cpu_memory_usage()} MB" ) - assert serialized_engine _LOGGER.info( f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}" ) - _LOGGER.info(f"TRT Engine uses: {serialized_engine.nbytes} bytes of Memory") - self.ctx.clear_cpu_weights_reference_holder() self._save_timing_cache( builder_config, self.compilation_settings.timing_cache_path ) - # Engine caching only for refittable engines - if ( - not self.compilation_settings.immutable_weights - and self.compilation_settings.cache_built_engines - and self.engine_cache is not None - ): - self._insert_engine_to_cache(hash_val, serialized_engine) - - with io.BytesIO() as engine_bytes: - engine_bytes.write(serialized_engine) - engine_str = engine_bytes.getvalue() - return TRTInterpreterResult( - engine_str, + cuda_engine, self._input_names, self._output_names, self.weight_name_map, diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py index 35b6c26617..76926107a4 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -1,7 +1,8 @@ from __future__ import annotations +import io import logging -from typing import Any, List, Optional, Sequence +from typing import Any, List, NamedTuple, Optional, Sequence import torch from torch_tensorrt._enums import dtype @@ -9,16 +10,25 @@ from torch_tensorrt._Input import Input from 
torch_tensorrt.dynamo._engine_cache import BaseEngineCache from torch_tensorrt.dynamo._settings import CompilationSettings -from torch_tensorrt.dynamo.conversion._TRTInterpreter import ( - TRTInterpreter, - TRTInterpreterResult, -) +from torch_tensorrt.dynamo.conversion._TRTInterpreter import TRTInterpreter from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule -from torch_tensorrt.dynamo.utils import get_output_dtypes +from torch_tensorrt.dynamo.utils import ( + get_cpu_memory_usage, + get_output_dtypes, + release_host_and_device_memory, +) logger = logging.getLogger(__name__) +class SerializedInterpreterResult(NamedTuple): + serialized_engine: bytes + input_names: Sequence[str] + output_names: Sequence[str] + weight_name_map: Optional[dict[Any, Any]] + requires_output_allocator: bool + + def infer_module_output_dtypes( module: torch.fx.GraphModule, truncate_double: bool = False, @@ -39,7 +49,7 @@ def interpret_module_to_result( arg_inputs: Optional[Sequence[Input]] = None, kwarg_inputs: Optional[dict[str, Any]] = None, engine_cache: Optional[BaseEngineCache] = None, -) -> TRTInterpreterResult: +) -> SerializedInterpreterResult: """Interpret an FX module to a TRTInterpreterResult Args: module: FX GraphModule to interpret @@ -50,8 +60,9 @@ def interpret_module_to_result( settings: Compilation settings engine_cache: Engine cache instance Returns: - TRTInterpreterResult + SerializedInterpreterResult """ + output_dtypes = infer_module_output_dtypes( module, truncate_double=settings.truncate_double ) @@ -65,7 +76,53 @@ def interpret_module_to_result( ) interpreter_result = interpreter.run() - return interpreter_result + # Delete the frozen parameters from the module to release CPU memory + del interpreter + for attr in dir(module): + if attr.startswith("_frozen_param"): + delattr(module, attr) + release_host_and_device_memory() + logger.debug( + f"CPU memory usage after clearing frozen parameters and building memory in conversion: {get_cpu_memory_usage()} MB" + ) + + serialized_engine = interpreter_result.engine.serialize() + with io.BytesIO() as engine_bytes: + engine_bytes.write(serialized_engine) + serialized_engine = engine_bytes.getvalue() + logger.debug( + f"CPU memory usage after serializing engine: {get_cpu_memory_usage()} MB" + ) + + # Engine caching only for refittable engines + if ( + not settings.immutable_weights + and settings.cache_built_engines + and engine_cache is not None + ): + hash_val = engine_cache.get_hash(module, inputs, settings) + engine_cache.insert( + hash_val, + ( + serialized_engine, + interpreter_result.input_names, + interpreter_result.output_names, + inputs, + settings, + interpreter_result.weight_name_map, + interpreter_result.requires_output_allocator, + ), + ) + + serialized_interpreter_result = SerializedInterpreterResult( + serialized_engine=serialized_engine, + input_names=interpreter_result.input_names, + output_names=interpreter_result.output_names, + weight_name_map=interpreter_result.weight_name_map, + requires_output_allocator=interpreter_result.requires_output_allocator, + ) + + return serialized_interpreter_result def convert_module( @@ -85,7 +142,7 @@ def convert_module( Returns: PythonTorchTensorRTModule or TorchTensorRTModule """ - interpreter_result = interpret_module_to_result( + serialized_interpreter_result = interpret_module_to_result( module, inputs, settings, engine_cache=engine_cache ) @@ -104,11 +161,11 @@ def convert_module( ) return rt_cls( - serialized_engine=interpreter_result.serialized_engine, - 
input_binding_names=list(interpreter_result.input_names), - output_binding_names=list(interpreter_result.output_names), + serialized_engine=serialized_interpreter_result.serialized_engine, + input_binding_names=list(serialized_interpreter_result.input_names), + output_binding_names=list(serialized_interpreter_result.output_names), name=name, settings=settings, - weight_name_map=interpreter_result.weight_name_map, - requires_output_allocator=interpreter_result.requires_output_allocator, + weight_name_map=serialized_interpreter_result.weight_name_map, + requires_output_allocator=serialized_interpreter_result.requires_output_allocator, ) diff --git a/py/torch_tensorrt/dynamo/debug/_Debugger.py b/py/torch_tensorrt/dynamo/debug/_Debugger.py index 39e4217f73..e565929861 100644 --- a/py/torch_tensorrt/dynamo/debug/_Debugger.py +++ b/py/torch_tensorrt/dynamo/debug/_Debugger.py @@ -220,6 +220,7 @@ def get_logging_config(self, log_level: Optional[int] = None) -> dict[str, Any]: "class": "logging.FileHandler", "filename": f"{self.cfg.logging_dir}/torch_tensorrt_logging.log", "formatter": "standard", + "mode": "w", # This will clear the previous content } config["loggers"][""]["handlers"].append("file") return config diff --git a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py index 5ba84b09b0..a1f8b9c6be 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py @@ -36,9 +36,16 @@ def constant_fold( # The constants are created on CPU to save GPU memory for TensorRT compilation. # For TRT INetwork construction the constants are moved to CPU in get_attr call. for node, constant in cf.node_replacements.items(): - replace_node_with_constant( - gm, node, torch.nn.Parameter(constant, requires_grad=False) - ) + if settings.offload_module_to_cpu: + replace_node_with_constant( + gm, + node, + torch.nn.Parameter(constant.cpu().contiguous(), requires_grad=False), + ) + else: + replace_node_with_constant( + gm, node, torch.nn.Parameter(constant, requires_grad=False) + ) erased_params = [] for node in gm.graph.nodes: diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index 97328acd6d..711a07ee27 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -1,7 +1,10 @@ from __future__ import annotations +import ctypes import gc import logging +import os +import platform import warnings from dataclasses import fields, replace from enum import Enum @@ -17,6 +20,7 @@ ) import numpy as np +import psutil import sympy import tensorrt as trt import torch @@ -853,3 +857,39 @@ def get_output_dtypes(output: Any, truncate_doulbe: bool = False) -> List[dtype] f"got unexpected type {type(output)}, expected type is a torch.fx.node.Node or a tuple/list of torch.fx.node.Node" ) return output_dtypes + + +def is_tegra_platform() -> bool: + if torch.cuda.get_device_capability() in [(8, 7), (7, 2)]: + return True + return False + + +def is_thor() -> bool: + if torch.cuda.get_device_capability() in [(11, 0)]: + return True + return False + + +def get_cpu_memory_usage() -> Any: + return psutil.Process().memory_info().rss / 1024 / 1024 + + +def release_host_and_device_memory() -> None: + gc.collect() + if torch.cuda.is_available(): + torch.cuda.synchronize() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + torch.cuda.synchronize() + + if ( + platform.system() == "Linux" + and 
os.environ.get("TORCHTRT_ENABLE_BUILDER_MALLOC_TRIM", "0") == "1" + ): + try: + libc = ctypes.CDLL("libc.so.6") + if libc.malloc_trim(0) != 1: + logger.warning("Failed to release CPU memory.") + except Exception: + logger.warning("Failed to release CPU memory.") diff --git a/tests/py/dynamo/conversion/harness.py b/tests/py/dynamo/conversion/harness.py index 31eaca5917..b1b73cf3b8 100644 --- a/tests/py/dynamo/conversion/harness.py +++ b/tests/py/dynamo/conversion/harness.py @@ -208,8 +208,9 @@ def run_test( interpreter_result = interpreter.run() sec = time.perf_counter() - start _LOGGER.info(f"Interpreter run time(s): {sec}") + serialized_engine = interpreter_result.engine.serialize() trt_mod = rt_cls( - serialized_engine=interpreter_result.serialized_engine, + serialized_engine=serialized_engine, input_binding_names=list(interpreter_result.input_names), output_binding_names=list(interpreter_result.output_names), name="test_engine", @@ -291,8 +292,9 @@ def run_test_custom_compare_results( self.assert_has_op(mod, expected_ops) interpreter_result = interpreter.run() + serialized_engine = interpreter_result.engine.serialize() trt_mod = rt_cls( - serialized_engine=interpreter_result.serialized_engine, + serialized_engine=serialized_engine, input_binding_names=list(interpreter_result.input_names), output_binding_names=list(interpreter_result.output_names), name="test_engine", diff --git a/tests/py/dynamo/models/test_models.py b/tests/py/dynamo/models/test_models.py index d4133ff4b4..a1600e46eb 100644 --- a/tests/py/dynamo/models/test_models.py +++ b/tests/py/dynamo/models/test_models.py @@ -55,6 +55,52 @@ def test_resnet18(ir): torch._dynamo.reset() +def compile_one(idx: int, ir: str): + model = models.resnet18(pretrained=True).eval().to("cuda") + input = torch.randn((idx + 1, 3, 224, 224)).to("cuda") + + compile_spec = { + "inputs": [ + torchtrt.Input( + input.shape, dtype=torch.float, format=torch.contiguous_format + ) + ], + "device": torchtrt.Device("cuda:0"), + "enabled_precisions": {torch.float}, + "ir": ir, + "pass_through_build_failures": True, + "optimization_level": 1, + "cache_built_engines": False, + "reuse_cached_engines": False, + } + + trt_mod = torchtrt.compile(model, **compile_spec) + cos_sim = cosine_similarity(model(input), trt_mod(input)) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"In multiprocess compilation test, process {idx} failed: Resnet18 TRT outputs don't match with the original model. 
Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + +@pytest.mark.unit +@unittest.skipIf( + not importlib.util.find_spec("torchvision"), + "torchvision is not installed", +) +def test_resnet18_multiprocess(ir): + import torch.multiprocessing as mp + + mp.set_start_method("spawn", force=True) + procs = [] + for i in range(3): + p = mp.Process(target=compile_one, args=(i, ir)) + p.start() + procs.append(p) + for p in procs: + p.join() + torch._dynamo.reset() + + @pytest.mark.unit @unittest.skipIf( not importlib.util.find_spec("torchvision"), diff --git a/tools/llm/torchtrt_ext/register_sdpa.py b/tools/llm/torchtrt_ext/register_sdpa.py index a82384fda9..c86ee6f3a4 100644 --- a/tools/llm/torchtrt_ext/register_sdpa.py +++ b/tools/llm/torchtrt_ext/register_sdpa.py @@ -23,6 +23,7 @@ torch.ops.aten.scaled_dot_product_attention.default, torch.ops.aten._scaled_dot_product_efficient_attention.default, torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops.aten._scaled_dot_product_cudnn_attention.default, ) @@ -43,6 +44,7 @@ def _remove_decompositions(): REPLACEABLE_ATEN_OPS = { torch.ops.aten._scaled_dot_product_efficient_attention.default, torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops.aten._scaled_dot_product_cudnn_attention.default, } from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import ( @@ -79,7 +81,10 @@ def _process_sdpa_node( ValueError: If the SDPA node has an unexpected number of arguments """ - if node.target == torch.ops.aten._scaled_dot_product_efficient_attention.default: + if node.target in [ + torch.ops.aten._scaled_dot_product_efficient_attention.default, + torch.ops.aten._scaled_dot_product_cudnn_attention.default, + ]: if len(node.args) == 7: ( query,