gigl.distributed.utils.neighborloader

Utils for Neighbor loaders.

Attributes

logger

Functions

labeled_to_homogeneous(supervision_edge_type, data)

Returns a Data object with the label edges removed.

patch_fanout_for_sampling(edge_types, num_neighbors)

Sets up an appropriate fanout for sampling.

shard_nodes_by_process(input_nodes, ...)

Shards input nodes based on the local process rank.

strip_label_edges(data)

Removes all label edge types from a heterogeneous graph.

Module Contents

gigl.distributed.utils.neighborloader.labeled_to_homogeneous(supervision_edge_type, data)

Returns a Data object with the label edges removed.

Parameters:
  • supervision_edge_type (EdgeType) – The edge type that contains the supervision edges.

  • data (HeteroData) – Heterogeneous graph containing the supervision edge type.

Returns:

Homogeneous graph with the label edge type removed.

Return type:

Data
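To make the conversion concrete, here is a toy sketch in plain Python. It is not GiGL's implementation, which works on PyG `HeteroData` and `Data` objects. It assumes a graph modeled as a dict from `(src_type, relation, dst_type)` edge types to edge lists, and assumes that exactly one edge type remains once the supervision edge type is dropped; the name `labeled_to_homogeneous_sketch` and this data model are hypothetical.

```python
EdgeType = tuple[str, str, str]
EdgeList = list[tuple[int, int]]


def labeled_to_homogeneous_sketch(
    supervision_edge_type: EdgeType,
    edges_by_type: dict[EdgeType, EdgeList],
) -> EdgeList:
    """Hypothetical toy: drop the supervision edges, return the remaining edges.

    Assumes exactly one non-supervision edge type remains, so the result can be
    treated as a single homogeneous edge list.
    """
    # Build a filtered copy; this toy leaves the caller's mapping untouched.
    remaining = {
        edge_type: edges
        for edge_type, edges in edges_by_type.items()
        if edge_type != supervision_edge_type
    }
    if len(remaining) != 1:
        raise ValueError(
            f"Expected exactly one non-supervision edge type, got {list(remaining)}"
        )
    # The single remaining edge type becomes the homogeneous graph's edges.
    (edges,) = remaining.values()
    return list(edges)


graph = {
    ("user", "to", "user"): [(0, 1), (1, 2)],
    ("user", "label", "user"): [(0, 2)],
}
print(labeled_to_homogeneous_sketch(("user", "label", "user"), graph))
# [(0, 1), (1, 2)]
```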

gigl.distributed.utils.neighborloader.patch_fanout_for_sampling(edge_types, num_neighbors)

Sets up an appropriate fanout for sampling.

Does the following:

  • For all label edge types, sets the fanout to be zero.

  • For all other edge types, if the fanout is not specified, uses the original fanout.

This is needed because the existing sampling logic in alibaba/graphlearn-for-pytorch makes strict assumptions about the fanout that we need to conform to.

Parameters:
  • edge_types (list[EdgeType]) – List of all edge types in the graph.

  • num_neighbors (dict[EdgeType, list[int]]) – Fanout specified by the user.

Returns:

Modified fanout that is appropriate for sampling.

Return type:

dict[EdgeType, list[int]]
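The two rules above can be sketched as follows. This is a hypothetical illustration, not GiGL's code: `default_fanout` stands in for "the original fanout", `is_label_edge_type` is an assumed caller-supplied predicate, and the `"label"` relation name in the usage is purely illustrative. It also assumes every fanout list has the same number of hops, so label edges get one zero per hop.

```python
from typing import Callable

EdgeType = tuple[str, str, str]


def patch_fanout_sketch(
    edge_types: list[EdgeType],
    num_neighbors: dict[EdgeType, list[int]],
    default_fanout: list[int],
    is_label_edge_type: Callable[[EdgeType], bool],
) -> dict[EdgeType, list[int]]:
    """Hypothetical toy version of the fanout-patching rules."""
    num_hops = len(default_fanout)
    patched: dict[EdgeType, list[int]] = {}
    for edge_type in edge_types:
        if is_label_edge_type(edge_type):
            # Rule 1: never sample neighbors along label edges (zero at every hop).
            patched[edge_type] = [0] * num_hops
        else:
            # Rule 2: keep the user's fanout, or fall back to the default one.
            patched[edge_type] = list(num_neighbors.get(edge_type, default_fanout))
    return patched


edge_types = [("user", "to", "item"), ("item", "to", "user"), ("user", "label", "item")]
fanout = patch_fanout_sketch(
    edge_types,
    num_neighbors={("user", "to", "item"): [15, 10]},
    default_fanout=[10, 5],
    is_label_edge_type=lambda et: et[1] == "label",
)
print(fanout)
# {('user', 'to', 'item'): [15, 10], ('item', 'to', 'user'): [10, 5], ('user', 'label', 'item'): [0, 0]}
```

Returning a fresh dict rather than mutating `num_neighbors` keeps the caller's configuration reusable across loaders.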

gigl.distributed.utils.neighborloader.shard_nodes_by_process(input_nodes, local_process_rank, local_process_world_size)

Shards input nodes based on the local process rank.

Parameters:
  • input_nodes (torch.Tensor) – Nodes which are split across each training or inference process.

  • local_process_rank (int) – Rank of the current local process.

  • local_process_world_size (int) – Total number of local processes on the current machine.

Returns:

The sharded nodes for the current local process.

Return type:

torch.Tensor
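A minimal sketch of one common sharding scheme, assuming contiguous, equal-sized shards where the last local rank also takes any remainder. The page does not say which split GiGL uses, so treat the scheme and the name `shard_nodes_sketch` as hypothetical. Because it only uses `len()` and slicing, it runs on a Python list and also on a 1-D `torch.Tensor`.

```python
def shard_nodes_sketch(input_nodes, local_process_rank: int, local_process_world_size: int):
    """Hypothetical contiguous sharding; the last rank also takes the remainder."""
    if not 0 <= local_process_rank < local_process_world_size:
        raise ValueError("local_process_rank must be in [0, local_process_world_size)")
    num_nodes = len(input_nodes)
    per_process = num_nodes // local_process_world_size
    start = local_process_rank * per_process
    # The last rank picks up the leftover nodes so none are dropped.
    is_last = local_process_rank == local_process_world_size - 1
    end = num_nodes if is_last else start + per_process
    return input_nodes[start:end]


nodes = list(range(10))
shards = [shard_nodes_sketch(nodes, rank, 3) for rank in range(3)]
print(shards)
# [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]
```

Every node lands in exactly one shard, at the cost of the last process getting up to `local_process_world_size - 1` extra nodes.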

gigl.distributed.utils.neighborloader.strip_label_edges(data)

Removes all label edge types from a heterogeneous graph.

Modifies the input in place.

Parameters:

data (HeteroData) – The input heterogeneous graph.

Returns:

The graph with the label edge types removed.

Return type:

HeteroData
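The in-place contract can be illustrated with a toy model. This is a hypothetical sketch, not GiGL's code: the graph is a plain dict from edge types to edge lists instead of a `HeteroData`, and `is_label_edge_type` is an assumed predicate. It shows why the keys are copied before deletion and why the returned object is the same one that was passed in.

```python
from typing import Callable

EdgeType = tuple[str, str, str]


def strip_label_edges_sketch(
    edges_by_type: dict[EdgeType, list[tuple[int, int]]],
    is_label_edge_type: Callable[[EdgeType], bool],
) -> dict[EdgeType, list[tuple[int, int]]]:
    """Hypothetical toy: delete label edge types from the mapping in place."""
    # Snapshot the keys first: deleting from a dict while iterating it raises RuntimeError.
    for edge_type in list(edges_by_type):
        if is_label_edge_type(edge_type):
            del edges_by_type[edge_type]  # mutates the caller's object
    # Return the same (now modified) object, mirroring "modifies the input in place".
    return edges_by_type


graph = {("user", "to", "item"): [(0, 1)], ("user", "label", "item"): [(0, 2)]}
result = strip_label_edges_sketch(graph, lambda et: et[1] == "label")
print(result is graph, list(graph))
# True [('user', 'to', 'item')]
```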

gigl.distributed.utils.neighborloader.logger