Add Bipartite Implementation for LightGCN #370
base: main
Conversation
python/gigl/module/models.py
Outdated
    data (Union[Data, HeteroData]): Graph data (homogeneous or heterogeneous).
    data (Union[Data, HeteroData]): Graph data.
        - For homogeneous: Data object with edge_index and node field
        - For bipartite: HeteroData with 2 node types and edge_index_dict
Why are we restricting this to bipartite? Does it get easier if we do?
Yes, LightGCN is not defined for cases beyond bipartite. We decided for now to only worry about bipartite.
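For reference, the propagation and layer-combination rules from the LightGCN paper (He et al., 2020), which in the bipartite setting only ever aggregate items into users and users into items:

$$
e_u^{(k+1)} = \sum_{i \in \mathcal{N}(u)} \frac{1}{\sqrt{|\mathcal{N}(u)|\,|\mathcal{N}(i)|}}\, e_i^{(k)}, \qquad
e_u = \sum_{k=0}^{K} \alpha_k\, e_u^{(k)}, \quad \alpha_k = \frac{1}{K+1}.
$$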
Did this happen while I was gone? I really don't see the point in doing so; I don't think a "heterogeneous" implementation would be any different.
And if we go with this approach, then in the future will we have a third _forward_heterogeneous implementation? Is that not just more complicated?
It seems to me that the current _forward_bipartite implementation has no restriction on the graph actually being bipartite, right? I don't see any asserts etc. in the code that would enforce that.
Why not keep the current requirement that users provide two node types, but rename this function to _forward_heterogeneous?
Or am I missing something, and the code here is in fact only correct for the bipartite case?
If you're asking whether the code can handle n different node types, then yes, I don't see why not. The problem I have with "heterogeneous" is that it implies handling both multiple node types and multiple edge types, and it is much more complicated to consider multiple edge types. For now, I changed the naming to _forward_heterogeneous and generalized the comments accordingly.
python/gigl/module/models.py
Outdated
    src_offset = node_type_to_offset[src_node_type]
    dst_offset = node_type_to_offset[dst_node_type]

    offset_edge_index = edge_index.clone()
Why do we clone here?
Hmm, is this because we mutate it? Does .to() create a copy already?
I don't believe .to() necessarily creates a copy. If data[edge_type_tuple].edge_index is already on the target device, .to(device) returns the same tensor, so the subsequent mutation operations would affect the original graph data, which we don't want.
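A minimal sketch of the aliasing concern, using a toy tensor as a stand-in for data[edge_type_tuple].edge_index:

```python
import torch

edge_index = torch.tensor([[0, 1], [1, 2]])       # stand-in for data[edge_type_tuple].edge_index
moved = edge_index.to(edge_index.device)          # already on the target device
assert moved.data_ptr() == edge_index.data_ptr()  # .to() returned the same tensor, no copy made

# Mutating "moved" in place would therefore also mutate the original graph data.
offset_edge_index = edge_index.clone()            # explicit copy breaks the aliasing
offset_edge_index[0] += 10                        # safe: offsets applied to the copy only
assert edge_index[0, 0].item() == 0               # original is untouched
```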
python/gigl/module/models.py
Outdated
    for node_type in output_node_types:
        node_type_str = str(node_type)
        key = f"{node_type_str}_id"
We should probably parameterise this somehow, either with something like NODE_ID_FMT = "{node_type}_id" and NODE_ID_FMT.format(node_type=node_type_str), or a helper like def get_nt_key().
Yeah, I agree. I added a static method to centralize the key naming.
python/gigl/module/models.py
Outdated
    # LightGCN propagation across node types
    all_node_types = list(node_type_to_embeddings_0.keys())

    # For bipartite, we need to create a unified edge representation
Why do we need to do this? Don't we have separate embedding tables for the different node types?
Or does the PyG convolution require all nodes to be in the same index space?
LGConv, unlike other convolutions usable with HeteroConv, doesn't support heterogeneous data, so we'd have to re-implement it. Thus, all nodes need to live in a single unified index space. Alternatively, we could run multiple forward passes, one per node type, but that seems inefficient, as we'd have to aggregate the results anyway.
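For concreteness, a rough sketch of the unified-index-space approach (toy sizes and names; this only loosely mirrors the hunks in this PR):

```python
import torch
from torch_geometric.nn import LGConv

# Toy setup: two node types, each with its own embedding table.
num_nodes = {"item": 4, "user": 3}
emb = {nt: torch.randn(n, 8) for nt, n in num_nodes.items()}

# Give each node type a contiguous block of a single unified index space.
node_type_to_offset: dict[str, int] = {}
offset = 0
for node_type in sorted(num_nodes):              # sorted for deterministic ordering
    node_type_to_offset[node_type] = offset
    offset += num_nodes[node_type]

# Stack the embeddings in the same order and shift each edge_index by the offsets.
x = torch.cat([emb[nt] for nt in sorted(num_nodes)], dim=0)   # [7, 8]
edge_index = torch.tensor([[0, 1, 2], [0, 1, 3]])             # ("user", "to", "item") edges
offset_edge_index = edge_index.clone()
offset_edge_index[0] += node_type_to_offset["user"]
offset_edge_index[1] += node_type_to_offset["item"]
# (A real LightGCN graph would also include the reverse item -> user edges.)

# A single LGConv call then propagates over all node types at once.
out = LGConv()(x, offset_edge_index)
print(out.shape)  # torch.Size([7, 8])
```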
python/gigl/module/models.py
Outdated
    if anchor_node_ids is not None:
        for node_type in all_node_types:
            if isinstance(anchor_node_ids, dict) and node_type in anchor_node_ids:
                anchors = anchor_node_ids[node_type].to(device).long()
                final_embeddings[node_type] = final_embeddings[node_type][anchors]
If we have node types a, b, c and anchor node ids {a: [1, 2], b: [3, 4]}, then the final embeddings would still contain all of the embeddings for c, right?
Should we fix this? (And also add a test for it?)
Not exactly sure what you're asking, but I added some anchor node checks in the test.
What I'm saying is: suppose a user provides anchor_node_ids = {a: [10, 20]} and the graph has node types {a, b}.
The returned final_embeddings will then be {a: embeddings for nodes [10, 20], b: embeddings for all nodes}. Right?
Ahhh, I see. Yes, this is the expected behavior. I take it you're saying we shouldn't be returning any embeddings for b here?
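To make the behavior under discussion concrete, a toy sketch (fake embedding tables, mirroring the filtering logic in the hunk above):

```python
import torch

# Pretend the model has already produced per-type embeddings.
final_embeddings = {
    "a": torch.randn(100, 8),
    "b": torch.randn(50, 8),
}
anchor_node_ids = {"a": torch.tensor([10, 20])}   # caller only filters node type "a"

# Same shape as the hunk above: only types present in anchor_node_ids get sliced.
for node_type in final_embeddings:
    if isinstance(anchor_node_ids, dict) and node_type in anchor_node_ids:
        anchors = anchor_node_ids[node_type].long()
        final_embeddings[node_type] = final_embeddings[node_type][anchors]

print(final_embeddings["a"].shape)  # torch.Size([2, 8])  -- filtered to the anchors
print(final_embeddings["b"].shape)  # torch.Size([50, 8]) -- every node of type "b" is still returned
```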
Force-pushed from af3d278 to 371c85a.
Force-pushed from b02fb53 to 50e1f47.
python/gigl/nn/models.py
Outdated
            Must have length K+1. If None, uses uniform weights 1/(K+1). Default: None.
        """

    @staticmethod
Let's not make this a static method?
Why not just a _private free function?
OK, that makes sense; changed it to a free function outside the LightGCN class.
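For illustration only, the free-function version of the key helper could look roughly like this (the constant and function names here are hypothetical, not necessarily what the PR ended up with):

```python
_NODE_ID_KEY_FMT = "{node_type}_id"

def _node_id_key(node_type) -> str:
    """Single source of truth for the '<node_type>_id' key naming convention."""
    return _NODE_ID_KEY_FMT.format(node_type=str(node_type))

# Usage inside LightGCN.forward(), replacing the inline f-string:
#   key = _node_id_key(node_type)
```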
python/gigl/module/models.py
    # Lookup initial embeddings e^(0) for each node type
    node_type_to_embeddings_0: dict[NodeType, torch.Tensor] = {}

    for node_type in output_node_types:
BTW, we should sort the node types. Otherwise things may end up in different orders on different machines, which causes tricky, annoying-to-debug issues!
We have https://github.com/Snapchat/GiGL/blob/main/python/gigl/common/collections/sorted_dict.py, which is a bit crufty and which I can clean up, to enforce that a dict is sorted. Do you think we can leverage it (or just use sorted) here?
I just used sorted, since I think we only need to ensure that the output_node_types list is sorted at the beginning of the _forward_heterogeneous method.
python/gigl/module/models.py
    # Determine which node types to process
    if output_node_types is None:
        # Sort node types for deterministic ordering across machines
        output_node_types = sorted([NodeType(str(nt)) for nt in data.node_types], key=str)
Suggested change:
-    output_node_types = sorted([NodeType(str(nt)) for nt in data.node_types], key=str)
+    output_node_types = sorted(data.node_types)
nit. This can be cleaner :)
    # LightGCN propagation across node types
    # Sort node types for deterministic ordering across machines
    all_node_types = sorted(node_type_to_embeddings_0.keys(), key=str)
Suggested change:
-    all_node_types = sorted(node_type_to_embeddings_0.keys(), key=str)
+    all_node_types = sorted(node_type_to_embeddings_0.keys())
Again I don't think we need the key here?
    # Combine all edges into a single edge_index
    combined_edge_list: list[torch.Tensor] = []
    for edge_type_tuple in data.edge_types:
I feel like we should sort this too.
Or ensure that GLT outputs a sorted list? Which I don't think it does atm. cc: @mkolodner-sc
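A tiny sketch of the sorting suggestion, on a hypothetical toy HeteroData (edge types are (src, relation, dst) string tuples, so they sort deterministically):

```python
import torch
from torch_geometric.data import HeteroData

data = HeteroData()
data["user", "to", "item"].edge_index = torch.tensor([[0, 1], [1, 0]])
data["item", "rev_to", "user"].edge_index = torch.tensor([[1, 0], [0, 1]])

combined_edge_list: list[torch.Tensor] = []
for edge_type_tuple in sorted(data.edge_types):   # deterministic order across machines
    combined_edge_list.append(data[edge_type_tuple].edge_index)
combined_edge_index = torch.cat(combined_edge_list, dim=1)
print(combined_edge_index.shape)  # torch.Size([2, 4])
```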
    output_with_anchors = model(
        data,
        self.device,
        output_node_types=[NodeType("user"), NodeType("item")],
Let's just do the one output node type and ensure the output dict is of length 1; I think the bug is still here :) Related to #370 (comment)
| output_node_types=[NodeType("user"), NodeType("item")], | |
| output_node_types=[NodeType("user")], |
Scope of work done
- HeteroData input to forward
Where is the documentation for this feature?: N/A
Did you add automated tests or write a test plan?
Updated Changelog.md? NO
Ready for code review?: YES