
Conversation

@swong3-sc (Collaborator) commented Oct 31, 2025

Scope of work done

  • Added heterogeneous functionality for LightGCN, using HeteroData input to forward (see the usage sketch below)

Where is the documentation for this feature?: N/A

Did you add automated tests or write a test plan?

  • Added unit test to check final embeddings for bipartite user-item graph
  • Added unit test to check anchor-node functionality for a bipartite user-item graph

Updated Changelog.md? NO

Ready for code review?: YES
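
For reviewers, a rough usage sketch of the heterogeneous forward path, pieced together from the diff hunks and unit tests quoted below. The "to" relation name and the commented-out call are illustrative only; the user_id/item_id fields follow the f"{node_type}_id" convention discussed further down, and model construction is elided since it is not shown in this excerpt.

import torch
from torch_geometric.data import HeteroData

data = HeteroData()
data["user"].user_id = torch.arange(4)     # per-type node ids, keyed as "{node_type}_id"
data["item"].item_id = torch.arange(6)
data["user", "to", "item"].edge_index = torch.tensor(
    [[0, 1, 2, 3],
     [0, 2, 4, 5]]
)                                          # bipartite user -> item edges

# model = LightGCN(...)                    # construction not shown in this excerpt
# embeddings = model(
#     data,
#     torch.device("cpu"),
#     output_node_types=[NodeType("user"), NodeType("item")],
# )
# -> dict keyed by node type with [num_nodes, dim] embeddings per type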

data (Union[Data, HeteroData]): Graph data (homogeneous or heterogeneous).
data (Union[Data, HeteroData]): Graph data.
- For homogeneous: Data object with edge_index and node field
- For bipartite: HeteroData with 2 node types and edge_index_dict

Collaborator

Why are we restricting this to bipartite? Does it get easier if we do?

Collaborator Author

Yes, LightGCN is not defined for cases beyond bipartite. We decided for now to only worry about bipartite.

Collaborator

Did this happen while I was gone? I really don't see the point in doing so; I don't think a "heterogeneous" implementation would be any different.

And if we go with this approach, then in the future will we have a third _forward_heterogeneous implementation? Is that not just more complicated?

It seems to me like the current _forward_bipartite implementation has no restrictions on this being bipartite or not, right? I don't see any asserts/etc in the code that would restrict us, right?

Why not keep the current requirement that users provide two node types, but rename this function to _forward_heterogeneous?

Or am I missing something and the code here is in fact only correct for bipartite?

Collaborator Author

If you're asking whether the code can handle n different node types, then yes, I don't see why not. The problem I have with "heterogeneous" is that it implies we should be able to handle multiple node types and multiple edge types, and handling multiple edge types is much more complicated. For now, I renamed the function to _forward_heterogeneous and generalized the comments accordingly.

src_offset = node_type_to_offset[src_node_type]
dst_offset = node_type_to_offset[dst_node_type]

offset_edge_index = edge_index.clone()

Collaborator

why do we clone here?

Collaborator

Hmmm, is this because we mutate this? Does .to() already create a copy?

Collaborator Author

I don't believe .to() necessarily creates a copy. If data[edge_type_tuple].edge_index is already on the target device, .to(device) returns the same tensor, so the subsequent in-place mutation would affect the original graph data, which I don't think we want.
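
For illustration, the standard PyTorch semantics at play here: .to() returns the very same tensor when no device/dtype conversion is needed, so an in-place offset would leak back into the original HeteroData unless we clone first.

import torch

edge_index = torch.tensor([[0, 1], [1, 0]])   # stand-in for data[edge_type_tuple].edge_index
same = edge_index.to(edge_index.device)       # already on the target device -> same tensor back
assert same is edge_index

same[0] += 5                                  # an in-place offset here...
assert edge_index[0, 0].item() == 5           # ...mutates the original graph data

safe = edge_index.clone()                     # clone() yields an independent copy
safe[0] += 5                                  # offsets no longer touch the original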


for node_type in output_node_types:
    node_type_str = str(node_type)
    key = f"{node_type_str}_id"

Collaborator

We should probably parameterise this somehow, either with something like NODE_ID_FMT = "{node_type}_id" and NODE_ID_FMT.format(node_type=node_type_str), or with a def get_nt_key() helper.

Collaborator Author

Yeah, I agree. I added a static method to centralize the key naming.

# LightGCN propagation across node types
all_node_types = list(node_type_to_embeddings_0.keys())

# For bipartite, we need to create a unified edge representation

Collaborator

Why do we need to do this? Don't we have separate embedding tables for the different nodes?

Or does the pyg convolution require all nodes to be in the same space?

Collaborator Author

LGConv, unlike the convolutions that work with HeteroConv, doesn't support heterogeneous data, so we'd have to re-implement it; thus all nodes need to be in a single unified index space. Alternatively, we could run a separate forward pass for each type, but that seems inefficient, as we'd have to aggregate the results anyway.
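
For anyone skimming, a self-contained sketch of the single-index-space idea (not the PR's code; sizes, names, and the reverse-edge handling are made up for illustration):

import torch
from torch_geometric.nn import LGConv

dim = 8
emb = {"item": torch.randn(6, dim), "user": torch.randn(4, dim)}   # per-type e^(0) tables

# Give each node type a contiguous block in one unified index space (sorted for determinism).
node_type_to_offset, offset = {}, 0
for node_type in sorted(emb):
    node_type_to_offset[node_type] = offset
    offset += emb[node_type].size(0)

# user -> item edges, shifted by the per-type offsets, plus the reverse direction
# so messages flow both ways.
edge_index = torch.tensor([[0, 1, 2, 3], [0, 2, 4, 5]])
offset_edge_index = edge_index.clone()
offset_edge_index[0] += node_type_to_offset["user"]
offset_edge_index[1] += node_type_to_offset["item"]
unified_edges = torch.cat([offset_edge_index, offset_edge_index.flip(0)], dim=1)

# One shared LGConv pass over the unified graph, then split back per node type.
x = torch.cat([emb[nt] for nt in sorted(emb)], dim=0)
x = LGConv()(x, unified_edges)
out = {nt: x[node_type_to_offset[nt]: node_type_to_offset[nt] + emb[nt].size(0)]
       for nt in sorted(emb)}
assert out["user"].shape == (4, dim) and out["item"].shape == (6, dim)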

Comment on lines 454 to 458
if anchor_node_ids is not None:
    for node_type in all_node_types:
        if isinstance(anchor_node_ids, dict) and node_type in anchor_node_ids:
            anchors = anchor_node_ids[node_type].to(device).long()
            final_embeddings[node_type] = final_embeddings[node_type][anchors]

Collaborator

If we have node types a, b, c and anchor node ids {a: [1, 2], b: [3, 4]}, then the final embeddings would still contain all of the embeddings for c, right?

Should we fix this? (and also add some test for it?)

Collaborator Author

Not exactly sure what you're asking, but I added some anchor node checks in the test.

Collaborator

What I'm saying is let's say a user provides anchor_node_ids = {a: [10, 20]}, and the graph has node types {a, b}

The returned final_embeddings will be {a: [10, 20], b: all_nodes}. Right?

Collaborator Author

Ahhh, I see. Yes, this is the expected behavior. I take it you're saying we shouldn't be returning any embeddings for b here?
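
For illustration (made-up sizes), the current filtering logic amounts to something like:

import torch

final_embeddings = {"a": torch.randn(100, 8), "b": torch.randn(100, 8)}
anchor_node_ids = {"a": torch.tensor([10, 20])}        # only "a" gets anchors

for node_type, anchors in anchor_node_ids.items():
    final_embeddings[node_type] = final_embeddings[node_type][anchors]

# final_embeddings["a"].shape == (2, 8); "b" still holds all 100 rows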

@swong3-sc force-pushed the swong3/add_dmp_tests branch from af3d278 to 371c85a on November 6, 2025 00:09
Base automatically changed from swong3/add_dmp_tests to main on November 8, 2025 01:05
@swong3-sc force-pushed the swong3/add_heterogenous_lightgcn branch from b02fb53 to 50e1f47 on November 8, 2025 06:50
@swong3-sc marked this pull request as ready for review on November 10, 2025 06:34
Must have length K+1. If None, uses uniform weights 1/(K+1). Default: None.
"""

@staticmethod

Collaborator

let's not make this a static method?

why not just a _private free function?

Collaborator Author

Ok, that makes sense; changed it to a free function outside the LightGCN class.
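
For concreteness, a hypothetical sketch of the centralized helper as a private free function (names are illustrative; the actual helper in the PR may differ):

NODE_ID_FMT = "{node_type}_id"

def _node_id_key(node_type) -> str:          # node_type: NodeType, imported as elsewhere in this file
    """Single source of truth for the per-node-type id key, e.g. "user_id"."""
    return NODE_ID_FMT.format(node_type=str(node_type))

# call sites then read: key = _node_id_key(node_type)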

# Lookup initial embeddings e^(0) for each node type
node_type_to_embeddings_0: dict[NodeType, torch.Tensor] = {}

for node_type in output_node_types:

Collaborator

BTW we should sort the node types. Otherwise, things may be in different orders on different machines, which causes tricky, annoying-to-debug issues!

We have https://github.com/Snapchat/GiGL/blob/main/python/gigl/common/collections/sorted_dict.py, which is a bit crufty and I can clean up, to enforce that a dict is sorted. Do you think we can leverage it (or just use sorted) here?

Collaborator Author

I just used sorted; I think we only need to ensure that the output_node_types list is sorted at the beginning of the _forward_heterogeneous method.

@swong3-sc (Collaborator Author)

Note: the answers to some of your comments, @kmontemayor2-sc, appear above my review for some reason.

# Determine which node types to process
if output_node_types is None:
    # Sort node types for deterministic ordering across machines
    output_node_types = sorted([NodeType(str(nt)) for nt in data.node_types], key=str)

Collaborator

Suggested change:
- output_node_types = sorted([NodeType(str(nt)) for nt in data.node_types], key=str)
+ output_node_types = sorted(data.node_types)

nit. This can be cleaner :)


# LightGCN propagation across node types
# Sort node types for deterministic ordering across machines
all_node_types = sorted(node_type_to_embeddings_0.keys(), key=str)

Collaborator

Suggested change:
- all_node_types = sorted(node_type_to_embeddings_0.keys(), key=str)
+ all_node_types = sorted(node_type_to_embeddings_0.keys())

Again I don't think we need the key here?


# Combine all edges into a single edge_index
combined_edge_list: list[torch.Tensor] = []
for edge_type_tuple in data.edge_types:

Collaborator

I feel like we should sort this too.

Or ensure that GLT outputs a sorted list? Which I don't think it does atm. cc: @mkolodner-sc
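
For instance, iterating over a sorted copy would make the ordering deterministic regardless of how GLT populated the HeteroData (a hypothetical one-line change):

# (src, rel, dst) string tuples sort lexicographically, so this is deterministic across machines
for edge_type_tuple in sorted(data.edge_types):
    ...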

output_with_anchors = model(
    data,
    self.device,
    output_node_types=[NodeType("user"), NodeType("item")],

Collaborator

Let's just do the one output node type and ensure the output dict is of length 1; I think the bug is still here :) Related to #370 (comment)

Suggested change:
- output_node_types=[NodeType("user"), NodeType("item")],
+ output_node_types=[NodeType("user")],
