Source code for gigl.src.inference.v1.lib.base_inferencer
from dataclasses import dataclass
from functools import wraps
from typing import Generic, Optional, Protocol, TypeVar, runtime_checkable
import torch
import torch.utils.data
from gigl.common.logger import Logger
from gigl.common.utils.torch_training import is_distributed_available_and_initialized
from gigl.src.common.types.model import BaseModelOperationsProtocol
from gigl.src.training.v1.lib.data_loaders.rooted_node_neighborhood_data_loader import (
    RootedNodeNeighborhoodBatch,
)
from gigl.src.training.v1.lib.data_loaders.supervised_node_classification_data_loader import (
    SupervisedNodeClassificationBatch,
)
T = TypeVar("T", contravariant=True) 
@dataclass
class InferBatchResults:
    embeddings: Optional[torch.Tensor] 
    predictions: Optional[torch.Tensor] 
 
def no_grad_eval(f):
    @wraps(f)
    def wrapper(self: BaseInferencer, *args, **kwargs):
        curr_model = self.model
        if is_distributed_available_and_initialized() and isinstance(
            self.model, torch.nn.parallel.DistributedDataParallel
        ):
            # We don't need DDP's synchronization overhead during inference, so unwrap the underlying module
            self.model = self.model.module
        was_training = self.model.training
        self.model.eval()
        with torch.no_grad():
            ret_val = f(self, *args, **kwargs)  # Call infer_batch
        self.model.train(
            mode=was_training
        )  # restore the model's original training/eval mode
        self.model = curr_model
        return ret_val
    return wrapper 
@runtime_checkable
class BaseInferencer(BaseModelOperationsProtocol, Protocol, Generic[T]):
    """
    The Protocol that you need to implement for your inferencer to function with automated Inference
    in tabularized mode.
    Note: the BaseInferencer class also implements the :class:`gigl.src.common.types.model.BaseModelOperationsProtocol`
    protocol, which requires the init_model method, and the getter and setter for the model property.
    """
    def infer_batch(
        self, batch: T, device: torch.device = torch.device("cpu")
    ) -> InferBatchResults:
        raise NotImplementedError 
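

# Example (illustrative sketch, not a GiGL API): a toy class that structurally satisfies
# the BaseInferencer surface above. The no-argument init_model signature, the toy MLP
# model, and the plain-tensor "batch" are assumptions made for illustration only; real
# inferencers consume one of the batch types handled by the subclasses below.
class _ToyInferencer:
    def init_model(self) -> torch.nn.Module:
        # Assumed signature; the real requirement comes from BaseModelOperationsProtocol.
        self._model = torch.nn.Sequential(
            torch.nn.Linear(16, 32), torch.nn.ReLU(), torch.nn.Linear(32, 8)
        )
        return self._model

    @property
    def model(self) -> torch.nn.Module:
        return self._model

    @model.setter
    def model(self, model: torch.nn.Module) -> None:
        self._model = model

    @no_grad_eval
    def infer_batch(
        self, batch: torch.Tensor, device: torch.device = torch.device("cpu")
    ) -> InferBatchResults:
        # no_grad_eval has already put the model into eval mode and disabled gradients.
        embeddings = self.model.to(device)(batch.to(device))
        return InferBatchResults(embeddings=embeddings, predictions=None)


# Usage sketch:
#   toy = _ToyInferencer()
#   toy.init_model()
#   results = toy.infer_batch(torch.randn(4, 16))  # InferBatchResults with (4, 8) embeddings
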
 
class SupervisedNodeClassificationBaseInferencer(
    BaseInferencer[SupervisedNodeClassificationBatch]
):
    """
    The protocol that you need to implement for inference for Supervised Node Classification tasks in tabularized mode.
    Note: the protocol also implements :class:`BaseInferencer` protocol.
    """ 
class NodeAnchorBasedLinkPredictionBaseInferencer(
    BaseInferencer[RootedNodeNeighborhoodBatch]
):
    """
    The protocol that you need to implement for inderence for Node Anchor Based Link Prediction tasks in tabularized mode.
    Note: the protocol also implements :class:`BaseInferencer` protocol.
    """