Source code for gigl.src.common.modeling_task_specs.node_classification_modeling_task_spec

from contextlib import ExitStack
from typing import Callable, Dict, Optional, OrderedDict

import tensorflow as tf
import torch
import torch.distributed
import torch.nn.functional as F
import torch.utils.data
from torch.distributed.algorithms.join import Join, Joinable

from gigl.common.logger import Logger
from gigl.common.utils.torch_training import (
    get_rank,
    is_distributed_available_and_initialized,
)
from gigl.src.common.constants.graph_metadata import DEFAULT_CONDENSED_NODE_TYPE
from gigl.src.common.modeling_task_specs.utils.profiler_wrapper import TorchProfiler
from gigl.src.common.models.pyg.homogeneous import TwoLayerGCN
from gigl.src.common.types.model_eval_metrics import (
    EvalMetric,
    EvalMetricsCollection,
    EvalMetricType,
)
from gigl.src.common.types.pb_wrappers.dataset_metadata import DatasetMetadataPbWrapper
from gigl.src.common.types.pb_wrappers.dataset_metadata_utils import (
    SupervisedNodeClassificationDatasetDataloaders,
)
from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper
from gigl.src.common.types.pb_wrappers.preprocessed_metadata import (
    PreprocessedMetadataPbWrapper,
)
from gigl.src.inference.v1.lib.base_inferencer import (
    InferBatchResults,
    SupervisedNodeClassificationBaseInferencer,
    no_grad_eval,
)
from gigl.src.training.v1.lib.base_trainer import BaseTrainer
from gigl.src.training.v1.lib.data_loaders.common import DataloaderTypes
from gigl.src.training.v1.lib.data_loaders.supervised_node_classification_data_loader import (
    SupervisedNodeClassificationBatch,
)
from snapchat.research.gbml import dataset_metadata_pb2

logger = Logger()


class NodeClassificationModelingTaskSpec(
    BaseTrainer, SupervisedNodeClassificationBaseInferencer
):
    def __init__(self, is_training: bool = True, **kwargs) -> None:
        self.__optim_lr = float(kwargs.get("optim_lr", 0.01))
        self.__optim_weight_decay = float(kwargs.get("optim_weight_decay", 5e-4))
        self.__num_epochs = int(kwargs.get("num_epochs", 5))
        self.__out_dim = int(kwargs.get("out_dim", 7))
        self.__is_training = is_training
        main_sample_batch_size = int(kwargs.get("main_sample_batch_size", 16))
        dataloader_batch_size_map: Dict[DataloaderTypes, int] = {
            DataloaderTypes.train_main: main_sample_batch_size,
            DataloaderTypes.val_main: main_sample_batch_size,
            DataloaderTypes.test_main: main_sample_batch_size,
        }
        # TODO (mkolodner-sc): Investigate how we can automatically infer num_worker values
        dataloader_num_workers_map: Dict[DataloaderTypes, int] = {
            DataloaderTypes.train_main: int(kwargs.get("train_main_num_workers", 0)),
            DataloaderTypes.val_main: int(kwargs.get("val_main_num_workers", 0)),
            DataloaderTypes.test_main: int(kwargs.get("test_main_num_workers", 0)),
        }
        self._dataloaders: SupervisedNodeClassificationDatasetDataloaders = (
            SupervisedNodeClassificationDatasetDataloaders(
                batch_size_map=dataloader_batch_size_map,
                num_workers_map=dataloader_num_workers_map,
            )
        )
        super().__init__(**kwargs)

    @property
    def model(self) -> torch.nn.Module:
        return self.__model

    @model.setter
    def model(self, model: torch.nn.Module) -> None:
        self.__model = model

    @property
    def gbml_config_pb_wrapper(self) -> GbmlConfigPbWrapper:
        if not self.__gbml_config_pb_wrapper:
            raise ValueError(
                "gbml_config_pb_wrapper is not initialized before use, "
                "run init_model to set."
            )
        return self.__gbml_config_pb_wrapper

    @property
    def supports_distributed_training(self) -> bool:
        return True

    def init_model(
        self,
        gbml_config_pb_wrapper: GbmlConfigPbWrapper,
        state_dict: Optional[OrderedDict[str, torch.Tensor]] = None,
    ) -> torch.nn.Module:
        self.__gbml_config_pb_wrapper = gbml_config_pb_wrapper
        preprocessed_metadata_pb_wrapper: PreprocessedMetadataPbWrapper = (
            gbml_config_pb_wrapper.preprocessed_metadata_pb_wrapper
        )
        model = TwoLayerGCN(
            in_dim=preprocessed_metadata_pb_wrapper.condensed_node_type_to_feature_dim_map[
                DEFAULT_CONDENSED_NODE_TYPE
            ],
            out_dim=self.__out_dim,
            is_training=self.__is_training,
        )
        if state_dict is not None:
            model.load_state_dict(state_dict)
        self.model = model
        self._graph_backend = model.graph_backend
        return model

    def setup_for_training(self):
        self._optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=self.__optim_lr,
            weight_decay=self.__optim_weight_decay,
        )
        self._train_loss_fn: Callable[
            [torch.Tensor, torch.Tensor], torch.Tensor
        ] = lambda input, target: F.cross_entropy(input=input, target=target)
        self.model.train()

    def _train(
        self, data_loader: torch.utils.data.DataLoader, device: torch.device
    ) -> Optional[torch.Tensor]:
        self.model.train()
        loss: Optional[torch.Tensor] = None
        with ExitStack() as stack:
            if is_distributed_available_and_initialized():
                assert isinstance(
                    self.model, Joinable
                ), "The model should be Joinable, i.e. wrapped with DistributedDataParallel"
                # See https://pytorch.org/tutorials/advanced/generic_join.html for context,
                # also: https://github.com/pytorch/pytorch/issues/38174
                # and: https://github.com/pytorch/pytorch/issues/33148
                # This is needed to train model with unequal batch sizes across different Ranks
                stack.enter_context(Join([self.model]))
                logger.info(f"Model on rank {get_rank()} joined.")
            batch: SupervisedNodeClassificationBatch
            for batch in data_loader:
                self._optimizer.zero_grad()
                inputs = batch.graph.to(device=device)
                root_node_indices = batch.root_node_indices.to(device=device)
                assert (
                    batch.root_node_labels is not None
                ), "Labels required for training."
                root_node_labels = batch.root_node_labels.to(device=device)
                out = self.model(inputs)
                # Figure out why below is a typing issue
                loss = self._train_loss_fn(
                    input=out[root_node_indices], target=root_node_labels
                )  # type: ignore
                loss.backward()
                self._optimizer.step()
            logger.info(
                f"Rank {get_rank()} has exhausted all of its inputs for current epoch of training!"
            )
        if is_distributed_available_and_initialized():
            torch.distributed.barrier()
        return loss

    @no_grad_eval
    def infer_batch(
        self,
        batch: SupervisedNodeClassificationBatch,
        device: torch.device = torch.device("cpu"),
    ) -> InferBatchResults:
        inputs = batch.graph.to(device)
        root_node_indices = batch.root_node_indices.to(device)
        out = self.model(inputs)
        embed = out[root_node_indices]
        pred = embed.argmax(dim=1)
        return InferBatchResults(embeddings=embed, predictions=pred)

    @no_grad_eval
    def score(
        self, data_loader: torch.utils.data.DataLoader, device: torch.device
    ) -> float:
        num_correct = 0
        num_evaluated = 0
        batch: SupervisedNodeClassificationBatch
        for batch in data_loader:
            assert batch.root_node_labels is not None, "Labels required for scoring."
            root_node_labels = batch.root_node_labels.to(device)
            assert root_node_labels is not None
            results: InferBatchResults = self.infer_batch(batch=batch, device=device)
            num_correct_in_batch = int((results.predictions == root_node_labels).sum())
            num_correct += num_correct_in_batch
            num_evaluated += len(batch.root_node_labels)
        logger.info(f"Rank {get_rank()} has exhausted all of its inputs!")
        if is_distributed_available_and_initialized():
            torch.distributed.barrier()
            num_correct_tensor = torch.Tensor([num_correct]).to(device)
            num_evaluated_tensor = torch.Tensor([num_evaluated]).to(device)
            logger.info(
                f"Will reduce num_correct: {num_correct}, and num_evaluated: {num_evaluated}"
            )
            torch.distributed.all_reduce(
                num_correct_tensor, op=torch.distributed.ReduceOp.SUM
            )
            torch.distributed.all_reduce(
                num_evaluated_tensor, op=torch.distributed.ReduceOp.SUM
            )
            num_correct = int(num_correct_tensor.item())
            num_evaluated = int(num_evaluated_tensor.item())
        acc = num_correct / num_evaluated
        logger.info(f"Computed acc: {acc}, in rank: {get_rank()}")
        return acc

    def train(
        self,
        gbml_config_pb_wrapper: GbmlConfigPbWrapper,
        device: torch.device,
        profiler: Optional[TorchProfiler] = None,
    ) -> None:
        dataset_metadata_pb_wrapper: DatasetMetadataPbWrapper = (
            gbml_config_pb_wrapper.dataset_metadata_pb_wrapper
        )
        assert (
            dataset_metadata_pb_wrapper.output_metadata_type
            == dataset_metadata_pb2.SupervisedNodeClassificationDataset
        ), "Expected a node classification dataset"
        data_loaders = self._dataloaders.get_training_dataloaders(
            gbml_config_pb_wrapper=gbml_config_pb_wrapper,
            graph_backend=self._graph_backend,
            device=device,
        )
        best_val_acc = 0.0
        for epoch in range(self.__num_epochs):
            logger.info(f"Batch training... for epoch {epoch}/{self.__num_epochs}")
            train_loss = self._train(
                data_loader=data_loaders.train_main, device=device  # type: ignore
            )
            train_loss_str = (
                f"{train_loss.item():.3f}" if train_loss is not None else None
            )
            val_acc = self.score(data_loader=data_loaders.val_main, device=device)
            if best_val_acc < val_acc:
                best_val_acc = val_acc
            if train_loss is not None:
                tf.summary.scalar("Train Loss", train_loss.item(), step=epoch)
            tf.summary.scalar("Acc", round(val_acc, 3), step=epoch)
            logger.info(
                f"Train Epoch {epoch}/{self.__num_epochs} | Loss: {train_loss_str} | "
                f"Val Acc: {val_acc:.3f} | Best Val Acc: {best_val_acc:.3f}"
            )
        logger.info("Finished training...")
        self._dataloaders.cleanup_dataloaders()

    def eval(
        self,
        gbml_config_pb_wrapper: GbmlConfigPbWrapper,
        device: torch.device,
    ) -> EvalMetricsCollection:
        logger.info("Start testing...")
        dataset_metadata_pb_wrapper: DatasetMetadataPbWrapper = (
            gbml_config_pb_wrapper.dataset_metadata_pb_wrapper
        )
        assert (
            dataset_metadata_pb_wrapper.output_metadata_type
            == dataset_metadata_pb2.SupervisedNodeClassificationDataset
        ), "Expected a node classification dataset"
        data_loaders = self._dataloaders.get_test_dataloaders(
            gbml_config_pb_wrapper=gbml_config_pb_wrapper,
            graph_backend=self._graph_backend,
            device=device,
        )
        test_acc = self.score(data_loader=data_loaders.test_main, device=device)
        logger.info(f"global test acc: {test_acc:.3f}")
        test_acc_metric = EvalMetric.from_eval_metric_type(
            eval_metric_type=EvalMetricType.acc, value=test_acc
        )
        model_eval_metrics = EvalMetricsCollection(metrics=[test_acc_metric])
        self._dataloaders.cleanup_dataloaders()
        return model_eval_metrics
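

A minimal usage sketch of this task spec, assuming a GbmlConfigPbWrapper instance (gbml_config_pb_wrapper below) has already been constructed elsewhere in the pipeline and that single-process CPU training is acceptable; in GiGL the trainer and inferencer components drive these calls, so this is illustrative only and not part of the module above.

    import torch

    from gigl.src.common.modeling_task_specs.node_classification_modeling_task_spec import (
        NodeClassificationModelingTaskSpec,
    )

    device = torch.device("cpu")
    # Hyperparameters are passed as kwargs; unspecified ones fall back to the defaults in __init__.
    spec = NodeClassificationModelingTaskSpec(optim_lr=0.01, num_epochs=5)
    # gbml_config_pb_wrapper is assumed to be a prebuilt GbmlConfigPbWrapper.
    spec.init_model(gbml_config_pb_wrapper=gbml_config_pb_wrapper)
    spec.setup_for_training()
    spec.train(gbml_config_pb_wrapper=gbml_config_pb_wrapper, device=device)
    metrics = spec.eval(gbml_config_pb_wrapper=gbml_config_pb_wrapper, device=device)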