Source code for gigl.src.common.models.pyg.nn.conv.hgt_conv

import math
from typing import Dict, List, Optional, Tuple, Union

import torch
from torch import Tensor
from torch.nn import Parameter
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense import HeteroDictLinear, HeteroLinear
from torch_geometric.nn.inits import ones
from torch_geometric.nn.parameter_dict import ParameterDict
from torch_geometric.typing import Adj, EdgeType, Metadata, NodeType
from torch_geometric.utils import softmax
from torch_geometric.utils.hetero import construct_bipartite_edge_index


class HGTConv(MessagePassing):
    r"""Modified version of PyG's HGTConv implementation
    https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/conv/hgt_conv.html#HGTConv

    PyG's implementation drops node types with no incoming message-passing edges
    (line 208 inside ``forward``), while ours keeps those node types in the output.

    The Heterogeneous Graph Transformer (HGT) operator from the
    `"Heterogeneous Graph Transformer" <https://arxiv.org/abs/2003.01332>`_ paper.

    .. note::
        For an example of using HGT, see `examples/hetero/hgt_dblp.py
        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/
        hetero/hgt_dblp.py>`_.

    Args:
        in_channels (int or Dict[str, int]): Size of each input sample of every
            node type, or :obj:`-1` to derive the size from the first input(s)
            to the forward method.
        out_channels (int): Size of each output sample.
        metadata (Tuple[List[str], List[Tuple[str, str, str]]]): The metadata
            of the heterogeneous graph, *i.e.* its node and edge types given
            by a list of strings and a list of string triplets, respectively.
            See :meth:`torch_geometric.data.HeteroData.metadata` for more
            information.
        heads (int, optional): Number of multi-head-attentions.
            (default: :obj:`1`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
    """

    def __init__(
        self,
        in_channels: Union[int, Dict[str, int]],
        out_channels: int,
        metadata: Metadata,
        heads: int = 1,
        **kwargs,
    ):
        super().__init__(aggr="add", node_dim=0, **kwargs)

        if out_channels % heads != 0:
            raise ValueError(
                f"'out_channels' (got {out_channels}) must be "
                f"divisible by the number of heads (got {heads})"
            )

        if not isinstance(in_channels, dict):
            in_channels = {node_type: in_channels for node_type in metadata[0]}
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.node_types = metadata[0]
        self.edge_types = metadata[1]
        self.edge_types_map = {
            edge_type: i for i, edge_type in enumerate(metadata[1])
        }

        self.kqv_lin = HeteroDictLinear(self.in_channels, self.out_channels * 3)

        self.out_lin = HeteroDictLinear(
            self.out_channels, self.out_channels, types=self.node_types
        )

        dim = out_channels // heads
        num_types = heads * len(self.edge_types)

        self.k_rel = HeteroLinear(dim, dim, num_types, bias=False, is_sorted=True)
        self.v_rel = HeteroLinear(dim, dim, num_types, bias=False, is_sorted=True)

        self.skip = ParameterDict(
            {node_type: Parameter(torch.empty(1)) for node_type in self.node_types}
        )

        self.p_rel = ParameterDict()
        for edge_type in self.edge_types:
            edge_type = "__".join(edge_type)
            self.p_rel[edge_type] = Parameter(torch.empty(1, heads))

        self.reset_parameters()
    def reset_parameters(self):
        super().reset_parameters()
        self.kqv_lin.reset_parameters()
        self.out_lin.reset_parameters()
        self.k_rel.reset_parameters()
        self.v_rel.reset_parameters()
        ones(self.skip)
        ones(self.p_rel)
    def _cat(self, x_dict: Dict[str, Tensor]) -> Tuple[Tensor, Dict[str, int]]:
        """Concatenates a dictionary of features."""
        cumsum = 0
        outs: List[Tensor] = []
        offset: Dict[str, int] = {}
        for key, x in x_dict.items():
            outs.append(x)
            offset[key] = cumsum
            cumsum += x.size(0)
        return torch.cat(outs, dim=0), offset

    def _construct_src_node_feat(
        self,
        k_dict: Dict[str, Tensor],
        v_dict: Dict[str, Tensor],
        edge_index_dict: Dict[EdgeType, Adj],
    ) -> Tuple[Tensor, Tensor, Dict[EdgeType, int]]:
        """Constructs the source node representations."""
        cumsum = 0
        num_edge_types = len(self.edge_types)
        H, D = self.heads, self.out_channels // self.heads

        # Flatten into a single tensor with shape [num_edge_types * heads, D]:
        ks: List[Tensor] = []
        vs: List[Tensor] = []
        type_list: List[Tensor] = []
        offset: Dict[EdgeType, int] = {}
        for edge_type in edge_index_dict.keys():
            src = edge_type[0]
            N = k_dict[src].size(0)
            offset[edge_type] = cumsum
            cumsum += N

            # Construct type_vec for the current edge_type with shape [H, N]:
            edge_type_offset = self.edge_types_map[edge_type]
            type_vec = (
                torch.arange(H, dtype=torch.long).view(-1, 1).repeat(1, N)
                * num_edge_types
                + edge_type_offset
            )

            type_list.append(type_vec)
            ks.append(k_dict[src])
            vs.append(v_dict[src])

        ks = torch.cat(ks, dim=0).transpose(0, 1).reshape(-1, D)  # type: ignore
        vs = torch.cat(vs, dim=0).transpose(0, 1).reshape(-1, D)  # type: ignore
        type_vec = torch.cat(type_list, dim=1).flatten()

        k = self.k_rel(ks, type_vec).view(H, -1, D).transpose(0, 1)
        v = self.v_rel(vs, type_vec).view(H, -1, D).transpose(0, 1)

        return k, v, offset
    def forward(
        self,
        x_dict: Dict[NodeType, Tensor],
        edge_index_dict: Dict[EdgeType, Adj],  # Supports both Tensor and SparseTensor.
    ) -> Dict[NodeType, Optional[Tensor]]:
        r"""Runs the forward pass of the module.

        Args:
            x_dict (Dict[str, torch.Tensor]): A dictionary holding input node
                features for each individual node type.
            edge_index_dict (Dict[Tuple[str, str, str], torch.Tensor]): A
                dictionary holding graph connectivity information for each
                individual edge type, either as a :class:`torch.Tensor` of
                shape :obj:`[2, num_edges]` or a
                :class:`torch_sparse.SparseTensor`.

        :rtype: :obj:`Dict[str, Optional[torch.Tensor]]` - The output node
            embeddings for each node type. Unlike the upstream PyG
            implementation, node types that do not receive any message are
            still included in the output instead of being set to :obj:`None`.
        """
        F = self.out_channels
        H = self.heads
        D = F // H

        k_dict, q_dict, v_dict, out_dict = {}, {}, {}, {}

        # Compute K, Q, V over node types:
        kqv_dict = self.kqv_lin(x_dict)
        for key, val in kqv_dict.items():
            k, q, v = torch.tensor_split(val, 3, dim=1)
            k_dict[key] = k.view(-1, H, D)
            q_dict[key] = q.view(-1, H, D)
            v_dict[key] = v.view(-1, H, D)

        q, dst_offset = self._cat(q_dict)
        k, v, src_offset = self._construct_src_node_feat(
            k_dict, v_dict, edge_index_dict
        )

        edge_index, edge_attr = construct_bipartite_edge_index(
            edge_index_dict, src_offset, dst_offset, edge_attr_dict=self.p_rel
        )

        out = self.propagate(edge_index, k=k, q=q, v=v, edge_attr=edge_attr, size=None)

        # Reconstruct output node embeddings dict:
        for node_type, start_offset in dst_offset.items():
            end_offset = start_offset + q_dict[node_type].size(0)
            out_dict[node_type] = out[start_offset:end_offset]

        # Transform output node embeddings:
        a_dict = self.out_lin(
            {
                k: torch.nn.functional.gelu(v) if v is not None else v
                for k, v in out_dict.items()
            }
        )

        # Iterate over node types:
        for node_type, out in out_dict.items():
            out = a_dict[node_type]

            if out.size(-1) == x_dict[node_type].size(-1):
                alpha = self.skip[node_type].sigmoid()
                out = alpha * out + (1 - alpha) * x_dict[node_type]

            out_dict[node_type] = out

        return out_dict
    def message(
        self,
        k_j: Tensor,
        q_i: Tensor,
        v_j: Tensor,
        edge_attr: Tensor,
        index: Tensor,
        ptr: Optional[Tensor],
        size_i: Optional[int],
    ) -> Tensor:
        # Scaled dot-product attention logits, weighted by the learned
        # per-relation prior `p_rel` (passed in as `edge_attr`):
        alpha = (q_i * k_j).sum(dim=-1) * edge_attr
        alpha = alpha / math.sqrt(q_i.size(-1))
        # Normalize attention coefficients over the incoming edges of each
        # destination node:
        alpha = softmax(alpha, index, ptr, size_i)
        out = v_j * alpha.view(-1, self.heads, 1)
        return out.view(-1, self.out_channels)
    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}(-1, {self.out_channels}, "
            f"heads={self.heads})"
        )
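
A minimal usage sketch (not part of the module above; the toy node types, edge types, feature sizes, and tensors are made up for illustration). It exercises the behavior described in the class docstring: the "author" node type never receives messages, yet it still appears in the output dictionary.

import torch

from gigl.src.common.models.pyg.nn.conv.hgt_conv import HGTConv

# Toy heterogeneous graph: the only edge type points from "author" to "paper",
# so "author" nodes never receive messages.
metadata = (["author", "paper"], [("author", "writes", "paper")])

x_dict = {
    "author": torch.randn(4, 16),  # 4 author nodes with 16-dim features
    "paper": torch.randn(3, 32),   # 3 paper nodes with 32-dim features
}
edge_index_dict = {
    ("author", "writes", "paper"): torch.tensor([[0, 1, 2, 3], [0, 0, 1, 2]]),
}

conv = HGTConv(
    in_channels={"author": 16, "paper": 32},  # or -1 to derive sizes lazily
    out_channels=64,
    metadata=metadata,
    heads=4,
)

out_dict = conv(x_dict, edge_index_dict)
print({node_type: out.shape for node_type, out in out_dict.items()})
# Both node types appear in the output (author: [4, 64], paper: [3, 64]);
# the upstream PyG implementation would omit "author" here because it has
# no incoming edge types.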