gigl.utils.data_splitters

Attributes

PADDING_NODE

logger

Classes

HashedNodeAnchorLinkSplitter

Selects train, val, and test nodes based on a provided edge index.

NodeAnchorLinkSplitter

Protocol that any object used to split on edges should satisfy.

Functions

get_labels_for_anchor_nodes(dataset, node_ids, ...[, ...])

Selects labels for the given node ids based on the provided edge types.

select_ssl_positive_label_edges(edge_index, ...)

Selects a percentage of edges from an edge index to use for self-supervised positive labels.

Module Contents

class gigl.utils.data_splitters.HashedNodeAnchorLinkSplitter(sampling_direction, num_val=0.1, num_test=0.1, hash_function=_fast_hash, supervision_edge_types=None, should_convert_labels_to_edges=True)

Selects train, val, and test nodes based on a provided edge index.

NOTE: This splitter must be called while a torch.distributed process group is initialized; that is, torch.distributed.init_process_group must be called before using this splitter.
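For local experimentation, a single-process group can be set up before using the splitter. The sketch below assumes the gloo backend and a TCP rendezvous on a free local port; init_single_process_group is a hypothetical helper, and production jobs will typically let their launcher (for example torchrun) supply these settings instead.

```python
import socket

import torch.distributed as dist


def init_single_process_group() -> None:
    """Hypothetical helper: initialize a one-process gloo group, e.g. for local testing."""
    if dist.is_initialized():
        return  # Already set up; nothing to do.
    # Ask the OS for a free local port to use for the TCP rendezvous.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("127.0.0.1", 0))
        port = sock.getsockname()[1]
    dist.init_process_group(
        backend="gloo",
        init_method=f"tcp://127.0.0.1:{port}",
        rank=0,
        world_size=1,
    )


init_single_process_group()
# The splitter can be constructed and called from this point on.
```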

In node-based splitting, a node can live in only one split. For example, if a node has two label edges, both of those edges are placed in the same split.
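The node-level guarantee can be illustrated with a small sketch. It assumes the anchor nodes are the sources (row 0) of the label edges and uses a simple multiplicative hash as a stand-in; this is not GiGL's _fast_hash, and split_anchor_nodes is a hypothetical name. Because the split is decided once per unique node, every edge of a given anchor node necessarily follows that node into the same split.

```python
import torch


def split_anchor_nodes(
    edge_index: torch.Tensor, num_val: float = 0.1, num_test: float = 0.1
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Hypothetical node-level split: each unique anchor node gets exactly one split."""
    anchors = torch.unique(edge_index[0])  # Row 0 holds the anchor ("source") nodes.
    # Stand-in hash: map each node id to a pseudo-random bucket, then to [0, 1).
    buckets = (anchors * 2654435761) % 10_000
    fractions = buckets.to(torch.float64) / 10_000
    val_mask = fractions < num_val
    test_mask = (fractions >= num_val) & (fractions < num_val + num_test)
    train_mask = ~(val_mask | test_mask)
    return anchors[train_mask], anchors[val_mask], anchors[test_mask]


# Node 0 has two label edges (0 -> 1 and 0 -> 2); both follow node 0's split.
edge_index = torch.tensor([[0, 0, 1, 2, 3], [1, 2, 2, 3, 0]])
train, val, test = split_anchor_nodes(edge_index)
```

Hashing the node id (rather than drawing a random number) keeps the assignment deterministic, so the same node lands in the same split on every run and every process.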

The edges must be provided in COO format, as dense tensors (see https://tbetcke.github.io/hpc_lecture_notes/sparse_data_structures.html). The first row of the input contains the node ids that are the “source” of each edge, and the second row contains the node ids that are the “destination” of each edge.
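As an illustration of this layout, a list of (source, destination) pairs can be turned into the dense 2 x N edge index by transposing it. The variable names below are only for the example.

```python
import torch

# Edges 0 -> 1, 1 -> 2, and 0 -> 2 as (source, destination) pairs.
pairs = torch.tensor([[0, 1], [1, 2], [0, 2]])

# Transpose to COO layout: row 0 = sources, row 1 = destinations, shape [2, N].
edge_index = pairs.t().contiguous()

sources, destinations = edge_index[0], edge_index[1]
```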

Note the interplay between this convention and the sampling_direction parameter. Take the graph [A -> B] as an example. If sampling_direction is “in”, then B is the source and A is the destination. If sampling_direction is “out”, then A is the source and B is the destination.
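A minimal sketch of that convention, using node ids A = 0 and B = 1. The helper orient_edges is hypothetical and only restates the rule above; it is not part of GiGL's API.

```python
import torch


def orient_edges(edge_index: torch.Tensor, sampling_direction: str) -> torch.Tensor:
    """Hypothetical helper: return edges as (source, destination) rows per the rule above."""
    if sampling_direction == "out":
        return edge_index  # A -> B: A is the source, B is the destination.
    if sampling_direction == "in":
        return edge_index.flip(0)  # Swap rows: B becomes the source, A the destination.
    raise ValueError(f"Unknown sampling_direction: {sampling_direction!r}")


A, B = 0, 1
graph = torch.tensor([[A], [B]])  # The single edge A -> B.
out_edges = orient_edges(graph, "out")
in_edges = orient_edges(graph, "in")
```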

Initializes the HashedNodeAnchorLinkSplitter.

Parameters:
  • sampling_direction (Union[Literal["in", "out"], str]) – The direction to sample the nodes. Either “in” or “out”.

  • num_val (float) – The fraction of nodes to use for validation. Defaults to 0.1 (10%).

  • num_test (float) – The fraction of nodes to use for testing. Defaults to 0.1 (10%).

  • hash_function (Callable[[torch.Tensor, torch.Tensor], torch.Tensor]) – The hash function to use. Defaults to _fast_hash.

  • supervision_edge_types (Optional[list[EdgeType]]) – The supervision edge types we should use for splitting. Must be provided if we are splitting a heterogeneous graph. If None, uses the default message passing edge type in the graph.

  • should_convert_labels_to_edges (bool) – Whether labels should be converted into edge types in the graph. If True, gigl.distributed.build_dataset will convert all labels into edges and infer the positive and negative edge types from supervision_edge_types.

property should_convert_labels_to_edges
class gigl.utils.data_splitters.NodeAnchorLinkSplitter

Bases: Protocol

Protocol that any object used to split on edges should satisfy.

The edges must be provided in COO format, as dense tensors. https://tbetcke.github.io/hpc_lecture_notes/sparse_data_structures.html

Parameters:

edge_index – The edges to split on, in COO format, with shape 2 x N.

Returns:

The train (1 x X), val (1 x Y), and test (1 x Z) nodes, where X + Y + Z = N.
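The sketch below shows the shape of an object that could satisfy such a protocol. It assumes the protocol consists of a __call__ that takes edge_index and returns the three node tensors, plus the should_convert_labels_to_edges property listed here; EdgeSplitter and PrefixSplitter are hypothetical names, and the toy split (by position among the sorted unique source nodes) is only for illustration.

```python
from typing import Protocol, runtime_checkable

import torch


@runtime_checkable
class EdgeSplitter(Protocol):
    """Hypothetical stand-in mirroring the documented protocol."""

    @property
    def should_convert_labels_to_edges(self) -> bool: ...

    def __call__(
        self, edge_index: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ...


class PrefixSplitter:
    """Toy splitter: first num_val unique source nodes go to val, the next num_test to test."""

    def __init__(self, num_val: int, num_test: int) -> None:
        self._num_val = num_val
        self._num_test = num_test

    @property
    def should_convert_labels_to_edges(self) -> bool:
        return False

    def __call__(self, edge_index: torch.Tensor):
        nodes = torch.unique(edge_index[0])  # Sorted unique source node ids.
        val = nodes[: self._num_val]
        test = nodes[self._num_val : self._num_val + self._num_test]
        train = nodes[self._num_val + self._num_test :]
        return train, val, test


splitter = PrefixSplitter(num_val=1, num_test=1)
train, val, test = splitter(torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]]))
```

Because Protocol uses structural typing, PrefixSplitter never inherits from EdgeSplitter; it qualifies simply by providing the same members.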

property should_convert_labels_to_edges
gigl.utils.data_splitters.get_labels_for_anchor_nodes(dataset, node_ids, positive_label_edge_type, negative_label_edge_type=None)

Selects labels for the given node ids based on the provided edge types.

The returned labels are padded with PADDING_NODE to the maximum number of labels, so that we don’t need to work with jagged tensors. The labels are N x M, where N is the number of nodes and M is the maximum number of labels. The ith row of the labels tensor contains the labels for the ith node id. For example, given the following topology:

    0 -> 1 -> 2
    0 -> 2

and node_ids = [0, 1], the returned tensor will be:

    [
        [
            1,  # Positive node (0 -> 1)
            2,  # Positive node (0 -> 2)
        ],
        [
            2,   # Positive node (1 -> 2)
            -1,  # Positive node (padded)
        ],
    ]
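A simplified, homogeneous sketch of the padding behavior reproduces the example above. It assumes the label edges are given as a single edge index and that the padding value is -1, as in the example; padded_labels is a hypothetical helper, whereas the documented function works on a heterogeneous Dataset.

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def padded_labels(
    edge_index: torch.Tensor, node_ids: torch.Tensor, padding_value: int = -1
) -> torch.Tensor:
    """Hypothetical helper: N x M tensor of destination ids per anchor node, right-padded."""
    # For each anchor node, collect the destinations of its outgoing label edges.
    rows = [edge_index[1][edge_index[0] == node] for node in node_ids.tolist()]
    # Pad the jagged rows to the longest row (M) so the result is rectangular.
    return pad_sequence(rows, batch_first=True, padding_value=padding_value)


# Topology from the example: 0 -> 1 -> 2 and 0 -> 2.
edge_index = torch.tensor([[0, 1, 0], [1, 2, 2]])
labels = padded_labels(edge_index, torch.tensor([0, 1]))
```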

If positive and negative label edge types are provided:
  • All negative label node ids must be present in the positive label node ids.

  • For any positive label node id that does not have a negative label, the negative label will be padded with PADDING_NODE.

Parameters:
  • dataset (Dataset) – The dataset storing the graph information; must be heterogeneous.

  • node_ids (torch.Tensor) – The node ids to use for the labels. [N]

  • positive_label_edge_type (PyGEdgeType) – The edge type to use for the positive labels.

  • negative_label_edge_type (Optional[PyGEdgeType]) – The edge type to use for the negative labels. Defaults to None. If not provided, no negative labels will be returned.

Returns:

A tuple of (positive_labels, negative_labels), where negative_labels is None if negative_label_edge_type is not provided. Each returned tensor has shape N x M, where N is the number of nodes and M is the maximum number of labels for that label type.

Return type:

tuple[torch.Tensor, Optional[torch.Tensor]]

gigl.utils.data_splitters.select_ssl_positive_label_edges(edge_index, positive_label_percentage)

Selects a percentage of edges from an edge index to use for self-supervised positive labels. Note that this function does not mask these labeled edges from the edge index tensor.

Parameters:
  • edge_index (torch.Tensor) – Edge Index tensor of shape [2, num_edges]

  • positive_label_percentage (float) – Percentage of edges to select as positive labels

Returns:

Tensor of positive edges of shape [2, num_labels]

Return type:

torch.Tensor
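A sketch of one way to perform such a selection, assuming positive_label_percentage is a fraction in [0, 1] and that edges are sampled uniformly without replacement; the library's actual sampling strategy may differ, and select_positive_edges is a hypothetical name. As with the documented function, the input edge index is left unmasked.

```python
import torch


def select_positive_edges(
    edge_index: torch.Tensor, positive_label_percentage: float
) -> torch.Tensor:
    """Hypothetical sketch: uniformly sample a fraction of edge columns as positive labels."""
    num_edges = edge_index.size(1)
    num_labels = int(num_edges * positive_label_percentage)
    # Random permutation of column indices; keep the first num_labels of them.
    selected = torch.randperm(num_edges)[:num_labels]
    # Indexing returns a new tensor; edge_index itself is not modified or masked.
    return edge_index[:, selected]


edge_index = torch.stack([torch.arange(10), torch.arange(1, 11)])
positives = select_positive_edges(edge_index, 0.5)
```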

gigl.utils.data_splitters.PADDING_NODE: Final[torch.Tensor]
gigl.utils.data_splitters.logger