Source code for gigl.src.common.modeling_task_specs.graphsage_template_modeling_spec

from contextlib import ExitStack
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.distributed
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed.algorithms.join import Join, Joinable
from torch_geometric.nn import GraphSAGE

from gigl.common.logger import Logger
from gigl.common.utils.torch_training import (
    get_rank,
    get_world_size,
    is_distributed_available_and_initialized,
)
from gigl.src.common.graph_builder.graph_builder_factory import GraphBuilderFactory
from gigl.src.common.types.graph_data import CondensedEdgeType, CondensedNodeType
from gigl.src.common.types.model import GraphBackend
from gigl.src.common.types.model_eval_metrics import (
    EvalMetric,
    EvalMetricsCollection,
    EvalMetricType,
)
from gigl.src.common.types.pb_wrappers.dataset_metadata_utils import (
    Dataloaders,
    DataloaderTypes,
    NodeAnchorBasedLinkPredictionDatasetDataloaders,
)
from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper
from gigl.src.common.utils.eval_metrics import hit_rate_at_k
from gigl.src.inference.v1.lib.base_inferencer import (
    InferBatchResults,
    NodeAnchorBasedLinkPredictionBaseInferencer,
    no_grad_eval,
)
from gigl.src.training.v1.lib.base_trainer import BaseTrainer
from gigl.src.training.v1.lib.data_loaders.node_anchor_based_link_prediction_data_loader import (
    NodeAnchorBasedLinkPredictionBatch,
)
from gigl.src.training.v1.lib.data_loaders.rooted_node_neighborhood_data_loader import (
    RootedNodeNeighborhoodBatch,
)
from gigl.src.training.v1.lib.eval_metrics import KS_FOR_EVAL as ks

logger = Logger()


class GraphSageTemplateTrainerSpec(
    BaseTrainer, NodeAnchorBasedLinkPredictionBaseInferencer
):
    """
    Template Simple Training Spec that uses GraphSAGE for Node Anchor Based Link Prediction with DDP support.
    Arguments are to be passed in via trainerArgs in GBML Config.

    Args:
        hidden_dim (int): Hidden dimension to use for the model (default: 64)
        num_layers (int): Number of layers to use for the model (default: 2)
        out_channels (int): Output channels to use for the model (default: 64)
        validate_every_n_batches (int): Number of batches to validate after (default: 20)
        num_val_batches (int): Number of batches to validate on (default: 10)
        num_test_batches (int): Number of batches to test on (default: 100)
        early_stop_patience (int): Number of consecutive checks without improvement to trigger early stopping (default: 3)
        num_epochs (int): Number of epochs to train the model for (default: 5)
        optim_lr (float): Learning rate to use for the optimizer (default: 0.001)
        main_sample_batch_size (int): Batch size to use for the main samples (default: 256)
        random_negative_batch_size (int): Batch size to use for the random negative samples (default: 64)
        train_main_num_workers (int): Number of workers to use for the train main dataloader (default: 2)
        val_main_num_workers (int): Number of workers to use for the val main dataloader (default: 1)
    """

    def __init__(self, **kwargs) -> None:
        super().__init__()
        self.hidden_dim = int(kwargs.get("hidden_dim", 64))
        self.num_layers = int(kwargs.get("num_layers", 2))
        self.out_channels = int(kwargs.get("out_channels", 64))
        self.validate_every_n_batches = int(kwargs.get("validate_every_n_batches", 20))
        self.num_val_batches = int(kwargs.get("num_val_batches", 10))
        self.num_test_batches = int(kwargs.get("num_test_batches", 100))
        self.early_stop_patience = int(kwargs.get("early_stop_patience", 3))
        self.num_epochs = int(kwargs.get("num_epochs", 5))
        self.optim_lr = float(kwargs.get("optim_lr", 0.001))
        self.main_sample_batch_size = int(kwargs.get("main_sample_batch_size", 256))
        self.random_negative_batch_size = int(
            kwargs.get("random_negative_batch_size", 64)
        )
        self._graph_builder = GraphBuilderFactory.get_graph_builder(
            backend_name=GraphBackend("PyG")
        )

        # Prepare dataloader configurations
        dataloader_batch_size_map: Dict[DataloaderTypes, int] = {
            DataloaderTypes.train_main: self.main_sample_batch_size,
            DataloaderTypes.val_main: self.main_sample_batch_size,
            DataloaderTypes.test_main: self.main_sample_batch_size,
            DataloaderTypes.train_random_negative: self.random_negative_batch_size,
            DataloaderTypes.val_random_negative: self.random_negative_batch_size,
            DataloaderTypes.test_random_negative: self.random_negative_batch_size,
        }
        dataloader_num_workers_map: Dict[DataloaderTypes, int] = {
            DataloaderTypes.train_main: int(kwargs.get("train_main_num_workers", 2)),
            DataloaderTypes.val_main: int(kwargs.get("val_main_num_workers", 1)),
            DataloaderTypes.test_main: int(kwargs.get("test_main_num_workers", 1)),
            DataloaderTypes.train_random_negative: int(
                kwargs.get("train_random_negative_num_workers", 2)
            ),
            DataloaderTypes.val_random_negative: int(
                kwargs.get("val_random_negative_num_workers", 1)
            ),
            DataloaderTypes.test_random_negative: int(
                kwargs.get("test_random_negative_num_workers", 1)
            ),
        }

        # Utility for data loader initialization
        self._dataloaders: NodeAnchorBasedLinkPredictionDatasetDataloaders = (
            NodeAnchorBasedLinkPredictionDatasetDataloaders(
                batch_size_map=dataloader_batch_size_map,
                num_workers_map=dataloader_num_workers_map,
            )
        )
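    # Illustrative sketch (not part of the spec): the keyword arguments above mirror the
    # `trainerArgs` entries in the GBML Config, so a spec can also be built directly in
    # Python for quick experimentation. The values below are hypothetical; any key that
    # is not passed falls back to the defaults listed in the class docstring.
    #
    #   spec = GraphSageTemplateTrainerSpec(
    #       hidden_dim=128,
    #       num_layers=3,
    #       optim_lr=0.005,
    #       main_sample_batch_size=512,
    #   )
    #   assert spec.num_layers == 3 and spec.out_channels == 64  # default retained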
    @property
    def model(self) -> torch.nn.Module:
        return self.__model

    @model.setter
    def model(self, model: torch.nn.Module) -> None:
        self.__model = model
        self.__model.graph_backend = GraphBackend.PYG  # type: ignore
    def init_model(
        self,
        gbml_config_pb_wrapper: GbmlConfigPbWrapper,
        state_dict: Optional[dict] = None,
        device: torch.device = torch.device("cuda"),
    ) -> nn.Module:
        node_feat_dim = gbml_config_pb_wrapper.preprocessed_metadata_pb_wrapper.condensed_node_type_to_feature_dim_map[
            CondensedNodeType(0)
        ]
        model = GraphSAGE(
            in_channels=node_feat_dim,
            hidden_channels=self.hidden_dim,
            num_layers=self.num_layers,
            out_channels=self.out_channels,
        )
        self.model = model
        if state_dict is not None:
            self.model.load_state_dict(state_dict)
        return self.model
    # function for setting up things like optimizer, scheduler, criterion etc.
    def setup_for_training(self):
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.optim_lr)
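    # Call-order sketch (illustrative; in practice the GiGL trainer harness drives these
    # calls): the model is built first, then the optimizer, then training and evaluation.
    # `spec`, `config_wrapper`, and `device` below are hypothetical objects.
    #
    #   spec.init_model(gbml_config_pb_wrapper=config_wrapper, device=device)
    #   spec.setup_for_training()
    #   spec.train(gbml_config_pb_wrapper=config_wrapper, device=device)
    #   metrics = spec.eval(gbml_config_pb_wrapper=config_wrapper, device=device)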
    def train(
        self,
        gbml_config_pb_wrapper: GbmlConfigPbWrapper,
        device: torch.device,
        profiler=None,
    ):
        """
        Main Training loop for the GraphSAGE model.

        Args:
            gbml_config_pb_wrapper: GbmlConfigPbWrapper for gbmlConfig proto
            device: torch.device to run the training on
            profiler: Profiler object to profile the training
        """
        early_stop_counter = 0
        best_val_loss = float("inf")

        data_loaders: Dataloaders = self._dataloaders.get_training_dataloaders(
            gbml_config_pb_wrapper=gbml_config_pb_wrapper,
            graph_backend=self.model.graph_backend,
            device=device,
        )
        assert (
            data_loaders.train_main is not None
            and data_loaders.val_main is not None
            and data_loaders.train_random_negative is not None
            and data_loaders.val_random_negative is not None
        )
        logger.info("Data loaders initialized")

        main_data_loader = data_loaders.train_main
        random_negative_data_loader = data_loaders.train_random_negative
        val_main_data_loader_iter = iter(data_loaders.val_main)
        val_random_data_loader_iter = iter(data_loaders.val_random_negative)

        main_batch: NodeAnchorBasedLinkPredictionBatch
        random_negative_batch: RootedNodeNeighborhoodBatch

        logger.info("Starting Training...")
        with ExitStack() as stack:
            if is_distributed_available_and_initialized():
                assert isinstance(self.model, Joinable)
                stack.enter_context(Join([self.model]))
                logger.info(f"Model on rank {get_rank()} joined.")
            self.model.train()
            for batch_index, (main_batch, random_negative_batch) in enumerate(
                zip(main_data_loader, random_negative_data_loader), start=1
            ):
                pos_scores, hard_neg_scores, random_neg_scores = self._process_batch(
                    main_batch=main_batch,
                    random_negative_batch=random_negative_batch,
                    device=device,
                )
                loss = self._compute_loss(
                    pos_scores, hard_neg_scores, random_neg_scores, device
                )
                logger.info(f"Processed batch {batch_index}, Loss: {loss.item()}")
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                if (
                    batch_index % (self.validate_every_n_batches // get_world_size())
                    == 0
                ):
                    if is_distributed_available_and_initialized():
                        torch.distributed.barrier()
                    logger.info(f"Validating at batch {batch_index}")
                    avg_val_loss = self.validate(
                        val_main_data_loader_iter,
                        val_random_data_loader_iter,
                        device,
                    )
                    if avg_val_loss < best_val_loss:
                        best_val_loss = avg_val_loss
                        early_stop_counter = 0
                        logger.info(
                            f"Validation Loss Improved to {best_val_loss:.4f}. Resetting early stop counter."
                        )
                    else:
                        early_stop_counter += 1
                        logger.info(
                            f"No improvement in Validation Loss for {early_stop_counter} consecutive checks. Early Stop Counter: {early_stop_counter}"
                        )
                        if early_stop_counter >= self.early_stop_patience:
                            logger.info(
                                f"Early stopping triggered after {early_stop_counter} checks without improvement. Best Validation Loss: {best_val_loss:.4f}"
                            )
                            break
                    logger.info(
                        f"Validation Loss: {avg_val_loss:.4f} at batch {batch_index}"
                    )
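    # Sketch of the Join pattern used above (assumes torch.distributed is already
    # initialized and the model is wrapped in DistributedDataParallel, which is Joinable).
    # Join lets ranks that run out of batches early keep participating in collective ops,
    # so per-rank dataloaders of unequal length do not deadlock the training loop.
    # `per_rank_dataloader` and `optimizer` below are hypothetical.
    #
    #   from torch.distributed.algorithms.join import Join
    #   from torch.nn.parallel import DistributedDataParallel as DDP
    #
    #   ddp_model = DDP(model)
    #   with Join([ddp_model]):
    #       for batch in per_rank_dataloader:  # may yield fewer batches on some ranks
    #           loss = ddp_model(batch).sum()
    #           optimizer.zero_grad()
    #           loss.backward()
    #           optimizer.step()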
    def _compute_loss(
        self,
        pos_scores_list: List[torch.Tensor],
        hard_neg_scores_list: List[torch.Tensor],
        random_neg_scores_list: List[torch.Tensor],
        device: torch.device,
    ) -> torch.Tensor:
        total_loss: torch.Tensor = torch.tensor(0.0, device=device)
        total_sample_size = 0
        for pos_scores, hard_neg_scores, random_neg_scores in zip(
            pos_scores_list, hard_neg_scores_list, random_neg_scores_list
        ):
            all_neg_scores = torch.cat(
                (hard_neg_scores, random_neg_scores), dim=1
            )  # shape=[1, num_hard_neg_nodes + num_random_neg_nodes]
            if all_neg_scores.numel() > 0 and pos_scores.numel() > 0:
                all_neg_scores_repeated = all_neg_scores.repeat(1, pos_scores.shape[1])
                pos_scores_repeated = pos_scores.repeat_interleave(
                    all_neg_scores.shape[1], dim=1
                )
                targets = torch.ones_like(pos_scores_repeated).to(device=device)
                loss = F.margin_ranking_loss(
                    input1=pos_scores_repeated,
                    input2=all_neg_scores_repeated,
                    target=targets,
                    margin=0.5,
                    reduction="sum",
                )
                total_loss += loss
                total_sample_size += pos_scores_repeated.numel()

        if total_sample_size > 0:
            average_loss = total_loss / total_sample_size
        else:
            average_loss = torch.tensor(0.0, device=device)
        return average_loss

    def _process_batch(
        self,
        main_batch: NodeAnchorBasedLinkPredictionBatch,
        random_negative_batch: RootedNodeNeighborhoodBatch,
        device: torch.device,
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
        main_embeddings = self.model(
            main_batch.graph.x.to(device), main_batch.graph.edge_index.to(device)
        )
        random_negative_embeddings = self.model(
            random_negative_batch.graph.x.to(device),
            random_negative_batch.graph.edge_index.to(device),
        )

        pos_score_list: List[torch.Tensor] = []
        hard_neg_score_list: List[torch.Tensor] = []
        random_neg_score_list: List[torch.Tensor] = []

        main_batch_root_node_indices = main_batch.root_node_indices.to(device=device)
        # For a homogeneous graph, we only have one condensed node type
        random_neg_root_node_indices = (
            random_negative_batch.condensed_node_type_to_root_node_indices_map[
                CondensedNodeType(0)
            ].to(device=device)
        )

        # inner product (decoder)
        batch_random_neg_scores = torch.mm(
            main_embeddings[main_batch_root_node_indices],
            random_negative_embeddings[random_neg_root_node_indices].T,
        )

        for root_node_index, root_node in enumerate(main_batch_root_node_indices):
            root_node = torch.unsqueeze(root_node, 0)
            pos_scores = torch.FloatTensor([]).to(device=device)
            hard_neg_scores = torch.FloatTensor([]).to(device=device)

            pos_nodes: torch.LongTensor = main_batch.pos_supervision_edge_data[
                CondensedEdgeType(0)
            ].root_node_to_target_node_id[root_node.item()]
            if pos_nodes.numel():
                pos_scores = torch.mm(
                    main_embeddings[root_node], main_embeddings[pos_nodes].T
                )

            hard_neg_nodes: (
                torch.LongTensor
            ) = main_batch.hard_neg_supervision_edge_data[
                CondensedEdgeType(0)
            ].root_node_to_target_node_id[
                root_node.item()
            ]  # shape=[num_hard_neg_nodes]
            if hard_neg_nodes.numel():
                hard_neg_scores = torch.mm(
                    main_embeddings[root_node], main_embeddings[hard_neg_nodes].T
                )

            random_neg_scores = batch_random_neg_scores[[root_node_index], :].to(
                device=device
            )

            pos_score_list.append(pos_scores)
            hard_neg_score_list.append(hard_neg_scores)
            random_neg_score_list.append(random_neg_scores)

        return pos_score_list, hard_neg_score_list, random_neg_score_list
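    # Toy illustration of the pairing in _compute_loss (hypothetical values): for one
    # anchor with pos_scores of shape [1, P] and all_neg_scores of shape [1, N], the
    # repeat / repeat_interleave expansion lines up every (positive, negative) pair
    # exactly once before the margin ranking loss is applied.
    #
    #   import torch
    #   import torch.nn.functional as F
    #
    #   pos = torch.tensor([[0.9, 0.4]])       # P = 2 positive scores
    #   neg = torch.tensor([[0.1, 0.5, 0.3]])  # N = 3 negative scores
    #   neg_rep = neg.repeat(1, pos.shape[1])                 # [n1 n2 n3 n1 n2 n3]
    #   pos_rep = pos.repeat_interleave(neg.shape[1], dim=1)  # [p1 p1 p1 p2 p2 p2]
    #   loss = F.margin_ranking_loss(
    #       pos_rep, neg_rep, torch.ones_like(pos_rep), margin=0.5, reduction="sum"
    #   )  # sums max(0, 0.5 - (p - n)) over all 6 pairs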
    @no_grad_eval
    def validate(
        self,
        main_data_loader: torch.utils.data.dataloader._BaseDataLoaderIter,
        random_negative_data_loader: torch.utils.data.dataloader._BaseDataLoaderIter,
        device: torch.device,
    ) -> float:
        """
        Get the validation loss for the model using the similarity scores for the positive and negative samples.

        Args:
            main_data_loader: DataLoader for the positive samples
            random_negative_data_loader: DataLoader for the random negative samples
            device: torch.device to run the validation on
        Returns:
            float: Average validation loss
        """
        validation_metrics = self._compute_metrics(
            main_data_loader=main_data_loader,
            random_negative_data_loader=random_negative_data_loader,
            device=device,
            num_batches=self.num_val_batches,
        )
        avg_val_mrr = validation_metrics["avg_mrr"]
        avg_val_loss = validation_metrics["avg_loss"]
        logger.info(f"Validation got MRR: {avg_val_mrr}")
        return avg_val_loss
    def eval(
        self, gbml_config_pb_wrapper: GbmlConfigPbWrapper, device: torch.device
    ) -> EvalMetricsCollection:
        """
        Evaluate the model using the test data loaders.

        Args:
            gbml_config_pb_wrapper: GbmlConfigPbWrapper for gbmlConfig proto
            device: torch.device to run the evaluation on
        """
        logger.info("Start testing...")
        data_loaders: Dataloaders = self._dataloaders.get_test_dataloaders(
            gbml_config_pb_wrapper=gbml_config_pb_wrapper,
            graph_backend=self.model.graph_backend,
            device=device,
        )
        assert (
            data_loaders.test_main is not None
            and data_loaders.test_random_negative is not None
        )
        eval_metrics = self._compute_metrics(
            main_data_loader=iter(data_loaders.test_main),
            random_negative_data_loader=iter(data_loaders.test_random_negative),
            device=device,
            num_batches=self.num_test_batches,
        )
        avg_mrr = eval_metrics["avg_mrr"]
        avg_hit_rates = eval_metrics["avg_hit_rates"]
        logger.info(f"Average MRR: {avg_mrr}")
        for k, hit_rate in zip(ks, avg_hit_rates):
            logger.info(f"Hit Rate@{k}: {hit_rate.item()}")
        hit_rates_model_metrics = [
            EvalMetric(
                name=f"HitRate_at_{k}",
                value=rate,
            )
            for k, rate in zip(ks, avg_hit_rates)
        ]
        metric_list = [
            EvalMetric.from_eval_metric_type(
                eval_metric_type=EvalMetricType.mrr,
                value=avg_mrr,
            ),
            *hit_rates_model_metrics,
        ]
        metrics = EvalMetricsCollection(metrics=metric_list)
        return metrics
    def _compute_metrics(
        self,
        main_data_loader: torch.utils.data.dataloader._BaseDataLoaderIter,
        random_negative_data_loader: torch.utils.data.dataloader._BaseDataLoaderIter,
        device: torch.device,
        num_batches: int,
    ) -> Dict[str, Any]:
        self.model.eval()
        total_mrr: float = 0.0
        total_loss: float = 0.0
        ks = [1, 5, 10, 50, 100, 500]
        total_hit_rates = torch.zeros(len(ks), device=device)
        num_batches_processed = 0

        if is_distributed_available_and_initialized():
            num_batches_per_rank = num_batches // get_world_size()
        else:
            num_batches_per_rank = num_batches

        with torch.no_grad():
            for batch_idx, (main_batch, random_negative_batch) in enumerate(
                zip(main_data_loader, random_negative_data_loader), start=1
            ):
                if batch_idx > num_batches_per_rank:
                    break
                (
                    pos_scores_list,
                    hard_neg_scores_list,
                    random_neg_scores_list,
                ) = self._process_batch(
                    main_batch=main_batch,
                    random_negative_batch=random_negative_batch,
                    device=device,
                )
                loss = self._compute_loss(
                    pos_scores_list=pos_scores_list,
                    hard_neg_scores_list=hard_neg_scores_list,
                    random_neg_scores_list=random_neg_scores_list,
                    device=device,
                )
                total_loss += loss.item()

                for pos_scores, hard_neg_scores, random_neg_scores in zip(
                    pos_scores_list, hard_neg_scores_list, random_neg_scores_list
                ):
                    neg_scores = torch.cat((hard_neg_scores, random_neg_scores), dim=1)
                    if pos_scores.numel() == 0:
                        continue
                    combined_scores = torch.cat((pos_scores, neg_scores), dim=1)
                    ranks = torch.argsort(combined_scores, dim=1, descending=True)
                    pos_rank = (ranks == 0).nonzero(as_tuple=True)[1] + 1
                    mrr = 1.0 / pos_rank.float()
                    hit_rates = hit_rate_at_k(
                        pos_scores=pos_scores,  # type: ignore
                        neg_scores=neg_scores,  # type: ignore
                        ks=torch.tensor(ks, device=device, dtype=torch.long),  # type: ignore
                    )
                    total_mrr += mrr.mean().item()
                    total_hit_rates += hit_rates.to(device)

                num_batches_processed += 1

        # Reduce the total_mrr, total_loss and total_hit_rates across all ranks (DDP)
        if is_distributed_available_and_initialized():
            total_mrr_tensor = torch.tensor(total_mrr, device=device)
            torch.distributed.all_reduce(
                total_mrr_tensor, op=torch.distributed.ReduceOp.SUM
            )
            total_mrr = total_mrr_tensor.item() / get_world_size()

            total_loss_tensor = torch.tensor(total_loss, device=device)
            torch.distributed.all_reduce(
                total_loss_tensor, op=torch.distributed.ReduceOp.SUM
            )
            total_loss = total_loss_tensor.item()

            torch.distributed.all_reduce(
                total_hit_rates, op=torch.distributed.ReduceOp.SUM
            )
            total_hit_rates /= get_world_size()

            num_batches_tensor = torch.tensor(num_batches_processed, device=device)
            torch.distributed.all_reduce(
                num_batches_tensor, op=torch.distributed.ReduceOp.SUM
            )
            num_batches_processed = int(num_batches_tensor.item())

        avg_mrr = total_mrr / num_batches_processed if num_batches_processed > 0 else 0
        avg_loss = (
            total_loss / num_batches_processed if num_batches_processed > 0 else 0
        )
        avg_hit_rates = (
            total_hit_rates / num_batches_processed
            if num_batches_processed > 0
            else torch.zeros(len(ks), device=device)
        )
        metrics = {
            "avg_mrr": avg_mrr,
            "avg_loss": avg_loss,
            "avg_hit_rates": avg_hit_rates,
        }
        return metrics
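    # Toy illustration of the MRR step in _compute_metrics (hypothetical values): the
    # first positive sits in column 0 of combined_scores, so argsort over the candidate
    # scores followed by (ranks == 0) recovers its 1-based rank, and the reciprocal of
    # that rank is the per-anchor MRR contribution.
    #
    #   import torch
    #
    #   pos = torch.tensor([[0.7]])
    #   neg = torch.tensor([[0.9, 0.2, 0.4]])
    #   combined = torch.cat((pos, neg), dim=1)                   # [[0.7, 0.9, 0.2, 0.4]]
    #   ranks = torch.argsort(combined, dim=1, descending=True)   # [[1, 0, 3, 2]]
    #   pos_rank = (ranks == 0).nonzero(as_tuple=True)[1] + 1     # tensor([2])
    #   mrr = 1.0 / pos_rank.float()                               # tensor([0.5000])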
    @no_grad_eval
    def infer_batch(
        self,
        batch: RootedNodeNeighborhoodBatch,
        device: torch.device = torch.device("cpu"),
    ) -> InferBatchResults:
        out = self.model(batch.graph.x.to(device), batch.graph.edge_index.to(device))
        batch_root_node_indices = batch.condensed_node_type_to_root_node_indices_map[
            CondensedNodeType(0)  # For a homogeneous graph there is only one condensed node type
        ].to(device=device)
        embeddings = out[batch_root_node_indices]
        return InferBatchResults(embeddings=embeddings, predictions=None)
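    # Usage sketch (hypothetical objects; assumes the model has been initialized via
    # init_model and `rooted_batch` comes from an inference dataloader): infer_batch
    # returns embeddings only, so the predictions field of the result is None.
    #
    #   results = spec.infer_batch(batch=rooted_batch, device=torch.device("cpu"))
    #   embeddings = results.embeddings  # shape: [num_root_nodes, out_channels]
    #   assert results.predictions is None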
    @property
    def supports_distributed_training(self) -> bool:
        return True