Source code for gigl.src.training.v1.lib.base_trainer

from __future__ import annotations

from typing import Optional

import torch

from gigl.common.logger import Logger
from gigl.src.common.modeling_task_specs.utils.profiler_wrapper import TorchProfiler
from gigl.src.common.types.model import BaseModelOperationsProtocol
from gigl.src.common.types.model_eval_metrics import EvalMetricsCollection
from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper

# Module-level logger built from gigl.common.logger.Logger.
# NOTE(review): not referenced in the visible portion of this module.
logger = Logger()
class BaseTrainer(BaseModelOperationsProtocol):
    """
    Interface a trainer must implement to work with Training in tabularized mode.

    Note:
        BaseTrainer also implements
        :class:`gigl.src.common.types.model.BaseModelOperationsProtocol`,
        which requires the ``init_model`` method plus a getter and a setter
        for the ``model`` property.

    Every hook defined here raises ``NotImplementedError`` in this base
    class; concrete trainers are expected to override them.
    """

    def train(
        self,
        gbml_config_pb_wrapper: GbmlConfigPbWrapper,
        device: torch.device,
        profiler: Optional[TorchProfiler] = None,
    ) -> None:
        """
        Training hook; subclasses supply the implementation.

        Args:
            gbml_config_pb_wrapper: Wrapped GBML config handed to the trainer.
            device: Torch device handed to the trainer (presumably the device
                training should run on — confirm against callers).
            profiler: Optional ``TorchProfiler``; defaults to ``None``.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()

    def eval(
        self,
        gbml_config_pb_wrapper: GbmlConfigPbWrapper,
        device: torch.device,
    ) -> EvalMetricsCollection:
        """
        Evaluation hook; subclasses supply the implementation.

        Args:
            gbml_config_pb_wrapper: Wrapped GBML config handed to the trainer.
            device: Torch device handed to the trainer.

        Returns:
            An ``EvalMetricsCollection`` (per the declared return type).

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()

    def setup_for_training(self) -> None:
        """
        Pre-training setup hook; subclasses supply the implementation.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()

    @property
    def supports_distributed_training(self) -> bool:
        """
        Whether this trainer supports distributed training.

        Raises:
            NotImplementedError: Always, in this base class; subclasses must
                override the property to return a ``bool``.
        """
        raise NotImplementedError()