Source code for gigl.src.common.models.pyg.nn.conv.edge_attr_gat_conv

from typing import Union

from torch import Tensor
from torch_geometric.nn import GATConv
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size
from torch_geometric.utils import add_self_loops, remove_self_loops
from torch_sparse import SparseTensor, set_diag


class EdgeAttrGATConv(GATConv):
    r"""Compared to :class:`GATConv`, :class:`EdgeAttrGATConv` combines node
    features with edge features for message passing.

    Args:
        share_edge_att_message_weight (bool, optional): If set to :obj:`False`
            and :obj:`edge_dim` is given, a separate linear layer
            (:obj:`lin_edge_message`) is used to transform edge features for
            message passing instead of reusing :obj:`lin_edge` from the
            attention computation. (default: :obj:`True`)
    """

    def __init__(self, share_edge_att_message_weight: bool = True, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if self.edge_dim is not None and not share_edge_att_message_weight:
            self.lin_edge_message = Linear(
                self.edge_dim,
                self.heads * self.out_channels,
                bias=False,
                weight_initializer="glorot",
            )
        else:
            self.lin_edge_message = None
        self.reset_parameters()
    def reset_parameters(self):
        super().reset_parameters()
        if hasattr(self, "lin_edge_message") and self.lin_edge_message is not None:
            self.lin_edge_message.reset_parameters()
    def forward(
        self,
        x: Union[Tensor, OptPairTensor],
        edge_index: Adj,
        edge_attr: OptTensor = None,
        size: Size = None,
        return_attention_weights=None,
    ):
        r"""Compared to :class:`GATConv`, :class:`EdgeAttrGATConv` passes
        :obj:`edge_attr` as an extra input to :meth:`propagate`.

        Args:
            return_attention_weights (bool, optional): If set to :obj:`True`,
                will additionally return the tuple
                :obj:`(edge_index, attention_weights)`, holding the computed
                attention weights for each edge. (default: :obj:`None`)
        """
        # NOTE: attention weights will be returned whenever
        # `return_attention_weights` is set to a value, regardless of its
        # actual value (might be `True` or `False`). This is a current somewhat
        # hacky workaround to allow for TorchScript support via the
        # `torch.jit._overload` decorator, as we can only change the output
        # arguments conditioned on type (`None` or `bool`), not based on its
        # actual value.
        H, C = self.heads, self.out_channels

        # We first transform the input node features. If a tuple is passed, we
        # transform source and target node features via separate weights:
        if isinstance(x, Tensor):
            assert x.dim() == 2, "Static graphs not supported in 'GATConv'"
            x_src = x_dst = self.lin_src(x).view(-1, H, C)
        else:  # Tuple of source and target node features:
            x_src, x_dst = x
            assert x_src.dim() == 2, "Static graphs not supported in 'GATConv'"
            x_src = self.lin_src(x_src).view(-1, H, C)
            if x_dst is not None:
                x_dst = self.lin_dst(x_dst).view(-1, H, C)

        x = (x_src, x_dst)

        # Next, we compute node-level attention coefficients, both for source
        # and target nodes (if present):
        alpha_src = (x_src * self.att_src).sum(dim=-1)
        alpha_dst = None if x_dst is None else (x_dst * self.att_dst).sum(-1)
        alpha = (alpha_src, alpha_dst)

        if self.add_self_loops:
            if isinstance(edge_index, Tensor):
                # We only want to add self-loops for nodes that appear both as
                # source and target nodes:
                num_nodes = x_src.size(0)
                if x_dst is not None:
                    num_nodes = min(num_nodes, x_dst.size(0))
                num_nodes = min(size) if size is not None else num_nodes
                edge_index, edge_attr = remove_self_loops(edge_index, edge_attr)
                edge_index, edge_attr = add_self_loops(
                    edge_index,
                    edge_attr,
                    fill_value=self.fill_value,
                    num_nodes=num_nodes,
                )
            elif isinstance(edge_index, SparseTensor):
                if self.edge_dim is None:
                    edge_index = set_diag(edge_index)
                else:
                    raise NotImplementedError(
                        "The usage of 'edge_attr' and 'add_self_loops' "
                        "simultaneously is currently not yet supported for "
                        "'edge_index' in a 'SparseTensor' form"
                    )

        # edge_updater_type: (alpha: OptPairTensor, edge_attr: OptTensor)
        alpha = self.edge_updater(edge_index, alpha=alpha, edge_attr=edge_attr)

        # propagate_type: (x: OptPairTensor, alpha: Tensor)
        out = self.propagate(
            edge_index, x=x, alpha=alpha, size=size, edge_attr=edge_attr
        )

        if self.concat:
            out = out.view(-1, self.heads * self.out_channels)
        else:
            out = out.mean(dim=1)

        if self.bias is not None:
            out = out + self.bias

        if isinstance(return_attention_weights, bool):
            if isinstance(edge_index, Tensor):
                return out, (edge_index, alpha)
            elif isinstance(edge_index, SparseTensor):
                return out, edge_index.set_value(alpha, layout="coo")
        else:
            return out
    def message(self, x_j, alpha, edge_attr):
        r"""Compared to :class:`GATConv`, :class:`EdgeAttrGATConv` adds an
        extra step that combines the node features with the transformed edge
        features before weighting by the attention coefficients.
        """
        if edge_attr is not None and (self.lin_edge_message or self.lin_edge):
            # Use the dedicated message layer if it exists, otherwise reuse
            # the edge projection from the attention computation:
            edge_message_layer = (
                self.lin_edge_message if self.lin_edge_message else self.lin_edge
            )
            edge_attr = edge_message_layer(edge_attr).view(
                -1, self.heads, self.out_channels
            )
            x_j += edge_attr
        return alpha.unsqueeze(-1) * x_j
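

A minimal usage sketch (not part of the module source above): it shows how the layer might be constructed and called, assuming the standard GATConv constructor keywords (in_channels, out_channels, heads, edge_dim); the tensor sizes are illustrative only.

import torch

from gigl.src.common.models.pyg.nn.conv.edge_attr_gat_conv import EdgeAttrGATConv

# Hypothetical toy graph: 4 nodes with 16-dim features, 3 edges with 8-dim edge features.
x = torch.randn(4, 16)
edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
edge_attr = torch.randn(3, 8)

# share_edge_att_message_weight=False creates a separate `lin_edge_message`
# layer for transforming edge features in `message`; with the default of True,
# `lin_edge` (the edge projection used for attention) is reused instead.
conv = EdgeAttrGATConv(
    in_channels=16,
    out_channels=32,
    heads=2,
    edge_dim=8,
    share_edge_att_message_weight=False,
)

# Edge features are combined with node features during message passing.
out = conv(x, edge_index, edge_attr=edge_attr)  # shape: [4, 2 * 32] (concat=True)

# Passing return_attention_weights=True also returns the per-edge attention
# coefficients (including any added self-loops), shaped [num_edges, heads].
out, (ei, att) = conv(
    x, edge_index, edge_attr=edge_attr, return_attention_weights=True
)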