gigl.transforms.graph_transformer#
Transform HeteroData to Graph Transformer sequence input.
This module provides functionality to convert PyG HeteroData objects (typically batched 2-hop subgraphs) into sequence format suitable for Graph Transformers.
For each anchor node in the batch, the transform extracts its k-hop neighborhood and creates a fixed-length sequence of node features with padding.
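The fixed-length padding and validity mask described above can be sketched for a single anchor as follows (a minimal standalone sketch, not the library implementation; `pad_neighborhood` is a hypothetical helper name):

```python
import torch

def pad_neighborhood(features: torch.Tensor, max_seq_len: int,
                     padding_value: float = 0.0):
    """Pad (or truncate) one anchor's neighborhood features to max_seq_len.

    features: (num_nodes, feature_dim) tensor of the anchor's k-hop
    neighbors, anchor first. Returns the padded sequence and a boolean
    mask marking which positions hold real nodes.
    """
    num_nodes, feature_dim = features.shape
    seq = torch.full((max_seq_len, feature_dim), padding_value)
    valid = torch.zeros(max_seq_len, dtype=torch.bool)
    keep = min(num_nodes, max_seq_len)  # neighbors beyond max_seq_len are truncated
    seq[:keep] = features[:keep]
    valid[:keep] = True
    return seq, valid

seq, valid = pad_neighborhood(torch.randn(5, 8), max_seq_len=12)
# seq: (12, 8); valid: first 5 positions True, remaining 7 False
```

The actual transform does this in batch across all anchors, but the per-anchor contract (truncate, pad, mask) is the same.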
- Example Usage:
>>> import torch
>>> from torch_geometric.data import HeteroData
>>> from gigl.transforms.graph_transformer import heterodata_to_graph_transformer_input
>>>
>>> # Create batched HeteroData (e.g., from NeighborLoader).
>>> # The first batch_size nodes of each node type are anchor nodes.
>>> data = HeteroData()
>>> data['user'].x = torch.randn(100, 64)  # 100 users, first N are anchors
>>> data['item'].x = torch.randn(50, 64)   # all node types must share one feature dim
>>> data['user', 'buys', 'item'].edge_index = ...
>>>
>>> # Transform to Graph Transformer input
>>> sequences, valid_mask, attention_bias_data = heterodata_to_graph_transformer_input(
...     data=data,
...     batch_size=32,
...     max_seq_len=128,
...     anchor_node_type='user',
... )
>>> # sequences: (batch_size, max_seq_len, feature_dim)
>>> # valid_mask: (batch_size, max_seq_len)
With Relative Encodings: Relative encodings stored as sparse graph-level attributes can be returned as raw attention-bias features:
>>> from torch_geometric.transforms import Compose
>>> from gigl.transforms.add_positional_encodings import (
...     AddHeteroRandomWalkEncodings,
...     AddHeteroHopDistanceEncoding,
... )
>>>
>>> # First apply PE transforms to the data
>>> pe_transform = Compose([
...     AddHeteroRandomWalkEncodings(walk_length=8),
...     AddHeteroHopDistanceEncoding(h_max=5),
... ])
>>> data = pe_transform(data)
>>>
>>> # Transform to sequences with relative bias features
>>> sequences, valid_mask, attention_bias_data = heterodata_to_graph_transformer_input(
...     data=data,
...     batch_size=32,
...     max_seq_len=128,
...     anchor_node_type='user',
...     anchor_based_attention_bias_attr_names=['hop_distance'],
... )
>>> # sequences: (batch_size, max_seq_len, feature_dim)
>>> # attention_bias_data['anchor_bias']: (batch_size, max_seq_len, 1)
Attributes#
Classes#
SequenceAuxiliaryData | Dictionary of raw token-aligned and attention-bias features returned alongside sequences. |
Functions#
heterodata_to_graph_transformer_input | Transform a HeteroData object to Graph Transformer sequence input. |
Module Contents#
- class gigl.transforms.graph_transformer.SequenceAuxiliaryData[source]#
Bases: TypedDict
Dictionary of raw token-aligned and attention-bias features returned by heterodata_to_graph_transformer_input, with keys:
- "anchor_bias": tensor shaped (batch, seq, num_anchor_attrs), or None
- "pairwise_bias": tensor shaped (batch, seq, seq, num_pairwise_attrs), or None
- "token_input": dict mapping attribute name to a (batch, seq, 1) tensor, or None
- gigl.transforms.graph_transformer.heterodata_to_graph_transformer_input(data, batch_size, max_seq_len, anchor_node_type, anchor_node_ids=None, hop_distance=2, sequence_construction_method='khop', include_anchor_first=True, padding_value=0.0, anchor_based_attention_bias_attr_names=None, anchor_based_input_attr_names=None, pairwise_attention_bias_attr_names=None)[source]#
Transform a HeteroData object to Graph Transformer sequence input.
Given a batched HeteroData object where the first batch_size nodes of anchor_node_type are anchor nodes, this function extracts the k-hop neighborhood for each anchor and creates padded sequences.
Uses sparse matrix operations for efficient batched k-hop neighbor extraction.
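The sparse batched extraction mentioned above can be illustrated with repeated sparse adjacency multiplication of per-anchor indicator vectors (a simplified homogeneous sketch under assumed conventions; the real transform operates over heterogeneous edge types, and `khop_reachable` is a hypothetical name):

```python
import torch

def khop_reachable(edge_index: torch.Tensor, num_nodes: int,
                   anchor_ids: torch.Tensor, k: int) -> torch.Tensor:
    """Boolean (num_anchors, num_nodes) mask of nodes within k hops of each anchor."""
    # Sparse adjacency with A[dst, src] = 1, so A @ x propagates src -> dst.
    adj = torch.sparse_coo_tensor(
        torch.stack([edge_index[1], edge_index[0]]),
        torch.ones(edge_index.size(1)),
        (num_nodes, num_nodes),
    )
    # One-hot indicator column per anchor, shape (num_nodes, num_anchors).
    frontier = torch.zeros(num_nodes, anchor_ids.numel())
    frontier[anchor_ids, torch.arange(anchor_ids.numel())] = 1.0
    reached = frontier.clone()
    for _ in range(k):
        # One sparse matmul expands every anchor's frontier by one hop at once.
        frontier = torch.sparse.mm(adj, frontier)
        reached = reached + frontier
    return (reached > 0).T  # (num_anchors, num_nodes)

# Path graph 0 -> 1 -> 2 -> 3: anchor 0 reaches {0, 1, 2} within 2 hops.
ei = torch.tensor([[0, 1, 2], [1, 2, 3]])
mask = khop_reachable(ei, num_nodes=4, anchor_ids=torch.tensor([0]), k=2)
```

The point of the sparse formulation is that all anchors are expanded in one matrix product per hop, rather than looping over anchors.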
- Parameters:
data (torch_geometric.data.HeteroData) – HeteroData object containing node features and edge indices. Expected to have node features as data[node_type].x. All node types must have the same feature dimension.
batch_size (int) – Number of anchor nodes (first batch_size nodes of anchor_node_type). Ignored if anchor_node_ids is provided.
max_seq_len (int) – Maximum sequence length (neighbors beyond this are truncated).
anchor_node_type (torch_geometric.typing.NodeType) – The node type of anchor nodes.
anchor_node_ids (Optional[torch.Tensor]) – Optional tensor of local node indices within anchor_node_type to use as anchors. If None, uses first batch_size nodes. (default: None)
hop_distance (int) – Number of hops to consider for the neighborhood when sequence_construction_method="khop". (default: 2)
sequence_construction_method (Literal['khop', 'ppr']) – Strategy used to build per-anchor sequences. "khop" performs the existing k-hop expansion over the sampled graph. "ppr" uses outgoing (anchor_type, "ppr", neighbor_type) edges, sorted by descending PPR weight from edge_attr. (default: "khop")
include_anchor_first (bool) – If True, the anchor node is always first in the sequence. (default: True)
padding_value (float) – Value to use for padding. (default: 0.0)
anchor_based_attention_bias_attr_names (Optional[list[str]]) – List of anchor-relative feature names used as attention bias. Sparse graph-level attributes are looked up from data, and the reserved name "ppr_weight" resolves to PPR edge weights in PPR sequence mode. Example: ['hop_distance', 'ppr_weight']. (default: None)
anchor_based_input_attr_names (Optional[list[str]]) – List of anchor-relative attribute names returned as token-aligned model-input features. Sparse graph-level attributes are looked up from data, and "ppr_weight" resolves to PPR edge weights in PPR sequence mode. Example: ['hop_distance', 'ppr_weight']. (default: None)
pairwise_attention_bias_attr_names (Optional[list[str]]) – List of pairwise feature names used as attention bias. These must correspond to sparse graph-level attributes on data. Example: ['pairwise_distance']. (default: None)
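For the "ppr" sequence mode, selecting a fixed number of neighbors by descending PPR weight can be sketched roughly as follows (a hypothetical standalone sketch; in the library the edges and weights come from the (anchor_type, "ppr", neighbor_type) edge store and its edge_attr):

```python
import torch

def top_ppr_neighbors(edge_index: torch.Tensor, edge_attr: torch.Tensor,
                      anchor: int, max_seq_len: int):
    """Return neighbor ids of `anchor` sorted by descending PPR weight,
    truncated to max_seq_len. edge_index is (2, E); edge_attr holds weights."""
    mask = edge_index[0] == anchor                 # outgoing edges of this anchor
    neighbors, weights = edge_index[1][mask], edge_attr[mask]
    order = torch.argsort(weights, descending=True)[:max_seq_len]
    return neighbors[order], weights[order]

ei = torch.tensor([[0, 0, 0, 1], [5, 7, 9, 2]])
w = torch.tensor([0.1, 0.6, 0.3, 0.9])
nbrs, ws = top_ppr_neighbors(ei, w, anchor=0, max_seq_len=2)
# nbrs -> tensor([7, 9]) (weights 0.6 and 0.3 rank highest for anchor 0)
```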
- Returns:
(sequences, valid_mask, attention_bias_data), where
- sequences: (batch_size, max_seq_len, feature_dim) padded node features taken directly from data[node_type].x in homogeneous order.
- valid_mask: (batch_size, max_seq_len) bool tensor indicating which sequence positions correspond to real nodes.
- attention_bias_data: dictionary of raw token-aligned and attention-bias features with:
  - "anchor_bias" shaped (batch, seq, num_anchor_attrs), or None
  - "pairwise_bias" shaped (batch, seq, seq, num_pairwise_attrs), or None
  - "token_input", a dict mapping attribute name to a (batch, seq, 1) tensor, or None
- Return type:
tuple[torch.Tensor, torch.Tensor, SequenceAuxiliaryData]
- Raises:
ValueError – If node types have different feature dimensions.
ValueError – If no node features exist in the data.
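One way a downstream model might consume the returned tensors is to fold the anchor bias into attention logits and mask padded key positions (a hypothetical sketch with assumed shapes; how the bias is actually applied is up to the Graph Transformer implementation):

```python
import torch

# Assumed shapes matching the documented outputs.
batch, seq, heads = 2, 6, 4
scores = torch.randn(batch, heads, seq, seq)    # raw attention logits
anchor_bias = torch.randn(batch, seq, 1)        # num_anchor_attrs = 1
valid_mask = torch.tensor([[True] * 4 + [False] * 2,
                           [True] * 6])

# Broadcast the per-token anchor bias over heads and query positions,
# then mask out padded key positions before the softmax.
bias = anchor_bias.squeeze(-1)[:, None, None, :]            # (batch, 1, 1, seq)
scores = scores + bias
scores = scores.masked_fill(~valid_mask[:, None, None, :], float("-inf"))
attn = torch.softmax(scores, dim=-1)            # padded keys receive zero weight
```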