From 59d9a8e19b61f56c6e7a907463b86c6954f0295f Mon Sep 17 00:00:00 2001 From: mkolodner Date: Tue, 9 Sep 2025 00:35:12 +0000 Subject: [PATCH 1/3] Initial commit --- .../distributed_neighborloader_test.py | 176 +++++++++++++++++- 1 file changed, 175 insertions(+), 1 deletion(-) diff --git a/python/tests/unit/distributed/distributed_neighborloader_test.py b/python/tests/unit/distributed/distributed_neighborloader_test.py index eed655705..926bb8fb0 100644 --- a/python/tests/unit/distributed/distributed_neighborloader_test.py +++ b/python/tests/unit/distributed/distributed_neighborloader_test.py @@ -1,10 +1,11 @@ import unittest from collections.abc import Mapping -from typing import Optional, Union +from typing import Literal, Optional, Union import torch import torch.multiprocessing as mp from graphlearn_torch.distributed import shutdown_rpc +from graphlearn_torch.typing import reverse_edge_type from parameterized import param, parameterized from torch_geometric.data import Data, HeteroData @@ -549,6 +550,87 @@ def _run_cora_supervised_node_classification( shutdown_rpc() +def _run_subgraph_looks_as_expected_given_edge_direction( + _, + dataset: DistLinkPredictionDataset, + expected_reversed_edge_index: dict[EdgeType, torch.Tensor], +): + torch.distributed.init_process_group( + rank=0, world_size=1, init_method=get_process_group_init_method() + ) + + assert isinstance(dataset.node_ids, Mapping) + + user_loader = DistNeighborLoader( + dataset=dataset, + input_nodes=(_USER, dataset.node_ids[_USER]), + num_neighbors=[2, 2], + pin_memory_device=torch.device("cpu"), + batch_size=1, + ) + + story_loader = DistNeighborLoader( + dataset=dataset, + input_nodes=(_STORY, dataset.node_ids[_STORY]), + num_neighbors=[2, 2], + pin_memory_device=torch.device("cpu"), + batch_size=1, + ) + + for user_datum, story_datum in zip(user_loader, story_loader): + for edge_type in user_datum.edge_types: + # First, we need to remap the edge index with local node ids in the HeteroData object to an edge index with the global node ids + global_src_nodes = user_datum[edge_type[0]].node + global_dst_nodes = user_datum[edge_type[2]].node + global_src_edge_index = global_src_nodes[ + user_datum[edge_type].edge_index[0] + ] + global_dst_edge_index = global_dst_nodes[ + user_datum[edge_type].edge_index[1] + ] + global_edge_index = torch.stack( + [global_src_edge_index, global_dst_edge_index], dim=0 + ) + + # Then, we can compare the global edge index with the expected reversed edge index from the input graph + assert ( + edge_type in expected_reversed_edge_index + ), f"User HeteroData contains edge type {edge_type} that is not in the expected graph edge types: {list(expected_reversed_edge_index.keys())}" + matches = global_edge_index == expected_reversed_edge_index[edge_type] + column_matches = matches.all(dim=0) + contains_column = column_matches.any() + assert ( + contains_column + ), f"User HeteroData contains an edge for edge type {edge_type} in {user_datum[edge_type].edge_index} that is not in the expected graph: {expected_reversed_edge_index[edge_type]}" + + for edge_type in story_datum.edge_types: + assert ( + edge_type in expected_reversed_edge_index + ), f"Story HeteroData contains edge type {edge_type} that is not inthe expected graph edge types: {list(expected_reversed_edge_index.keys())}" + # First, we need to remap the edge index with local node ids in the HeteroData object to an edge index with the global node ids + global_src_nodes = story_datum[edge_type[0]].node + global_dst_nodes = story_datum[edge_type[2]].node + global_src_edge_index = global_src_nodes[ + story_datum[edge_type].edge_index[0] + ] + global_dst_edge_index = global_dst_nodes[ + story_datum[edge_type].edge_index[1] + ] + global_edge_index = torch.stack( + [global_src_edge_index, global_dst_edge_index], dim=0 + ) + + # Then, we can compare the global edge index with the expected reversed edge index from the input graph + matches = global_edge_index == expected_reversed_edge_index[edge_type] + column_matches = matches.all(dim=0) + contains_column = column_matches.any() + assert ( + contains_column + ), f"User HeteroData contains an edge for edge type {edge_type} in {user_datum[edge_type].edge_index} that is not in the expected graph: {expected_reversed_edge_index[edge_type]}" + + shutdown_rpc() + + class DistributedNeighborLoaderTest(unittest.TestCase): def setUp(self): self._master_ip_address = "localhost" @@ -1049,6 +1131,98 @@ def test_cora_supervised_node_classification(self): ), ) + @parameterized.expand( + [ + param( + "Test subgraph looks as expected given outward edge direction", "out" + ), + param("Test subgraph looks as expected given inward edge direction", "in"), + ] + ) + def test_subgraph_looks_as_expected_given_edge_direction( + self, _, edge_direction: Literal["in", "out"] + ): + # We define the graph here so that we have edges + # User -> Story + # 0 -> 0 + # 1 -> 1 + # 2 -> 2 + # 3 -> 3 + # 4 -> 4 + + # Story -> User + # 0 -> 1 + # 1 -> 2 + # 2 -> 3 + # 3 -> 4 + # 4 -> 0 + + user_to_story_edge_index = torch.tensor([[0, 1, 2, 3, 4], [0, 1, 2, 3, 4]]) + story_to_user_edge_index = torch.tensor([[0, 1, 2, 3, 4], [1, 2, 3, 4, 0]]) + + partition_output = PartitionOutput( + node_partition_book={ + _USER: torch.zeros(5), + _STORY: torch.zeros(5), + }, + edge_partition_book={ + _USER_TO_STORY: torch.zeros(5), + _STORY_TO_USER: torch.zeros(5), + }, + partitioned_edge_index={ + _USER_TO_STORY: GraphPartitionData( + edge_index=user_to_story_edge_index, + edge_ids=None, + ), + _STORY_TO_USER: GraphPartitionData( + edge_index=story_to_user_edge_index, + edge_ids=None, + ), + }, + partitioned_node_features={ + _USER: FeaturePartitionData( + feats=torch.zeros(5, 2), ids=torch.arange(5) + ), + _STORY: FeaturePartitionData( + feats=torch.zeros(5, 2), ids=torch.arange(5) + ), + }, + partitioned_edge_features=None, + partitioned_positive_labels=None, + partitioned_negative_labels=None, + partitioned_node_labels=None, + ) + + dataset = DistLinkPredictionDataset( + rank=0, world_size=1, edge_dir=edge_direction + ) + dataset.build(partition_output=partition_output) + + if edge_direction == "out": + # If the edge direction is out, we expect the produced HeteroData object to have the edge type reversed and the + # edge index tensor also swapped. This is because GLT swaps the outward direction under-the-hood as a convenience for message passing: + # https://github.com/alibaba/graphlearn-for-pytorch/blob/main/graphlearn_torch/python/loader/transform.py#L116-L124 + expected_reversed_edge_index = { + reverse_edge_type(_USER_TO_STORY): user_to_story_edge_index[[1, 0], :], + reverse_edge_type(_STORY_TO_USER): story_to_user_edge_index[[1, 0], :], + } + else: + # If the edge direction is in, we expect the produced HeteroData object to have the edge type and edge tensor be the same as the input + # graph. This is because GLT swaps the inward direction under-the-hood as a convenience for message passing: + # https://github.com/alibaba/graphlearn-for-pytorch/blob/main/graphlearn_torch/python/loader/transform.py#L116-L124 + expected_reversed_edge_index = { + _USER_TO_STORY: user_to_story_edge_index, + _STORY_TO_USER: story_to_user_edge_index, + } + + mp.spawn( + fn=_run_subgraph_looks_as_expected_given_edge_direction, + args=( + dataset, + expected_reversed_edge_index, + ), + ) + if __name__ == "__main__": unittest.main() From 9ba9d130651036ba3b99b5fbb6c355bcc4f46ae1 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Tue, 9 Sep 2025 00:37:48 +0000 Subject: [PATCH 2/3] Update --- .../distributed_neighborloader_test.py | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/python/tests/unit/distributed/distributed_neighborloader_test.py b/python/tests/unit/distributed/distributed_neighborloader_test.py index 926bb8fb0..415166e14 100644 --- a/python/tests/unit/distributed/distributed_neighborloader_test.py +++ b/python/tests/unit/distributed/distributed_neighborloader_test.py @@ -553,7 +553,7 @@ def _run_cora_supervised_node_classification( def _run_subgraph_looks_as_expected_given_edge_direction( _, dataset: DistLinkPredictionDataset, - expected_reversed_edge_index: dict[EdgeType, torch.Tensor], + possible_edge_indices_in_subgraph: dict[EdgeType, torch.Tensor], ): torch.distributed.init_process_group( rank=0, world_size=1, init_method=get_process_group_init_method() @@ -594,19 +594,19 @@ def _run_subgraph_looks_as_expected_given_edge_direction( # Then, we can compare the global edge index with the expected reversed edge index from the input graph assert ( - edge_type in expected_reversed_edge_index - ), f"User HeteroData contains edge type {edge_type} that is not in the expected graph edge types: {list(expected_reversed_edge_index.keys())}" - matches = global_edge_index == expected_reversed_edge_index[edge_type] + edge_type in possible_edge_indices_in_subgraph + ), f"User HeteroData contains edge type {edge_type} that is not in the expected graph edge types: {list(possible_edge_indices_in_subgraph.keys())}" + matches = global_edge_index == possible_edge_indices_in_subgraph[edge_type] column_matches = matches.all(dim=0) contains_column = column_matches.any() assert ( contains_column - ), f"User HeteroData contains an edge for edge type {edge_type} in {user_datum[edge_type].edge_index} that is not in the expected graph: {expected_reversed_edge_index[edge_type]}" + ), f"User HeteroData contains an edge for edge type {edge_type} in {user_datum[edge_type].edge_index} that is not in the expected graph: {possible_edge_indices_in_subgraph[edge_type]}" for edge_type in story_datum.edge_types: assert ( - edge_type in expected_reversed_edge_index - ), f"Story HeteroData contains edge type {edge_type} that is not inthe expected graph edge types: {list(expected_reversed_edge_index.keys())}" + edge_type in possible_edge_indices_in_subgraph + ), f"Story HeteroData contains edge type {edge_type} that is not inthe expected graph edge types: {list(possible_edge_indices_in_subgraph.keys())}" # First, we need to remap the edge index with local node ids in the HeteroData object to an edge index with the global node ids global_src_nodes = story_datum[edge_type[0]].node global_dst_nodes = story_datum[edge_type[2]].node @@ -621,12 +621,12 @@ def _run_subgraph_looks_as_expected_given_edge_direction( ) # Then, we can compare the global edge index with the expected reversed edge index from the input graph - matches = global_edge_index == expected_reversed_edge_index[edge_type] + matches = global_edge_index == possible_edge_indices_in_subgraph[edge_type] column_matches = matches.all(dim=0) contains_column = column_matches.any() assert ( contains_column - ), f"User HeteroData contains an edge for edge type {edge_type} in {user_datum[edge_type].edge_index} that is not in the expected graph: {expected_reversed_edge_index[edge_type]}" + ), f"User HeteroData contains an edge for edge type {edge_type} in {user_datum[edge_type].edge_index} that is not in the expected graph: {possible_edge_indices_in_subgraph[edge_type]}" shutdown_rpc() @@ -1202,7 +1202,7 @@ def test_subgraph_looks_as_expected_given_edge_direction( # If the edge direction is out, we expect the produced HeteroData object to have the edge type reversed and the # edge index tensor also swapped. This is because GLT swaps the outward direction under-the-hood as a convenience for message passing: # https://github.com/alibaba/graphlearn-for-pytorch/blob/main/graphlearn_torch/python/loader/transform.py#L116-L124 - expected_reversed_edge_index = { + possible_edge_indices_in_subgraph = { reverse_edge_type(_USER_TO_STORY): user_to_story_edge_index[[1, 0], :], reverse_edge_type(_STORY_TO_USER): story_to_user_edge_index[[1, 0], :], } @@ -1210,7 +1210,7 @@ def test_subgraph_looks_as_expected_given_edge_direction( # If the edge direction is in, we expect the produced HeteroData object to have the edge type and edge tensor be the same as the input # graph. This is because GLT swaps the inward direction under-the-hood as a convenience for message passing: # https://github.com/alibaba/graphlearn-for-pytorch/blob/main/graphlearn_torch/python/loader/transform.py#L116-L124 - expected_reversed_edge_index = { + possible_edge_indices_in_subgraph = { _USER_TO_STORY: user_to_story_edge_index, _STORY_TO_USER: story_to_user_edge_index, } @@ -1219,7 +1219,7 @@ def test_subgraph_looks_as_expected_given_edge_direction( fn=_run_subgraph_looks_as_expected_given_edge_direction, args=( dataset, - expected_reversed_edge_index, + possible_edge_indices_in_subgraph, ), ) From 07cf6e4a521ad41fd5e50385e7ee5bcaad2c8116 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Tue, 9 Sep 2025 00:40:28 +0000 Subject: [PATCH 3/3] Update --- .../distributed_neighborloader_test.py | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/python/tests/unit/distributed/distributed_neighborloader_test.py b/python/tests/unit/distributed/distributed_neighborloader_test.py index 415166e14..89052caff 100644 --- a/python/tests/unit/distributed/distributed_neighborloader_test.py +++ b/python/tests/unit/distributed/distributed_neighborloader_test.py @@ -553,7 +553,7 @@ def _run_cora_supervised_node_classification( def _run_subgraph_looks_as_expected_given_edge_direction( _, dataset: DistLinkPredictionDataset, - possible_edge_indices_in_subgraph: dict[EdgeType, torch.Tensor], + possible_edge_indices: dict[EdgeType, torch.Tensor], ): torch.distributed.init_process_group( rank=0, world_size=1, init_method=get_process_group_init_method() @@ -592,21 +592,21 @@ def _run_subgraph_looks_as_expected_given_edge_direction( [global_src_edge_index, global_dst_edge_index], dim=0 ) - # Then, we can compare the global edge index with the expected reversed edge index from the input graph + # Then, we can compare the global edge index with the possible edge indices that can exist for this edge type assert ( - edge_type in possible_edge_indices_in_subgraph - ), f"User HeteroData contains edge type {edge_type} that is not in the expected graph edge types: {list(possible_edge_indices_in_subgraph.keys())}" - matches = global_edge_index == possible_edge_indices_in_subgraph[edge_type] + edge_type in possible_edge_indices + ), f"User HeteroData contains edge type {edge_type} that is not in the expected graph edge types: {list(possible_edge_indices.keys())}" + matches = global_edge_index == possible_edge_indices[edge_type] column_matches = matches.all(dim=0) contains_column = column_matches.any() assert ( contains_column - ), f"User HeteroData contains an edge for edge type {edge_type} in {user_datum[edge_type].edge_index} that is not in the expected graph: {possible_edge_indices_in_subgraph[edge_type]}" + ), f"User HeteroData contains an edge for edge type {edge_type} in {user_datum[edge_type].edge_index} that is not in the expected graph: {possible_edge_indices[edge_type]}" for edge_type in story_datum.edge_types: assert ( - edge_type in possible_edge_indices_in_subgraph - ), f"Story HeteroData contains edge type {edge_type} that is not inthe expected graph edge types: {list(possible_edge_indices_in_subgraph.keys())}" + edge_type in possible_edge_indices + ), f"Story HeteroData contains edge type {edge_type} that is not inthe expected graph edge types: {list(possible_edge_indices.keys())}" # First, we need to remap the edge index with local node ids in the HeteroData object to an edge index with the global node ids global_src_nodes = story_datum[edge_type[0]].node global_dst_nodes = story_datum[edge_type[2]].node @@ -621,12 +621,12 @@ def _run_subgraph_looks_as_expected_given_edge_direction( ) # Then, we can compare the global edge index with the expected reversed edge index from the input graph - matches = global_edge_index == possible_edge_indices_in_subgraph[edge_type] + matches = global_edge_index == possible_edge_indices[edge_type] column_matches = matches.all(dim=0) contains_column = column_matches.any() assert ( contains_column - ), f"User HeteroData contains an edge for edge type {edge_type} in {user_datum[edge_type].edge_index} that is not in the expected graph: {possible_edge_indices_in_subgraph[edge_type]}" + ), f"User HeteroData contains an edge for edge type {edge_type} in {user_datum[edge_type].edge_index} that is not in the expected graph: {possible_edge_indices[edge_type]}" shutdown_rpc() @@ -1202,7 +1202,7 @@ def test_subgraph_looks_as_expected_given_edge_direction( # If the edge direction is out, we expect the produced HeteroData object to have the edge type reversed and the # edge index tensor also swapped. This is because GLT swaps the outward direction under-the-hood as a convenience for message passing: # https://github.com/alibaba/graphlearn-for-pytorch/blob/main/graphlearn_torch/python/loader/transform.py#L116-L124 - possible_edge_indices_in_subgraph = { + possible_edge_indices = { reverse_edge_type(_USER_TO_STORY): user_to_story_edge_index[[1, 0], :], reverse_edge_type(_STORY_TO_USER): story_to_user_edge_index[[1, 0], :], } @@ -1210,7 +1210,7 @@ def test_subgraph_looks_as_expected_given_edge_direction( # If the edge direction is in, we expect the produced HeteroData object to have the edge type and edge tensor be the same as the input # graph. This is because GLT swaps the inward direction under-the-hood as a convenience for message passing: # https://github.com/alibaba/graphlearn-for-pytorch/blob/main/graphlearn_torch/python/loader/transform.py#L116-L124 - possible_edge_indices_in_subgraph = { + possible_edge_indices = { _USER_TO_STORY: user_to_story_edge_index, _STORY_TO_USER: story_to_user_edge_index, } @@ -1219,7 +1219,7 @@ def test_subgraph_looks_as_expected_given_edge_direction( fn=_run_subgraph_looks_as_expected_given_edge_direction, args=( dataset, - possible_edge_indices_in_subgraph, + possible_edge_indices, ), )