# Source code for gigl.src.common.types.model

from enum import Enum
from typing import Optional, OrderedDict, Protocol, runtime_checkable

import torch

from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper


@runtime_checkable
class BaseModelOperationsProtocol(Protocol):
    """
    The Protocol that you need to implement for your model to function with
    Training and Inference in tabularized mode.

    Conforming classes expose the wrapped ``torch.nn.Module`` through a
    readable and writable ``model`` property, and build it via
    ``init_model``. Because the protocol is ``runtime_checkable``,
    ``isinstance`` checks only verify that these members exist, not their
    signatures.
    """

    @property
    def model(self) -> torch.nn.Module:
        """Return the wrapped ``torch.nn.Module``."""

    @model.setter
    def model(self, model: torch.nn.Module) -> None:
        """Replace the wrapped ``torch.nn.Module`` with ``model``."""

    def init_model(
        self,
        gbml_config_pb_wrapper: GbmlConfigPbWrapper,
        state_dict: Optional[OrderedDict[str, torch.Tensor]] = None,
    ) -> torch.nn.Module:
        """
        Construct and return the model described by the given config.

        Args:
            gbml_config_pb_wrapper: Wrapped GBML config used to build the model.
            state_dict: Optional mapping of parameter names to tensors.
                Presumably these weights are loaded into the new model.
                Exact semantics are defined by implementers (TODO: confirm).

        Returns:
            The initialized ``torch.nn.Module``.
        """
class GraphBackend(str, Enum):
    """Supported graph backends; members are also plain ``str`` values."""

    # PyTorch Geometric
    PYG = "PyG"
class GnnModel(Protocol):
    """
    Structural interface for GNN models that report their graph backend
    through a read-only ``graph_backend`` property.
    """

    @property
    def graph_backend(self) -> GraphBackend:
        """Read-only property to infer graph-backend from a GNN model."""