# Source code for gigl.src.training.v1.lib.data_loaders.common
from __future__ import annotations
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Union
from gigl.common import Uri
from gigl.common.logger import Logger
from gigl.src.common.types.graph_data import NodeType
# Module-private fallback values for the optional fields of DataloaderConfig;
# they apply when a config is built without an explicit value for that field.
# Default for DataloaderConfig.batch_size.
_DEFAULT_DATA_LOADER_BATCH_SIZE = 32
# Default for DataloaderConfig.num_workers.
_DEFAULT_DATA_LOADER_NUM_WORKERS = 0
# Default for DataloaderConfig.seed.
_DEFAULT_DATA_LOADER_SEED = 42
# TODO(nshah-sc): refactor out sample-wise preprocessing methods into pb wrappers.
class DataloaderTypes(Enum):
    """Closed set of identifiers for the dataloaders used by this package.

    There is one ``*_main`` and one ``*_random_negative`` member for each of
    the ``train``, ``val`` and ``test`` prefixes.  Every member's value is the
    same string as its name, so ``DataloaderTypes("train_main")`` and
    ``DataloaderTypes["train_main"]`` resolve to the same member.

    NOTE(review): the exact role of the "main" vs. "random_negative" loaders
    is defined by the code that consumes this enum, not here -- confirm there.
    """

    train_main = "train_main"
    val_main = "val_main"
    test_main = "test_main"
    train_random_negative = "train_random_negative"
    val_random_negative = "val_random_negative"
    test_random_negative = "test_random_negative"
@dataclass
class DataloaderConfig:
    """Settings bundle describing how one dataloader should be built.

    Only ``uris`` is required; every other field falls back to a default.

    Attributes:
        uris: Either a flat list of ``Uri`` objects, or a mapping from
            ``NodeType`` to a list of ``Uri`` objects.
            NOTE(review): presumably the input data locations, with the dict
            form used for per-node-type inputs -- confirm against callers.
        batch_size: Defaults to ``_DEFAULT_DATA_LOADER_BATCH_SIZE``.
        num_workers: Defaults to ``_DEFAULT_DATA_LOADER_NUM_WORKERS``.
        should_loop: Defaults to ``False``.
            NOTE(review): presumably makes the loader restart once its data
            is exhausted -- confirm against the loader implementation.
        pin_memory: Defaults to ``False``.
            NOTE(review): presumably forwarded to the underlying loader's
            pinned-memory option -- confirm.
        seed: Defaults to ``_DEFAULT_DATA_LOADER_SEED``.
    """

    uris: Union[List[Uri], Dict[NodeType, List[Uri]]]
    batch_size: int = _DEFAULT_DATA_LOADER_BATCH_SIZE
    num_workers: int = _DEFAULT_DATA_LOADER_NUM_WORKERS
    should_loop: bool = False
    pin_memory: bool = False
    seed: int = _DEFAULT_DATA_LOADER_SEED