gigl.distributed.dataset_factory#

DatasetFactory is responsible for building and returning a DistLinkPredictionDataset or a subclass of it. It does this by spawning a process which initializes the rpc and worker group, loads and builds a partitioned dataset, and then shuts down the rpc and worker group.

Attributes#

logger

Functions#

build_dataset(serialized_graph_metadata, ...[, ...])

Launches a spawned process for building and returning a DistLinkPredictionDataset instance provided some SerializedGraphMetadata

build_dataset_from_task_config_uri(task_config_uri, ...)

Builds a dataset from a provided task_config_uri as part of GiGL orchestration.

Module Contents#

gigl.distributed.dataset_factory.build_dataset(serialized_graph_metadata, distributed_context, sample_edge_direction, should_load_tensors_in_parallel=True, partitioner_class=None, node_tf_dataset_options=TFDatasetOptions(), edge_tf_dataset_options=TFDatasetOptions(), should_convert_labels_to_edges=False, splitter=None, _ssl_positive_label_percentage=None, _dataset_building_port=DEFAULT_MASTER_DATA_BUILDING_PORT)[source]#

Launches a spawned process for building and returning a DistLinkPredictionDataset instance provided some SerializedGraphMetadata.

Parameters:
  • serialized_graph_metadata (gigl.common.data.load_torch_tensors.SerializedGraphMetadata) – Metadata about TFRecords that are serialized to disk.

  • distributed_context (gigl.distributed.dist_context.DistributedContext) – Distributed context containing information for master_ip_address, rank, and world size.

  • sample_edge_direction (Union[Literal["in", "out"], str]) – Whether edges in the graph are directed inward or outward. Note that this is listed as a possible string to satisfy the type check, but in practice must be a Literal["in", "out"].

  • should_load_tensors_in_parallel (bool) – Whether tensors should be loaded from serialized information in parallel or in sequence across the [node, edge, pos_label, neg_label] entity types.

  • partitioner_class (Optional[Type[DistPartitioner]]) – Partitioner class to partition the graph inputs. If provided, this must be DistPartitioner or a subclass of it. If not provided, the DistPartitioner class will be used.

  • node_tf_dataset_options (TFDatasetOptions) – Options provided to a tf.data.Dataset to tune how serialized node data is read.

  • edge_tf_dataset_options (TFDatasetOptions) – Options provided to a tf.data.Dataset to tune how serialized edge data is read.

  • should_convert_labels_to_edges (bool) – Whether to convert labels to edges in the graph. If this is set to True, the output dataset will be heterogeneous.

  • splitter (Optional[NodeAnchorLinkSplitter]) – Optional splitter to use for splitting the graph data into train, val, and test sets. If not provided (None), no splitting will be performed.

  • _ssl_positive_label_percentage (Optional[float]) – Percentage of edges to select as self-supervised labels. Must be None if supervised edge labels are provided in advance. Slotted for refactor once this functionality is available in the transductive splitter directly.

  • _dataset_building_port (int) – WARNING: You do not need to configure this unless you hit port conflicts. Slotted for refactor. The RPC port to use to build the dataset. In the future, the port will be assigned automatically based on availability. Currently defaults to: gigl.distributed.constants.DEFAULT_MASTER_DATA_BUILDING_PORT

Returns:

The built GraphLearn-for-PyTorch dataset instance

Return type:

DistLinkPredictionDataset
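
A minimal usage sketch is shown below. It assumes that serialized_graph_metadata (a SerializedGraphMetadata) and distributed_context (a DistributedContext) were already created earlier in the pipeline; those variable names and the choice of "in" for sample_edge_direction are illustrative only.

```python
# Minimal sketch, assuming `serialized_graph_metadata` and `distributed_context`
# were constructed earlier (e.g., by upstream GiGL preprocessing / launch steps).
from gigl.distributed.dataset_factory import build_dataset

dataset = build_dataset(
    serialized_graph_metadata=serialized_graph_metadata,  # metadata for TFRecords on disk
    distributed_context=distributed_context,              # master_ip_address, rank, world size
    sample_edge_direction="in",                            # must effectively be "in" or "out"
    should_load_tensors_in_parallel=True,                  # load entity tensors concurrently
)
# `dataset` is the DistLinkPredictionDataset built by the spawned process.
```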

gigl.distributed.dataset_factory.build_dataset_from_task_config_uri(task_config_uri, distributed_context, is_inference=True)[source]#

Builds a dataset from a provided task_config_uri as part of GiGL orchestration. Parameters to this step should be provided in the inferenceArgs field of the GbmlConfig for inference or the trainerArgs field of the GbmlConfig for training. The currently parsable arguments are:

  • sample_edge_direction: Direction of the graph

  • should_use_range_partitioning: Whether we should be using range-based partitioning

  • should_load_tensors_in_parallel: Whether TFRecord loading should happen in parallel across entities

Parameters:
  • task_config_uri (str) – URI to a GBML Config

  • distributed_context (DistributedContext) – Distributed context containing information for master_ip_address, rank, and world size

  • is_inference (bool) – Whether the run is for inference or training. If True, arguments will be read from inferenceArgs. Otherwise, arguments will be read from trainerArgs.

Return type:

gigl.distributed.dist_link_prediction_dataset.DistLinkPredictionDataset
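
A minimal sketch of calling this orchestration entry point follows. The task_config_uri value is a hypothetical placeholder, and distributed_context is assumed to be a DistributedContext already created for this worker.

```python
# Minimal sketch, assuming a DistributedContext named `distributed_context` already
# exists for this worker; the config URI below is a hypothetical placeholder.
from gigl.distributed.dataset_factory import build_dataset_from_task_config_uri

dataset = build_dataset_from_task_config_uri(
    task_config_uri="gs://<bucket>/path/to/gbml_config.yaml",  # hypothetical GBML Config URI
    distributed_context=distributed_context,
    is_inference=False,  # read arguments from trainerArgs rather than inferenceArgs
)
```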

gigl.distributed.dataset_factory.logger[source]#