Source code for gigl.src.training.v1.lib.data_loaders.utils
import math
from typing import List, Optional, Union
import numpy as np
import torch
import torch.utils.data
from torch.utils.data._utils.worker import WorkerInfo
from torch_geometric.data import Data
from torch_geometric.data.hetero_data import HeteroData
from gigl.common.logger import Logger
from gigl.common.utils.torch_training import get_rank, get_world_size
from gigl.src.common.graph_builder.pyg_graph_data import PygGraphData
from gigl.src.common.types.graph_data import EdgeType, NodeType
from gigl.src.common.types.pb_wrappers.graph_metadata import GraphMetadataPbWrapper
from gigl.src.common.types.pb_wrappers.preprocessed_metadata import (
PreprocessedMetadataPbWrapper,
)
logger = Logger()

def get_data_split_for_current_worker(data_list: np.ndarray) -> np.ndarray:
"""Split list of data per worker
Selects a subset of data based on Torch get_worker_info.
Used as a shard selection function in Dataset.
"""
    # Worker info is only available inside a worker process, i.e. not in the main process
    worker_info: Optional[WorkerInfo] = torch.utils.data.get_worker_info()
    if worker_info is None:
        return data_list  # Only the main process is loading data; use all elements
else:
worker_id: int = worker_info.id
num_workers: int = worker_info.num_workers
# Accounting for distributed training
global_worker_id = get_rank() * num_workers + worker_id
global_num_workers = num_workers * get_world_size()
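        # e.g. with world_size = 2 and 3 dataloader workers per rank, there are 6 global
        # workers; rank 1, local worker 0 maps to global_worker_id = 1 * 3 + 0 = 3.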
global_num_worker_to_data_list_ratio = global_num_workers / len(data_list)
if global_num_worker_to_data_list_ratio > 1:
            logger.warning(
                f"Number of workers ({global_num_workers}) is greater than the number of elements ({len(data_list)}). "
                f"Data will be replicated, which may lead to increased memory usage. "
                f"Consider reducing the number of workers or increasing the dataset size for better efficiency."
            )
data_list = np.tile(
data_list, math.ceil(global_num_worker_to_data_list_ratio)
)
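            # e.g. with len(data_list) == 2 and global_num_workers == 3, the list is tiled
            # ceil(3 / 2) == 2 times to length 4, so every worker receives at least one element.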
        # Starting at index `global_worker_id`, take every `global_num_workers`-th element.
        # e.g. if global_worker_id = 2 and global_num_workers = 3, this worker gets the elements
        # at indices [2, 5, 8, ...]; worker ids are 0-indexed, so id 2 is the 3rd worker.
worker_data = data_list[global_worker_id::global_num_workers]
logger.debug(
f"Worker {global_worker_id} has {len(worker_data)} elements: {worker_data}."
)
return worker_data
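
# Illustrative usage sketch (not part of this module): `get_data_split_for_current_worker` is
# typically called from inside `IterableDataset.__iter__`, so that each dataloader worker (on
# each distributed rank) reads only its own shard of file URIs. The `ShardedUriDataset` class
# and `_read_records` helper below are hypothetical placeholders.
#
#     class ShardedUriDataset(torch.utils.data.IterableDataset):
#         def __init__(self, file_uris: np.ndarray):
#             self._file_uris = file_uris
#
#         def __iter__(self):
#             # Each worker sees only its slice of the URIs.
#             worker_uris = get_data_split_for_current_worker(self._file_uris)
#             for uri in worker_uris:
#                 yield from _read_records(uri)  # hypothetical: parse samples from one shard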
def cast_graph_for_training(
batch_graph_data: PygGraphData,
graph_metadata_pb_wrapper: GraphMetadataPbWrapper,
preprocessed_metadata_pb_wrapper: PreprocessedMetadataPbWrapper,
batch_type: str,
should_register_edge_features: Optional[bool],
) -> Union[Data, HeteroData]:
"""
Casts the PygGraphData object into a Data or HeteroData object. Also fills in any missing fields from graph
builder with empty tensors in cases where there are no edges for a graph or given edge type.
Args:
batch_graph_data (PygGraphData): Coalesced batch graph
graph_metadata_pb_wrapper (GraphMetadataPbWrapper): Graph Metadata Pb Wrapper for this training job
preprocessed_metadata_pb_wrapper (PreprocessedMetadataPbWrapper): Preprocessed Metadata Pb Wrapper for this training job
should_register_edge_features (bool): Whether we should register edge features for the built graph
"""
if graph_metadata_pb_wrapper.is_heterogeneous:
casted_graph = batch_graph_data
missing_node_types_list: List[NodeType] = []
# If we have a node type missing in the graph, insert that node type into the x_dict field
for (
condensed_node_type,
node_feature_dim,
) in (
preprocessed_metadata_pb_wrapper.condensed_node_type_to_feature_dim_map.items()
):
node_type = graph_metadata_pb_wrapper.condensed_node_type_to_node_type_map[
condensed_node_type
]
if node_type not in casted_graph.node_types:
missing_node_types_list.append(node_type)
casted_graph[node_type].x = torch.empty(
(0, node_feature_dim), dtype=torch.float32
)
has_any_missing_node_types: bool = len(missing_node_types_list) > 0
if has_any_missing_node_types:
            logger.info(
                f"Found the following node types missing from the heterogeneous {batch_type} batched graph: {missing_node_types_list}. "
                f"If you see this log repeatedly across batches and it is not expected, please revisit the graph definition and sampling strategy."
            )
missing_edge_types_list: List[EdgeType] = []
for (
condensed_edge_type,
edge_feature_dim,
) in (
preprocessed_metadata_pb_wrapper.condensed_edge_type_to_feature_dim_map.items()
):
edge_type = graph_metadata_pb_wrapper.condensed_edge_type_to_edge_type_map[
condensed_edge_type
]
is_edge_type_missing_from_casted_graph = (
edge_type not in casted_graph.edge_types
or casted_graph[edge_type].edge_index is None
)
if is_edge_type_missing_from_casted_graph:
missing_edge_types_list.append(edge_type)
casted_graph[edge_type].edge_index = torch.empty(
(2, 0), dtype=torch.int64
)
if should_register_edge_features:
casted_graph[edge_type].edge_attr = torch.empty(
(0, edge_feature_dim), dtype=torch.float32
)
has_any_missing_edge_types: bool = len(missing_edge_types_list) > 0
if has_any_missing_edge_types:
            logger.info(
                f"Found the following edge types missing from the heterogeneous {batch_type} batched graph: {missing_edge_types_list}. "
                f"If you see this log repeatedly across batches and it is not expected, please revisit the graph definition and sampling strategy."
            )
else:
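        # Note: `to_homogeneous()` comes from PyG's HeteroData (which PygGraphData is assumed to
        # extend); it merges all node/edge stores into a single Data object, adding `node_type`
        # and `edge_type` index vectors.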
casted_graph = batch_graph_data.to_homogeneous()
if casted_graph.num_nodes == 0:
            logger.warning(
                f"Found no nodes in the homogeneous {batch_type} batched graph."
            )
condensed_edge_type = graph_metadata_pb_wrapper.condensed_edge_types[0]
is_edge_missing = casted_graph.edge_index is None
if is_edge_missing:
logger.warning(f"Found no edges in homogeneous {batch_type} batched graph.")
casted_graph.edge_index = torch.empty((2, 0), dtype=torch.int64)
if should_register_edge_features:
edge_feature_dim = preprocessed_metadata_pb_wrapper.condensed_edge_type_to_feature_dim_map[
condensed_edge_type
]
casted_graph.edge_attr = torch.empty(
(0, edge_feature_dim), dtype=torch.float32
)
return casted_graph
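
# Illustrative usage sketch (not part of this module): a training step might cast each coalesced
# batch before running the model forward pass. The `batch`, `model`, and `*_pb_wrapper` variables
# below are hypothetical placeholders.
#
#     graph = cast_graph_for_training(
#         batch_graph_data=batch.graph,
#         graph_metadata_pb_wrapper=graph_metadata_pb_wrapper,
#         preprocessed_metadata_pb_wrapper=preprocessed_metadata_pb_wrapper,
#         batch_type="main",
#         should_register_edge_features=True,
#     )
#     if isinstance(graph, HeteroData):
#         out = model(graph.x_dict, graph.edge_index_dict)
#     else:
#         out = model(graph.x, graph.edge_index)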