gigl.src.common.models.pyg.nn.conv.hgt_conv
Classes
| HGTConv | Modified version of PyG's HGTConv implementation |
Module Contents
- class gigl.src.common.models.pyg.nn.conv.hgt_conv.HGTConv(in_channels, out_channels, metadata, heads=1, **kwargs)
- Bases: torch_geometric.nn.conv.MessagePassing
- Modified version of PyG’s HGTConv implementation (https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/conv/hgt_conv.html#HGTConv). PyG’s implementation drops node types that have no incoming message-passing edges (line 208 inside its forward method), whereas this version keeps those node types in the output.
- The Heterogeneous Graph Transformer (HGT) operator from the “Heterogeneous Graph Transformer” paper.
- Note: For an example of using HGT, see examples/hetero/hgt_dblp.py in the PyG repository.
- Parameters:
- in_channels (int or dict[str, int]) – Size of each input sample of every node type, or -1 to derive the size from the first input(s) to the forward method.
- out_channels (int) – Size of each output sample.
- metadata (Tuple[list[str], list[Tuple[str, str, str]]]) – The metadata of the heterogeneous graph, i.e. its node and edge types given by a list of strings and a list of string triplets, respectively. See torch_geometric.data.HeteroData.metadata() for more information.
- heads (int, optional) – Number of attention heads. (default: 1)
- **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.
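The metadata format and the node-type difference from upstream PyG can be illustrated without torch. The sketch below is plain Python; the helper names and the author/paper/venue graph are hypothetical, and it assumes (as in upstream PyG) that a node type counts as a destination when it is the last element of at least one edge-type triplet in metadata. It models only which keys appear in the output of forward(), not their values.

```python
from typing import List, Set, Tuple

# Same layout as HeteroData.metadata(): (node types, edge-type triplets),
# where each triplet is (source node type, relation, destination node type).
Metadata = Tuple[List[str], List[Tuple[str, str, str]]]


def destination_node_types(metadata: Metadata) -> Set[str]:
    """Node types that are the destination of at least one edge type."""
    _, edge_types = metadata
    return {dst for (_src, _rel, dst) in edge_types}


def expected_output_node_types(metadata: Metadata, keep_all: bool) -> List[str]:
    """Hypothetical helper: keys expected from forward(), in metadata order.

    keep_all=False models the upstream PyG rule (node types without an incoming
    edge type are dropped); keep_all=True models this class, which keeps them.
    """
    node_types, _ = metadata
    dst_types = destination_node_types(metadata)
    return [nt for nt in node_types if keep_all or nt in dst_types]


# Hypothetical graph: "venue" only ever appears as a source node type.
metadata: Metadata = (
    ["author", "paper", "venue"],
    [
        ("author", "writes", "paper"),
        ("paper", "written_by", "author"),
        ("venue", "publishes", "paper"),
    ],
)

print(expected_output_node_types(metadata, keep_all=False))  # ['author', 'paper']
print(expected_output_node_types(metadata, keep_all=True))   # ['author', 'paper', 'venue']
```

With upstream PyG, a downstream layer that expects a "venue" entry would fail on the missing key; keeping every node type avoids that.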
 
- Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(x_dict, edge_index_dict)
- Runs the forward pass of the module.
- Parameters:
- x_dict (dict[str, torch.Tensor]) – A dictionary holding input node features for each individual node type. 
- edge_index_dict (dict[Tuple[str, str, str], torch.Tensor]) – A dictionary holding graph connectivity information for each individual edge type, either as a torch.Tensor of shape [2, num_edges] or a torch_sparse.SparseTensor.
 
- Return type:
- dict[str, Optional[torch.Tensor]]
- The output node embeddings for each node type. If a node type does not receive any message, its output is set to None.
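The input contract of forward() can be sketched in plain Python. The helper below is hypothetical and not part of GiGL or PyG: nested lists stand in for the [2, num_edges] tensors, and a per-type node count stands in for the first dimension of each x_dict entry. It follows PyG's convention that row 0 of an edge index addresses nodes of the source type and row 1 addresses nodes of the destination type; the torch_sparse.SparseTensor alternative is not covered.

```python
from typing import Dict, List, Tuple

EdgeType = Tuple[str, str, str]
# Plain-list stand-in for a [2, num_edges] tensor: row 0 holds indices into the
# source node type, row 1 holds indices into the destination node type.
EdgeIndex = List[List[int]]


def check_forward_inputs(
    num_nodes: Dict[str, int], edge_index_dict: Dict[EdgeType, EdgeIndex]
) -> None:
    """Raise ValueError if an edge index is inconsistent with the node counts.

    num_nodes[t] plays the role of x_dict[t].size(0) for node type t.
    """
    for edge_type, edge_index in edge_index_dict.items():
        src, _rel, dst = edge_type
        if src not in num_nodes or dst not in num_nodes:
            raise ValueError(f"{edge_type}: unknown node type")
        if len(edge_index) != 2 or len(edge_index[0]) != len(edge_index[1]):
            raise ValueError(f"{edge_type}: expected shape [2, num_edges]")
        src_idx, dst_idx = edge_index
        if any(not 0 <= s < num_nodes[src] for s in src_idx):
            raise ValueError(f"{edge_type}: source index out of range")
        if any(not 0 <= d < num_nodes[dst] for d in dst_idx):
            raise ValueError(f"{edge_type}: destination index out of range")


num_nodes = {"author": 2, "paper": 3}
# Three edges: author 0 -> paper 0, author 1 -> paper 0, author 1 -> paper 2.
edge_index_dict = {("author", "writes", "paper"): [[0, 1, 1], [0, 0, 2]]}
check_forward_inputs(num_nodes, edge_index_dict)  # passes silently
```

An edge index such as [[0, 2], [0, 1]] for the same edge type would raise, because there is no author with index 2.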
 
- message(k_j, q_i, v_j, edge_attr, index, ptr, size_i)
- Constructs messages from node \(j\) to node \(i\) in analogy to \(\phi_{\mathbf{\Theta}}\) for each edge in edge_index. This function can take any argument as input that was initially passed to propagate(). Furthermore, tensors passed to propagate() can be mapped to the respective nodes \(i\) and \(j\) by appending _i or _j to the variable name, e.g. x_i and x_j.
- Parameters:
- k_j (torch.Tensor) 
- q_i (torch.Tensor) 
- v_j (torch.Tensor) 
- edge_attr (torch.Tensor) 
- index (torch.Tensor) 
- ptr (Optional[torch.Tensor]) 
- size_i (Optional[int]) 
 
- Return type:
- torch.Tensor 
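Assuming this class keeps PyG's upstream message computation (the documented modification concerns only which node types forward() returns), each message is the source value \(v_j\) weighted by a softmax, taken over all edges entering node \(i\), of the scaled dot product \(q_i \cdot k_j / \sqrt{d}\) multiplied by a per-relation prior passed as edge_attr. The plain-Python sketch below shows a single attention head; the function name and list-based inputs are hypothetical stand-ins for the tensors, and grouping edges by destination node replaces the index, ptr and size_i arguments.

```python
import math
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple

Vector = Sequence[float]


def hgt_messages(
    q: Sequence[Vector],  # q[i]: query of destination node i
    k: Sequence[Vector],  # k[j]: relation-transformed key of source node j
    v: Sequence[Vector],  # v[j]: relation-transformed value of source node j
    edges: Sequence[Tuple[int, int]],  # (j, i): a message flows from j to i
    prior: Sequence[float],  # per-edge relation prior (the edge_attr role)
) -> List[List[float]]:
    """Hypothetical single-head sketch: one attention-weighted message per edge."""
    dim = len(q[0])
    # Scaled dot-product score, multiplied by the relation prior.
    scores = [
        sum(qa * ka for qa, ka in zip(q[i], k[j])) * p / math.sqrt(dim)
        for (j, i), p in zip(edges, prior)
    ]
    # Numerically stable softmax over edges that share a destination node i.
    max_score: Dict[int, float] = {}
    for (_, i), s in zip(edges, scores):
        max_score[i] = max(s, max_score.get(i, -math.inf))
    exp_scores = [math.exp(s - max_score[i]) for (_, i), s in zip(edges, scores)]
    denom: Dict[int, float] = defaultdict(float)
    for (_, i), e in zip(edges, exp_scores):
        denom[i] += e
    # Each message is the source value scaled by its attention weight.
    return [[(e / denom[i]) * x for x in v[j]] for (j, i), e in zip(edges, exp_scores)]


# Two equally scored edges into destination node 0 get weight 0.5 each.
msgs = hgt_messages(
    q=[[1.0, 0.0]],
    k=[[1.0, 0.0], [1.0, 0.0]],
    v=[[2.0, 0.0], [0.0, 4.0]],
    edges=[(0, 0), (1, 0)],
    prior=[1.0, 1.0],
)
print(msgs)  # [[1.0, 0.0], [0.0, 2.0]]
```

Summing the messages per destination node, assuming additive aggregation as in upstream PyG, yields the attention-weighted average of the values, here [1.0, 2.0] for node 0. The actual method operates on tensors with a separate weight per head.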
 
 
