gigl.src.common.models.pyg.nn.conv.hgt_conv

Classes

HGTConv

Modified version of PyG's HGTConv implementation.

Module Contents

class gigl.src.common.models.pyg.nn.conv.hgt_conv.HGTConv(in_channels, out_channels, metadata, heads=1, **kwargs)

Bases: torch_geometric.nn.conv.MessagePassing

Modified version of PyG’s HGTConv implementation: https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/conv/hgt_conv.html#HGTConv

PyG’s implementation drops node types that have no incoming message-passing edges from its output (line 208 of its forward method). This version keeps those node types in the output.
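The following pure-Python sketch shows how the two behaviours differ in which keys appear in the output. It assumes a node type receives messages only when it is the destination of at least one edge type passed to forward. The helper `output_node_types` and the toy `author`/`paper`/`venue` schema are hypothetical, and the sketch uses no tensors.

```python
from typing import List, Tuple

EdgeType = Tuple[str, str, str]  # (src_type, relation, dst_type)


def output_node_types(
    node_types: List[str],
    edge_types: List[EdgeType],
    keep_all: bool,
) -> List[str]:
    """Return the node types that would appear as keys of the output dict.

    keep_all=False mimics upstream PyG (only destination types survive);
    keep_all=True mimics the behaviour described for this modified HGTConv.
    """
    dst_types = {dst for (_, _, dst) in edge_types}
    if keep_all:
        return list(node_types)
    return [nt for nt in node_types if nt in dst_types]


# Hypothetical schema: "venue" only ever appears as a source node type.
node_types = ["author", "paper", "venue"]
edge_types = [
    ("author", "writes", "paper"),
    ("paper", "written_by", "author"),
    ("venue", "publishes", "paper"),
]

print(output_node_types(node_types, edge_types, keep_all=False))  # ['author', 'paper']
print(output_node_types(node_types, edge_types, keep_all=True))   # ['author', 'paper', 'venue']
```

The return-type note for forward says that a node type which receives no messages is expected to map to None.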

The Heterogeneous Graph Transformer (HGT) operator from the “Heterogeneous Graph Transformer” paper.

Note

For an example of using HGT, see examples/hetero/hgt_dblp.py in the PyG repository.

Parameters:
  • in_channels (int or Dict[str, int]) – Size of each input sample of every node type, or -1 to derive the size from the first input(s) to the forward method.

  • out_channels (int) – Size of each output sample.

  • metadata (Tuple[List[str], List[Tuple[str, str, str]]]) – The metadata of the heterogeneous graph, i.e. its node and edge types given by a list of strings and a list of string triplets, respectively. See torch_geometric.data.HeteroData.metadata() for more information.

  • heads (int, optional) – Number of attention heads. (default: 1)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.
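The sketch below shows how `metadata` and `heads` relate. `metadata` has the same shape as the `(node_types, edge_types)` pair returned by `HeteroData.metadata()`, and multi-head attention splits `out_channels` evenly across `heads`. The divisibility check is an assumption based on recent upstream PyG versions of `HGTConv`, which reject an `out_channels` that is not a multiple of `heads`. `per_head_dim` is a hypothetical helper.

```python
from typing import List, Tuple

# Same shape as torch_geometric.data.HeteroData.metadata():
# (list of node types, list of (src, relation, dst) triplets).
Metadata = Tuple[List[str], List[Tuple[str, str, str]]]


def per_head_dim(out_channels: int, heads: int = 1) -> int:
    """Width of each attention head's slice of the output (hypothetical helper)."""
    if out_channels % heads != 0:
        raise ValueError(
            f"out_channels ({out_channels}) must be divisible by heads ({heads})"
        )
    return out_channels // heads


metadata: Metadata = (
    ["author", "paper"],
    [("author", "writes", "paper"), ("paper", "written_by", "author")],
)
node_types, edge_types = metadata
print(len(node_types), len(edge_types))  # 2 2
print(per_head_dim(64, heads=4))         # 16
```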


forward(x_dict, edge_index_dict)

Runs the forward pass of the module.

Parameters:
  • x_dict (Dict[str, torch.Tensor]) – A dictionary holding input node features for each individual node type.

  • edge_index_dict (Dict[Tuple[str, str, str], torch.Tensor]) – A dictionary holding graph connectivity information for each individual edge type, either as a torch.Tensor of shape [2, num_edges] or a torch_sparse.SparseTensor.
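The following tensor-free sketch uses nested lists to illustrate the expected input layout. Each x_dict entry is a [num_nodes, in_channels] feature matrix. Each edge_index_dict entry is a [2, num_edges] array. Its first row indexes nodes of the triplet's source type and its second row indexes nodes of the destination type, following PyG's default source-to-target flow. The `check_inputs` helper is hypothetical. With real data, these values would be torch.Tensor objects.

```python
from typing import Dict, List, Tuple

EdgeType = Tuple[str, str, str]


def check_inputs(
    x_dict: Dict[str, List[List[float]]],
    edge_index_dict: Dict[EdgeType, List[List[int]]],
) -> None:
    """Validate shapes and index ranges of heterogeneous inputs (hypothetical helper)."""
    for (src, rel, dst), edge_index in edge_index_dict.items():
        if len(edge_index) != 2 or len(edge_index[0]) != len(edge_index[1]):
            raise ValueError(f"{(src, rel, dst)}: edge_index must have shape [2, num_edges]")
        # Row 0 indexes nodes of the source type, row 1 nodes of the destination type.
        if any(not 0 <= i < len(x_dict[src]) for i in edge_index[0]):
            raise ValueError(f"{(src, rel, dst)}: source index out of range")
        if any(not 0 <= j < len(x_dict[dst]) for j in edge_index[1]):
            raise ValueError(f"{(src, rel, dst)}: destination index out of range")


x_dict = {
    "author": [[0.1, 0.2], [0.3, 0.4]],             # 2 authors, 2 features each
    "paper": [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]],  # 3 papers, 2 features each
}
edge_index_dict = {
    # author 0 wrote papers 0 and 1; author 1 wrote paper 2
    ("author", "writes", "paper"): [[0, 0, 1], [0, 1, 2]],
}
check_inputs(x_dict, edge_index_dict)  # no exception: inputs are consistent
```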

Return type:

Dict[str, Optional[torch.Tensor]] - The output node embeddings for each node type. If a node type receives no messages, its output is set to None.

message(k_j, q_i, v_j, edge_attr, index, ptr, size_i)

Constructs messages from node \(j\) to node \(i\) in analogy to \(\phi_{\mathbf{\Theta}}\) for each edge in edge_index. This function can take as input any argument that was initially passed to propagate(). Tensors passed to propagate() can also be mapped to the respective nodes \(i\) and \(j\) by appending _i or _j to the variable name, e.g. x_i and x_j.
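The _i/_j suffix convention amounts to indexing a node-level tensor with one row of edge_index. The minimal pure-Python sketch below shows that lifting step. It assumes PyG's default source-to-target flow, so j is the source node in row 0 and i is the target node in row 1. `lift` is a hypothetical name. PyG performs this step internally with tensor indexing.

```python
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def lift(x: Sequence[T], edge_index: List[List[int]]) -> Tuple[List[T], List[T]]:
    """Return (x_i, x_j): per-edge copies of target and source node values."""
    src, dst = edge_index  # row 0: source nodes j, row 1: target nodes i
    x_j = [x[j] for j in src]
    x_i = [x[i] for i in dst]
    return x_i, x_j


x = ["a", "b", "c"]                  # one value per node
edge_index = [[0, 2, 1], [1, 1, 0]]  # edges 0->1, 2->1, 1->0
x_i, x_j = lift(x, edge_index)
print(x_i)  # ['b', 'b', 'a']
print(x_j)  # ['a', 'c', 'b']
```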

Parameters:
  • k_j (torch.Tensor)

  • q_i (torch.Tensor)

  • v_j (torch.Tensor)

  • edge_attr (torch.Tensor)

  • index (torch.Tensor)

  • ptr (Optional[torch.Tensor])

  • size_i (Optional[int])

Return type:

torch.Tensor
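In upstream PyG's HGTConv.message, which this class is assumed to keep unchanged, each edge is processed in four steps:

1. Score the edge as the dot product of the target's query q_i and the source's relation-transformed key k_j.
2. Multiply the score by the per-edge-type prior passed in as edge_attr (derived from p_rel), then scale by 1/sqrt(dim).
3. Normalize with a softmax over all edges that share the same target index.
4. Weight the source value v_j by the result.

The single-head, pure-Python sketch below illustrates that computation. `hgt_messages` is a hypothetical name. The real method operates on [num_edges, heads, dim] tensors and uses ptr/size_i for the grouped softmax.

```python
import math
from collections import defaultdict
from typing import Dict, List


def hgt_messages(
    q_i: List[List[float]],  # per-edge query of the target node
    k_j: List[List[float]],  # per-edge relation-transformed key of the source node
    v_j: List[List[float]],  # per-edge relation-transformed value of the source node
    prior: List[float],      # per-edge relation prior (edge_attr)
    index: List[int],        # per-edge target node id that groups the softmax
) -> List[List[float]]:
    dim = len(q_i[0])
    # Raw score: <q_i, k_j> * prior / sqrt(dim)
    scores = [
        sum(q * k for q, k in zip(q_e, k_e)) * p / math.sqrt(dim)
        for q_e, k_e, p in zip(q_i, k_j, prior)
    ]
    # Softmax over all edges pointing at the same target, shifted by the max for stability.
    max_per_target: Dict[int, float] = {}
    for s, t in zip(scores, index):
        max_per_target[t] = max(s, max_per_target.get(t, -math.inf))
    exp = [math.exp(s - max_per_target[t]) for s, t in zip(scores, index)]
    denom: Dict[int, float] = defaultdict(float)
    for e, t in zip(exp, index):
        denom[t] += e
    alpha = [e / denom[t] for e, t in zip(exp, index)]
    # Message = attention-weighted source value.
    return [[a * v for v in v_e] for a, v_e in zip(alpha, v_j)]


# Two equally scored edges into target 0, one edge into target 1.
msgs = hgt_messages(
    q_i=[[1.0, 0.0], [1.0, 0.0], [0.0, 1.0]],
    k_j=[[1.0, 0.0], [1.0, 0.0], [0.0, 1.0]],
    v_j=[[2.0, 4.0], [6.0, 8.0], [3.0, 3.0]],
    prior=[1.0, 1.0, 1.0],
    index=[0, 0, 1],
)
print(msgs)  # [[1.0, 2.0], [3.0, 4.0], [3.0, 3.0]]
```

In upstream PyG, propagate then sums these per-edge messages for each target node.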

reset_parameters()

Resets all learnable parameters of the module.

edge_types
edge_types_map
heads = 1
in_channels
k_rel
kqv_lin
node_types
out_channels
out_lin
p_rel
skip
v_rel
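These attribute names match those of upstream PyG's HGTConv. In PyG, edge_types_map assigns each edge type in metadata a contiguous integer id. skip holds one learnable scalar per node type. A sigmoid of that scalar gates a residual connection whenever the input and output widths agree. Assuming this class keeps those semantics, the sketch below mirrors both behaviours in plain Python. `build_edge_types_map` and `gated_skip` are hypothetical names. In the module itself, these values are tensors and ParameterDict entries.

```python
import math
from typing import Dict, List, Tuple

EdgeType = Tuple[str, str, str]


def build_edge_types_map(edge_types: List[EdgeType]) -> Dict[EdgeType, int]:
    """Assign each edge type a contiguous integer id in metadata order."""
    return {edge_type: i for i, edge_type in enumerate(edge_types)}


def gated_skip(out: List[float], x: List[float], skip_param: float) -> List[float]:
    """Blend the transformed output with the input through a sigmoid gate."""
    if len(out) != len(x):
        return out  # widths differ: no residual connection is applied
    alpha = 1.0 / (1.0 + math.exp(-skip_param))  # sigmoid(skip_param)
    return [alpha * o + (1.0 - alpha) * xi for o, xi in zip(out, x)]


edge_types = [("author", "writes", "paper"), ("paper", "written_by", "author")]
print(build_edge_types_map(edge_types))
# {('author', 'writes', 'paper'): 0, ('paper', 'written_by', 'author'): 1}
print(gated_skip([2.0, 4.0], [0.0, 0.0], skip_param=0.0))  # [1.0, 2.0]
```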