"""
DatasetFactory is responsible for building and returning a DistLinkPredictionDataset instance (or an instance of a subclass). It does this by spawning a
process which initializes the RPC and worker group, loads and builds a partitioned dataset, and then shuts down the RPC and worker group.
"""
import time
from collections import abc
from distutils.util import strtobool
from typing import Dict, Literal, MutableMapping, Optional, Type, Union
import torch
import torch.multiprocessing as mp
from graphlearn_torch.distributed import (
barrier,
get_context,
init_rpc,
init_worker_group,
rpc_is_initialized,
shutdown_rpc,
)
from gigl.common import UriFactory
from gigl.common.data.dataloaders import TFRecordDataLoader
from gigl.common.data.load_torch_tensors import (
SerializedGraphMetadata,
TFDatasetOptions,
load_torch_tensors_from_tf_record,
)
from gigl.common.logger import Logger
from gigl.common.utils.decorator import tf_on_cpu
from gigl.distributed.constants import DEFAULT_MASTER_DATA_BUILDING_PORT
from gigl.distributed.dist_context import DistributedContext
from gigl.distributed.dist_link_prediction_dataset import DistLinkPredictionDataset
from gigl.distributed.dist_partitioner import DistPartitioner
from gigl.distributed.dist_range_partitioner import DistRangePartitioner
from gigl.distributed.utils import get_process_group_name
from gigl.distributed.utils.serialized_graph_metadata_translator import (
convert_pb_to_serialized_graph_metadata,
)
from gigl.src.common.types.graph_data import EdgeType
from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper
from gigl.types.graph import (
DEFAULT_HOMOGENEOUS_EDGE_TYPE,
GraphPartitionData,
message_passing_to_negative_label,
message_passing_to_positive_label,
)
from gigl.utils.data_splitters import (
HashedNodeAnchorLinkSplitter,
NodeAnchorLinkSplitter,
select_ssl_positive_label_edges,
)
logger = Logger()
@tf_on_cpu
def _load_and_build_partitioned_dataset(
serialized_graph_metadata: SerializedGraphMetadata,
should_load_tensors_in_parallel: bool,
edge_dir: Literal["in", "out"],
partitioner_class: Optional[Type[DistPartitioner]],
node_tf_dataset_options: TFDatasetOptions,
edge_tf_dataset_options: TFDatasetOptions,
should_convert_labels_to_edges: bool = False,
splitter: Optional[NodeAnchorLinkSplitter] = None,
_ssl_positive_label_percentage: Optional[float] = None,
) -> DistLinkPredictionDataset:
"""
Given some information about serialized TFRecords, loads and builds a partitioned dataset into a DistLinkPredictionDataset class.
We require that init_rpc and init_worker_group have been called to set up the RPC connection and worker context, respectively, prior to calling this function.
If this setup has not been done beforehand, this function will raise an error.
Args:
serialized_graph_metadata (SerializedGraphMetadata): Serialized Graph Metadata contains serialized information for loading TFRecords across node and edge types
should_load_tensors_in_parallel (bool): Whether tensors should be loaded from serialized information in parallel or in sequence across the [node, edge, pos_label, neg_label] entity types.
edge_dir (Literal["in", "out"]): Edge direction of the provided graph
partitioner_class (Optional[Type[DistPartitioner]]): Partitioner class to partition the graph inputs. If provided, this must be a
DistPartitioner or subclass of it. If not provided, the DistPartitioner class will be used.
node_tf_dataset_options (TFDatasetOptions): Options provided to a tf.data.Dataset to tune how serialized node data is read.
edge_tf_dataset_options (TFDatasetOptions): Options provided to a tf.data.Dataset to tune how serialized edge data is read.
should_convert_labels_to_edges (bool): Whether to convert labels to edges in the graph. If this is set to true, the output dataset will be heterogeneous.
splitter (Optional[NodeAnchorLinkSplitter]): Optional splitter to use for splitting the graph data into train, val, and test sets. If not provided (None), no splitting will be performed.
_ssl_positive_label_percentage (Optional[float]): Percentage of edges to select as self-supervised labels. Must be None if supervised edge labels are provided in advance.
Slotted for refactor once this functionality is available in the transductive `splitter` directly
Returns:
DistLinkPredictionDataset: Initialized dataset with partitioned graph information
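Example (illustrative sketch only; in practice this function is invoked via `build_dataset`, which performs the
worker-group/RPC setup below inside a spawned process. The variable names and group name here are placeholders):
init_worker_group(world_size=world_size, rank=rank, group_name="distributed_dataset_building")
init_rpc(master_addr=master_ip_address, master_port=master_port)
dataset = _load_and_build_partitioned_dataset(
serialized_graph_metadata=serialized_graph_metadata,
should_load_tensors_in_parallel=True,
edge_dir="in",
partitioner_class=None,
node_tf_dataset_options=TFDatasetOptions(),
edge_tf_dataset_options=TFDatasetOptions(),
)
shutdown_rpc()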
"""
assert (
get_context() is not None
), "Context must be setup prior to calling `load_and_build_partitioned_dataset` through glt.distributed.init_worker_group()"
assert (
rpc_is_initialized()
), "RPC must be setup prior to calling `load_and_build_partitioned_dataset` through glt.distributed.init_rpc()"
rank: int = get_context().rank
world_size: int = get_context().world_size
tfrecord_data_loader = TFRecordDataLoader(rank=rank, world_size=world_size)
loaded_graph_tensors = load_torch_tensors_from_tf_record(
tf_record_dataloader=tfrecord_data_loader,
serialized_graph_metadata=serialized_graph_metadata,
should_load_tensors_in_parallel=should_load_tensors_in_parallel,
rank=rank,
node_tf_dataset_options=node_tf_dataset_options,
edge_tf_dataset_options=edge_tf_dataset_options,
)
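# If requested, convert the loaded positive/negative label tensors into dedicated label edge types so that
# they are partitioned (and later split) like regular graph edges; this makes the resulting dataset heterogeneous.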
if should_convert_labels_to_edges:
loaded_graph_tensors.treat_labels_as_edges()
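# With edge_dir="in", each edge is assigned to the machine that owns its destination node; with edge_dir="out",
# to the machine that owns its source node.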
should_assign_edges_by_src_node: bool = edge_dir == "out"
if partitioner_class is None:
partitioner_class = DistPartitioner
if should_assign_edges_by_src_node:
logger.info(
f"Initializing {partitioner_class.__name__} instance while partitioning edges to its source node machine"
)
else:
logger.info(
f"Initializing {partitioner_class.__name__} instance while partitioning edges to its destination node machine"
)
partitioner = partitioner_class(
should_assign_edges_by_src_node=should_assign_edges_by_src_node
)
partitioner.register_node_ids(node_ids=loaded_graph_tensors.node_ids)
partitioner.register_edge_index(edge_index=loaded_graph_tensors.edge_index)
if loaded_graph_tensors.node_features is not None:
partitioner.register_node_features(
node_features=loaded_graph_tensors.node_features
)
if loaded_graph_tensors.edge_features is not None:
partitioner.register_edge_features(
edge_features=loaded_graph_tensors.edge_features
)
if loaded_graph_tensors.positive_label is not None:
partitioner.register_labels(
label_edge_index=loaded_graph_tensors.positive_label, is_positive=True
)
if loaded_graph_tensors.negative_label is not None:
partitioner.register_labels(
label_edge_index=loaded_graph_tensors.negative_label, is_positive=False
)
# We call del so that the reference count of these registered fields is 1,
# allowing these intermediate assets to be cleaned up as necessary inside of the partitioner.partition() call
del (
loaded_graph_tensors.node_ids,
loaded_graph_tensors.node_features,
loaded_graph_tensors.edge_index,
loaded_graph_tensors.edge_features,
loaded_graph_tensors.positive_label,
loaded_graph_tensors.negative_label,
)
del loaded_graph_tensors
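# Partition the registered nodes, edges, features, and labels across all machines. This is a distributed,
# cooperative step, so every rank must reach this call before any rank shuts down its RPC connection.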
partition_output = partitioner.partition()
# TODO (mkolodner-sc): Move this code block to transductive splitter once that is ready
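# For self-supervised training, sample a percentage of the already-partitioned message-passing edges and
# register them as positive label edges.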
if _ssl_positive_label_percentage is not None:
assert (
partition_output.partitioned_positive_labels is None
and partition_output.partitioned_negative_labels is None
), "Cannot have partitioned positive and negative labels when attempting to select self-supervised positive edges from edge index."
positive_label_edges: Union[torch.Tensor, Dict[EdgeType, torch.Tensor]]
# TODO (mkolodner-sc): Only add necessary edge types to positive label dictionary, rather than all of the keys in the partitioned edge index
if isinstance(partition_output.partitioned_edge_index, abc.Mapping):
positive_label_edges = {}
for (
edge_type,
graph_partition_data,
) in partition_output.partitioned_edge_index.items():
edge_index = graph_partition_data.edge_index
positive_label_edges[edge_type] = select_ssl_positive_label_edges(
edge_index=edge_index,
positive_label_percentage=_ssl_positive_label_percentage,
)
elif isinstance(partition_output.partitioned_edge_index, GraphPartitionData):
positive_label_edges = select_ssl_positive_label_edges(
edge_index=partition_output.partitioned_edge_index.edge_index,
positive_label_percentage=_ssl_positive_label_percentage,
)
else:
raise ValueError(
"Found no partitioned edge index when attempting to select positive labels"
)
partition_output.partitioned_positive_labels = positive_label_edges
logger.info(
f"Initializing DistLinkPredictionDataset instance with edge direction {edge_dir}"
)
dataset = DistLinkPredictionDataset(
rank=rank, world_size=world_size, edge_dir=edge_dir
)
dataset.build(
partition_output=partition_output,
splitter=splitter,
)
return dataset
def _build_dataset_process(
process_number_on_current_machine: int,
output_dict: MutableMapping[str, DistLinkPredictionDataset],
serialized_graph_metadata: SerializedGraphMetadata,
distributed_context: DistributedContext,
dataset_building_port: int,
sample_edge_direction: Literal["in", "out"],
should_load_tensors_in_parallel: bool,
partitioner_class: Optional[Type[DistPartitioner]],
node_tf_dataset_options: TFDatasetOptions,
edge_tf_dataset_options: TFDatasetOptions,
should_convert_labels_to_edges: bool = False,
splitter: Optional[NodeAnchorLinkSplitter] = None,
_ssl_positive_label_percentage: Optional[float] = None,
) -> None:
"""
This function runs in a single spawned process per machine and is responsible for:
1. Initializing worker group and rpc connections
2. Loading Torch tensors from serialized TFRecords
3. Partitioning loaded Torch tensors across multiple machines
4. Loading and formatting graph and feature partition data into a `DistLinkPredictionDataset` class, which is used downstream for training or inference
5. Tearing down these connections
Steps 2-4 are done by the `_load_and_build_partitioned_dataset` function.
We wrap this logic inside of an `mp.spawn` process so that assets from these steps are properly cleaned up after the dataset has been built. Without
this, we observe inference performance degradation via cached entities that remain during the inference loop. As such, using an `mp.spawn` process is an easy
way to ensure all cached entities are cleaned up. We use `mp.spawn` instead of `mp.Process` so that any exceptions thrown in this function are correctly
propagated to the parent process.
This step is currently only supported on CPU.
Args:
process_number_on_current_machine (int): Process number on current machine. This parameter is required and provided by mp.spawn.
Since only a single process is spawned per machine for dataset building, this is always 0.
output_dict (MutableMapping[str, DistLinkPredictionDataset]): A dictionary spawned by a mp.manager which the built dataset
will be written to for use by the parent process
serialized_graph_metadata (SerializedGraphMetadata): Metadata about TFRecords that are serialized to disk
distributed_context (DistributedContext): Distributed context containing information for master_ip_address, rank, and world size
dataset_building_port (int): RPC port to use to build the dataset
sample_edge_direction (Literal["in", "out"]): Whether edges in the graph are directed inward or outward
should_load_tensors_in_parallel (bool): Whether tensors should be loaded from serialized information in parallel or in sequence across the [node, edge, pos_label, neg_label] entity types.
partitioner_class (Optional[Type[DistPartitioner]]): Partitioner class to partition the graph inputs. If provided, this must be a
DistPartitioner or subclass of it. If not provided, the DistPartitioner class will be used.
node_tf_dataset_options (TFDatasetOptions): Options provided to a tf.data.Dataset to tune how serialized node data is read.
edge_tf_dataset_options (TFDatasetOptions): Options provided to a tf.data.Dataset to tune how serialized edge data is read.
should_convert_labels_to_edges (bool): Whether to convert labels to edges in the graph. If this is set to true, the output dataset will be heterogeneous.
splitter (Optional[NodeAnchorLinkSplitter]): Optional splitter to use for splitting the graph data into train, val, and test sets. If not provided (None), no splitting will be performed.
_ssl_positive_label_percentage (Optional[float]): Percentage of edges to select as self-supervised labels. Must be None if supervised edge labels are provided in advance.
Slotted for refactor once this functionality is available in the transductive `splitter` directly
"""
# Sets up the worker group and rpc connection. We need to ensure we cleanup by calling shutdown_rpc() after we no longer need the rpc connection.
init_worker_group(
world_size=distributed_context.global_world_size,
rank=distributed_context.global_rank,
group_name=get_process_group_name(process_number_on_current_machine),
)
init_rpc(
master_addr=distributed_context.main_worker_ip_address,
master_port=dataset_building_port,
num_rpc_threads=16,
)
output_dataset: DistLinkPredictionDataset = _load_and_build_partitioned_dataset(
serialized_graph_metadata=serialized_graph_metadata,
should_load_tensors_in_parallel=should_load_tensors_in_parallel,
edge_dir=sample_edge_direction,
partitioner_class=partitioner_class,
node_tf_dataset_options=node_tf_dataset_options,
edge_tf_dataset_options=edge_tf_dataset_options,
should_convert_labels_to_edges=should_convert_labels_to_edges,
splitter=splitter,
_ssl_positive_label_percentage=_ssl_positive_label_percentage,
)
output_dict["dataset"] = output_dataset
# We add a barrier here so that all processes exit this function at the same time. Without this, some machines may call shutdown_rpc() while other
# machines still require the RPC connection for partitioning, which will result in failure.
barrier()
shutdown_rpc()
def build_dataset(
serialized_graph_metadata: SerializedGraphMetadata,
distributed_context: DistributedContext,
sample_edge_direction: Union[Literal["in", "out"], str],
should_load_tensors_in_parallel: bool = True,
partitioner_class: Optional[Type[DistPartitioner]] = None,
node_tf_dataset_options: TFDatasetOptions = TFDatasetOptions(),
edge_tf_dataset_options: TFDatasetOptions = TFDatasetOptions(),
should_convert_labels_to_edges: bool = False,
splitter: Optional[NodeAnchorLinkSplitter] = None,
_ssl_positive_label_percentage: Optional[float] = None,
_dataset_building_port: int = DEFAULT_MASTER_DATA_BUILDING_PORT,
) -> DistLinkPredictionDataset:
"""
Launches a spawned process for building and returning a DistLinkPredictionDataset instance provided some SerializedGraphMetadata
Args:
serialized_graph_metadata (SerializedGraphMetadata): Metadata about TFRecords that are serialized to disk
distributed_context (DistributedContext): Distributed context containing information for master_ip_address, rank, and world size
sample_edge_direction (Union[Literal["in", "out"], str]): Whether edges in the graph are directed inward or outward. Note that this is
also typed as a plain string to satisfy the type checker, but in practice it must be either "in" or "out".
should_load_tensors_in_parallel (bool): Whether tensors should be loaded from serialized information in parallel or in sequence across the [node, edge, pos_label, neg_label] entity types.
partitioner_class (Optional[Type[DistPartitioner]]): Partitioner class to partition the graph inputs. If provided, this must be a
DistPartitioner or subclass of it. If not provided, the DistPartitioner class will be used.
node_tf_dataset_options (TFDatasetOptions): Options provided to a tf.data.Dataset to tune how serialized node data is read.
edge_tf_dataset_options (TFDatasetOptions): Options provided to a tf.data.Dataset to tune how serialized edge data is read.
should_convert_labels_to_edges (bool): Whether to convert labels to edges in the graph. If this is set to true, the output dataset will be heterogeneous.
splitter (Optional[NodeAnchorLinkSplitter]): Optional splitter to use for splitting the graph data into train, val, and test sets. If not provided (None), no splitting will be performed.
_ssl_positive_label_percentage (Optional[float]): Percentage of edges to select as self-supervised labels. Must be None if supervised edge labels are provided in advance.
Slotted for refactor once this functionality is available in the transductive `splitter` directly
_dataset_building_port (int): WARNING: You should not need to configure this unless you run into port conflicts. Slotted for refactor.
The RPC port to use to build the dataset. In the future, the port will be automatically assigned based on availability.
Currently defaults to: gigl.distributed.constants.DEFAULT_MASTER_DATA_BUILDING_PORT
Returns:
DistLinkPredictionDataset: Built GraphLearn-for-PyTorch Dataset class
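Example (a minimal sketch; assumes `serialized_graph_metadata` and `distributed_context` have already been
constructed for the current job, e.g. via GiGL orchestration):
dataset = build_dataset(
serialized_graph_metadata=serialized_graph_metadata,
distributed_context=distributed_context,
sample_edge_direction="in",
)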
"""
assert (
sample_edge_direction == "in" or sample_edge_direction == "out"
), f"Provided edge direction from inference args must be one of `in` or `out`, got {sample_edge_direction}"
if should_convert_labels_to_edges:
if splitter is not None:
logger.warning(
f"Received splitter {splitter} and should_convert_labels_to_edges=True. Will use {splitter} to split the graph data."
)
else:
logger.info(
f"Using default splitter {type(HashedNodeAnchorLinkSplitter)} for ABLP labels."
)
# TODO(kmonte): Read train/val/test split counts from config.
# TODO(kmonte): Read label edge dir from config.
splitter = HashedNodeAnchorLinkSplitter(
sampling_direction=sample_edge_direction,
edge_types=[
message_passing_to_positive_label(DEFAULT_HOMOGENEOUS_EDGE_TYPE),
message_passing_to_negative_label(DEFAULT_HOMOGENEOUS_EDGE_TYPE),
],
)
logger.info(
"Will be treating the ABLP labels as heterogeneous edges in the graph."
)
manager = mp.Manager()
dataset_building_start_time = time.time()
# Used for directing the outputs of the dataset building process back to the parent process
output_dict = manager.dict()
# Launches process for loading serialized TFRecords from disk into memory, partitioning the data across machines, and storing data inside a GLT dataset class
mp.spawn(
fn=_build_dataset_process,
args=(
output_dict,
serialized_graph_metadata,
distributed_context,
_dataset_building_port,
sample_edge_direction,
should_load_tensors_in_parallel,
partitioner_class,
node_tf_dataset_options,
edge_tf_dataset_options,
should_convert_labels_to_edges,
splitter,
_ssl_positive_label_percentage,
),
)
output_dataset: DistLinkPredictionDataset = output_dict["dataset"]
logger.info(
f"--- Dataset Building finished on rank {distributed_context.global_rank}, which took {time.time()-dataset_building_start_time:.2f} seconds"
)
return output_dataset
def build_dataset_from_task_config_uri(
task_config_uri: str,
distributed_context: DistributedContext,
is_inference: bool = True,
) -> DistLinkPredictionDataset:
"""
Builds a dataset from a provided `task_config_uri` as part of GiGL orchestration. Parameters to
this step should be provided in the `inferenceArgs` field of the GbmlConfig for inference or the
`trainerArgs` field of the GbmlConfig for training. The currently parsable arguments here are:
- sample_edge_direction: Direction of the graph
- should_use_range_partitioning: Whether we should be using range-based partitioning
- should_load_tensors_in_parallel: Whether TFRecord loading should happen in parallel across entities
Args:
task_config_uri (str): URI to a GBML Config
distributed_context (DistributedContext): Distributed context containing information for
master_ip_address, rank, and world size
is_inference (bool): Whether the run is for inference or training. If True, arguments will
be read from inferenceArgs. Otherwise, arguments will be read from trainerArgs.
"""
# Read the GbmlConfig pb wrapper from the provided task config URI; it contains the preprocessed data metadata and the runtime arguments used below.
gbml_config_pb_wrapper = GbmlConfigPbWrapper.get_gbml_config_pb_wrapper_from_uri(
gbml_config_uri=UriFactory.create_uri(task_config_uri)
)
if is_inference:
args = dict(gbml_config_pb_wrapper.inferencer_config.inferencer_args)
args_path = "inferencerConfig.inferencerArgs"
should_convert_labels_to_edges = False
else:
args = dict(gbml_config_pb_wrapper.trainer_config.trainer_args)
args_path = "trainerConfig.trainerArgs"
# TODO(kmonte): Maybe we should enable this as a flag?
should_convert_labels_to_edges = True
sample_edge_direction = args.get("sample_edge_direction", "in")
assert sample_edge_direction in (
"in",
"out",
), f"Provided edge direction from args must be one of `in` or `out`, got {sample_edge_direction}"
should_use_range_partitioning = bool(
strtobool(args.get("should_use_range_partitioning", "True"))
)
should_load_tensors_in_parallel = bool(
strtobool(args.get("should_load_tensors_in_parallel", "True"))
)
logger.info(
f"Inferred 'sample_edge_direction' argument as : {sample_edge_direction} from argument path {args_path}. To override, please provide 'sample_edge_direction' flag."
)
logger.info(
f"Inferred 'should_use_range_partitioning' argument as : {should_use_range_partitioning} from argument path {args_path}. To override, please provide 'should_use_range_partitioning' flag."
)
logger.info(
f"Inferred 'should_load_tensors_in_parallel' argument as : {should_load_tensors_in_parallel} from argument path {args_path}. To override, please provide 'should_load_tensors_in_parallel' flag."
)
# We use a `SerializedGraphMetadata` object to store and organize information for loading serialized TFRecords from disk into memory.
# We provide a convenience utility `convert_pb_to_serialized_graph_metadata` to build the
# `SerializedGraphMetadata` object when using GiGL orchestration, leveraging fields of the GBMLConfigPbWrapper
serialized_graph_metadata = convert_pb_to_serialized_graph_metadata(
preprocessed_metadata_pb_wrapper=gbml_config_pb_wrapper.preprocessed_metadata_pb_wrapper,
graph_metadata_pb_wrapper=gbml_config_pb_wrapper.graph_metadata_pb_wrapper,
)
if should_use_range_partitioning:
partitioner_class = DistRangePartitioner
else:
partitioner_class = DistPartitioner
dataset = build_dataset(
serialized_graph_metadata=serialized_graph_metadata,
distributed_context=distributed_context,
sample_edge_direction=sample_edge_direction,
partitioner_class=partitioner_class,
should_convert_labels_to_edges=should_convert_labels_to_edges,
)
return dataset
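# Example usage of `build_dataset_from_task_config_uri` (an illustrative sketch; assumes GiGL orchestration has
# already produced a frozen task config and that a DistributedContext exists for this machine. The URI below is
# hypothetical):
#
# dataset = build_dataset_from_task_config_uri(
# task_config_uri="gs://example-bucket/frozen_gbml_config.yaml",
# distributed_context=distributed_context,
# is_inference=True,
# )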