diff --git a/examples/id_embeddings/__init__.py b/examples/id_embeddings/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/id_embeddings/configs/e2e_hom_cora_sup_task_config.yaml b/examples/id_embeddings/configs/e2e_hom_cora_sup_task_config.yaml new file mode 100644 index 000000000..38d0047f1 --- /dev/null +++ b/examples/id_embeddings/configs/e2e_hom_cora_sup_task_config.yaml @@ -0,0 +1,41 @@ +# This config is used to run homogeneous CORA supervised training and inference using in memory GiGL SGS. This can be run with `make run_hom_cora_sup_test`. +graphMetadata: + edgeTypes: + - dstNodeType: paper + relation: cites + srcNodeType: paper + nodeTypes: + - paper +datasetConfig: + dataPreprocessorConfig: + dataPreprocessorConfigClsPath: gigl.src.mocking.mocking_assets.passthrough_preprocessor_config_for_mocked_assets.PassthroughPreprocessorConfigForMockedAssets + dataPreprocessorArgs: + # This argument is specific for the `PassthroughPreprocessorConfigForMockedAssets` preprocessor to indicate which dataset we should be using + mocked_dataset_name: 'cora_homogeneous_node_anchor_edge_features_user_defined_labels' +trainerConfig: + trainerArgs: + # Example argument to trainer + log_every_n_batch: "50" # Frequency in which we log batch information + num_neighbors: "[10, 10]" # Fanout per hop, specified as a string representation of a list for the homogeneous use case + command: python -m examples.id_embeddings.homogeneous_training +inferencerConfig: + inferencerArgs: + # Example argument to inferencer + log_every_n_batch: "50" # Frequency in which we log batch information + num_neighbors: "[10, 10]" # Fanout per hop, specified as a string representation of a list for the homogeneous use case + inferenceBatchSize: 512 + command: python -m examples.id_embeddings.homogeneous_inference +sharedConfig: + shouldSkipAutomaticTempAssetCleanup: false + shouldSkipInference: false + # Model Evaluation is currently only supported for tabularized SGS GiGL pipelines. This will soon be added for in-mem SGS GiGL pipelines. + shouldSkipModelEvaluation: true +taskMetadata: + nodeAnchorBasedLinkPredictionTaskMetadata: + supervisionEdgeTypes: + - dstNodeType: paper + relation: cites + srcNodeType: paper +featureFlags: + should_run_glt_backend: 'True' + data_preprocessor_num_shards: '2' diff --git a/examples/id_embeddings/homogeneous_training.py b/examples/id_embeddings/homogeneous_training.py new file mode 100644 index 000000000..00ba29a22 --- /dev/null +++ b/examples/id_embeddings/homogeneous_training.py @@ -0,0 +1,912 @@ +""" +This file contains an example for how to run homogeneous training using live subgraph sampling powered by GraphLearn-for-PyTorch (GLT). +While `run_example_training` is coupled with GiGL orchestration, the `_training_process` and `testing_process` functions are generic +and can be used as references for writing training for pipelines not dependent on GiGL orchestration. + +To run this file with GiGL orchestration, set the fields similar to below: + +trainerConfig: + trainerArgs: + # Example argument to trainer + log_every_n_batch: "50" + command: python -m examples.id_embeddings.homogeneous_training +featureFlags: + should_run_glt_backend: 'True' + +You can run this example in a full pipeline with `make run_hom_cora_sup_test` from GiGL root. 
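+
+All `trainerArgs` values are passed to this script as strings and parsed explicitly; as an illustrative sketch
+of the parsing done in `_run_example_training` below:
+
+    trainer_args = dict(gbml_config_pb_wrapper.trainer_config.trainer_args)
+    log_every_n_batch = int(trainer_args.get("log_every_n_batch", "25"))
+    num_neighbors = parse_fanout(trainer_args.get("num_neighbors", "[10, 10]"))  # e.g. -> [10, 10]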
+ +Given a frozen task config with some already populated data preprocessor output, the following training script can be run locally using: +WORLD_SIZE=1 RANK=0 MASTER_ADDR="localhost" MASTER_PORT=20000 python -m examples.id_embeddings.homogeneous_training --task_config_uri= + +A frozen task config with data preprocessor outputs can be generated by running an e2e pipeline with `stop_after=data_preprocessor` and using the +frozen config generated from the `config_populator` component after the run has completed. +""" + +from __future__ import annotations +import os + +# Suppress TensorFlow logs +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" # isort: skip + +import argparse +import statistics +import time +from collections.abc import Iterator +from dataclasses import dataclass +from typing import Literal, Optional, Sequence, Union + +import torch +import torch.distributed +import torch.multiprocessing as mp +from torch import nn +from torch.optim import AdamW +from torch_geometric.data import Data +from torchrec.distributed.embedding_types import EmbeddingComputeKernel + +# TorchRec / DMP imports +from torchrec.distributed.model_parallel import DistributedModelParallel as DMP + +import gigl.distributed.utils +from gigl.common import Uri, UriFactory +from gigl.common.logger import Logger +from gigl.common.utils.torch_training import is_distributed_available_and_initialized +from gigl.distributed import ( + DistABLPLoader, + DistDataset, + build_dataset_from_task_config_uri, +) +from gigl.distributed.distributed_neighborloader import DistNeighborLoader +from gigl.distributed.utils import get_available_device + +from gigl.module.models import LightGCN, LinkPredictionGNN +from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper +from gigl.src.common.utils.model import load_state_dict_from_uri, save_state_dict +from gigl.types.graph import to_homogeneous +from gigl.utils.iterator import InfiniteIterator +from gigl.utils.sampling import parse_fanout + +logger = Logger() + + +@dataclass +class DMPConfig: + device: torch.device + world_size: int + local_world_size: int + pg: Optional[torch.distributed.ProcessGroup] = None + compute_device: str = "cuda" # or "cpu" + prefer_sharding_types: Optional[Sequence[str]] = ("table_wise", "row_wise") + compute_kernel: EmbeddingComputeKernel = EmbeddingComputeKernel.FUSED + + +def wrap_with_dmp( + model: nn.Module, cfg: DMPConfig +) -> Union[tuple[nn.Module, object], nn.Module]: + """Wraps `model` with TorchRec DMP (shards EBCs, DP for the rest). 
Returns (dmp_model, plan).""" + # model.to(cfg.device) + + # sharders = [EmbeddingBagCollectionSharder()] + + # constraints = None + # if cfg.prefer_sharding_types: + # constraints = { + # "_embedding_bag_collection": ParameterConstraints( + # sharding_types=list(cfg.prefer_sharding_types), + # compute_kernel=cfg.compute_kernel, + # ) + # } + + # topology = Topology( + # world_size=cfg.world_size, + # local_world_size=cfg.local_world_size, + # compute_device=cfg.compute_device, + # ) + # planner = EmbeddingShardingPlanner(topology=topology, constraints=constraints) + # env = ShardingEnv.from_process_group(cfg.pg) + # plan = planner.collective_plan(module=model, sharders=sharders, pg=cfg.pg) + + # dmp_model = DMP( + # module=model, + # env=env, + # plan=plan, + # sharders=sharders, + # device=cfg.device, + # init_data_parallel=True, + # ) + + dmp_model = DMP(module=model, device=cfg.device) + + return dmp_model + + +def unwrap_from_dmp(model: nn.Module) -> nn.Module: + """Return the underlying nn.Module if wrapped by DMP, otherwise the module itself.""" + return getattr(model, "module", model) + + +def _sync_metric_across_processes(metric: torch.Tensor) -> float: + """ + Takes the average of a training metric across multiple processes. Note that this function requires DDP/DMP to be initialized. + Args: + metric (torch.Tensor): The metric, expressed as a torch Tensor, which should be synced across multiple processes + Returns: + float: The average of the provided metric across all training processes + """ + assert is_distributed_available_and_initialized(), "DDP/DMP is not initialized" + # Make a copy of the local loss tensor + loss_tensor = metric.detach().clone() + torch.distributed.all_reduce(loss_tensor, op=torch.distributed.ReduceOp.SUM) + return loss_tensor.item() / torch.distributed.get_world_size() + + +def _setup_dataloaders( + dataset: DistDataset, + split: Literal["train", "val", "test"], + num_neighbors: list[int], + sampling_workers_per_process: int, + main_batch_size: int, + random_batch_size: int, + device: torch.device, + sampling_worker_shared_channel_size: str, + process_start_gap_seconds: int, +) -> tuple[DistABLPLoader, DistNeighborLoader]: + """ + Sets up main and random dataloaders for training and testing purposes + Args: + dataset (DistDataset): Loaded Distributed Dataset for training and testing + split (Literal["train", "val", "test"]): The current split which we are loading data for + num_neighbors: list[int]: Fanout for subgraph sampling, where the ith item corresponds to the number of items to sample for the ith hop + sampling_workers_per_process (int): sampling_workers_per_process (int): Number of sampling workers per training/testing process + main_batch_size (int): Batch size for main dataloader with query and labeled nodes + random_batch_size (int): Batch size for random negative dataloader + device (torch.device): Device to put loaded subgraphs on + sampling_worker_shared_channel_size (str): Shared-memory buffer size (bytes) allocated for the channel during sampling + process_start_gap_seconds (int): The amount of time to sleep for initializing each dataloader. For large-scale settings, consider setting this + field to 30-60 seconds to ensure dataloaders don't compete for memory during initialization, causing OOM. 
+ Returns: + DistABLPLoader: Dataloader for loading main batch data with query and labeled nodes + DistNeighborLoader: Dataloader for loading random negative data + """ + + rank = torch.distributed.get_rank() + + if split == "train": + main_input_nodes = to_homogeneous(dataset.train_node_ids) + shuffle = True + elif split == "val": + main_input_nodes = to_homogeneous(dataset.val_node_ids) + shuffle = False + else: + main_input_nodes = to_homogeneous(dataset.test_node_ids) + shuffle = False + + main_loader = DistABLPLoader( + dataset=dataset, + num_neighbors=num_neighbors, + input_nodes=main_input_nodes, + num_workers=sampling_workers_per_process, + batch_size=main_batch_size, + pin_memory_device=device, + worker_concurrency=sampling_workers_per_process, + channel_size=sampling_worker_shared_channel_size, + # Each main_loader will wait for `process_start_gap_seconds` * `local_process_rank` seconds before initializing to reduce peak memory usage. + # This is done so that each process on the current machine which initializes a `main_loader` doesn't compete for memory, causing potential OOM + process_start_gap_seconds=process_start_gap_seconds, + shuffle=shuffle, + ) + + logger.info(f"---Rank {rank} finished setting up main loader") + + # We need to wait for all processes to finish initializing the main_loader before creating the random_negative_loader so that its initialization doesn't compete for memory with the main_loader, causing potential OOM. + torch.distributed.barrier() + + random_negative_loader = DistNeighborLoader( + dataset=dataset, + num_neighbors=num_neighbors, + input_nodes=to_homogeneous(dataset.node_ids), + num_workers=sampling_workers_per_process, + batch_size=random_batch_size, + pin_memory_device=device, + worker_concurrency=sampling_workers_per_process, + channel_size=sampling_worker_shared_channel_size, + process_start_gap_seconds=process_start_gap_seconds, + shuffle=shuffle, + ) + + logger.info(f"--Rank {rank} finished setting up random negative loader") + + # Wait for all processes to finish initializing the random_loader + torch.distributed.barrier() + + return main_loader, random_negative_loader + + +def bpr_loss( + query_emb: torch.Tensor, # [M, D] + pos_emb: torch.Tensor, # [M, D] + neg_emb: torch.Tensor, # [M, D] or [M, K, D] + l2_lambda: float = 0.0, + l2_params: Optional[Sequence[torch.Tensor]] = None, +) -> torch.Tensor: + """Bayesian Personalized Ranking loss with dot-product scores. + + + Supports one negative per positive ([M, D]) or K negatives ([M, K, D]). + """ + # s_pos: [M] + s_pos = (query_emb * pos_emb).sum(dim=-1) + + if neg_emb.dim() == 2: # [M, D] + s_neg = (query_emb * neg_emb).sum(dim=-1) + loss = -torch.nn.functional.logsigmoid(s_pos - s_neg) + elif neg_emb.dim() == 3: # [M, K, D] + # Broadcast query: [M, 1, D] + s_neg = (query_emb.unsqueeze(1) * neg_emb).sum(dim=-1) # [M, K] + loss = -torch.nn.functional.logsigmoid(s_pos.unsqueeze(1) - s_neg).mean(dim=1) + else: + raise ValueError("neg_emb must be [M, D] or [M, K, D]") + + loss = loss.mean() + + if l2_lambda > 0.0 and l2_params: + l2 = sum(p.pow(2).sum() for p in l2_params) + loss = loss + l2_lambda * l2 + + return loss + + +def _compute_bpr_batch( + model: nn.Module, + main_data: Data, + random_negative_data: Data, + device: torch.device, + num_random_negs_per_pos: int = 1, + use_hard_negs: bool = True, + l2_lambda: float = 0.0, +) -> torch.Tensor: + """Compute a BPR batch using LightGCN embeddings and GLT-batched indices. + + + Strategy: one (or K) random negative(s) per positive. 
If hard negatives exist in the batch, + we concatenate them as additional negatives (weighting equally). + """ + + # Encode + main_emb = model(data=main_data, device=device) # [N_main, D] + rand_emb = model(data=random_negative_data, device=device) # [N_rand, D] + + # Query indices and positives from the main batch + B = int(main_data.batch_size) + query_idx = torch.arange(B, device=device) # [B] + + pos_idx = torch.cat(list(main_data.y_positive.values())).to(device) # [M] + # Repeat queries to align with positives + rep_query_idx = query_idx.repeat_interleave( + torch.tensor([len(v) for v in main_data.y_positive.values()], device=device) + ) # [M] + + # Optional hard negatives from the main batch + if use_hard_negs and hasattr(main_data, "y_negative"): + hard_neg_idx = torch.cat(list(main_data.y_negative.values())).to(device) + hard_neg_emb = main_emb[hard_neg_idx] # [H, D] + else: + hard_neg_idx = torch.empty(0, dtype=torch.long, device=device) + hard_neg_emb = torch.empty(0, main_emb.size(1), device=device) + + # Random negatives: take the first K*M rows from rand_emb for simplicity + M = rep_query_idx.numel() + D = main_emb.size(1) + + total_needed = M * max(1, num_random_negs_per_pos) + if rand_emb.size(0) < total_needed: + # Tile if fewer than needed + tile = (total_needed + rand_emb.size(0) - 1) // rand_emb.size(0) + rand_pool = rand_emb.repeat(tile, 1)[:total_needed] + else: + rand_pool = rand_emb[:total_needed] + + if num_random_negs_per_pos == 1: + rand_neg_emb = rand_pool # [M, D] + else: + rand_neg_emb = rand_pool.view(M, num_random_negs_per_pos, D) # [M, K, D] + + # Positive and query embeddings + q = main_emb[rep_query_idx] # [M, D] + pos = main_emb[pos_idx] # [M, D] + + # If we have hard negatives, merge with random negatives by stacking along K + if hard_neg_emb.numel() > 0: + # Align hard negatives count to M. + if hard_neg_emb.size(0) < M: + ht = (M + hard_neg_emb.size(0) - 1) // hard_neg_emb.size(0) + hard_neg_emb = hard_neg_emb.repeat(ht, 1)[:M] + if rand_neg_emb.dim() == 2: # [M, D] + neg = torch.stack([rand_neg_emb, hard_neg_emb], dim=1) # [M, 2, D] + else: # [M, K, D] + neg = torch.cat( + [rand_neg_emb, hard_neg_emb.unsqueeze(1)], dim=1 + ) # [M, K+1, D] + else: + neg = rand_neg_emb # [M, D] or [M, K, D] + + loss = bpr_loss(q, pos, neg, l2_lambda=l2_lambda, l2_params=None) + return loss + +def _training_process( + local_rank: int, + local_world_size: int, + machine_rank: int, + machine_world_size: int, + dataset: DistDataset, + num_nodes_total: int, + embedding_dim: int, + num_layers: int, + master_ip_address: str, + master_default_process_group_port: int, + model_uri: Uri, + num_neighbors: list[int], + sampling_workers_per_process: int, + main_batch_size: int, + random_batch_size: int, + sampling_worker_shared_channel_size: str, + process_start_gap_seconds: int, + log_every_n_batch: int, + learning_rate: float, + weight_decay: float, + num_max_train_batches: int, + num_val_batches: int, + val_every_n_batch: int, + should_skip_training: bool, + num_random_negs_per_pos: int, + l2_lambda: float, +) -> None: + """ + This function is spawned by each machine for training a GNN model given some loaded distributed dataset. 
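+    Each spawned process computes its global rank as `machine_rank * local_world_size + local_rank` and joins a
+    single process group of size `machine_world_size * local_world_size` over TCP at `master_ip_address:master_default_process_group_port`.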
+    Args:
+        local_rank (int): Process number on the current machine
+        local_world_size (int): Number of training processes spawned by each machine
+        machine_rank (int): Rank of the current machine
+        machine_world_size (int): Total number of machines
+        dataset (DistDataset): Loaded Distributed Dataset for training
+        num_nodes_total (int): Total number of nodes in the dataset
+        embedding_dim (int): Embedding dimension for the model
+        num_layers (int): Number of layers in the model
+        master_ip_address (str): IP Address of the master worker for distributed communication
+        master_default_process_group_port (int): Port on the master worker for setting up distributed process group communication
+        model_uri (Uri): URI Path to save the model to
+        num_neighbors (list[int]): Fanout for subgraph sampling, where the ith item corresponds to the number of neighbors to sample for the ith hop
+        sampling_workers_per_process (int): Number of sampling workers per training process
+        main_batch_size (int): Batch size for main dataloader with query and labeled nodes
+        random_batch_size (int): Batch size for random negative dataloader
+        sampling_worker_shared_channel_size (str): Shared-memory buffer size (bytes) allocated for the channel during sampling
+        process_start_gap_seconds (int): The amount of time to sleep for initializing each dataloader. For large-scale settings, consider setting this
+            field to 30-60 seconds so that dataloaders don't compete for memory during initialization, which can cause OOM.
+        log_every_n_batch (int): The frequency at which we log batch information during training
+        learning_rate (float): Learning rate for training
+        weight_decay (float): Weight decay for training
+        num_max_train_batches (int): The maximum number of batches to train for across all training processes
+        num_val_batches (int): The number of batches to run validation for across all training processes
+        val_every_n_batch (int): The frequency (in training batches) at which validation is run
+        should_skip_training (bool): Whether training should be skipped and we should only run testing. Assumes the model has already been uploaded to model_uri.
+ num_random_negs_per_pos (int): Number of random negatives per positive + l2_lambda (float): L2 regularization parameter + """ + + world_size = machine_world_size * local_world_size + rank = machine_rank * local_world_size + local_rank + logger.info( + f"---Current training process rank: {rank}, training process world size: {world_size}" + ) + + torch.distributed.init_process_group( + backend="nccl" if torch.cuda.is_available() else "gloo", + init_method=f"tcp://{master_ip_address}:{master_default_process_group_port}", + rank=rank, + world_size=world_size, + ) + + logger.info(f"---Rank {rank} training process started") + + device = get_available_device(local_process_rank=local_rank) + if torch.cuda.is_available(): + torch.cuda.set_device(device) + logger.info(f"---Rank {rank} training process set device {device}") + + logger.info(f"---Rank {rank} training process group initialized") + + # loss_fn = RetrievalLoss( + # loss=torch.nn.CrossEntropyLoss(reduction="mean"), + # temperature=0.07, + # remove_accidental_hits=True, + # ) + + # Dataloaders + if not should_skip_training: + train_main_loader, train_random_negative_loader = _setup_dataloaders( + dataset=dataset, + split="train", + num_neighbors=num_neighbors, + sampling_workers_per_process=sampling_workers_per_process, + main_batch_size=main_batch_size, + random_batch_size=random_batch_size, + device=device, + sampling_worker_shared_channel_size=sampling_worker_shared_channel_size, + process_start_gap_seconds=process_start_gap_seconds, + ) + + # We keep track of both the dataloader and the iterator for it + # so we can clean up resources from the dataloader later. + train_main_loader_iter = InfiniteIterator(train_main_loader) + train_random_negative_loader_iter = InfiniteIterator( + train_random_negative_loader + ) + + val_main_loader, val_random_negative_loader = _setup_dataloaders( + dataset=dataset, + split="val", + num_neighbors=num_neighbors, + sampling_workers_per_process=sampling_workers_per_process, + main_batch_size=main_batch_size, + random_batch_size=random_batch_size, + device=device, + sampling_worker_shared_channel_size=sampling_worker_shared_channel_size, + process_start_gap_seconds=process_start_gap_seconds, + ) + + # We keep track of both the dataloader and the iterator for it + # so we can clean up resources from the dataloader later. 
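+        # `InfiniteIterator` yields batches indefinitely (it never raises StopIteration), so the validation
+        # loop is bounded by `num_val_batches` rather than by the loader being exhausted.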
+ val_main_loader_iter = InfiniteIterator(val_main_loader) + val_random_negative_loader_iter = InfiniteIterator(val_random_negative_loader) + + # Build LightGCN + logger.info(f"---Rank {rank} building LightGCN model") + base_model = LightGCN( + node_type_to_num_nodes=num_nodes_total, + embedding_dim=embedding_dim, + num_layers=num_layers, + device=device, + ) + + logger.info(f"---Rank {rank} wrapping LightGCN model with DMP") + dmp_model = wrap_with_dmp( + base_model, + DMPConfig( + device=device, + world_size=world_size, + local_world_size=local_world_size, + pg=torch.distributed.group.WORLD, + compute_device="cuda" if device.type == "cuda" else "cpu", + prefer_sharding_types=("table_wise", "row_wise"), + compute_kernel=EmbeddingComputeKernel.FUSED, + ), + ) + + logger.info(f"---Rank {rank} initializing optimizer") + optimizer = AdamW( + params=dmp_model.parameters(), lr=learning_rate, weight_decay=weight_decay + ) + + logger.info( + f"Model initialized on rank {rank} training device {device}\n{dmp_model}" + ) + + # We add a barrier to wait for all processes to finish preparing the dataloader and initializing the model prior to the start of training + torch.distributed.barrier() + + # Train + if not should_skip_training: + # Entering the training loop + training_start_time = time.time() + batch_idx = 0 + # avg_train_loss = 0.0 + last_n_batch_avg_loss: list[float] = [] + last_n_batch_time: list[float] = [] + num_max_train_batches_per_process = max(1, num_max_train_batches // world_size) + num_val_batches_per_process = max(1, num_val_batches // world_size) + logger.info( + f"num_max_train_batches_per_process is set to {num_max_train_batches_per_process}" + ) + + dmp_model.train() + + # start_time gets updated every log_every_n_batch batches, batch_start gets updated every batch + batch_start = time.time() + for main_data, random_data in zip( + train_main_loader_iter, train_random_negative_loader_iter + ): + if batch_idx >= num_max_train_batches_per_process: + logger.info( + f"rank={rank}: num_max_train_batches_per_process={num_max_train_batches_per_process} reached, " + f"stopping training on machine {machine_rank} local rank {local_rank}" + ) + break + loss = _compute_bpr_batch( + model=dmp_model, + main_data=main_data, + random_negative_data=random_data, + device=device, + num_random_negs_per_pos=num_random_negs_per_pos, + use_hard_negs=True, + l2_lambda=l2_lambda, + ) + optimizer.zero_grad() + loss.backward() + optimizer.step() + avg_train_loss = _sync_metric_across_processes(metric=loss) + last_n_batch_avg_loss.append(avg_train_loss) + last_n_batch_time.append(time.time() - batch_start) + batch_start = time.time() + batch_idx += 1 + # log_every_n_batch = 50 + if batch_idx % log_every_n_batch == 0: + logger.info( + f"rank={rank}, batch={batch_idx}, latest local train_loss={loss:.6f}" + ) + if torch.cuda.is_available(): + # Wait for GPU operations to finish + torch.cuda.synchronize() + logger.info( + f"rank={rank}, mean(batch_time)={statistics.mean(last_n_batch_time):.3f} sec, max(batch_time)={max(last_n_batch_time):.3f} sec, min(batch_time)={min(last_n_batch_time):.3f} sec" + ) + last_n_batch_time.clear() + # log the global average training loss + logger.info( + f"rank={rank}, latest avg_train_loss={avg_train_loss:.6f}, last {log_every_n_batch} mean(avg_train_loss)={statistics.mean(last_n_batch_avg_loss):.6f}" + ) + last_n_batch_time.clear() + last_n_batch_avg_loss.clear() + + # if batch_idx % val_every_n_batch == 0: + # logger.info(f"rank={rank}, batch={batch_idx}, validating...") + # 
dmp_model.eval() + # _run_validation_loops( + # model=dmp_model, + # main_loader=val_main_loader_iter, + # random_negative_loader=val_random_negative_loader_iter, + # device=device, + # log_every_n_batch=log_every_n_batch, + # num_batches=num_val_batches_per_process, + # num_random_negs_per_pos=num_random_negs_per_pos, + # l2_lambda=l2_lambda, + # ) + + logger.info(f"---Rank {rank} finished training") + + # Memory cleanup and waiting for all processes to finish + if torch.cuda.is_available(): + torch.cuda.empty_cache() # Releases all unoccupied cached memory currently held by the caching allocator on the CUDA-enabled GPU + torch.cuda.synchronize() # Ensures all CUDA operations have finished + torch.distributed.barrier() # Waits for all processes to reach the current point + + # We explicitly shutdown all the dataloaders to reduce their memory footprint. Otherwise, experimentally we have + # observed that not all memory may be cleaned up, leading to OOM. + train_main_loader.shutdown() + train_random_negative_loader.shutdown() + val_main_loader.shutdown() + val_random_negative_loader.shutdown() + + # We save the model on the process with the 0th node rank and 0th local rank. + if machine_rank == 0 and local_rank == 0: + logger.info( + f"Training loop finished, took {time.time() - training_start_time:.3f} seconds, saving model to {model_uri}" + ) + # We unwrap the model from DDP to save it + # We do this so we can use the model without DDP later, e.g. for inference. + base_to_save = unwrap_from_dmp(dmp_model) + save_state_dict(model=base_to_save, save_to_path_uri=model_uri) + else: # should_skip_training is True, meaning we should only run testing + state_dict = load_state_dict_from_uri(load_from_uri=model_uri, device=device) + base_model.load_state_dict(state_dict) + logger.info(f"Model loaded on rank {rank}") + + # logger.info(f"---Rank {rank} started testing") + # testing_start_time = time.time() + # # model.eval() + + # test_main_loader, test_random_negative_loader = _setup_dataloaders( + # dataset=dataset, + # split="test", + # num_neighbors=num_neighbors, + # sampling_workers_per_process=sampling_workers_per_process, + # main_batch_size=main_batch_size, + # random_batch_size=random_batch_size, + # device=device, + # sampling_worker_shared_channel_size=sampling_worker_shared_channel_size, + # process_start_gap_seconds=process_start_gap_seconds, + # ) + + # # We keep track of both the dataloader and the iterator for it + # # so we can clean up resources from the dataloader later. + # # Since we are doing testing, we only want to go through the data once, so we use iter instead of InfiniteIterator. 
+ # test_main_loader_iter = iter(test_main_loader) + # test_random_negative_loader_iter = iter(test_random_negative_loader) + + # _run_validation_loops( + # model=dmp_model, + # main_loader=test_main_loader_iter, + # random_negative_loader=test_random_negative_loader_iter, + # device=device, + # log_every_n_batch=log_every_n_batch, + # num_batches=None, + # num_random_negs_per_pos=num_random_negs_per_pos, + # l2_lambda=l2_lambda, + # ) + + # # Memory cleanup and waiting for all processes to finish + # if torch.cuda.is_available(): + # torch.cuda.empty_cache() # Releases all unoccupied cached memory currently held by the caching allocator on the CUDA-enabled GPU + # torch.cuda.synchronize() # Ensures all CUDA operations have finished + # torch.distributed.barrier() # Waits for all processes to reach the current point + + # test_main_loader.shutdown() + # test_random_negative_loader.shutdown() + + # logger.info( + # f"---Rank {rank} finished testing in {time.time() - testing_start_time:.3f} seconds" + # ) + + torch.distributed.destroy_process_group() + + +@torch.inference_mode() +def _run_validation_loops( + model: nn.Module, + main_loader: Iterator[Data], + random_negative_loader: Iterator[Data], + device: torch.device, + log_every_n_batch: int, + num_batches: Optional[int] = None, + num_random_negs_per_pos: int = 1, + l2_lambda: float = 0.0, +) -> None: + """ + Runs validation using the provided models and dataloaders. + This function is shared for both validation while training and testing after training has completed. + Args: + model (LinkPredictionGNN): DDP-wrapped LinkPredictionGNN model for training and testing + main_loader (Iterator[Data]): Dataloader for loading main batch data with query and labeled nodes + random_negative_loader (Iterator[Data]): Dataloader for loading random negative data + device (torch.device): Device to use for training or testing + log_every_n_batch (int): The frequency we should log batch information when training and validating + num_batches (Optional[int]): The number of batches to run the validation loop for. If this is not set, this function will loop until the data loaders are exhausted. + For validation, this field is required to be set, as the data loaders are wrapped with InfiniteIterator. + num_random_negs_per_pos (int): The number of random negative samples to use for each positive sample + l2_lambda (float): The lambda value for the L2 regularization + """ + rank = torch.distributed.get_rank() + + logger.info( + f"Running validation loop on rank={rank}, num_random_negs_per_pos={num_random_negs_per_pos}, l2_lambda={l2_lambda}, num_batches={num_batches}" + ) + + batch_idx = 0 + batch_losses: list[float] = [] + last_n_batch_time: list[float] = [] + batch_start = time.time() + + while True: + if num_batches and batch_idx >= num_batches: + # If num_batches is set, we stop the validation loop after processing num_batches batches. This is not expected to be used for _testing_process. + break + try: + main_data = next(main_loader) + random_data = next(random_negative_loader) + except StopIteration: + # If the test data loader is exhausted, we stop the test loop. + # Note that validation dataloaders are infinite, so this is not expected to happen during training. 
+ break + + loss = _compute_bpr_batch( + model=model, + main_data=main_data, + random_negative_data=random_data, + device=device, + num_random_negs_per_pos=num_random_negs_per_pos, + use_hard_negs=True, + l2_lambda=l2_lambda, + ) + + batch_losses.append(loss.item()) + last_n_batch_time.append(time.time() - batch_start) + batch_start = time.time() + batch_idx += 1 + if batch_idx % log_every_n_batch == 0: + logger.info(f"rank={rank}, batch={batch_idx}, latest test_loss={loss:.6f}") + if torch.cuda.is_available(): + # Wait for GPU operations to finish + torch.cuda.synchronize() + logger.info( + f"rank={rank}, batch={batch_idx}, mean(batch_time)={statistics.mean(last_n_batch_time):.3f} sec, max(batch_time)={max(last_n_batch_time):.3f} sec, min(batch_time)={min(last_n_batch_time):.3f} sec" + ) + last_n_batch_time.clear() + + if batch_losses: + local_avg = float(sum(batch_losses) / len(batch_losses)) + else: + local_avg = float("nan") + + logger.info(f"rank={rank} finished validation loop, local loss: {local_avg=:.6f}") + device_for_tensor = device if device.type == "cuda" else torch.device("cpu") + global_avg_val_loss = _sync_metric_across_processes( + metric=torch.tensor(local_avg, device=device_for_tensor) + ) + logger.info(f"rank={rank} got global validation loss {global_avg_val_loss=:.6f}") + + return + + +def _run_example_training( + task_config_uri: str, +): + """ + Runs an example training + testing loop using GiGL Orchestration. + Args: + task_config_uri (str): Path to YAML-serialized GbmlConfig proto. + """ + start_time = time.time() + mp.set_start_method("spawn") + logger.info(f"Starting sub process method: {mp.get_start_method()}") + + gbml_config_pb_wrapper = GbmlConfigPbWrapper.get_gbml_config_pb_wrapper_from_uri( + gbml_config_uri=UriFactory.create_uri(task_config_uri) + ) + + # Training Hyperparameters for the training and test processes + + trainer_args = dict(gbml_config_pb_wrapper.trainer_config.trainer_args) + + local_world_size = int(trainer_args.get("local_world_size", "1")) + if torch.cuda.is_available(): + if local_world_size > torch.cuda.device_count(): + raise ValueError( + f"Specified a local world size of {local_world_size} which exceeds the number of devices {torch.cuda.device_count()}" + ) + + # Parses the fanout as a string. For the homogeneous case, the fanouts should be specified as a string of a list of integers, such as "[10, 10]". + fanout = trainer_args.get("num_neighbors", "[10, 10]") + num_neighbors = parse_fanout(fanout) + + # While the ideal value for `sampling_workers_per_process` has been identified to be between `2` and `4`, this may need some tuning depending on the + # pipeline. We default this value to `4` here for simplicity. A `sampling_workers_per_process` which is too small may not have enough parallelization for + # sampling, which would slow down training, while a value which is too large may slow down each sampling process due to competing resources, which would also + # then slow down training. 
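+    # This value can be overridden via `trainerConfig.trainerArgs`, e.g. `sampling_workers_per_process: "2"` in the task config YAML.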
+ sampling_workers_per_process: int = int( + trainer_args.get("sampling_workers_per_process", "4") + ) + + main_batch_size = int(trainer_args.get("main_batch_size", "16")) + random_batch_size = int(trainer_args.get("random_batch_size", "16")) + + # LightGCN Hyperparameters + embedding_dim = int(trainer_args.get("embedding_dim", "64")) + num_layers = int(trainer_args.get("num_layers", "2")) + + # BPR params + num_random_negs_per_pos = int(trainer_args.get("num_random_negs_per_pos", "1")) + l2_lambda = float(trainer_args.get("l2_lambda", "0.0")) + + # This value represents the the shared-memory buffer size (bytes) allocated for the channel during sampling, and + # is the place to store pre-fetched data, so if it is too small then prefetching is limited, causing sampling slowdown. This parameter is a string + # with `{numeric_value}{storage_size}`, where storage size could be `MB`, `GB`, etc. We default this value to 4GB, + # but in production may need some tuning. + sampling_worker_shared_channel_size: str = trainer_args.get( + "sampling_worker_shared_channel_size", "4GB" + ) + + process_start_gap_seconds = int(trainer_args.get("process_start_gap_seconds", "0")) + log_every_n_batch = int(trainer_args.get("log_every_n_batch", "25")) + + learning_rate = float(trainer_args.get("learning_rate", "0.0005")) + weight_decay = float(trainer_args.get("weight_decay", "0.0005")) + num_max_train_batches = int(trainer_args.get("num_max_train_batches", "1000")) + num_val_batches = int(trainer_args.get("num_val_batches", "100")) + val_every_n_batch = int(trainer_args.get("val_every_n_batch", "50")) + + logger.info( + f"Got training args local_world_size={local_world_size}, \ + num_neighbors={num_neighbors}, \ + sampling_workers_per_process={sampling_workers_per_process}, \ + main_batch_size={main_batch_size}, \ + random_batch_size={random_batch_size}, \ + embedding_dim={embedding_dim}, \ + num_layers={num_layers}, \ + num_random_negs_per_pos={num_random_negs_per_pos}, \ + l2_lambda={l2_lambda}, \ + sampling_worker_shared_channel_size={sampling_worker_shared_channel_size}, \ + process_start_gap_seconds={process_start_gap_seconds}, \ + log_every_n_batch={log_every_n_batch}, \ + learning_rate={learning_rate}, \ + weight_decay={weight_decay}, \ + num_max_train_batches={num_max_train_batches}, \ + num_val_batches={num_val_batches}, \ + val_every_n_batch={val_every_n_batch}" + ) + + # This `init_process_group` is only called to get the master_ip_address, master port, and rank/world_size fields which help with partitioning, sampling, + # and distributed training/testing. We can use `gloo` here since these fields we are extracting don't require GPU capabilities provided by `nccl`. + # Note that this init_process_group uses env:// to setup the connection. + # In VAI we create one process per node thus these variables are exposed through env i.e. MASTER_PORT , MASTER_ADDR , WORLD_SIZE , RANK that VAI sets up for us. + # If running locally, these env variables will need to be setup by the user manually. 
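+    # For a single-machine local run this amounts to e.g.:
+    #   WORLD_SIZE=1 RANK=0 MASTER_ADDR="localhost" MASTER_PORT=20000 python -m examples.id_embeddings.homogeneous_training --task_config_uri=<frozen task config uri>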
+ torch.distributed.init_process_group(backend="gloo") + + master_ip_address = gigl.distributed.utils.get_internal_ip_from_master_node() + machine_rank = torch.distributed.get_rank() + machine_world_size = torch.distributed.get_world_size() + master_default_process_group_port = ( + gigl.distributed.utils.get_free_ports_from_master_node(num_ports=1) + )[0] + # Destroying the process group as one will be re-initialized in the training process using above information + torch.distributed.destroy_process_group() + + logger.info(f"--- Launching data loading process ---") + dataset = build_dataset_from_task_config_uri( + task_config_uri=task_config_uri, + is_inference=False, + ) + + all_ids = to_homogeneous(dataset.node_ids) # Tensor of global node IDs + max_id = int(all_ids.max().item()) + num_nodes_total = max_id + 1 + + logger.info( + f"--- Data loading process finished, took {time.time() - start_time:.3f} seconds" + ) + + model_uri = UriFactory.create_uri( + gbml_config_pb_wrapper.gbml_config_pb.shared_config.trained_model_metadata.trained_model_uri + ) + + should_skip_training = gbml_config_pb_wrapper.shared_config.should_skip_training + + logger.info("--- Launching training processes ...\n") + start_time = time.time() + torch.multiprocessing.spawn( + _training_process, + args=( # Corresponding arguments in `_training_process` function + local_world_size, # local_world_size + machine_rank, # machine_rank + machine_world_size, # machine_world_size + dataset, # dataset + num_nodes_total, # num_nodes_total + embedding_dim, # embedding_dim + num_layers, # num_layers + master_ip_address, # master_ip_address + master_default_process_group_port, # master_default_process_group_port + model_uri, # model_uri + num_neighbors, # num_neighbors + sampling_workers_per_process, # sampling_workers_per_process + main_batch_size, # main_batch_size + random_batch_size, # random_batch_size + sampling_worker_shared_channel_size, # sampling_worker_shared_channel_size + process_start_gap_seconds, # process_start_gap_seconds + log_every_n_batch, # log_every_n_batch + learning_rate, # learning_rate + weight_decay, # weight_decay + num_max_train_batches, # num_max_train_batches + num_val_batches, # num_val_batches + val_every_n_batch, # val_every_n_batch + should_skip_training, # should_skip_training + num_random_negs_per_pos, # num_random_negs_per_pos + l2_lambda, # l2_lambda + ), + nprocs=local_world_size, + join=True, + ) + logger.info(f"--- Training finished, took {time.time() - start_time} seconds") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Arguments for distributed model training on VertexAI" + ) + parser.add_argument("--task_config_uri", type=str, help="Gbml config uri") + + # We use parse_known_args instead of parse_args since we only need job_name and task_config_uri for distributed trainer + args, unused_args = parser.parse_known_args() + logger.info(f"Args: {args}") + logger.info(f"Unused arguments: {unused_args}") + + # We only need `task_config_uri` for running trainer + _run_example_training( + task_config_uri=args.task_config_uri, + ) diff --git a/python/gigl/module/models.py b/python/gigl/module/models.py index 20275d722..f1c9d2b99 100644 --- a/python/gigl/module/models.py +++ b/python/gigl/module/models.py @@ -8,6 +8,7 @@ from torchrec.modules.embedding_configs import EmbeddingBagConfig from torchrec.modules.embedding_modules import EmbeddingBagCollection from torchrec.sparse.jagged_tensor import KeyedJaggedTensor +from torchrec.distributed.types import 
Awaitable from typing_extensions import Self from gigl.src.common.types.graph_data import NodeType @@ -293,6 +294,10 @@ def _forward_homogeneous( key, global_ids ) # shape [N_sub, D], where N_sub is number of nodes in subgraph and D is embedding_dim + # When using DMP, EmbeddingBagCollection returns Awaitable that needs to be resolved + if isinstance(embeddings_0, Awaitable): + embeddings_0 = embeddings_0.wait() + all_layer_embeddings: list[torch.Tensor] = [embeddings_0] embeddings_k = embeddings_0 diff --git a/python/tests/unit/module/models_test.py b/python/tests/unit/module/models_test.py index 7675ee7a6..625a46cbd 100644 --- a/python/tests/unit/module/models_test.py +++ b/python/tests/unit/module/models_test.py @@ -3,6 +3,7 @@ import torch import torch.nn as nn +import torch.distributed as dist from torch_geometric.data import Data, HeteroData from torch_geometric.nn.models import LightGCN as PyGLightGCN @@ -283,6 +284,99 @@ def test_gradient_flow(self): self.assertIsNotNone(embedding_table.weight.grad) self.assertTrue(torch.any(embedding_table.weight.grad != 0)) + def test_dmp_wrapped_model_produces_correct_output(self): + """ + Test that DMP-wrapped LightGCN produces the same output as non-wrapped model. Note: We only test with a single process for unit test. + """ + from torchrec.distributed.model_parallel import DistributedModelParallel as DMP + + # Initialize distributed + if not dist.is_initialized(): + dist.init_process_group( + backend="gloo", + init_method="tcp://localhost:29500", + rank=0, + world_size=1, # Single process for unit test + ) + + # It is worth noting that when using CPU, we must set embeddings again after DMP wrapping + + try: + # Create model and move to device + model = self._create_lightgcn_model(self.num_nodes) + + # Wrap with DMP + dmp_model = DMP( + module=model, + device=self.device, + ) + + # Set embeddings After DMP wrapping + self._set_embeddings(model, "default_homogeneous_node_type") + + # Run forward pass on DMP-wrapped model + with torch.no_grad(): + output = dmp_model(data=self.data, device=self.device) + + # Verify output matches expected values + self.assertTrue( + torch.allclose(output, self.expected_output, atol=1e-4, rtol=1e-4), + f"DMP output doesn't match expected.\nGot:\n{output}\nExpected:\n{self.expected_output}", + ) + + finally: + if dist.is_initialized(): + dist.destroy_process_group() + + def test_dmp_gradient_flow(self): + """ + Test that gradients flow properly through DMP-wrapped model. 
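+        As in `test_dmp_wrapped_model_produces_correct_output`, this runs a single-process gloo group and sets
+        the embeddings after DMP wrapping.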
+ """ + from torchrec.distributed.model_parallel import DistributedModelParallel as DMP + + # Initialize distributed + if not dist.is_initialized(): + dist.init_process_group( + backend="gloo", + init_method="tcp://localhost:29500", + rank=0, + world_size=1, + ) + + try: + # Create and wrap model + model = self._create_lightgcn_model(self.num_nodes) + + dmp_model = DMP( + module=model, + device=self.device, + ) + + self._set_embeddings(model, "default_homogeneous_node_type") + + model.train() + + # Forward and backward pass + output = dmp_model(data=self.data, device=self.device) + loss = output.sum() + loss.backward() + + # Check that gradients exist and are non-zero + embedding_table = model._embedding_bag_collection.embedding_bags[ + "node_embedding_default_homogeneous_node_type" + ] + self.assertIsNotNone( + embedding_table.weight.grad, + "Gradients should exist after backward pass", + ) + self.assertTrue( + torch.any(embedding_table.weight.grad != 0), + "Gradients should be non-zero", + ) + + finally: + if dist.is_initialized(): + dist.destroy_process_group() if __name__ == "__main__": unittest.main()
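
For reviewers, a minimal self-contained sketch (not part of the patch) of how the `bpr_loss` helper introduced in `examples/id_embeddings/homogeneous_training.py` behaves for the two negative shapes it accepts. It assumes the GiGL repo and its dependencies are importable; tensor sizes are arbitrary toy values.

```python
# Illustrative only: exercise bpr_loss from the new training script with toy embeddings.
import torch

from examples.id_embeddings.homogeneous_training import bpr_loss

M, K, D = 8, 2, 16                 # num positives, negatives per positive, embedding dim
query = torch.randn(M, D)          # query-node embeddings, repeated per positive
pos = torch.randn(M, D)            # positive-node embeddings

neg_single = torch.randn(M, D)     # one random negative per positive
neg_multi = torch.randn(M, K, D)   # K negatives per positive (loss is averaged over K)

print(bpr_loss(query, pos, neg_single))
print(bpr_loss(query, pos, neg_multi, l2_lambda=1e-6, l2_params=[query, pos]))
```

Each printed value is the batch mean of -log(sigmoid(s_pos - s_neg)), where scores are dot products of the embeddings; when `l2_lambda > 0`, the squared norms of the tensors passed in `l2_params` are added as regularization.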