from typing import Optional
import torch
import torch_geometric
from torch import nn
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import softmax
class SimpleHGNConv(MessagePassing):
"""
The SimpleHGN convolution layer based on https://arxiv.org/pdf/2112.14936
Here, we adopt a form which includes support for edge-features in addition to node-features for attention calculation.
This layer is based on the adaptation for link prediction tasks listed below Eq.14 in the paper.
Args:
in_channels (int): the input dimension of node features
edge_in_channels (Optional[int]): the input dimension of edge features
out_channels (int): the output dimension of node features
edge_type_dim (int): the hidden dimension allocated to edge-type embeddings (per head)
num_heads (int): the number of heads
num_edge_types (int): the number of edge types
dropout (float): the feature drop rate
negative_slope (float): the negative slope used in the LeakyReLU
should_use_node_residual (boolean): whether we need the node residual operation
"""
def __init__(
self,
in_channels: int,
out_channels: int,
num_edge_types: int,
edge_in_channels: Optional[int] = None,
num_heads: int = 1,
edge_type_dim: int = 16,
should_use_node_residual: bool = True,
negative_slope: float = 0.2,
dropout: float = 0.0,
):
super().__init__(aggr="add", node_dim=0)
        self.in_dim = in_channels
        self.out_dim = out_channels
        self.edge_in_dim = edge_in_channels
        self.edge_type_dim = edge_type_dim
        self.num_edge_types = num_edge_types
        self.num_heads = num_heads
        # Learnable embedding for each edge type.
        self.edge_type_emb = nn.Parameter(
            torch.empty(size=(self.num_edge_types, self.edge_type_dim))
        )
# Multi-headed linear projection for edge-type embedding
self.W_etype = torch_geometric.nn.HeteroLinear(
self.edge_type_dim, self.edge_type_dim * self.num_heads, self.num_edge_types
)
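        # ``HeteroLinear`` applies a separate weight matrix per type id, so
        # each edge type's embedding gets its own per-head projection.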
# Linear projection for node features (for each head)
        self.W_nfeat = nn.Parameter(
            torch.empty(size=(self.in_dim, self.out_dim * self.num_heads))
        )
        if self.edge_in_dim is not None:
            # Linear projection for edge features (for each head)
            self.W_efeat = nn.Parameter(
                torch.empty(size=(self.edge_in_dim, self.edge_in_dim * self.num_heads))
            )
            # Attention weights for edge features
            self.a_efeat = nn.Parameter(
                torch.empty(size=(1, self.num_heads, self.edge_in_dim))
            )
            # Dropout for edge features
            self.efeat_drop = nn.Dropout(dropout)
self.a_l = nn.Parameter(torch.empty(size=(1, self.num_heads, self.out_dim)))
self.a_r = nn.Parameter(torch.empty(size=(1, self.num_heads, self.out_dim)))
self.a_etype = nn.Parameter(
torch.empty(size=(1, self.num_heads, self.edge_type_dim))
)
self.nfeat_drop = nn.Dropout(dropout)
self.leakyrelu = nn.LeakyReLU(negative_slope)
if should_use_node_residual:
self.residual = nn.Linear(self.in_dim, self.out_dim * self.num_heads)
else:
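            # Registering ``None`` keeps ``self.residual`` defined as an
            # attribute even when the node residual is disabled.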
self.register_buffer("residual", None)
self.reset_parameters()
def reset_parameters(self):
for param in [
self.edge_type_emb,
self.W_nfeat,
self.a_l,
self.a_r,
self.a_etype,
]:
nn.init.xavier_uniform_(param, gain=1.414)
        if self.edge_in_dim is not None:
            for param in [self.W_efeat, self.a_efeat]:
                nn.init.xavier_uniform_(param, gain=1.414)
        # ``residual`` is ``None`` when the node residual is disabled, so
        # guard before resetting it.
        if self.residual is not None:
            self.residual.reset_parameters()
        self.W_etype.reset_parameters()
def forward(
self,
edge_index: torch.LongTensor,
node_feat: torch.FloatTensor,
edge_type: torch.LongTensor,
edge_feat: Optional[torch.FloatTensor] = None,
):
        # edge_index shape: [2, num_edges]
        # node_feat shape: [num_nodes, in_dim]
        # edge_type shape: [num_edges]
        # edge_feat shape: None | [num_edges, edge_in_dim]
        # For each head, project node features to out_dim and zero out NaNs.
        # Output shape: [num_nodes, num_heads, out_dim]
node_emb = self.nfeat_drop(node_feat)
node_emb = torch.matmul(node_emb, self.W_nfeat).view(
-1, self.num_heads, self.out_dim
)
node_emb[torch.isnan(node_emb)] = 0.0
        # For each head, project edge features to edge_in_dim and zero out
        # NaNs. Output shape: [num_edges, num_heads, edge_in_dim]
if edge_feat is not None and self.edge_in_dim is not None:
edge_emb = self.efeat_drop(edge_feat)
edge_emb = torch.matmul(edge_emb, self.W_efeat).view(
-1, self.num_heads, self.edge_in_dim
)
edge_emb[torch.isnan(edge_emb)] = 0.0
# For each edge type, get an embedding of dimension edge_type_dim for each head
# Output shape: [num_edges, num_heads, edge_type_dim]
edge_type_emb = self.W_etype(self.edge_type_emb[edge_type], edge_type).view(
-1, self.num_heads, self.edge_type_dim
)
# Compute the attention scores (alpha) for all heads
# Output shape: [num_edges, num_heads]
row, col = edge_index[0, :], edge_index[1, :]
h_l_term = (self.a_l * node_emb).sum(dim=-1)[row]
h_r_term = (self.a_r * node_emb).sum(dim=-1)[col]
h_etype_term = (self.a_etype * edge_type_emb).sum(dim=-1)
h_efeat_term = (
0
if edge_feat is None or self.edge_in_dim is None
else (self.a_efeat * edge_emb).sum(dim=-1)
)
alpha = self.leakyrelu(h_l_term + h_r_term + h_etype_term + h_efeat_term)
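        # This mirrors the paper's attention score, roughly:
        #   e_ij = LeakyReLU(a^T [W h_i || W h_j || W_r r_psi(e)]),
        # extended here with an optional edge-feature term; alpha is its
        # softmax over each neighborhood.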
        # Normalize attention over the incoming edges of each target node
        # (propagate aggregates at edge_index[1] under the default
        # source-to-target flow).
        alpha = softmax(alpha, col)
# Propagate messages
# Output shape: [num_nodes, num_heads, out_dim]
        out = self.propagate(edge_index, node_emb=node_emb, alpha=alpha)
# Concatenate embeddings across heads
# Output shape: [num_nodes, num_heads * out_dim]
out = out.view(-1, self.num_heads * self.out_dim)
        # Add node residual
        # Output shape: [num_nodes, num_heads * out_dim]
        if self.residual is not None:
            out = out + self.residual(node_feat)
        return out
    def message(self, node_emb_j, alpha):
        # Weight each head's source-node embedding by its attention score.
        # node_emb_j shape: [num_edges, num_heads, out_dim]
        # alpha shape: [num_edges, num_heads]
        return alpha.unsqueeze(-1) * node_emb_j
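    # Shape walk-through for one propagate() call: for an edge e = (j -> i),
    # message() emits alpha[e].unsqueeze(-1) * node_emb[j] of shape
    # [num_heads, out_dim]; the "add" aggregation then sums these messages
    # over the incoming edges of each target node i.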