gigl.distributed.utils.neighborloader#

Utils for Neighbor loaders.

Attributes#

Classes#

DatasetSchema

Shared metadata between the local and remote datasets.

SamplingClusterSetup

The setup of the sampling cluster.

Functions#

attach_ppr_outputs(data, ppr_edge_indices, ppr_weights)

Attach PPR edge indices and weights onto a HeteroData object.

extract_edge_type_metadata(metadata, prefixes)

Extract entries matching any of the given prefixes from metadata, grouped by prefix.

extract_metadata(msg, device)

Separate user-defined metadata from a SampleMessage.

labeled_to_homogeneous(supervision_edge_type, data)

Returns a Data object with the label edges removed.

patch_fanout_for_sampling(edge_types, num_neighbors)

Sets up an appropriate fanout for sampling.

set_missing_features(data, node_feature_info, ...)

If a feature is missing from a produced Data or HeteroData object due to not fanning out to it, populates it in-place with an empty tensor.

shard_nodes_by_process(input_nodes, ...)

Shards input nodes based on the local process rank.

strip_label_edges(data)

Removes all label edge types from a heterogeneous graph.

strip_non_ppr_edge_types(data, ppr_edge_types)

Remove all edge types not in ppr_edge_types from a HeteroData object.

Module Contents#

class gigl.distributed.utils.neighborloader.DatasetSchema[source]#

Shared metadata between the local and remote datasets.

edge_dir: str | Literal['in', 'out'][source]#
edge_feature_info: gigl.types.graph.FeatureInfo | dict[torch_geometric.typing.EdgeType, gigl.types.graph.FeatureInfo] | None[source]#
edge_types: list[torch_geometric.typing.EdgeType] | None[source]#
is_homogeneous_with_labeled_edge_type: bool[source]#
node_feature_info: gigl.types.graph.FeatureInfo | dict[torch_geometric.typing.NodeType, gigl.types.graph.FeatureInfo] | None[source]#
class gigl.distributed.utils.neighborloader.SamplingClusterSetup(*args, **kwds)[source]#

Bases: enum.Enum

The setup of the sampling cluster.

COLOCATED = 'colocated'[source]#
GRAPH_STORE = 'graph_store'[source]#
gigl.distributed.utils.neighborloader.attach_ppr_outputs(data, ppr_edge_indices, ppr_weights)[source]#

Attach PPR edge indices and weights onto a HeteroData object.

For each PPR edge type, sets data[edge_type].edge_index and data[edge_type].edge_attr in-place. Called from the loader’s _collate_fn only when a PPR sampler is active; the function is a no-op if both dicts are empty.

Parameters:
  • data (Union[torch_geometric.data.Data, torch_geometric.data.HeteroData]) – The Data or HeteroData object to attach outputs to.

  • ppr_edge_indices (dict[torch_geometric.typing.EdgeType, torch.Tensor]) – Dict mapping PPR edge type to [2, N] edge-index tensor.

  • ppr_weights (dict[torch_geometric.typing.EdgeType, torch.Tensor]) – Dict mapping PPR edge type to [N] weight tensor.

Raises:

AssertionError – If ppr_edge_indices and ppr_weights have different edge-type keys.

Return type:

None
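
The attach step described above can be sketched without PyG. This is a minimal illustration, not the GiGL implementation: a defaultdict of per-edge-type dicts stands in for HeteroData, plain lists stand in for tensors, and the name attach_ppr_sketch and the edge type ("user", "ppr", "item") are hypothetical.

```python
from collections import defaultdict
from typing import Any

EdgeType = tuple[str, str, str]


def attach_ppr_sketch(
    data: dict[EdgeType, dict[str, Any]],
    ppr_edge_indices: dict[EdgeType, Any],
    ppr_weights: dict[EdgeType, Any],
) -> None:
    """Illustrative stand-in: sets edge_index and edge_attr per PPR edge type."""
    # Both dicts must describe exactly the same edge types.
    assert ppr_edge_indices.keys() == ppr_weights.keys(), "edge-type keys differ"
    # With two empty dicts the loop body never runs, so the call is a no-op.
    for edge_type, edge_index in ppr_edge_indices.items():
        data[edge_type]["edge_index"] = edge_index
        data[edge_type]["edge_attr"] = ppr_weights[edge_type]


# Stand-in for a HeteroData object: one attribute dict per edge type.
data: dict[EdgeType, dict[str, Any]] = defaultdict(dict)
ppr_edge = ("user", "ppr", "item")  # hypothetical PPR edge type
attach_ppr_sketch(
    data,
    ppr_edge_indices={ppr_edge: [[0, 0], [1, 2]]},  # shape [2, N] with N = 2
    ppr_weights={ppr_edge: [0.7, 0.3]},  # shape [N]
)
```

Checking the keys before writing anything means a mismatch fails before data is partially modified.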

gigl.distributed.utils.neighborloader.extract_edge_type_metadata(metadata, prefixes)[source]#

Extract entries matching any of the given prefixes from metadata, grouped by prefix.

Scans metadata for keys that start with any of the provided prefixes. For each match, the suffix (everything after the matched prefix) is parsed via ast.literal_eval as an EdgeType tuple and added to that prefix’s sub-dict. All unmatched keys are placed in the remaining dict.

Each prefix gets its own sub-dict in the result, so distinct categories (e.g. positive labels, negative labels) can never collide even when extracted in one call.

The original metadata is not modified.

Example

    matched, remaining = extract_edge_type_metadata(
        metadata=metadata,
        prefixes=[POSITIVE_LABEL_METADATA_KEY, NEGATIVE_LABEL_METADATA_KEY],
    )
    positive_labels = matched[POSITIVE_LABEL_METADATA_KEY]
    negative_labels = matched[NEGATIVE_LABEL_METADATA_KEY]

Parameters:
  • metadata (dict[str, torch.Tensor]) – Dict of string keys to tensors.

  • prefixes (list[str]) – List of prefixes to match against. Prefixes should be unique (no repeats).

Returns:

A 2-tuple of (matched, remaining):

  • matched: Dict mapping each prefix to a sub-dict of {EdgeType: tensor} for all keys that started with that prefix. Every prefix in prefixes is guaranteed to be present as a key (with an empty dict if nothing matched).

  • remaining: Dict of all key/value pairs that matched no prefix.
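
The prefix-matching behaviour described for extract_edge_type_metadata can be sketched in plain Python. This is an illustration under stated assumptions rather than the library's code: the function name extract_by_prefix_sketch and the prefix strings "pos_label." and "neg_label." are hypothetical, lists stand in for tensors, and a key that starts with more than one prefix is assigned to the first match.

```python
import ast
from typing import Any

EdgeType = tuple[str, str, str]


def extract_by_prefix_sketch(
    metadata: dict[str, Any], prefixes: list[str]
) -> tuple[dict[str, dict[EdgeType, Any]], dict[str, Any]]:
    """Illustrative stand-in, not the GiGL implementation."""
    # Every prefix gets a sub-dict up front, so it is present even with no matches.
    matched: dict[str, dict[EdgeType, Any]] = {prefix: {} for prefix in prefixes}
    remaining: dict[str, Any] = {}
    for key, value in metadata.items():
        for prefix in prefixes:
            if key.startswith(prefix):
                # The suffix after the prefix is parsed as an EdgeType tuple.
                edge_type = ast.literal_eval(key[len(prefix):])
                matched[prefix][edge_type] = value
                break
        else:
            # No prefix matched: the entry is copied to `remaining` unchanged.
            remaining[key] = value
    return matched, remaining


# Hypothetical prefixes and values; real values would be torch.Tensor objects.
POS = "pos_label."
NEG = "neg_label."
edge_type = ("user", "buys", "item")
metadata = {
    f"{POS}{edge_type!r}": [1, 2],
    f"{NEG}{edge_type!r}": [3],
    "batch_id": [7],
}
matched, remaining = extract_by_prefix_sketch(metadata, [POS, NEG])
```

Because the result dicts are built fresh, the input metadata is left untouched, matching the documented behaviour.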

gigl.distributed.utils.neighborloader.extract_metadata(msg, device)[source]#

Separate user-defined metadata from a SampleMessage.

GLT’s to_hetero_data misinterprets #META.-prefixed keys as edge types, causing failures with edge_dir="out" (it tries to call reverse_edge_type on metadata key strings). This function separates metadata from the sampling data so the stripped message can be passed to GLT’s _collate_fn without triggering the bug.

The original msg is not modified.

Parameters:
  • msg (graphlearn_torch.channel.SampleMessage) – The SampleMessage to extract metadata from.

  • device (torch.device) – The device to move metadata tensors to.

Returns:

A 2-tuple of (metadata, stripped_msg):

  • metadata: Dict mapping metadata key (without #META. prefix) to tensor.

  • stripped_msg: A new SampleMessage with #META.-prefixed keys removed.
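
The separation step can be sketched as a split over dictionary keys. This is a simplified illustration, assuming the message behaves like a plain dict of string keys; the name split_metadata_sketch and the "ids" key are hypothetical, lists stand in for tensors, and the move to device is omitted.

```python
from typing import Any

META_PREFIX = "#META."


def split_metadata_sketch(
    msg: dict[str, Any],
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Illustrative stand-in: splits #META.-prefixed entries from the rest."""
    metadata: dict[str, Any] = {}
    stripped_msg: dict[str, Any] = {}
    for key, value in msg.items():
        if key.startswith(META_PREFIX):
            # Store the entry under its key with the prefix removed.
            metadata[key[len(META_PREFIX):]] = value
        else:
            # Sampling data is copied into a new dict; `msg` is not modified.
            stripped_msg[key] = value
    return metadata, stripped_msg


msg = {"#META.batch_id": [7], "ids": [0, 1, 2]}  # "ids" is a hypothetical sampling key
metadata, stripped_msg = split_metadata_sketch(msg)
```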

gigl.distributed.utils.neighborloader.labeled_to_homogeneous(supervision_edge_type, data)[source]#

Returns a Data object with the label edges removed.

Parameters:
  • supervision_edge_type (EdgeType) – The edge type that contains the supervision edges.

  • data (HeteroData) – Heterogeneous graph with the supervision edge type

Returns:

Homogeneous graph with the labeled edge type removed.

Return type:

Data

gigl.distributed.utils.neighborloader.patch_fanout_for_sampling(edge_types, num_neighbors)[source]#

Sets up an appropriate fanout for sampling.

Does the following:

  • For all label edge types, sets the fanout to zero.

  • For all other edge types, if the fanout is not specified, uses the original fanout.

Note that if fanout is provided as a dict, the keys (edges) in the fanout must be in edge_types.

This is needed because the existing sampling logic in alibaba/graphlearn-for-pytorch makes strict assumptions that the fanout must conform to.

Parameters:
  • edge_types (Optional[list[EdgeType]]) – List of all edge types in the graph; None for homogeneous datasets

  • num_neighbors (dict[EdgeType, list[int]]) – Specified fanout by the user

Returns:

Modified fanout that is appropriate for sampling. Is a list[int] if the dataset is homogeneous, otherwise is dict[EdgeType, list[int]].

Return type:

Union[list[int], dict[EdgeType, list[int]]]
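
The patching rule can be sketched for the case where the user supplies a single list fanout. This is a partial illustration, not the GiGL implementation: dict-valued fanouts are left out, the name patch_list_fanout_sketch is hypothetical, and label edge types are passed in explicitly as a set because how the library identifies them is not described here.

```python
from typing import Optional, Union

EdgeType = tuple[str, str, str]


def patch_list_fanout_sketch(
    edge_types: Optional[list[EdgeType]],
    num_neighbors: list[int],
    label_edge_types: set[EdgeType],
) -> Union[list[int], dict[EdgeType, list[int]]]:
    """Illustrative stand-in covering only list-valued fanouts."""
    if edge_types is None:
        # Homogeneous dataset: the fanout stays a plain list.
        return list(num_neighbors)
    # Label edge types get a zero fanout with one entry per hop.
    zero_fanout = [0] * len(num_neighbors)
    return {
        edge_type: list(zero_fanout)
        if edge_type in label_edge_types
        else list(num_neighbors)
        for edge_type in edge_types
    }


message_edge = ("user", "follows", "user")  # hypothetical edge types
label_edge = ("user", "label", "user")
patched = patch_list_fanout_sketch([message_edge, label_edge], [10, 5], {label_edge})
homogeneous = patch_list_fanout_sketch(None, [10, 5], set())
```

Giving every edge type an entry of the same length keeps the number of hops consistent across the returned dict.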

gigl.distributed.utils.neighborloader.set_missing_features(data, node_feature_info, edge_feature_info, device)[source]#

If a feature is missing from a produced Data or HeteroData object because sampling did not fan out to it, populates it in-place with an empty tensor with the appropriate feature dim. Note that PyG natively does this in its DistNeighborLoader for missing edge features, edge indices, and node features: https://pytorch-geometric.readthedocs.io/en/2.4.0/_modules/torch_geometric/sampler/neighbor_sampler.html#NeighborSampler

However, native Graphlearn-for-PyTorch (alibaba/graphlearn-for-pytorch) only does this for edge indices, so this function does the same for the sampled node and edge features.

# TODO (mkolodner-sc): Migrate this utility to GLT once we fork their repo

Parameters:
  • data (_GraphType) – Data or HeteroData object which we are setting the missing features for

  • node_feature_info (Optional[Union[FeatureInfo, dict[NodeType, FeatureInfo]]]) – Node feature dimension and data type. Note that if heterogeneous, only node types with features should be provided. Can be None in the homogeneous case if there are no node features

  • edge_feature_info (Optional[Union[FeatureInfo, dict[EdgeType, FeatureInfo]]]) – Edge feature dimension and data type. Note that if heterogeneous, only edge types with features should be provided. Can be None in the homogeneous case if there are no edge features

  • device (torch.device) – Device to move the empty features to

Returns:

Data or HeteroData type with the updated feature fields

Return type:

_GraphType

gigl.distributed.utils.neighborloader.shard_nodes_by_process(input_nodes, local_process_rank, local_process_world_size)[source]#

Shards input nodes based on the local process rank.

Parameters:
  • input_nodes (torch.Tensor) – Nodes which are split across each training or inference process

  • local_process_rank (int) – Rank of the current local process

  • local_process_world_size (int) – Total number of local processes on the current machine

Returns:

The sharded nodes for the current local process

Return type:

torch.Tensor
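
One way to shard by local rank is to give each process a contiguous, near-equal slice. The sketch below assumes that scheme, which may differ from what shard_nodes_by_process actually does; the name shard_sketch is hypothetical and a list stands in for the torch.Tensor of node IDs.

```python
def shard_sketch(
    input_nodes: list[int], local_process_rank: int, local_process_world_size: int
) -> list[int]:
    """Illustrative stand-in: contiguous, near-equal slice per local process."""
    if not 0 <= local_process_rank < local_process_world_size:
        raise ValueError("rank must be in [0, world_size)")
    # The first `extra` ranks each take one additional node.
    base, extra = divmod(len(input_nodes), local_process_world_size)
    start = local_process_rank * base + min(local_process_rank, extra)
    end = start + base + (1 if local_process_rank < extra else 0)
    return input_nodes[start:end]


nodes = list(range(10))
shards = [shard_sketch(nodes, rank, 3) for rank in range(3)]
```

With this scheme every node lands in exactly one shard, and shard sizes differ by at most one.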

gigl.distributed.utils.neighborloader.strip_label_edges(data)[source]#

Removes all label edge types from a heterogeneous graph.

Modifies the input in place.

Parameters:

data (HeteroData) – The input heterogeneous graph.

Returns:

The graph with the label edge types removed.

Return type:

HeteroData

gigl.distributed.utils.neighborloader.strip_non_ppr_edge_types(data, ppr_edge_types)[source]#

Remove all edge types not in ppr_edge_types from a HeteroData object.

GLT’s collate function creates edge stores for all edge types registered in the sampler (including original graph and reverse edge types) even when the PPR sampler provides empty row/col tensors. This removes those ghost stores so the output contains only PPR edge types.

Modifies the input in place.

Parameters:
  • data (torch_geometric.data.HeteroData) – The HeteroData object to clean up.

  • ppr_edge_types (set[torch_geometric.typing.EdgeType]) – The exact set of PPR edge types to keep, as returned by attach_ppr_outputs.

Returns:

The same object with non-PPR edge types removed.

Return type:

torch_geometric.data.HeteroData
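
The clean-up can be sketched as deleting every edge store whose type is not in the set to keep. This is a stand-in illustration, not the GiGL implementation: a plain dict keyed by edge type replaces HeteroData, and the name strip_non_ppr_sketch and the edge types are hypothetical.

```python
from typing import Any

EdgeType = tuple[str, str, str]


def strip_non_ppr_sketch(
    edge_stores: dict[EdgeType, Any], ppr_edge_types: set[EdgeType]
) -> dict[EdgeType, Any]:
    """Illustrative stand-in: drops edge stores outside ppr_edge_types, in place."""
    # Iterate over a snapshot of the keys because the dict is mutated in the loop.
    for edge_type in list(edge_stores):
        if edge_type not in ppr_edge_types:
            del edge_stores[edge_type]
    return edge_stores


ppr_edge = ("user", "ppr", "item")  # hypothetical PPR edge type
ghost_edge = ("user", "buys", "item")  # store created with empty row/col tensors
edge_stores = {ppr_edge: {"edge_index": [[0], [1]]}, ghost_edge: {"edge_index": [[], []]}}
result = strip_non_ppr_sketch(edge_stores, {ppr_edge})
```

Returning the same object mirrors the documented in-place contract.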

gigl.distributed.utils.neighborloader.logger[source]#