gigl.distributed.utils.neighborloader
Utils for Neighbor loaders.
Attributes
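Functions
| Function | Description |
|---|---|
| labeled_to_homogeneous(supervision_edge_type, data) | Returns a Data object with the label edges removed. |
| patch_fanout_for_sampling(edge_types, num_neighbors) | Sets up an appropriate fanout for sampling. |
| set_missing_features(data, node_feature_info, edge_feature_info, device) | If a feature is missing from a produced Data or HeteroData object due to not fanning out to it, populates it in-place with an empty tensor with the appropriate feature dimension. |
| shard_nodes_by_process(input_nodes, local_process_rank, local_process_world_size) | Shards input nodes based on the local process rank. |
| strip_label_edges(data) | Removes all label edge types from a heterogeneous graph. |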
Module Contents
- gigl.distributed.utils.neighborloader.labeled_to_homogeneous(supervision_edge_type, data)
Converts a labeled heterogeneous graph into a homogeneous Data object with the label edges removed.
- Parameters:
supervision_edge_type (EdgeType) – The edge type that contains the supervision edges.
data (HeteroData) – Heterogeneous graph with the supervision edge type
- Returns:
Homogeneous graph with the labeled edge type removed
- Return type:
Data
- gigl.distributed.utils.neighborloader.patch_fanout_for_sampling(edge_types, num_neighbors)
Sets up an appropriate fanout for sampling.
Does the following:
- For all label edge types, sets the fanout to be zero.
- For all other edge types, if the fanout is not specified, uses the original fanout.
Note that if the fanout is provided as a dict, its keys (edge types) must be in edge_types.
This is needed because the sampling logic in GraphLearn-for-PyTorch (alibaba/graphlearn-for-pytorch) makes strict assumptions about the fanout that we need to conform to.
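- Parameters:
edge_types – The edge types of the dataset. Label edge types among them receive a fanout of zero.
num_neighbors – The fanout to patch, either a list[int] shared by all edge types or a dict[EdgeType, list[int]] keyed by edge type.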
- Returns:
Modified fanout that is appropriate for sampling: a list[int] if the dataset is homogeneous, otherwise a dict[EdgeType, list[int]].
- Return type:
Union[list[int], dict[EdgeType, list[int]]]
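The fanout rules above can be sketched in plain Python. This is a minimal illustration, not GiGL's implementation: `is_label_edge_type` is a hypothetical stand-in for however GiGL recognizes label edge types, a zero fanout is represented as one `0` per hop, and for a dict fanout the sketch only overrides label edge types while copying the other entries unchanged.

```python
from typing import Callable, Optional, Union

EdgeType = tuple[str, str, str]
Fanout = Union[list[int], dict[EdgeType, list[int]]]


def patch_fanout_sketch(
    edge_types: Optional[list[EdgeType]],
    num_neighbors: Fanout,
    is_label_edge_type: Callable[[EdgeType], bool],  # hypothetical predicate
) -> Fanout:
    if isinstance(num_neighbors, dict):
        # Every key of a dict fanout must be one of the dataset's edge types.
        unknown = set(num_neighbors) - set(edge_types or [])
        if unknown:
            raise ValueError(f"Fanout keys {unknown} are not in edge_types")
    if not edge_types:
        # Homogeneous dataset: the list[int] fanout is returned unchanged.
        return num_neighbors
    if isinstance(num_neighbors, list):
        # Broadcast the shared fanout to every edge type.
        num_hops = len(num_neighbors)
        patched = {et: list(num_neighbors) for et in edge_types}
    else:
        # Copy the per-edge-type fanout so the caller's dict is not mutated.
        patched = {et: list(f) for et, f in num_neighbors.items()}
        num_hops = max((len(f) for f in patched.values()), default=0)
    for et in edge_types:
        if is_label_edge_type(et):
            patched[et] = [0] * num_hops  # never sample along label edges
    return patched


def is_label(et: EdgeType) -> bool:
    return et[1].startswith("label")  # hypothetical naming convention


follows = ("user", "follows", "user")
label = ("user", "label_positive", "user")
print(patch_fanout_sketch([follows, label], [10, 5], is_label))
# {('user', 'follows', 'user'): [10, 5], ('user', 'label_positive', 'user'): [0, 0]}
```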
- gigl.distributed.utils.neighborloader.set_missing_features(data, node_feature_info, edge_feature_info, device)
If a feature is missing from a produced Data or HeteroData object because sampling did not fan out to it, populates it in-place with an empty tensor of the appropriate feature dimension. Note that PyG natively does this in its DistNeighborLoader for missing edge features, edge indices, and node features: https://pytorch-geometric.readthedocs.io/en/2.4.0/_modules/torch_geometric/sampler/neighbor_sampler.html#NeighborSampler
However, native GraphLearn-for-PyTorch (alibaba/graphlearn-for-pytorch) only does this for edge indices, so we need to do it for our sampled node and edge features as well.
# TODO (mkolodner-sc): Migrate this utility to GLT once we fork their repo
- Parameters:
data (_GraphType) – Data or HeteroData object which we are setting the missing features for
node_feature_info (Optional[Union[FeatureInfo, dict[NodeType, FeatureInfo]]]) – Node feature dimension and data type. Note that if heterogeneous, only node types with features should be provided. Can be None in the homogeneous case if there are no node features
edge_feature_info (Optional[Union[FeatureInfo, dict[EdgeType, FeatureInfo]]]) – Edge feature dimension and data type. Note that if heterogeneous, only edge types with features should be provided. Can be None in the homogeneous case if there are no edge features
device (torch.device) – Device to move the empty features to
- Returns:
Data or HeteroData object with the missing feature fields populated
- Return type:
_GraphType
- gigl.distributed.utils.neighborloader.shard_nodes_by_process(input_nodes, local_process_rank, local_process_world_size)
Shards input nodes based on the local process rank.
- Parameters:
input_nodes (torch.Tensor) – Nodes to be split across the training or inference processes
local_process_rank (int) – Rank of the current local process
local_process_world_size (int) – Total number of local processes on the current machine
- Returns:
The sharded nodes for the current local process
- Return type:
torch.Tensor
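A minimal sketch of one possible sharding scheme, assuming contiguous chunks with the remainder assigned to the last local process. This split rule is an assumption made for illustration, and `shard_contiguous` is a hypothetical helper. Since it only relies on `len()` and slicing, the same function works on a Python list or a 1-D `torch.Tensor`:

```python
from collections.abc import Sequence


def shard_contiguous(
    input_nodes: Sequence[int], local_process_rank: int, local_process_world_size: int
) -> Sequence[int]:
    """Hypothetical helper: return the contiguous chunk of nodes owned by this rank."""
    if not 0 <= local_process_rank < local_process_world_size:
        raise ValueError("local_process_rank must be in [0, local_process_world_size)")
    per_process = len(input_nodes) // local_process_world_size
    start = local_process_rank * per_process
    # The last local process also takes the leftover nodes so none are dropped.
    is_last = local_process_rank == local_process_world_size - 1
    end = len(input_nodes) if is_last else start + per_process
    return input_nodes[start:end]


nodes = list(range(10))
shards = [shard_contiguous(nodes, rank, 3) for rank in range(3)]
print(shards)  # [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]
```

Under this scheme the shards are disjoint and together cover every input node, so each local process works on its own slice of the seeds.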
- gigl.distributed.utils.neighborloader.strip_label_edges(data)
Removes all label edge types from a heterogeneous graph.
Modifies the input in place.
- Parameters:
data (HeteroData) – The input heterogeneous graph.
- Returns:
The graph with the label edge types removed.
- Return type:
HeteroData