gigl.distributed.dataset_factory

DatasetFactory is responsible for building and returning an instance of DistLinkPredictionDataset or one of its subclasses. It does this by spawning a process that initializes the RPC and worker group, loads and builds a partitioned dataset, and then shuts down the RPC and worker group.

Attributes

Functions

build_dataset(serialized_graph_metadata, ...[, ...])

Launches a spawned process for building and returning a DistLinkPredictionDataset instance provided some SerializedGraphMetadata.

build_dataset_from_task_config_uri(task_config_uri[, ...])

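Builds a dataset from a provided task_config_uri as part of GiGL orchestration. Parameters to this step should be provided in the inferenceArgs field of the GbmlConfig for inference or the trainerArgs field of the GbmlConfig for training.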

Module Contents

gigl.distributed.dataset_factory.build_dataset(serialized_graph_metadata, sample_edge_direction, distributed_context=None, should_load_tensors_in_parallel=True, partitioner_class=None, node_tf_dataset_options=TFDatasetOptions(), edge_tf_dataset_options=TFDatasetOptions(), splitter=None, _ssl_positive_label_percentage=None, _dataset_building_port=None)

Launches a spawned process for building and returning a DistLinkPredictionDataset instance provided some SerializedGraphMetadata.

It is expected that there is only one build_dataset call per node (machine). This requirement ensures that each machine participates only once in housing a partition of a dataset; otherwise, a machine may end up housing multiple partitions of the same dataset, which may cause memory issues.

This function expects that a process group is initialized between the processes on the nodes participating in hosting the dataset partitions. This allows the nodes to communicate the information needed to configure RPC, e.g. free port information and the master IP address. If no process group is initialized, the function will initialize one using the env:// config. See torch.distributed.init_process_group for more info.
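The env:// fallback described above relies on environment variables being set by the launcher. The following is a minimal sketch, not part of GiGL, that checks for the variables torch.distributed's env:// init method reads before a dataset build is attempted. The helper name `missing_env_vars` is hypothetical. Rank and world size can alternatively be passed directly to torch.distributed.init_process_group, so treat this as an illustration rather than a required step.

```python
import os
from typing import List, Mapping

# Variables read by torch.distributed's env:// init method.
REQUIRED_ENV_VARS = ("MASTER_ADDR", "MASTER_PORT", "RANK", "WORLD_SIZE")


def missing_env_vars(environ: Mapping[str, str] = os.environ) -> List[str]:
    """Return the env:// variables that are unset or empty (hypothetical helper)."""
    return [name for name in REQUIRED_ENV_VARS if not environ.get(name)]


if __name__ == "__main__":
    missing = missing_env_vars()
    if missing:
        print(f"env:// initialization would fail; missing: {missing}")
    else:
        print("All env:// variables are set.")
```

Launchers such as torchrun typically export these variables for each process, in which case the check passes without extra setup.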

Parameters:
  • serialized_graph_metadata (SerializedGraphMetadata) – Metadata about TFRecords that are serialized to disk

  • distributed_context (deprecated field - will be removed soon) (Optional[DistributedContext]) – Distributed context containing information for master_ip_address, rank, and world size. Defaults to None, in which case it will be initialized from the current torch.distributed context. If provided, you need not initialize a process group yourself; one will be initialized for you.

  • sample_edge_direction (Union[Literal["in", "out"], str]) – Whether edges in the graph are directed inward or outward. Note that str is listed as a possible type only to satisfy the type checker; in practice the value must be a Literal["in", "out"].

  • should_load_tensors_in_parallel (bool) – Whether tensors should be loaded from serialized information in parallel or in sequence across the [node, edge, pos_label, neg_label] entity types.

  • partitioner_class (Optional[Type[DistPartitioner]]) – Partitioner class used to partition the graph inputs. If provided, this must be DistPartitioner or a subclass of it. If not provided, the DistPartitioner class is used.

  • node_tf_dataset_options (TFDatasetOptions) – Options provided to a tf.data.Dataset to tune how serialized node data is read.

  • edge_tf_dataset_options (TFDatasetOptions) – Options provided to a tf.data.Dataset to tune how serialized edge data is read.

  • splitter (Optional[NodeAnchorLinkSplitter]) – Optional splitter to use for splitting the graph data into train, val, and test sets. If not provided (None), no splitting will be performed.

  • _ssl_positive_label_percentage (Optional[float]) – Percentage of edges to select as self-supervised labels. Must be None if supervised edge labels are provided in advance. Slated for refactoring once this functionality is available in the transductive splitter directly.

  • _dataset_building_port (deprecated field - will be removed soon) (Optional[int]) – Contains information about master port. Defaults to None, in which case it will be initialized from the current torch.distributed context.

Returns:

The built GraphLearn-for-PyTorch dataset instance

Return type:

DistLinkPredictionDataset
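Because sample_edge_direction is typed as a plain string but must in practice be "in" or "out", callers that read the value from configuration may want to validate it before calling build_dataset. The following is a minimal sketch of such a check. The alias `EdgeDirection` and the helper `as_edge_direction` are hypothetical names introduced for illustration and are not part of GiGL.

```python
from typing import Literal, cast, get_args

# Hypothetical alias for the two accepted directions.
EdgeDirection = Literal["in", "out"]


def as_edge_direction(value: str) -> EdgeDirection:
    """Narrow an arbitrary string to "in" or "out", raising on anything else."""
    allowed = get_args(EdgeDirection)
    if value not in allowed:
        raise ValueError(
            f"sample_edge_direction must be one of {allowed}, got {value!r}"
        )
    return cast(EdgeDirection, value)


if __name__ == "__main__":
    print(as_edge_direction("in"))
```

The validated value could then be passed as the sample_edge_direction argument, which surfaces a bad configuration value before the dataset-building process is spawned.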

gigl.distributed.dataset_factory.build_dataset_from_task_config_uri(task_config_uri, distributed_context=None, is_inference=True, _tfrecord_uri_pattern='.*-of-.*\\.tfrecord(\\.gz)?$')

Builds a dataset from a provided task_config_uri as part of GiGL orchestration. Parameters to this step should be provided in the inferenceArgs field of the GbmlConfig for inference or the trainerArgs field of the GbmlConfig for training.
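The selection between inferenceArgs and trainerArgs can be pictured with a small sketch. It assumes both fields behave like string-to-string mappings and that boolean arguments arrive as strings such as "True"; both are assumptions, and the helpers `parse_bool` and `select_dataset_args` are hypothetical, not GiGL functions. Keys that are absent are left out rather than given defaults, since the defaults are not stated here.

```python
from typing import Dict, Mapping, Union

# Boolean-valued arguments this step is documented to parse.
BOOL_KEYS = ("should_use_range_partitioning", "should_load_tensors_in_parallel")


def parse_bool(value: str) -> bool:
    """Parse a boolean given as a string (assumed encoding)."""
    normalized = value.strip().lower()
    if normalized in ("true", "1"):
        return True
    if normalized in ("false", "0"):
        return False
    raise ValueError(f"Expected a boolean string, got {value!r}")


def select_dataset_args(
    trainer_args: Mapping[str, str],
    inference_args: Mapping[str, str],
    is_inference: bool,
) -> Dict[str, Union[str, bool]]:
    """Pick the argument mapping for the run type and parse the known keys."""
    raw = inference_args if is_inference else trainer_args
    parsed: Dict[str, Union[str, bool]] = {}
    if "sample_edge_direction" in raw:
        parsed["sample_edge_direction"] = raw["sample_edge_direction"]
    for key in BOOL_KEYS:
        if key in raw:
            parsed[key] = parse_bool(raw[key])
    return parsed


if __name__ == "__main__":
    trainer = {"sample_edge_direction": "out"}
    inference = {"sample_edge_direction": "in", "should_use_range_partitioning": "True"}
    print(select_dataset_args(trainer, inference, is_inference=True))
```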

It is expected that there is only one build_dataset_from_task_config_uri call per node (machine). This requirement ensures that each machine participates only once in housing a partition of a dataset; otherwise, a machine may end up housing multiple partitions of the same dataset, which may cause memory issues.

This function expects that a process group is initialized between the processes on the nodes participating in hosting the dataset partitions. This allows the nodes to communicate the information needed to configure RPC, e.g. free port information and the master IP address. If no process group is initialized, the function will initialize one using the env:// config. See torch.distributed.init_process_group for more info.

The currently parsable arguments are:
  • sample_edge_direction – Direction of the graph

  • should_use_range_partitioning – Whether range-based partitioning should be used

  • should_load_tensors_in_parallel – Whether TFRecord loading should happen in parallel across entities

Parameters:
  • task_config_uri (Union[str, gigl.common.Uri]) – URI to a GBML Config

  • distributed_context (Optional[DistributedContext]) – Distributed context containing information for master_ip_address, rank, and world size. Defaults to None, in which case it will be initialized from the current torch.distributed context.

  • is_inference (bool) – Whether the run is for inference or training. If True, arguments will be read from inferenceArgs. Otherwise, arguments will be read from trainerArgs.

  • _tfrecord_uri_pattern (str) – INTERNAL ONLY. Regex pattern for loading serialized TFRecords. Defaults to ".*-of-.*\.tfrecord(\.gz)?$".

Return type:

gigl.distributed.dist_link_prediction_dataset.DistLinkPredictionDataset
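To make the default _tfrecord_uri_pattern concrete, the sketch below applies the same regular expression to a few example URIs using Python's re module. The bucket paths and the helper `filter_tfrecord_uris` are hypothetical, and how GiGL applies the pattern internally is not shown here; this only illustrates which file names the expression accepts.

```python
import re
from typing import Iterable, List

# Default pattern from the function signature, written as a raw string.
TFRECORD_URI_PATTERN = r".*-of-.*\.tfrecord(\.gz)?$"


def filter_tfrecord_uris(uris: Iterable[str]) -> List[str]:
    """Keep only the URIs that match the sharded TFRecord naming pattern."""
    pattern = re.compile(TFRECORD_URI_PATTERN)
    return [uri for uri in uris if pattern.match(uri)]


if __name__ == "__main__":
    candidates = [
        "gs://example-bucket/nodes/part-00000-of-00010.tfrecord",
        "gs://example-bucket/nodes/part-00001-of-00010.tfrecord.gz",
        "gs://example-bucket/nodes/_SUCCESS",
        "gs://example-bucket/nodes/part-00002-of-00010.tfrecord.tmp",
    ]
    for uri in filter_tfrecord_uris(candidates):
        print(uri)
```

Only sharded files ending in .tfrecord or .tfrecord.gz pass the filter; marker files and temporary files are excluded.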

gigl.distributed.dataset_factory.logger