# Source code for gigl.src.training.v1.lib.data_loaders.common

from __future__ import annotations

from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Union

from gigl.common import Uri
from gigl.common.logger import Logger
from gigl.src.common.types.graph_data import NodeType

# Module-level logger instance shared by code in this module.
logger = Logger()
# Defaults used for the corresponding DataloaderConfig fields
# (batch_size, num_workers, seed) when the caller does not supply them.
_DEFAULT_DATA_LOADER_BATCH_SIZE = 32
_DEFAULT_DATA_LOADER_NUM_WORKERS = 0
_DEFAULT_DATA_LOADER_SEED = 42

# TODO(nshah-sc): refactor out sample-wise preprocessing methods into pb wrappers.
class DataloaderTypes(Enum):
    """Identifiers for the six data loaders known to this module.

    There is a ``main`` and a ``random_negative`` variant for each of the
    ``train``, ``val`` and ``test`` prefixes. Every member's value is the
    same string as its name, so a member can be looked up either by
    attribute (``DataloaderTypes.train_main``) or by value
    (``DataloaderTypes("train_main")``).
    """

    train_main = "train_main"
    val_main = "val_main"
    test_main = "test_main"
    train_random_negative = "train_random_negative"
    val_random_negative = "val_random_negative"
    test_random_negative = "test_random_negative"
@dataclass
class DataloaderConfig:
    """Settings bundle describing how one data loader should be built.

    Only ``uris`` is required; every other field has a default, so
    ``DataloaderConfig(uris=...)`` is a complete configuration.
    """

    # Input locations: either one flat list of URIs, or a mapping from
    # node type to that node type's list of URIs.
    uris: Union[List[Uri], Dict[NodeType, List[Uri]]]
    # Defaults to _DEFAULT_DATA_LOADER_BATCH_SIZE.
    batch_size: int = _DEFAULT_DATA_LOADER_BATCH_SIZE
    # Defaults to _DEFAULT_DATA_LOADER_NUM_WORKERS.
    # NOTE(review): presumably forwarded to the underlying loader's
    # worker-count setting — confirm against the consumer of this config.
    num_workers: int = _DEFAULT_DATA_LOADER_NUM_WORKERS
    # NOTE(review): presumably makes the loader restart instead of stopping
    # once its data is exhausted — confirm against the consumer.
    should_loop: bool = False
    # NOTE(review): presumably forwarded to the underlying loader's
    # pinned-memory option — confirm against the consumer.
    pin_memory: bool = False
    # Defaults to _DEFAULT_DATA_LOADER_SEED.
    seed: int = _DEFAULT_DATA_LOADER_SEED