gigl.src.common.models.pyg.nn.conv.hgt_conv
Classes
HGTConv | Modified version of PyG's HGTConv implementation
Module Contents
- class gigl.src.common.models.pyg.nn.conv.hgt_conv.HGTConv(in_channels, out_channels, metadata, heads=1, **kwargs)
Bases: torch_geometric.nn.conv.MessagePassing
Modified version of PyG's HGTConv implementation: https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/conv/hgt_conv.html#HGTConv
PyG's implementation drops node types that have no incoming message-passing edges (line 208 inside forward), while this version keeps those node types in the output.
The Heterogeneous Graph Transformer (HGT) operator from the “Heterogeneous Graph Transformer” paper.
Note
For an example of using HGT, see examples/hetero/hgt_dblp.py.
- Parameters:
in_channels (int or Dict[str, int]) – Size of each input sample of every node type, or -1 to derive the size from the first input(s) to the forward method.
out_channels (int) – Size of each output sample.
metadata (Tuple[List[str], List[Tuple[str, str, str]]]) – The metadata of the heterogeneous graph, i.e. its node and edge types given by a list of strings and a list of string triplets, respectively. See torch_geometric.data.HeteroData.metadata() for more information.
heads (int, optional) – Number of multi-head-attentions. (default: 1)
**kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
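The metadata tuple and the "no incoming edges" case described above can be illustrated in plain Python. This is only a sketch: the node and edge type names are hypothetical, and the helper node_types_without_incoming_edges is not part of GiGL or PyG. It lists the node types that never appear as the destination of an edge type, which are the types PyG's implementation is described as dropping and this class as keeping.

```python
from typing import List, Tuple

NodeType = str
EdgeType = Tuple[str, str, str]  # (source node type, relation, destination node type)
Metadata = Tuple[List[NodeType], List[EdgeType]]

# Hypothetical two-type graph; the names are illustrative only.
node_types: List[NodeType] = ["author", "paper"]
edge_types: List[EdgeType] = [("author", "writes", "paper")]
metadata: Metadata = (node_types, edge_types)


def node_types_without_incoming_edges(metadata: Metadata) -> List[NodeType]:
    """Return node types that never appear as the destination of an edge type."""
    node_types, edge_types = metadata
    destinations = {dst for _, _, dst in edge_types}
    return [node_type for node_type in node_types if node_type not in destinations]


print(node_types_without_incoming_edges(metadata))  # ['author']
```

Adding a reverse edge type such as ("paper", "written_by", "author") would give every node type an incoming edge type, and the helper would return an empty list.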
- forward(x_dict, edge_index_dict)
Runs the forward pass of the module.
- Parameters:
x_dict (Dict[str, torch.Tensor]) – A dictionary holding input node features for each individual node type.
edge_index_dict (Dict[Tuple[str, str, str], torch.Tensor]) – A dictionary holding graph connectivity information for each individual edge type, either as a torch.Tensor of shape [2, num_edges] or a torch_sparse.SparseTensor.
- Return type:
Dict[str, Optional[torch.Tensor]]
- The output node embeddings for each node type. In case a node type does not receive any message, its output will be set to None.
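The input layout and the optional outputs can be sketched without any tensor library. In the snippet below, nested lists stand in for torch.Tensor values, the type names are hypothetical, and count_edges and present_outputs are illustrative helpers rather than GiGL or PyG functions. It assumes the usual PyG convention that row 0 of an edge index holds source indices and row 1 holds destination indices.

```python
from typing import Dict, List, Optional, Tuple

EdgeType = Tuple[str, str, str]  # (source type, relation, destination type)
Features = List[List[float]]     # stand-in for a [num_nodes, channels] tensor

# Hypothetical inputs; nested lists stand in for torch.Tensor values.
x_dict: Dict[str, Features] = {
    "author": [[0.1, 0.2], [0.3, 0.4]],             # 2 nodes, 2 features each
    "paper": [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]],  # 3 nodes, 2 features each
}
edge_index_dict: Dict[EdgeType, List[List[int]]] = {
    # Shape [2, num_edges]: row 0 = source indices, row 1 = destination indices.
    ("author", "writes", "paper"): [[0, 0, 1], [0, 1, 2]],
}


def count_edges(x_dict, edge_index_dict) -> int:
    """Check the [2, num_edges] layout and index bounds; return the edge total."""
    total = 0
    for (src, rel, dst), edge_index in edge_index_dict.items():
        if len(edge_index) != 2 or len(edge_index[0]) != len(edge_index[1]):
            raise ValueError(f"edge index for {(src, rel, dst)} is not [2, num_edges]")
        sources, destinations = edge_index
        if any(not 0 <= s < len(x_dict[src]) for s in sources):
            raise ValueError(f"source index out of range for node type {src!r}")
        if any(not 0 <= d < len(x_dict[dst]) for d in destinations):
            raise ValueError(f"destination index out of range for node type {dst!r}")
        total += len(sources)
    return total


def present_outputs(out_dict: Dict[str, Optional[Features]]) -> Dict[str, Features]:
    """Drop node types whose output is None, which the return type permits."""
    return {node_type: out for node_type, out in out_dict.items() if out is not None}


# Stand-in for a returned dictionary in which one entry is None.
out_dict = {"author": None, "paper": [[0.0, 0.0]] * 3}
print(count_edges(x_dict, edge_index_dict))  # 3
print(sorted(present_outputs(out_dict)))     # ['paper']
```

Because the declared return type is Dict[str, Optional[torch.Tensor]], downstream code that indexes the result may want a guard of this kind before using an entry.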
- message(k_j, q_i, v_j, edge_attr, index, ptr, size_i)
Constructs messages from node \(j\) to node \(i\) in analogy to \(\phi_{\mathbf{\Theta}}\) for each edge in edge_index. This function can take any argument as input which was initially passed to propagate(). Furthermore, tensors passed to propagate() can be mapped to the respective nodes \(i\) and \(j\) by appending _i or _j to the variable name, e.g. x_i and x_j.
- Parameters:
k_j (torch.Tensor)
q_i (torch.Tensor)
v_j (torch.Tensor)
edge_attr (torch.Tensor)
index (torch.Tensor)
ptr (Optional[torch.Tensor])
size_i (Optional[int])
- Return type:
torch.Tensor
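The _i and _j suffix convention amounts to gathering one row per edge. The sketch below mimics that gather with plain lists. It assumes the default MessagePassing flow of "source_to_target", under which \(j\) is the source node (row 0 of edge_index) and \(i\) is the target node (row 1). The gather helper and the toy values are illustrative, not library code.

```python
from typing import List


def gather(x: List[List[float]], index: List[int]) -> List[List[float]]:
    """Pick one feature row per edge, mirroring how suffixed arguments are built."""
    return [x[node] for node in index]


# Three nodes with a single feature each (toy values).
x = [[1.0], [2.0], [3.0]]
# Edges 0 -> 2, 1 -> 2, 1 -> 0; row 0 holds sources (j), row 1 holds targets (i).
edge_index = [[0, 1, 1], [2, 2, 0]]

x_j = gather(x, edge_index[0])  # features of the sending node of each edge
x_i = gather(x, edge_index[1])  # features of the receiving node of each edge

print(x_j)  # [[1.0], [2.0], [2.0]]
print(x_i)  # [[3.0], [3.0], [1.0]]
```

By the same naming rule, k_j and v_j in the signature above are per-edge values taken from the sending node and q_i is the per-edge value taken from the receiving node.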