Source code for gigl.src.common.models.pyg.homogeneous

from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.data
from torch_geometric.nn import (
    GATConv,
    GATv2Conv,
    GCNConv,
    GINConv,
    SAGEConv,
    TransformerConv,
)
from torch_geometric.nn.models import MLP

from gigl.common.logger import Logger
from gigl.src.common.constants.training import DEFAULT_NUM_GNN_HOPS
from gigl.src.common.models.layers.normalization import l2_normalize_embeddings
from gigl.src.common.models.pyg import utils as pyg_utils
from gigl.src.common.models.pyg.nn.conv.edge_attr_gat_conv import EdgeAttrGATConv
from gigl.src.common.models.pyg.nn.conv.gin_conv import GINEConv
from gigl.src.common.models.pyg.nn.models.feature_embedding import FeatureEmbeddingLayer
from gigl.src.common.models.pyg.nn.models.feature_interaction import FeatureInteraction
from gigl.src.common.models.pyg.nn.models.jumping_knowledge import JumpingKnowledge
from gigl.src.common.types.model import GnnModel, GraphBackend

logger = Logger()

class BasicHomogeneousGNN(nn.Module, GnnModel):
    def __init__(
        self,
        in_dim: int,
        hid_dim: int,
        out_dim: int,
        conv_kwargs: Dict[str, Any] = {},
        edge_dim: Optional[int] = None,
        num_layers: int = DEFAULT_NUM_GNN_HOPS,
        activation: Callable = F.relu,
        activation_before_norm: bool = False,  # apply the activation function before normalization
        activation_after_last_conv: bool = False,  # apply the activation after the last conv layer
        dropout: float = 0.0,  # dropout is automatically disabled when model.eval() is called
        batchnorm: bool = False,  # batch norm
        linear_layer: bool = False,
        return_emb: bool = False,
        should_l2_normalize_embedding_layer_output: bool = False,
        jk_mode: Optional[str] = None,
        jk_lstm_dim: Optional[int] = None,
        feature_interaction_layer: Optional[FeatureInteraction] = None,
        feature_embedding_layer: Optional[FeatureEmbeddingLayer] = None,
        **kwargs,
    ):
        super(BasicHomogeneousGNN, self).__init__()
        self.in_dim = in_dim
        self.hid_dim = hid_dim
        self.out_dim = out_dim
        self.activation = activation
        self.activation_before_norm = activation_before_norm
        self.activation_after_last_conv = activation_after_last_conv
        self.dropout = nn.Dropout(p=dropout)
        self.batchnorm = batchnorm
        self.num_layers = num_layers
        # Feature embedding layer to pass selected features through an embedding layer
        self.feature_embedding_layer = feature_embedding_layer
        # Feature interaction layers
        self.feats_interaction = feature_interaction_layer
        self.conv_layers: nn.ModuleList = self.init_conv_layers(  # type: ignore
            in_dim=in_dim,
            out_dim=hid_dim if linear_layer or jk_mode else out_dim,
            edge_dim=edge_dim,
            hid_dim=hid_dim,
            num_layers=num_layers,
            **conv_kwargs,
        )
        if batchnorm:
            num_heads = int(conv_kwargs.get("heads", 1))
            num_batchnorm_layers = num_layers if jk_mode else num_layers - 1
            self.batchnorm_layers = nn.ModuleList(
                [
                    nn.BatchNorm1d(hid_dim * num_heads)
                    for i in range(num_batchnorm_layers)
                ]
            )
        self.should_l2_normalize_embedding_layer_output = (
            should_l2_normalize_embedding_layer_output
        )
        if jk_mode:
            self.jk_layer = JumpingKnowledge(
                mode=jk_mode,
                hid_dim=hid_dim,
                out_dim=out_dim if not linear_layer else hid_dim,
                num_layers=num_layers,
                lstm_dim=jk_lstm_dim,
            )
        else:
            self.jk_layer = None  # type: ignore
        self.return_emb = return_emb
        self.linear_layer = linear_layer
        if linear_layer:
            self.linear = nn.Linear(hid_dim, out_dim)

    def forward(
        self,
        data: torch_geometric.data.Data,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        # pass selected features through an embedding layer
        if self.feature_embedding_layer:
            x = self.feature_embedding_layer(x)
        # node feature interaction before graph convolution
        if self.feats_interaction:
            x = self.feats_interaction(x)
        xs: List[torch.Tensor] = []
        for i, conv_layer in enumerate(self.conv_layers):
            if self.supports_edge_attr:
                x = conv_layer(x=x, edge_index=edge_index, edge_attr=edge_attr)
            else:
                x = conv_layer(x=x, edge_index=edge_index)
            # exclude batch norm, activation, and dropout after the last layer
            if (
                i == self.num_layers - 1
                and not self.jk_layer
                and not self.activation_after_last_conv
            ):
                break
            if self.activation_before_norm:
                x = self.activation(x)
            if self.batchnorm:
                x = self.batchnorm_layers[i](x)
            if not self.activation_before_norm:
                x = self.activation(x)
            x = self.dropout(x)
            if self.jk_layer:
                xs.append(x)
        if self.jk_layer:
            x = self.jk_layer(xs)
        if self.should_l2_normalize_embedding_layer_output:
            x = l2_normalize_embeddings(node_typed_embeddings=x)
        if self.return_emb:
            return x
        if self.linear_layer:
            x = self.linear(x)
        return x

    def init_conv_layers(
        self,
        in_dim: Union[int, Tuple[int, int]],
        out_dim: int,
        edge_dim: Optional[int],
        hid_dim: int,
        num_layers: int,
        **kwargs,
    ) -> nn.ModuleList:
        raise NotImplementedError

    @property
    def graph_backend(self) -> GraphBackend:
        return GraphBackend.PYG
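
# Illustrative note (not part of the original source): the conv stack's final
# width depends on `linear_layer` and `jk_mode`. When either is set, the convs
# emit hid_dim and the jumping-knowledge and/or linear head maps back to
# out_dim; otherwise the last conv emits out_dim directly. A hypothetical
# configuration, assuming the JK wrapper accepts PyG-style modes such as "cat":
#
#     # convs: 32 -> 64 -> 64; JK aggregates layer outputs; linear head: 64 -> 16
#     model = GraphSAGE(
#         in_dim=32, hid_dim=64, out_dim=16, num_layers=2,
#         jk_mode="cat", linear_layer=True,
#     )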

class GraphSAGE(BasicHomogeneousGNN):
    supports_edge_weight = False
    supports_edge_attr = False

    def init_conv_layers(
        self,
        in_dim: Union[int, Tuple[int, int]],
        out_dim: int,
        edge_dim: Optional[int],
        hid_dim: int,
        num_layers: int,
        **kwargs,
    ) -> nn.ModuleList:
        remaining_kwargs, discarded_kwargs = pyg_utils.filter_dict(
            input_dict=kwargs,
            keys_to_keep=pyg_utils.MESSAGE_PASSING_BASE_CLS_ARGS
            + ["aggr", "normalize", "root_weight", "project", "bias"],
        )
        logger.info(
            f"Discarded kwargs for {SAGEConv.__name__}: {discarded_kwargs.keys()}"
        )
        conv_layers = nn.ModuleList(
            [
                SAGEConv(
                    in_channels=in_dim if i == 0 else hid_dim,
                    out_channels=hid_dim if i < num_layers - 1 else out_dim,
                    **remaining_kwargs,
                )
                for i in range(num_layers)
            ]
        )
        return conv_layers
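
# Usage sketch (illustrative; all dimensions and data below are made up): build
# a 2-layer GraphSAGE encoder and run it on a toy graph.
#
#     model = GraphSAGE(in_dim=8, hid_dim=16, out_dim=4, num_layers=2)
#     data = torch_geometric.data.Data(
#         x=torch.randn(5, 8),  # 5 nodes with 8 features each
#         edge_index=torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]]),
#     )
#     out = model(data)  # shape: [5, 4]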

class GIN(BasicHomogeneousGNN):
    supports_edge_weight = False
    supports_edge_attr = False

    def init_conv_layers(
        self,
        in_dim: Union[int, Tuple[int, int]],
        out_dim: int,
        edge_dim: Optional[int],
        hid_dim: int,
        num_layers: int,
        **kwargs,
    ) -> nn.ModuleList:
        eps: float = kwargs.pop("eps", 0.0)
        train_eps: bool = kwargs.pop("train_eps", False)
        remaining_kwargs, discarded_kwargs = pyg_utils.filter_dict(
            input_dict=kwargs, keys_to_keep=pyg_utils.MESSAGE_PASSING_BASE_CLS_ARGS
        )
        logger.info(
            f"Discarded kwargs for {GINConv.__name__}: {discarded_kwargs.keys()}"
        )
        conv_layers = nn.ModuleList(
            [
                GINConv(
                    nn=MLP(
                        [
                            in_dim if i == 0 else hid_dim,
                            hid_dim if i < num_layers - 1 else out_dim,
                            hid_dim if i < num_layers - 1 else out_dim,
                        ],
                        act=self.activation,
                        act_first=self.activation_before_norm,
                        # Note: PyG has its own BatchNorm class, so this norm won't
                        # be converted to torch.nn.SyncBatchNorm.
                        norm="batch_norm" if self.batchnorm else None,
                    ),
                    eps=eps,
                    train_eps=train_eps,
                    **remaining_kwargs,
                )
                for i in range(num_layers)
            ]
        )
        return conv_layers
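
# Shape note (illustrative): each GINConv wraps a two-layer MLP. For a
# hypothetical GIN(in_dim=8, hid_dim=16, out_dim=4, num_layers=2), the
# per-layer MLP channel lists resolve to:
#
#     layer 0: [8, 16, 16]   # in_dim -> hid_dim -> hid_dim
#     layer 1: [16, 4, 4]    # hid_dim -> out_dim -> out_dim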

class GINE(BasicHomogeneousGNN):
    supports_edge_weight = False
    supports_edge_attr = True

    def init_conv_layers(
        self,
        in_dim: Union[int, Tuple[int, int]],
        out_dim: int,
        edge_dim: Optional[int],
        hid_dim: int,
        num_layers: int,
        **kwargs,
    ) -> nn.ModuleList:
        eps: float = kwargs.pop("eps", 0.0)
        train_eps: bool = kwargs.pop("train_eps", False)
        remaining_kwargs, discarded_kwargs = pyg_utils.filter_dict(
            input_dict=kwargs, keys_to_keep=pyg_utils.MESSAGE_PASSING_BASE_CLS_ARGS
        )
        logger.info(
            f"Discarded kwargs for {GINEConv.__name__}: {discarded_kwargs.keys()}"
        )
        conv_layers = nn.ModuleList(
            [
                GINEConv(
                    nn=MLP(
                        [
                            in_dim if i == 0 else hid_dim,
                            hid_dim if i < num_layers - 1 else out_dim,
                            hid_dim if i < num_layers - 1 else out_dim,
                        ],
                        act=self.activation,
                        act_first=self.activation_before_norm,
                        # Note: PyG has its own BatchNorm class, so this norm won't
                        # be converted to torch.nn.SyncBatchNorm.
                        norm="batch_norm" if self.batchnorm else None,
                    ),
                    eps=eps,
                    train_eps=train_eps,
                    edge_dim=edge_dim,
                    **remaining_kwargs,
                )
                for i in range(num_layers)
            ]
        )
        return conv_layers

class GAT(BasicHomogeneousGNN):
    supports_edge_weight = False
    supports_edge_attr = True

    def init_conv_layers(
        self,
        in_dim: Union[int, Tuple[int, int]],
        out_dim: int,
        edge_dim: Optional[int],
        hid_dim: int,
        num_layers: int,
        **kwargs,
    ) -> nn.ModuleList:
        num_heads = int(kwargs.pop("heads", 1))
        remaining_kwargs, discarded_kwargs = pyg_utils.filter_dict(
            input_dict=kwargs,
            keys_to_keep=pyg_utils.MESSAGE_PASSING_BASE_CLS_ARGS
            + [
                "concat",
                "negative_slope",
                "dropout",
                "add_self_loops",
                "fill_value",
                "bias",
            ],
        )
        logger.info(
            f"Discarded kwargs for {GATConv.__name__}: {discarded_kwargs.keys()}"
        )
        conv_layers = nn.ModuleList(
            [
                GATConv(
                    in_channels=in_dim if i == 0 else hid_dim * num_heads,
                    out_channels=hid_dim if i < num_layers - 1 else out_dim,
                    edge_dim=edge_dim,
                    heads=num_heads if i < num_layers - 1 else 1,
                    **remaining_kwargs,
                )
                for i in range(num_layers)
            ]
        )
        return conv_layers
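
# Shape note (illustrative): hidden GAT layers concatenate their attention
# heads, so each hidden layer's output width is hid_dim * heads and the next
# layer's input is sized to match; the last layer uses a single head so the
# model still emits out_dim. For a hypothetical GAT(in_dim=8, hid_dim=16,
# out_dim=4, num_layers=2, conv_kwargs={"heads": 4}):
#
#     layer 0: GATConv(in_channels=8, out_channels=16, heads=4)   # width 16 * 4 = 64
#     layer 1: GATConv(in_channels=64, out_channels=4, heads=1)   # width 4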

class GATv2(BasicHomogeneousGNN):
    supports_edge_weight = False
    supports_edge_attr = True

    def init_conv_layers(
        self,
        in_dim: Union[int, Tuple[int, int]],
        out_dim: int,
        edge_dim: Optional[int],
        hid_dim: int,
        num_layers: int,
        **kwargs,
    ) -> nn.ModuleList:
        num_heads = int(kwargs.pop("heads", 1))
        fill_value = kwargs.pop("fill_value", "mean")
        share_weights = kwargs.pop("share_weights", False)
        remaining_kwargs, discarded_kwargs = pyg_utils.filter_dict(
            input_dict=kwargs,
            keys_to_keep=pyg_utils.MESSAGE_PASSING_BASE_CLS_ARGS
            + ["concat", "negative_slope", "dropout", "add_self_loops", "bias"],
        )
        logger.info(
            f"Discarded kwargs for {GATv2Conv.__name__}: {discarded_kwargs.keys()}"
        )
        conv_layers = nn.ModuleList(
            [
                GATv2Conv(
                    in_channels=in_dim if i == 0 else hid_dim * num_heads,
                    out_channels=hid_dim if i < num_layers - 1 else out_dim,
                    edge_dim=edge_dim,
                    heads=num_heads if i < num_layers - 1 else 1,
                    fill_value=fill_value,
                    share_weights=share_weights,
                    **remaining_kwargs,
                )
                for i in range(num_layers)
            ]
        )
        return conv_layers

class EdgeAttrGAT(BasicHomogeneousGNN):
    supports_edge_weight = False
    supports_edge_attr = True

    def init_conv_layers(
        self,
        in_dim: Union[int, Tuple[int, int]],
        out_dim: int,
        edge_dim: Optional[int],
        hid_dim: int,
        num_layers: int,
        **kwargs,
    ) -> nn.ModuleList:
        num_heads = int(kwargs.pop("heads", 1))
        share_edge_att_message_weight = kwargs.pop(
            "share_edge_att_message_weight", True
        )
        remaining_kwargs, discarded_kwargs = pyg_utils.filter_dict(
            input_dict=kwargs,
            keys_to_keep=pyg_utils.MESSAGE_PASSING_BASE_CLS_ARGS
            + [
                "concat",
                "negative_slope",
                "dropout",
                "add_self_loops",
                "fill_value",
                "bias",
            ],
        )
        logger.info(
            f"Discarded kwargs for {EdgeAttrGATConv.__name__}: {discarded_kwargs.keys()}"
        )
        conv_layers = nn.ModuleList(
            [
                EdgeAttrGATConv(
                    in_channels=in_dim if i == 0 else hid_dim * num_heads,
                    out_channels=hid_dim if i < num_layers - 1 else out_dim,
                    edge_dim=edge_dim,
                    heads=num_heads if i < num_layers - 1 else 1,
                    share_edge_att_message_weight=share_edge_att_message_weight,
                    **remaining_kwargs,
                )
                for i in range(num_layers)
            ]
        )
        return conv_layers

class Transformer(BasicHomogeneousGNN):
    supports_edge_weight = False
    supports_edge_attr = True

    def init_conv_layers(
        self,
        in_dim: Union[int, Tuple[int, int]],
        out_dim: int,
        edge_dim: Optional[int],
        hid_dim: int,
        num_layers: int,
        **kwargs,
    ) -> nn.ModuleList:
        num_heads = int(kwargs.pop("heads", 1))
        beta = kwargs.pop("beta", False)
        remaining_kwargs, discarded_kwargs = pyg_utils.filter_dict(
            input_dict=kwargs,
            keys_to_keep=pyg_utils.MESSAGE_PASSING_BASE_CLS_ARGS
            + [
                "concat",
                "dropout",
                "bias",
                "root_weight",
            ],
        )
        logger.info(
            f"Discarded kwargs for {TransformerConv.__name__}: {discarded_kwargs.keys()}"
        )
        # Layers prior to the last concatenate their attention heads by default,
        # so their output width is hid_dim * num_heads. The last layer uses a
        # single head so the output is still out_dim.
        return nn.ModuleList(
            [
                TransformerConv(
                    in_channels=in_dim if i == 0 else hid_dim * num_heads,
                    out_channels=hid_dim if i < num_layers - 1 else out_dim,
                    edge_dim=edge_dim,
                    heads=num_heads if i < num_layers - 1 else 1,
                    beta=beta,
                    **remaining_kwargs,
                )
                for i in range(num_layers)
            ]
        )

class TwoLayerGCN(torch.nn.Module, GnnModel):
    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        hid_dim: int = 16,
        is_training: bool = True,
        should_l2_normalize_output: bool = False,
        **kwargs,
    ):
        """
        Simple 2-layer GCN implementation using PyG constructs.

        Args:
            in_dim (int): number of input features
            out_dim (int): number of output classes
            hid_dim (int, optional): number of hidden features. Defaults to 16.
            is_training (bool, optional): whether dropout is applied in forward. Defaults to True.
            should_l2_normalize_output (bool, optional): whether to L2-normalize the output. Defaults to False.
            **kwargs (:class:`torch_geometric.nn.conv.MessagePassing`): Additional
                arguments for all GCNConv layers
        """
        super().__init__()
        self.is_training = is_training
        self.should_normalize = should_l2_normalize_output
        remaining_kwargs, discarded_kwargs = pyg_utils.filter_dict(
            input_dict=kwargs,
            keys_to_keep=pyg_utils.MESSAGE_PASSING_BASE_CLS_ARGS
            + [
                "improved",
                "cached",
                "add_self_loops",
                "normalize",
                "bias",
            ],
        )
        logger.info(
            f"Discarded kwargs for {GCNConv.__name__}: {discarded_kwargs.keys()}"
        )
        self.conv1 = GCNConv(
            in_channels=in_dim, out_channels=hid_dim, **remaining_kwargs
        )
        self.conv2 = GCNConv(
            in_channels=hid_dim, out_channels=out_dim, **remaining_kwargs
        )

    def forward(self, data: torch_geometric.data.Data) -> torch.Tensor:
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        # dropout is gated by the is_training flag provided at construction time
        x = F.dropout(x, training=self.is_training)
        x = self.conv2(x, edge_index)
        if self.should_normalize:
            x = F.normalize(x, p=2, dim=1)
        return x

    @property
    def graph_backend(self) -> GraphBackend:
        return GraphBackend.PYG
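
# Usage sketch (illustrative; shapes and data below are made up): a forward
# pass through TwoLayerGCN on a toy graph.
#
#     model = TwoLayerGCN(in_dim=8, out_dim=3, hid_dim=16)
#     data = torch_geometric.data.Data(
#         x=torch.randn(4, 8),  # 4 nodes with 8 features each
#         edge_index=torch.tensor([[0, 1, 2], [1, 2, 3]]),
#     )
#     logits = model(data)  # shape: [4, 3]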