gigl.utils.data_splitters#

Attributes#

PADDING_NODE

logger

Classes#

HashedNodeAnchorLinkSplitter

Selects train, val, and test nodes based on a provided edge index.

HashedNodeSplitter

Selects train, val, and test nodes based on provided node IDs directly.

NodeAnchorLinkSplitter

Protocol that should be satisfied for anything that is used to split on edges.

NodeSplitter

Protocol that should be satisfied for anything that is used to split on nodes directly.

Functions#

get_labels_for_anchor_nodes(dataset, node_ids, ...[, ...])

Selects labels for the given node ids based on the provided edge types.

select_ssl_positive_label_edges(edge_index, ...)

Selects a percentage of edges from an edge index to use for self-supervised positive labels.

Module Contents#

class gigl.utils.data_splitters.HashedNodeAnchorLinkSplitter(sampling_direction, num_val=0.1, num_test=0.1, hash_function=_fast_hash, supervision_edge_types=None, should_convert_labels_to_edges=True)[source]#

Selects train, val, and test nodes based on a provided edge index.

NOTE: This splitter must be called while a Torch distributed process group is initialized, i.e. torch.distributed.init_process_group must be called before using this splitter.

We need this communication between processes to determine the maximum and minimum hashed node id across all machines.

In node-based splitting, a node may only ever live in one split. E.g. if one node has two label edges, both of those edges will be placed into the same split.

The edges must be provided in COO format, as dense tensors (https://tbetcke.github.io/hpc_lecture_notes/sparse_data_structures.html), where the first row of the input contains the node ids that are the “source” of the edge, and the second row contains the node ids that are the “destination” of the edge.

Note that there is some tricky interplay between this layout and the sampling_direction parameter. Take the graph [A -> B] as an example. If sampling_direction is “in”, then B is the source and A is the destination. If sampling_direction is “out”, then A is the source and B is the destination.
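
To make this concrete, here is a minimal sketch of how the edge [A -> B] would be laid out under each direction, assigning A and B the illustrative ids 0 and 1:

```python
import torch

# The graph [A -> B], with A = 0 and B = 1. Row 0 must hold the "source"
# node ids and row 1 the "destination" node ids; which endpoint plays
# which role depends on sampling_direction.

# sampling_direction="out": A is the source, B is the destination.
edge_index_out = torch.tensor([[0], [1]])

# sampling_direction="in": B is the source, A is the destination.
edge_index_in = torch.tensor([[1], [0]])
```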

Initializes the HashedNodeAnchorLinkSplitter.

Parameters:
  • sampling_direction (Union[Literal["in", "out"], str]) – The direction to sample the nodes. Either “in” or “out”.

  • num_val (float) – The percentage of nodes to use for validation. Defaults to 0.1 (10%).

  • num_test (float) – The percentage of nodes to use for testing. Defaults to 0.1 (10%).

  • hash_function (Callable[[torch.Tensor, torch.Tensor], torch.Tensor]) – The hash function to use. Defaults to _fast_hash.

  • supervision_edge_types (Optional[list[EdgeType]]) – The supervision edge types we should use for splitting. Must be provided if we are splitting a heterogeneous graph. If None, uses the default message passing edge type in the graph.

  • should_convert_labels_to_edges (bool) – Whether labels should be converted into edge types in the graph. If True, gigl.distributed.build_dataset will convert all labels into edges and will infer positive and negative edge types based on supervision_edge_types.
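
A minimal usage sketch, assuming the splitter is invoked directly on a 2 x N edge index as the NodeAnchorLinkSplitter protocol below describes; the single-process “gloo” group and localhost address are illustrative only:

```python
import torch
import torch.distributed as dist

from gigl.utils.data_splitters import HashedNodeAnchorLinkSplitter

# Illustrative single-process group; real jobs get this from their launcher.
dist.init_process_group(
    backend="gloo", init_method="tcp://localhost:29500", rank=0, world_size=1
)

splitter = HashedNodeAnchorLinkSplitter(
    sampling_direction="out", num_val=0.1, num_test=0.1
)

# 2 x N COO edge index: row 0 = source node ids, row 1 = destination node ids.
edge_index = torch.randint(0, 1_000, (2, 10_000))

# Per the NodeAnchorLinkSplitter protocol, calling the splitter on the edge
# index yields the train, val, and test anchor-node tensors.
train_nodes, val_nodes, test_nodes = splitter(edge_index)

dist.destroy_process_group()
```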

property should_convert_labels_to_edges[source]#
class gigl.utils.data_splitters.HashedNodeSplitter(num_val=0.1, num_test=0.1, hash_function=_fast_hash)[source]#

Selects train, val, and test nodes based on provided node IDs directly.

NOTE: This splitter must be called while a Torch distributed process group is initialized, i.e. torch.distributed.init_process_group must be called before using this splitter.

We need this communication between processes to determine the maximum and minimum hashed node id across all machines.

In node-based splitting, each node will be placed into exactly one split based on its hash value. This is simpler than edge-based splitting as it doesn’t require extracting anchor nodes from edges.

Additionally, the HashedNodeSplitter does not de-duplicate repeated node ids. If repeated node ids are passed in, the same number of repeats appears in the output, all placed into the same split. This differs from the HashedNodeAnchorLinkSplitter, which does de-duplicate the repeated source or destination nodes that appear in the labeled edges.

Parameters:
  • node_ids – The node IDs to split. Either a 1D tensor for homogeneous graphs, or a mapping from node types to 1D tensors for heterogeneous graphs.

  • num_val (float)

  • num_test (float)

  • hash_function (Callable[[torch.Tensor], torch.Tensor])

Returns:

The train, val, test node splits as tensors or mappings depending on input format.

Initializes the HashedNodeSplitter.

Parameters:
  • num_val (float) – The percentage of nodes to use for validation. Defaults to 0.1 (10%).

  • num_test (float) – The percentage of nodes to use for testing. Defaults to 0.1 (10%).

  • hash_function (Callable[[torch.Tensor], torch.Tensor]) – The hash function to use. Defaults to _fast_hash.
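
A minimal usage sketch, assuming the splitter is called directly on the node ids per the NodeSplitter protocol below; the single-process “gloo” group is illustrative only:

```python
import torch
import torch.distributed as dist

from gigl.utils.data_splitters import HashedNodeSplitter

# Illustrative single-process group; see the NOTE above.
dist.init_process_group(
    backend="gloo", init_method="tcp://localhost:29500", rank=0, world_size=1
)

splitter = HashedNodeSplitter(num_val=0.1, num_test=0.1)

# Repeated ids are not de-duplicated; every copy of id 4 below lands in the
# same split.
node_ids = torch.tensor([0, 1, 2, 3, 4, 4])
train_nodes, val_nodes, test_nodes = splitter(node_ids)

dist.destroy_process_group()
```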

class gigl.utils.data_splitters.NodeAnchorLinkSplitter[source]#

Bases: Protocol

Protocol that should be satisfied for anything that is used to split on edges.

The edges must be provided in COO format, as dense tensors (see https://tbetcke.github.io/hpc_lecture_notes/sparse_data_structures.html).

Parameters:

edge_index – The edges to split on in COO format. 2 x N

Returns:

The train (1 x X), val (1 x Y), test (1 x Z) nodes. X + Y + Z = N

property should_convert_labels_to_edges[source]#
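
Any callable of this shape satisfies the protocol. A toy, hash-free sketch (illustration only, not the library's implementation):

```python
import torch

class SliceLinkSplitter:
    """Toy NodeAnchorLinkSplitter: partitions the unique source nodes
    80/10/10 by position rather than by hash."""

    @property
    def should_convert_labels_to_edges(self) -> bool:
        return False

    def __call__(
        self, edge_index: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        nodes = torch.unique(edge_index[0])  # anchors = unique source nodes
        n_val = int(0.1 * nodes.numel())
        n_test = int(0.1 * nodes.numel())
        n_train = nodes.numel() - n_val - n_test
        return (
            nodes[:n_train],
            nodes[n_train : n_train + n_val],
            nodes[n_train + n_val :],
        )
```
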
class gigl.utils.data_splitters.NodeSplitter[source]#

Bases: Protocol

Protocol that should be satisfied for anything that is used to split on nodes directly.

Parameters:

node_ids – The node IDs to split on. 1D tensor for homogeneous or mapping for heterogeneous. 1 x N

Returns:

The train (1 x X), val (1 x Y), test (1 x Z) nodes. X + Y + Z = N
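
As with NodeAnchorLinkSplitter, any callable of this shape qualifies; a toy sketch:

```python
import torch

def modulo_node_splitter(
    node_ids: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # Toy NodeSplitter: roughly 80/10/10 split by the last decimal digit.
    bucket = node_ids % 10
    return node_ids[bucket < 8], node_ids[bucket == 8], node_ids[bucket == 9]
```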

gigl.utils.data_splitters.get_labels_for_anchor_nodes(dataset, node_ids, positive_label_edge_type, negative_label_edge_type=None)[source]#

Selects labels for the given node ids based on the provided edge types.

The labels returned are padded with PADDING_NODE to the maximum number of labels, so that we don’t need to work with jagged tensors. The labels are N x M, where N is the number of nodes and M is the max number of labels. For a given ith node id, the ith row of the labels tensor contains the labels for that node id. E.g. given the following topology:

0 -> 1 -> 2
0 -> 2

and node_ids = [0, 1], the returned tensor will be:

[
    [
        1,   # Positive node (0 -> 1)
        2,   # Positive node (0 -> 2)
    ],
    [
        2,   # Positive node (1 -> 2)
        -1,  # Padded with PADDING_NODE
    ],
]

If positive and negative label edge types are provided:
  • All negative label node ids must be present in the positive label node ids.

  • For any positive label node id that does not have a negative label, the negative label will be padded with PADDING_NODE.

Parameters:
  • dataset (Dataset) – The dataset storing the graph info, must be heterogeneous.

  • node_ids (torch.Tensor) – The node ids to use for the labels. [N]

  • positive_label_edge_type (PyGEdgeType) – The edge type to use for the positive labels.

  • negative_label_edge_type (Optional[PyGEdgeType]) – The edge type to use for the negative labels. Defaults to None. If not provided, no negative labels will be returned.

Returns:

Tuple of (positive_labels, negative_labels), where negative_labels is None if negative_label_edge_type is not provided. The returned tensors are of shape N x M, where N is the number of nodes and M is the max number of labels per type.

Return type:

tuple[torch.Tensor, Optional[torch.Tensor]]
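
A sketch of a call, assuming `dataset` is an existing heterogeneous GiGL dataset (e.g. from gigl.distributed.build_dataset) and using hypothetical edge types:

```python
import torch

from gigl.utils.data_splitters import get_labels_for_anchor_nodes

# `dataset` is assumed to exist already; the edge types are hypothetical.
positive_labels, negative_labels = get_labels_for_anchor_nodes(
    dataset=dataset,
    node_ids=torch.tensor([0, 1, 2]),
    positive_label_edge_type=("user", "positive_label", "item"),
    negative_label_edge_type=("user", "negative_label", "item"),
)
# positive_labels: shape [3, M], padded with PADDING_NODE.
# negative_labels: shape [3, M'], or None if negative_label_edge_type is
# omitted.
```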

gigl.utils.data_splitters.select_ssl_positive_label_edges(edge_index, positive_label_percentage)[source]#

Selects a percentage of edges from an edge index to use for self-supervised positive labels. Note that this function does not mask these labeled edges from the edge index tensor.

Parameters:
  • edge_index (torch.Tensor) – Edge Index tensor of shape [2, num_edges]

  • positive_label_percentage (float) – Percentage of edges to select as positive labels

Returns:

Tensor of positive edges of shape [2, num_labels]

Return type:

torch.Tensor
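
A minimal usage sketch:

```python
import torch

from gigl.utils.data_splitters import select_ssl_positive_label_edges

edge_index = torch.tensor([[0, 1, 0, 2], [1, 2, 2, 3]])  # shape [2, 4]

# Select ~25% of the edges as self-supervised positive labels. The selected
# edges are NOT masked out of `edge_index`; remove them yourself if needed.
positive_label_edges = select_ssl_positive_label_edges(
    edge_index=edge_index, positive_label_percentage=0.25
)
# positive_label_edges has shape [2, num_labels] (here, one edge).
```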

gigl.utils.data_splitters.PADDING_NODE: Final[torch.Tensor][source]#
gigl.utils.data_splitters.logger[source]#