Source code for gigl.src.inference.v1.lib.base_inferencer

from dataclasses import dataclass
from functools import wraps
from typing import Generic, Optional, Protocol, TypeVar, runtime_checkable

import torch
import torch.utils.data

from gigl.common.logger import Logger
from gigl.common.utils.torch_training import is_distributed_available_and_initialized
from gigl.src.common.types.model import BaseModelOperationsProtocol
from gigl.src.training.v1.lib.data_loaders.rooted_node_neighborhood_data_loader import (
    RootedNodeNeighborhoodBatch,
)
from gigl.src.training.v1.lib.data_loaders.supervised_node_classification_data_loader import (
    SupervisedNodeClassificationBatch,
)

# Batch type accepted by an inferencer. It is contravariant because it only
# appears in input position (the ``batch`` argument of ``infer_batch``).
T = TypeVar("T", covariant=False, contravariant=True)
# Module-level logger instance.
logger: Logger = Logger()
@dataclass
class InferBatchResults:
    """Outputs produced by a single ``infer_batch`` call.

    Either field may be ``None`` when an inferencer does not produce that kind
    of output. Both fields default to ``None``, so callers only need to pass
    what they actually compute, e.g. ``InferBatchResults(embeddings=emb)``.

    Attributes:
        embeddings: Embedding tensor for the batch, or ``None``.
        predictions: Prediction tensor for the batch, or ``None``.
    """

    embeddings: Optional[torch.Tensor] = None
    predictions: Optional[torch.Tensor] = None
def no_grad_eval(f):
    """Decorate an inferencer method so it runs in eval mode without autograd.

    The wrapped method is called with ``self.model`` switched to ``eval()`` and
    inside ``torch.no_grad()``. If ``self.model`` is a
    ``torch.nn.parallel.DistributedDataParallel`` wrapper and distributed is
    initialized, the wrapper is temporarily replaced by its underlying
    ``.module`` for the duration of the call.

    After the call, the original ``self.model`` object is restored. The model is
    also put back into the train/eval mode it had before the call. Both happen
    even when the wrapped method raises.

    Args:
        f: A method whose ``self`` exposes a settable ``model`` attribute
            holding a ``torch.nn.Module``.

    Returns:
        The wrapped method. It returns whatever ``f`` returns.
    """

    @wraps(f)
    def wrapper(self: "BaseInferencer", *args, **kwargs):
        # String annotation: avoids evaluating the forward reference to
        # BaseInferencer at decoration time.
        curr_model = self.model
        # Do the cheap type check first; only a DDP wrapper needs unwrapping.
        if isinstance(
            curr_model, torch.nn.parallel.DistributedDataParallel
        ) and is_distributed_available_and_initialized():
            # We don't need DDP's unnecessary synchronization for inference.
            self.model = curr_model.module
        was_training = self.model.training
        try:
            self.model.eval()
            with torch.no_grad():
                return f(self, *args, **kwargs)  # e.g. infer_batch
        finally:
            # Restore the mode and the original (possibly DDP-wrapped) model
            # even if ``f`` raised.
            self.model.train(mode=was_training)
            self.model = curr_model

    return wrapper
@runtime_checkable
class BaseInferencer(BaseModelOperationsProtocol, Protocol, Generic[T]):
    """Protocol an inferencer must satisfy to run automated inference in tabularized mode.

    In addition to ``infer_batch``, ``BaseInferencer`` extends
    :class:`gigl.src.common.types.model.BaseModelOperationsProtocol`. That
    protocol requires an ``init_model`` method plus a getter and a setter for
    the ``model`` property.
    """

    def infer_batch(
        self,
        batch: T,
        device: torch.device = torch.device("cpu"),
    ) -> InferBatchResults:
        """Run inference on a single batch.

        Args:
            batch: The batch to run inference on.
            device: Torch device handed to the implementation. Defaults to CPU.

        Returns:
            An :class:`InferBatchResults` holding embeddings and/or predictions.

        Raises:
            NotImplementedError: Always, in this base definition. Implementers
                must override it.
        """
        raise NotImplementedError()
class SupervisedNodeClassificationBaseInferencer(
    BaseInferencer[SupervisedNodeClassificationBatch],
):
    """Inferencer interface for Supervised Node Classification tasks in tabularized mode.

    Implementations provide ``infer_batch`` for a
    :class:`SupervisedNodeClassificationBatch`, specializing the
    :class:`BaseInferencer` protocol to that batch type.

    Note: ``Protocol`` is not repeated among this class's own bases. Under
    PEP 544 that makes this a regular (nominal) subclass of ``BaseInferencer``
    rather than a protocol in its own right.
    """
class NodeAnchorBasedLinkPredictionBaseInferencer(
    BaseInferencer[RootedNodeNeighborhoodBatch]
):
    """
    The protocol that you need to implement for inference for Node Anchor Based Link Prediction tasks in tabularized mode.

    Implementations provide ``infer_batch`` for a
    :class:`RootedNodeNeighborhoodBatch`.

    Note: the protocol also implements :class:`BaseInferencer` protocol.
    """