Source code for gigl.src.common.models.pyg.nn.conv.simplehgn_conv

from typing import Optional

import torch
import torch_geometric
from torch import nn
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import softmax


class SimpleHGNConv(MessagePassing):
    """
    The SimpleHGN convolution layer, based on https://arxiv.org/pdf/2112.14936.
    Here, we adopt a form that supports edge features in addition to node features
    in the attention calculation. This layer is based on the adaptation for link
    prediction tasks listed below Eq. 14 in the paper.

    Args:
        in_channels (int): the input dimension of node features
        out_channels (int): the output dimension of node features
        num_edge_types (int): the number of edge types
        edge_in_channels (Optional[int]): the input dimension of edge features
        num_heads (int): the number of attention heads
        edge_type_dim (int): the hidden dimension allocated to edge-type embeddings (per head)
        should_use_node_residual (bool): whether to apply the node residual connection
        negative_slope (float): the negative slope used in the LeakyReLU
        dropout (float): the feature dropout rate
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_edge_types: int,
        edge_in_channels: Optional[int] = None,
        num_heads: int = 1,
        edge_type_dim: int = 16,
        should_use_node_residual: bool = True,
        negative_slope: float = 0.2,
        dropout: float = 0.0,
    ):
        super().__init__(aggr="add", node_dim=0)
        self.in_dim = in_channels
        self.out_dim = out_channels
        self.edge_in_dim = edge_in_channels
        self.edge_type_dim = edge_type_dim
        self.num_edge_types = num_edge_types
        self.num_heads = num_heads
        # Encodes an embedding for each edge type.
        self.edge_type_emb = nn.Parameter(
            torch.empty(size=(self.num_edge_types, self.edge_type_dim))
        )
        # Multi-headed linear projection for the edge-type embedding.
        self.W_etype = torch_geometric.nn.HeteroLinear(
            self.edge_type_dim, self.edge_type_dim * self.num_heads, self.num_edge_types
        )
        # Linear projection for node features (for each head).
        self.W_nfeat = nn.Parameter(
            torch.FloatTensor(self.in_dim, self.out_dim * self.num_heads)
        )
        if self.edge_in_dim:
            # Linear projection for edge features (for each head).
            self.W_efeat = nn.Parameter(
                torch.FloatTensor(self.edge_in_dim, self.edge_in_dim * self.num_heads)
            )
            # Attention weights for edge features.
            self.a_efeat = nn.Parameter(
                torch.empty(size=(1, self.num_heads, self.edge_in_dim))
            )
            # Dropout for edge features.
            self.efeat_drop = nn.Dropout(dropout)
        # Attention weights for source-node, destination-node, and edge-type embeddings.
        self.a_l = nn.Parameter(torch.empty(size=(1, self.num_heads, self.out_dim)))
        self.a_r = nn.Parameter(torch.empty(size=(1, self.num_heads, self.out_dim)))
        self.a_etype = nn.Parameter(
            torch.empty(size=(1, self.num_heads, self.edge_type_dim))
        )
        self.nfeat_drop = nn.Dropout(dropout)
        self.leakyrelu = nn.LeakyReLU(negative_slope)
        if should_use_node_residual:
            self.residual = nn.Linear(self.in_dim, self.out_dim * self.num_heads)
        else:
            self.register_buffer("residual", None)

        self.reset_parameters()
    def reset_parameters(self):
        for param in [
            self.edge_type_emb,
            self.W_nfeat,
            self.a_l,
            self.a_r,
            self.a_etype,
        ]:
            nn.init.xavier_uniform_(param, gain=1.414)
        if self.edge_in_dim:
            for param in [self.W_efeat, self.a_efeat]:
                nn.init.xavier_uniform_(param, gain=1.414)
        # The residual projection only exists when should_use_node_residual is True.
        if self.residual is not None:
            self.residual.reset_parameters()
        self.W_etype.reset_parameters()
    def forward(
        self,
        edge_index: torch.LongTensor,
        node_feat: torch.FloatTensor,
        edge_type: torch.LongTensor,
        edge_feat: Optional[torch.FloatTensor] = None,
    ):
        # edge_index shape: [2, num_edges]
        # node_feat shape: [num_nodes, in_dim]
        # edge_type shape: [num_edges]
        # edge_feat shape: None | [num_edges, edge_in_dim]

        # For each head, project node features to out_dim and replace NaNs with 0.
        # Output shape: [num_nodes, num_heads, out_dim]
        node_emb = self.nfeat_drop(node_feat)
        node_emb = torch.matmul(node_emb, self.W_nfeat).view(
            -1, self.num_heads, self.out_dim
        )
        node_emb[torch.isnan(node_emb)] = 0.0

        # For each head, project edge features to edge_in_dim and replace NaNs with 0.
        # Output shape: [num_edges, num_heads, edge_in_dim]
        if edge_feat is not None and self.edge_in_dim is not None:
            edge_emb = self.efeat_drop(edge_feat)
            edge_emb = torch.matmul(edge_emb, self.W_efeat).view(
                -1, self.num_heads, self.edge_in_dim
            )
            edge_emb[torch.isnan(edge_emb)] = 0.0

        # For each edge type, get an embedding of dimension edge_type_dim for each head.
        # Output shape: [num_edges, num_heads, edge_type_dim]
        edge_type_emb = self.W_etype(self.edge_type_emb[edge_type], edge_type).view(
            -1, self.num_heads, self.edge_type_dim
        )

        # Compute the attention scores (alpha) for all heads.
        # Output shape: [num_edges, num_heads]
        row, col = edge_index[0, :], edge_index[1, :]
        h_l_term = (self.a_l * node_emb).sum(dim=-1)[row]
        h_r_term = (self.a_r * node_emb).sum(dim=-1)[col]
        h_etype_term = (self.a_etype * edge_type_emb).sum(dim=-1)
        h_efeat_term = (
            0
            if edge_feat is None or self.edge_in_dim is None
            else (self.a_efeat * edge_emb).sum(dim=-1)
        )
        alpha = self.leakyrelu(h_l_term + h_r_term + h_etype_term + h_efeat_term)
        alpha = softmax(alpha, row)

        # Propagate messages.
        # Output shape: [num_nodes, num_heads, out_dim]
        out = self.propagate(
            edge_index, node_emb=node_emb, node_feat=node_feat, alpha=alpha
        )

        # Concatenate embeddings across heads.
        # Output shape: [num_nodes, num_heads * out_dim]
        out = out.view(-1, self.num_heads * self.out_dim)

        # Add the node residual.
        # Output shape: [num_nodes, num_heads * out_dim]
        if self.residual is not None:
            out += self.residual(node_feat)
        return out
    def message(self, node_emb_j, alpha):
        # Scale each head's source-node embedding by its attention score.
        # node_emb_j shape: [num_edges, num_heads, out_dim]
        # alpha shape: [num_edges, num_heads]
        return alpha.unsqueeze(-1) * node_emb_j
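# ------------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the library): builds a toy
# graph with random node and edge features and runs one SimpleHGNConv forward pass.
# All sizes below are hypothetical example values, not defaults from the paper.
if __name__ == "__main__":
    num_nodes, num_edges, num_edge_types = 6, 10, 3
    in_channels, edge_in_channels, out_channels, num_heads = 8, 4, 16, 2

    conv = SimpleHGNConv(
        in_channels=in_channels,
        out_channels=out_channels,
        num_edge_types=num_edge_types,
        edge_in_channels=edge_in_channels,
        num_heads=num_heads,
    )

    edge_index = torch.randint(0, num_nodes, (2, num_edges))    # [2, num_edges]
    node_feat = torch.randn(num_nodes, in_channels)              # [num_nodes, in_dim]
    edge_type = torch.randint(0, num_edge_types, (num_edges,))   # [num_edges]
    edge_feat = torch.randn(num_edges, edge_in_channels)         # [num_edges, edge_in_dim]

    out = conv(edge_index, node_feat, edge_type, edge_feat)
    print(out.shape)  # [num_nodes, num_heads * out_channels]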