# Source code for gigl.src.training.v1.lib.data_loaders.common
from __future__ import annotations
from dataclasses import dataclass
from enum import Enum
from typing import Union
from gigl.common import Uri
from gigl.common.logger import Logger
from gigl.src.common.types.graph_data import NodeType
_DEFAULT_DATA_LOADER_BATCH_SIZE = 32
_DEFAULT_DATA_LOADER_NUM_WORKERS = 0
_DEFAULT_DATA_LOADER_SEED = 42
# TODO(nshah-sc): refactor out sample-wise preprocessing methods into pb wrappers.
class DataloaderTypes(Enum):
    """Identifiers for the data loaders used during training.

    Every member's value is the same string as its name. The names suggest a
    "main" and a "random_negative" loader for each of the train, val and test
    phases.
    """

    # "Main" loaders, one per phase.
    train_main = "train_main"
    val_main = "val_main"
    test_main = "test_main"

    # "Random negative" loaders, one per phase.
    train_random_negative = "train_random_negative"
    val_random_negative = "val_random_negative"
    test_random_negative = "test_random_negative"
 
@dataclass
class DataloaderConfig:
    """Settings bundle describing how a single data loader should be built.

    Attributes:
        uris: Required. Either a flat list of ``Uri`` objects, or a mapping
            from ``NodeType`` to the list of ``Uri`` objects for that node
            type.
        batch_size: Defaults to ``_DEFAULT_DATA_LOADER_BATCH_SIZE``.
        num_workers: Defaults to ``_DEFAULT_DATA_LOADER_NUM_WORKERS``.
        should_loop: Defaults to ``False``. Presumably makes the loader
            restart once its data is exhausted -- confirm against the
            consumer of this config.
        pin_memory: Defaults to ``False``. Presumably forwarded to the
            underlying loader's option of the same name -- confirm against
            the consumer of this config.
        seed: Defaults to ``_DEFAULT_DATA_LOADER_SEED``.
    """

    # The only field without a default, so it must stay first.
    uris: Union[list[Uri], dict[NodeType, list[Uri]]]
    batch_size: int = _DEFAULT_DATA_LOADER_BATCH_SIZE
    num_workers: int = _DEFAULT_DATA_LOADER_NUM_WORKERS
    should_loop: bool = False
    pin_memory: bool = False
    seed: int = _DEFAULT_DATA_LOADER_SEED