diff --git a/examples/distributed_inference/tensor_parallel_initialize_dist.py b/examples/distributed_inference/tensor_parallel_initialize_dist.py
index 98d3ca18e9..5fffb3fa00 100644
--- a/examples/distributed_inference/tensor_parallel_initialize_dist.py
+++ b/examples/distributed_inference/tensor_parallel_initialize_dist.py
@@ -14,32 +14,13 @@
 import tensorrt as trt
 import torch
 import torch.distributed as dist
-from torch.distributed._tensor.device_mesh import init_device_mesh
+from torch.distributed._tensor.device_mesh import DeviceMesh, init_device_mesh
 
+logger = logging.getLogger(__name__)
 
-def find_repo_root(max_depth=10):
-    dir_path = os.path.dirname(os.path.realpath(__file__))
-    for i in range(max_depth):
-        files = os.listdir(dir_path)
-        if "MODULE.bazel" in files:
-            return dir_path
-        else:
-            dir_path = os.path.dirname(dir_path)
-    raise RuntimeError("Could not find repo root")
-
-
-def initialize_logger(rank, logger_file_name):
-    logger = logging.getLogger()
-    logger.setLevel(logging.INFO)
-    fh = logging.FileHandler(logger_file_name + f"_{rank}.log", mode="w")
-    fh.setLevel(logging.INFO)
-    logger.addHandler(fh)
-    return logger
-
-
-# This is required for env initialization since we use mpirun
-def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=29500):
+# This is kept at the application level and is used when mpirun launches the application
+def initialize_distributed_env(rank=0, world_size=1, port=29500):
     local_rank = int(
         os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count())
     )
@@ -50,9 +31,6 @@ def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=2950
     os.environ["WORLD_SIZE"] = str(world_size)
     os.environ["MASTER_ADDR"] = "127.0.0.1"
     os.environ["MASTER_PORT"] = str(port)
-    os.environ["TRTLLM_PLUGINS_PATH"] = (
-        find_repo_root() + "/lib/libnvinfer_plugin_tensorrt_llm.so"
-    )
 
     # Necessary to assign a device to each rank.
     torch.cuda.set_device(local_rank)
@@ -66,16 +44,49 @@ def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=2950
     device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,))
     rank = device_mesh.get_rank()
     assert rank == local_rank
-    logger = initialize_logger(rank, logger_file_name)
     device_id = (
         rank % torch.cuda.device_count()
     )  # Ensure each rank gets a unique device
     torch.cuda.set_device(device_id)
 
-    return device_mesh, world_size, rank, logger
+    return device_mesh, world_size, rank
 
 
 def cleanup_distributed_env():
     """Clean up distributed process group to prevent resource leaks."""
     if dist.is_initialized():
         dist.destroy_process_group()
+
+
+def check_tensor_parallel_device_number(world_size: int) -> None:
+    if world_size % 2 != 0:
+        raise ValueError(
+            f"TP examples require an even number of GPUs, but got {world_size} GPUs"
+        )
+
+
+def get_tensor_parallel_device_mesh(
+    rank: int = 0, world_size: int = 1
+) -> tuple[DeviceMesh, int, int]:
+    local_rank = int(
+        os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count())
+    )
+    world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", world_size))
+    device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,))
+    rank = device_mesh.get_rank()
+    assert rank == local_rank
+    device_id = (
+        rank % torch.cuda.device_count()
+    )  # Ensure each rank gets a unique device
+    torch.cuda.set_device(device_id)
+
+    return device_mesh, world_size, rank
+
+
+def initialize_distributed_logger(rank: int, logger_file_name: str) -> logging.Logger:
+    logger = logging.getLogger()
+    logger.setLevel(logging.INFO)
+    fh = logging.FileHandler(logger_file_name + f"_{rank}.log", mode="w")
+    fh.setLevel(logging.INFO)
+    logger.addHandler(fh)
+    return logger
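With this refactor, environment setup, device-mesh creation, per-rank logging, and cleanup live in separate helpers. The sketch below (not part of the patch) shows how an example script might wire them together under mpirun; the helper names come from the file above, while the call order and the try/finally cleanup are assumptions drawn from the examples that follow.

    # Hypothetical driver script; names are from tensor_parallel_initialize_dist.py above.
    import torch.distributed as dist

    from tensor_parallel_initialize_dist import (
        check_tensor_parallel_device_number,
        cleanup_distributed_env,
        get_tensor_parallel_device_mesh,
        initialize_distributed_env,
        initialize_distributed_logger,
    )

    if not dist.is_initialized():
        initialize_distributed_env()  # sets RANK/WORLD_SIZE/MASTER_* from the mpirun env

    device_mesh, world_size, rank = get_tensor_parallel_device_mesh()
    check_tensor_parallel_device_number(world_size)  # raises if the GPU count is odd
    logger = initialize_distributed_logger(rank, "my_tp_example")  # writes my_tp_example_<rank>.log

    try:
        pass  # build, parallelize, and run the model here
    finally:
        cleanup_distributed_env()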
diff --git a/examples/distributed_inference/tensor_parallel_rotary_embedding.py b/examples/distributed_inference/tensor_parallel_rotary_embedding.py
index da3f3fd8fd..7a55497703 100644
--- a/examples/distributed_inference/tensor_parallel_rotary_embedding.py
+++ b/examples/distributed_inference/tensor_parallel_rotary_embedding.py
@@ -9,26 +9,31 @@
 """
 
-import logging
-import os
 import time
 
 import torch
-import torch_tensorrt
-from rotary_embedding import RotaryAttention, parallel_rotary_block
+import torch.distributed as dist
 from tensor_parallel_initialize_dist import (
     cleanup_distributed_env,
+    get_tensor_parallel_device_mesh,
     initialize_distributed_env,
+    initialize_distributed_logger,
 )
 
-device_mesh, _world_size, _rank, logger = initialize_distributed_env(
-    "./tensor_parallel_rotary_embedding"
-)
+if not dist.is_initialized():
+    initialize_distributed_env()
+import torch_tensorrt
+
+device_mesh, _world_size, _rank = get_tensor_parallel_device_mesh()
+logger = initialize_distributed_logger(_rank, "tensor_parallel_rotary_embedding")
+
+from rotary_embedding import RotaryAttention, parallel_rotary_block
 
 """
 This example covers the rotary embedding in Llama3 model and is derived from https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning
-Command to run with single GPU: mpirun -n 1 --allow-run-as-root python tensor_parallel_rotary_embedding.py
+Command to run with single GPU: USE_TRTLLM_PLUGINS=1 mpirun -n 1 --allow-run-as-root python tensor_parallel_rotary_embedding.py
+Command to run with 2 GPUs: USE_TRTLLM_PLUGINS=1 mpirun -n 2 --allow-run-as-root python tensor_parallel_rotary_embedding.py
 """
 
 BATCH = 2
diff --git a/examples/distributed_inference/tensor_parallel_simple_example.py b/examples/distributed_inference/tensor_parallel_simple_example.py
index c5688c6e5b..bf0c13560f 100755
--- a/examples/distributed_inference/tensor_parallel_simple_example.py
+++ b/examples/distributed_inference/tensor_parallel_simple_example.py
@@ -16,7 +16,7 @@
 -----
 .. code-block:: bash
 
-    mpirun -n 2 --allow-run-as-root python tensor_parallel_simple_example.py
+    USE_TRTLLM_PLUGINS=1 mpirun -n 2 --allow-run-as-root python tensor_parallel_simple_example.py
 """
 
 import time
@@ -25,22 +25,31 @@
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-import torch_tensorrt
 from tensor_parallel_initialize_dist import (
     cleanup_distributed_env,
+    get_tensor_parallel_device_mesh,
     initialize_distributed_env,
+    initialize_distributed_logger,
 )
+
+if not dist.is_initialized():
+    initialize_distributed_env()
+import torch_tensorrt
 from torch.distributed._tensor import Shard
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
     RowwiseParallel,
     parallelize_module,
 )
-
-device_mesh, _world_size, _rank, logger = initialize_distributed_env(
-    "./tensor_parallel_simple_example"
+from torch_tensorrt.dynamo.distributed.utils import (
+    get_tensor_parallel_device_mesh,
+    initialize_distributed_logger,
 )
+device_mesh, _world_size, _rank = get_tensor_parallel_device_mesh()
+logger = initialize_distributed_logger(_rank, "tensor_parallel_simple_example")
+
+
 """
 This example takes some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py
 """
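For readers new to the tensor-parallel APIs this example relies on (ColwiseParallel, RowwiseParallel, parallelize_module), here is a minimal sharding sketch. It is not taken from the example file: the toy module and the sharding plan are illustrative, and a 2-GPU mesh is assumed.

    # Illustrative only; the real example builds its own model and plan.
    import torch
    import torch.nn as nn
    from torch.distributed._tensor.device_mesh import init_device_mesh
    from torch.distributed.tensor.parallel import (
        ColwiseParallel,
        RowwiseParallel,
        parallelize_module,
    )

    mesh = init_device_mesh(device_type="cuda", mesh_shape=(2,))  # assumes 2 ranks/GPUs


    class ToyMLP(nn.Module):
        def __init__(self, dim: int = 16):
            super().__init__()
            self.in_proj = nn.Linear(dim, 4 * dim)
            self.out_proj = nn.Linear(4 * dim, dim)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.out_proj(torch.relu(self.in_proj(x)))


    # Column-shard the first linear and row-shard the second so the intermediate
    # activation stays sharded and only the final output needs an all-reduce.
    tp_model = parallelize_module(
        ToyMLP().cuda(),
        mesh,
        {"in_proj": ColwiseParallel(), "out_proj": RowwiseParallel()},
    )
    out = tp_model(torch.randn(8, 16, device="cuda"))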
diff --git a/py/torch_tensorrt/_utils.py b/py/torch_tensorrt/_utils.py
index f59dce9b1c..a259a54997 100644
--- a/py/torch_tensorrt/_utils.py
+++ b/py/torch_tensorrt/_utils.py
@@ -5,6 +5,7 @@
 import platform
 import sys
 import tempfile
+import time
 import urllib.request
 from pathlib import Path
 from typing import Any, Optional
@@ -143,13 +144,65 @@ def _extracted_dir_trtllm(platform_system: str, platform_machine: str) -> Path:
     )
 
 
+def extract_wheel_file(wheel_path: Path, extract_dir: Path) -> None:
+    """
+    Safely extract a wheel file to a directory with a lock to prevent concurrent extraction.
+    """
+    rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", 0))  # MPI rank from OpenMPI
+    torch.cuda.set_device(rank)
+    lock_file = extract_dir / ".extracting"
+
+    # Rank 0 performs extraction
+    if rank == 0:
+        logger.debug(
+            f"[Rank {rank}] Starting extraction of {wheel_path} to {extract_dir}"
+        )
+        try:
+            import zipfile
+        except ImportError as e:
+            raise ImportError(
+                "zipfile module is required but not found. Please install zipfile"
+            )
+        # Create lock file to signal extraction in progress
+        extract_dir.mkdir(parents=True, exist_ok=False)
+        lock_file.touch(exist_ok=False)
+        try:
+            with zipfile.ZipFile(wheel_path) as zip_ref:
+                zip_ref.extractall(extract_dir)
+                logger.debug(f"[Rank {rank}] Extraction complete: {extract_dir}")
+        except FileNotFoundError as e:
+            logger.error(f"[Rank {rank}] Wheel file not found at {wheel_path}: {e}")
+            raise RuntimeError(
+                f"Failed to find downloaded wheel file at {wheel_path}"
+            ) from e
+        except zipfile.BadZipFile as e:
+            logger.error(f"[Rank {rank}] Invalid or corrupted wheel file: {e}")
+            raise RuntimeError(
+                "Downloaded wheel file is corrupted or not a valid zip archive"
+            ) from e
+        except Exception as e:
+            logger.error(f"[Rank {rank}] Unexpected error while extracting wheel: {e}")
+            raise RuntimeError(
+                "Unexpected error during extraction of TensorRT-LLM wheel"
+            ) from e
+        finally:
+            # Remove lock file to signal completion
+            lock_file.unlink(missing_ok=True)
+
+    else:
+        # Other ranks wait for extraction to complete
+        while lock_file.exists():
+            logger.debug(
+                f"[Rank {rank}] Waiting for extraction to finish at {extract_dir}..."
+            )
+            time.sleep(0.5)
+
+
 def download_and_get_plugin_lib_path() -> Optional[str]:
     """
     Returns the path to the TensorRT‑LLM shared library, downloading and extracting if necessary.
-
     Args:
         platform (str): Platform identifier (e.g., 'linux_x86_64')
-
     Returns:
         Optional[str]: Path to shared library or None if operation fails.
     """
@@ -174,7 +227,6 @@ def download_and_get_plugin_lib_path() -> Optional[str]:
         return str(plugin_lib_path)
 
     wheel_path.parent.mkdir(parents=True, exist_ok=True)
-    extract_dir.mkdir(parents=True, exist_ok=True)
 
     if not wheel_path.exists():
         base_url = "https://pypi.nvidia.com/tensorrt-llm/"
@@ -194,32 +246,7 @@ def download_and_get_plugin_lib_path() -> Optional[str]:
         except OSError as e:
             logger.error(f"Local file write error: {e}")
 
-    try:
-        import zipfile
-    except ImportError as e:
-        raise ImportError(
-            "zipfile module is required but not found. Please install zipfile"
-        )
-    try:
-        with zipfile.ZipFile(wheel_path) as zip_ref:
-            zip_ref.extractall(extract_dir)
-            logger.debug(f"Extracted wheel to {extract_dir}")
-    except FileNotFoundError as e:
-        # This should capture the errors in the download failure above
-        logger.error(f"Wheel file not found at {wheel_path}: {e}")
-        raise RuntimeError(
-            f"Failed to find downloaded wheel file at {wheel_path}"
-        ) from e
-    except zipfile.BadZipFile as e:
-        logger.error(f"Invalid or corrupted wheel file: {e}")
-        raise RuntimeError(
-            "Downloaded wheel file is corrupted or not a valid zip archive"
-        ) from e
-    except Exception as e:
-        logger.error(f"Unexpected error while extracting wheel: {e}")
-        raise RuntimeError(
-            "Unexpected error during extraction of TensorRT-LLM wheel"
-        ) from e
+    extract_wheel_file(wheel_path, extract_dir)
 
     try:
         wheel_path.unlink(missing_ok=True)
@@ -238,10 +265,8 @@ def download_and_get_plugin_lib_path() -> Optional[str]:
 def load_and_initialize_trtllm_plugin(plugin_lib_path: str) -> bool:
     """
     Loads and initializes the TensorRT-LLM plugin from the given shared library path.
-
     Args:
         plugin_lib_path (str): Path to the shared TensorRT-LLM plugin library.
-
     Returns:
         bool: True if successful, False otherwise.
     """
@@ -293,7 +318,6 @@ def load_tensorrt_llm_for_nccl() -> bool:
     Attempts to load the TensorRT-LLM plugin and initialize it.
     Either the env variable TRTLLM_PLUGINS_PATH can specify the path
     Or the user can specify USE_TRTLLM_PLUGINS as either of (1, true, yes, on) to download the TRT-LLM distribution and load it
-
     Returns:
         bool: True if the plugin was successfully loaded and initialized, False otherwise.
     """
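The extraction helper above is internal; what a user controls are the two environment variables described in the load_tensorrt_llm_for_nccl docstring. A hedged sketch of the opt-in flow follows; the setdefault call and print messages are illustrative, not part of the library, and _utils is a private module imported here only for illustration.

    # Sketch of how an application opts in to the TensorRT-LLM NCCL plugins.
    import os

    from torch_tensorrt._utils import load_tensorrt_llm_for_nccl

    # Option 1: point at an existing plugin build.
    # os.environ["TRTLLM_PLUGINS_PATH"] = "/path/to/libnvinfer_plugin_tensorrt_llm.so"

    # Option 2: let torch_tensorrt download the TensorRT-LLM wheel and extract the
    # plugin on first use (rank 0 extracts, other ranks wait on the lock file).
    os.environ.setdefault("USE_TRTLLM_PLUGINS", "1")

    if load_tensorrt_llm_for_nccl():
        print("TensorRT-LLM NCCL plugins loaded and initialized")
    else:
        print("TensorRT-LLM NCCL plugins unavailable; NCCL ops may not convert")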
diff --git a/tests/py/dynamo/distributed/distributed_utils.py b/tests/py/dynamo/distributed/distributed_utils.py
index bc058aaaec..6d13ecb1a1 100644
--- a/tests/py/dynamo/distributed/distributed_utils.py
+++ b/tests/py/dynamo/distributed/distributed_utils.py
@@ -1,5 +1,6 @@
 import logging
 import os
+import random
 
 import numpy as np
 import tensorrt as trt
@@ -8,24 +9,21 @@
 from torch.distributed._tensor.device_mesh import init_device_mesh
 
 
-def set_environment_variables_pytest():
+# The two functions below set the environment variables for single- and multi-process pytest runs.
+# They are used in the GitHub CI, where the tests are driven by pytest.
+def set_environment_variables_pytest_single_process():
+    port = 29500 + random.randint(1, 1000)
     os.environ["WORLD_SIZE"] = str(1)
     os.environ["RANK"] = str(0)
     os.environ["MASTER_ADDR"] = "127.0.0.1"
-    os.environ["MASTER_PORT"] = str(29500)
-
-
-def initialize_logger(rank, logger_file_name):
-    logger = logging.getLogger()
-    logger.setLevel(logging.INFO)
-    fh = logging.FileHandler(logger_file_name + f"_{rank}.log", mode="w")
-    fh.setLevel(logging.INFO)
-    logger.addHandler(fh)
-    return logger
+    os.environ["MASTER_PORT"] = str(port)
 
 
-# This is required for env initialization since we use mpirun
-def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=29500):
+def set_environment_variables_pytest_multi_process(
+    rank: int = 0, world_size: int = 1
+) -> None:
+    port = 29500 + random.randint(1, 1000)
+    # These OMPI_COMM_WORLD_* variables are set by mpirun (e.g. mpirun -n 2)
     local_rank = int(
         os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count())
     )
@@ -36,7 +34,6 @@ def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=2950
     os.environ["WORLD_SIZE"] = str(world_size)
     os.environ["MASTER_ADDR"] = "127.0.0.1"
     os.environ["MASTER_PORT"] = str(port)
-    os.environ["TRTLLM_PLUGINS_PATH"] = "./tmp/lib/libnvinfer_plugin_tensorrt_llm.so"
 
     # Necessary to assign a device to each rank.
     torch.cuda.set_device(local_rank)
@@ -46,14 +43,3 @@ def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=2950
 
     # set a manual seed for reproducibility
     torch.manual_seed(1111)
-
-    device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,))
-    rank = device_mesh.get_rank()
-    assert rank == local_rank
-    logger = initialize_logger(rank, logger_file_name)
-    device_id = (
-        rank % torch.cuda.device_count()
-    )  # Ensure each rank gets a unique device
-    torch.cuda.set_device(device_id)
-
-    return device_mesh, world_size, rank, logger
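The randomized MASTER_PORT above reduces, but does not eliminate, the chance of two CI jobs colliding on the same port, and in the multi-process case every rank must end up agreeing on the same value. A common alternative (not part of this patch) is to ask the OS for a free port on one rank and share it:

    # Generic free-port helper; an illustrative alternative to 29500 + random.randint(1, 1000).
    import os
    import socket


    def get_free_port() -> int:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("127.0.0.1", 0))  # port 0 lets the kernel pick an unused port
            return s.getsockname()[1]


    # In a multi-process launch this value must be computed once (e.g. on rank 0)
    # and broadcast or exported to all ranks before init_process_group is called.
    os.environ["MASTER_PORT"] = str(get_free_port())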
diff --git a/tests/py/dynamo/distributed/test_nccl_ops.py b/tests/py/dynamo/distributed/test_nccl_ops.py
index eafe16d455..d239179d23 100644
--- a/tests/py/dynamo/distributed/test_nccl_ops.py
+++ b/tests/py/dynamo/distributed/test_nccl_ops.py
@@ -5,11 +5,26 @@
 import torch.distributed as dist
 import torch.nn as nn
 from conversion.harness import DispatchTestCase
-from distributed_utils import set_environment_variables_pytest
+
+from distributed_utils import (
+    set_environment_variables_pytest_multi_process,
+    set_environment_variables_pytest_single_process,
+)
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
 from torch_tensorrt._utils import is_platform_supported_for_trtllm
 
+if "OMPI_COMM_WORLD_SIZE" in os.environ:
+    set_environment_variables_pytest_multi_process()
+else:
+    set_environment_variables_pytest_single_process()
+
+if not dist.is_initialized():
+    dist.init_process_group(
+        backend="nccl",
+        init_method="env://",
+    )
+
 
 class DistributedGatherModel(nn.Module):
     def __init__(self, input_dim, world_size, group_name):
@@ -48,11 +63,9 @@ class TestNcclOpsConverter(DispatchTestCase):
     )
     @classmethod
     def setUpClass(cls):
-        set_environment_variables_pytest()
-        cls.world_size = 1
-        if not dist.is_initialized():
-            dist.init_process_group(backend="nccl")
-        cls.group = dist.new_group(ranks=[0])
+        cls.world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", 1))
+        cls.rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", 0))
+        cls.group = dist.new_group(ranks=list(range(cls.world_size)))
        cls.group_name = cls.group.group_name
 
     @classmethod