gigl.src.common.models.pyg.nn.conv.simplehgn_conv#

Classes#

SimpleHGNConv

The SimpleHGN convolution layer based on https://arxiv.org/pdf/2112.14936

Module Contents#

class gigl.src.common.models.pyg.nn.conv.simplehgn_conv.SimpleHGNConv(in_channels, out_channels, num_edge_types, edge_in_channels=None, num_heads=1, edge_type_dim=16, should_use_node_residual=True, negative_slope=0.2, dropout=0.0)[source]#

Bases: torch_geometric.nn.conv.MessagePassing

The SimpleHGN convolution layer based on https://arxiv.org/pdf/2112.14936

Here, we adopt a form that supports edge features in addition to node features when computing attention. The layer follows the adaptation for link-prediction tasks described below Eq. 14 in the paper. A construction sketch is given after the parameter list below.

Parameters:
  • in_channels (int) – the input dimension of node features

  • out_channels (int) – the output dimension of node features

  • num_edge_types (int) – the number of edge types

  • edge_in_channels (Optional[int]) – the input dimension of edge features, if edges carry features

  • num_heads (int) – the number of attention heads

  • edge_type_dim (int) – the hidden dimension allocated to edge-type embeddings (per head)

  • should_use_node_residual (bool) – whether to apply the node residual connection

  • negative_slope (float) – the negative slope used in the LeakyReLU

  • dropout (float) – the feature dropout rate

Initialize internal Module state, shared by both nn.Module and ScriptModule.
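The sketch below shows one way the layer might be constructed. The dimensions are illustrative assumptions chosen for the example; only the documented signature above is taken from the library.

```python
from gigl.src.common.models.pyg.nn.conv.simplehgn_conv import SimpleHGNConv

# Illustrative dimensions (assumed), chosen only for this sketch.
conv = SimpleHGNConv(
    in_channels=64,            # node feature dimension
    out_channels=32,           # output node feature dimension
    num_edge_types=3,          # number of distinct edge types in the heterogeneous graph
    edge_in_channels=16,       # edge feature dimension; leave as None if edges carry no features
    num_heads=4,
    edge_type_dim=16,
    should_use_node_residual=True,
    negative_slope=0.2,
    dropout=0.1,
)
```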

forward(edge_index, node_feat, edge_type, edge_feat=None)[source]#

Runs the forward pass of the module. A call example is sketched after the parameter list below.

Parameters:
  • edge_index (torch.LongTensor) – the edge indices of the graph

  • node_feat (torch.FloatTensor) – the input node features

  • edge_type (torch.LongTensor) – the edge-type index of each edge

  • edge_feat (Optional[torch.FloatTensor]) – the input edge features, if any
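Continuing the construction sketch above, a hedged example of calling the layer. The tensor shapes follow common PyTorch Geometric conventions (source/target indices stacked in edge_index, one type id and optional feature row per edge) and are assumptions, not shapes confirmed by the source.

```python
import torch

num_nodes, num_edges = 10, 40                    # illustrative sizes (assumed)
node_feat = torch.randn(num_nodes, 64)           # matches in_channels above
edge_index = torch.randint(0, num_nodes, (2, num_edges))
edge_type = torch.randint(0, 3, (num_edges,))    # one type id per edge
edge_feat = torch.randn(num_edges, 16)           # matches edge_in_channels above

out = conv(edge_index, node_feat, edge_type, edge_feat=edge_feat)
# `out` holds the updated node embeddings; whether heads are concatenated or
# averaged (and hence the exact second dimension) depends on the implementation.
```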

message(node_emb_j, alpha)[source]#

Constructs messages from node \(j\) to node \(i\) in analogy to \(\phi_{\mathbf{\Theta}}\) for each edge in edge_index. This function can take any argument as input which was initially passed to propagate(). Furthermore, tensors passed to propagate() can be mapped to the respective nodes \(i\) and \(j\) by appending _i or _j to the variable name, e.g. x_i and x_j.
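The _i/_j suffix convention described above comes from torch_geometric.nn.conv.MessagePassing, not from this class. The toy layer below (not SimpleHGNConv's actual implementation) illustrates it; by that convention, node_emb_j in the signature above is the per-edge source-node embedding, and alpha is presumably the attention coefficient computed elsewhere in the layer.

```python
import torch
from torch_geometric.nn.conv import MessagePassing

class ToyConv(MessagePassing):
    """Illustration only: how tensors handed to propagate() become _i/_j views."""

    def __init__(self):
        super().__init__(aggr="add")

    def forward(self, x, edge_index):
        # `x` is forwarded to propagate(); message() may then request x_i / x_j.
        return self.propagate(edge_index, x=x)

    def message(self, x_j, x_i):
        # x_j: features of the source node of each edge (analogous to node_emb_j above),
        # x_i: features of the target node of each edge.
        return x_j - x_i

out = ToyConv()(torch.randn(4, 8), torch.tensor([[0, 1, 2], [1, 2, 3]]))
```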

reset_parameters()[source]#

Resets all learnable parameters of the module.

W_etype[source]#
W_nfeat[source]#
a_etype[source]#
a_l[source]#
a_r[source]#
edge_in_dim = None[source]#
edge_type_dim = 16[source]#
edge_type_emb[source]#
in_dim[source]#
leakyrelu[source]#
nfeat_drop[source]#
num_edge_types[source]#
num_heads = 1[source]#
out_dim[source]#