gigl.utils.data_splitters

Attributes

PADDING_NODE

logger

Classes

HashedNodeAnchorLinkSplitter

Selects train, val, and test nodes based on some provided edge index.

NodeAnchorLinkSplitter

Protocol that should be satisfied for anything that is used to split on edges.

Functions

get_labels_for_anchor_nodes(dataset, node_ids, ...[, ...])

Selects labels for the given node ids based on the provided edge types.

select_ssl_positive_label_edges(edge_index, ...)

Selects a percentage of edges from an edge index to use for self-supervised positive labels.

Module Contents

class gigl.utils.data_splitters.HashedNodeAnchorLinkSplitter(sampling_direction, num_val=0.1, num_test=0.1, hash_function=_fast_hash, edge_types=None)

Selects train, val, and test nodes based on some provided edge index.

In node-based splitting, a node may only ever live in one split. For example, if a node has two label edges, both of those edges are placed into the same split.

The edges must be provided in COO format as dense tensors (see https://tbetcke.github.io/hpc_lecture_notes/sparse_data_structures.html). The first row of the input contains the node ids that are the “source” of each edge, and the second row contains the node ids that are the “destination” of each edge.

Note that this convention interacts with the sampling_direction parameter, which determines which endpoint of an edge counts as the source. Take the graph [A -> B] as an example: if sampling_direction is “in”, then B is the source and A is the destination; if sampling_direction is “out”, then A is the source and B is the destination.
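To make the convention concrete, the sketch below stores [A -> B] as a dense COO tensor and picks the source and destination rows for each sampling_direction. The helper `source_and_destination` is hypothetical and not part of gigl; it only mirrors the rule stated above.

```python
import torch

# The graph [A -> B] with A = 0 and B = 1, as a dense COO edge index of shape [2, num_edges].
A, B = 0, 1
edge_index = torch.tensor([[A], [B]])


def source_and_destination(edge_index: torch.Tensor, sampling_direction: str):
    """Hypothetical helper: return (source_ids, destination_ids) for a sampling direction."""
    if sampling_direction == "out":
        # Outgoing sampling: the first row holds the sources.
        return edge_index[0], edge_index[1]
    if sampling_direction == "in":
        # Incoming sampling: the roles flip, so the second row holds the sources.
        return edge_index[1], edge_index[0]
    raise ValueError(f"Unknown sampling_direction: {sampling_direction!r}")


src, dst = source_and_destination(edge_index, "in")
print(src.tolist(), dst.tolist())  # [1] [0]  (B is the source, A is the destination)
```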

Initializes the HashedNodeAnchorLinkSplitter.

Parameters:
  • sampling_direction (Union[Literal["in", "out"], str]) – The direction to sample the nodes. Either “in” or “out”.

  • num_val (Union[float, int]) – The fraction of nodes to use for validation. Defaults to 0.1 (10%). If an integer is provided, then exactly that number of nodes will be in the validation split.

  • num_test (Union[float, int]) – The fraction of nodes to use for testing. Defaults to 0.1 (10%). If an integer is provided, then exactly that number of nodes will be in the test split.

  • hash_function (Callable[[torch.Tensor, torch.Tensor], torch.Tensor]) – The hash function to use. Defaults to _fast_hash.

  • edge_types (Optional[Union[gigl.src.common.types.graph_data.EdgeType, collections.abc.Sequence[gigl.src.common.types.graph_data.EdgeType]]]) – The supervision edge types to use for splitting. Must be provided when splitting a heterogeneous graph.
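The splitting behaviour described for this class can be sketched in plain PyTorch: deduplicate the anchor (source) nodes, order them by a deterministic hash, and carve off the validation and test splits. This is a sketch under stated assumptions, not gigl's implementation. `knuth_hash` is a hypothetical stand-in for `_fast_hash` (whose real signature and algorithm may differ), and ordering nodes by hash value is an assumed strategy chosen because it handles float fractions and integer counts the same way.

```python
from typing import Literal, Tuple, Union

import torch


def knuth_hash(ids: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for _fast_hash: Knuth multiplicative hashing modulo 2**32.
    # The multiplier is odd, so distinct ids in [0, 2**32) never collide.
    return (ids.to(torch.int64) * 2654435761) % (2**32)


def resolve_count(n: Union[float, int], total: int) -> int:
    # A float is treated as a fraction of the anchor nodes, an int as an exact count.
    if isinstance(n, float):
        return int(round(n * total))
    return int(n)


def hashed_node_split(
    edge_index: torch.Tensor,
    sampling_direction: Literal["in", "out"],
    num_val: Union[float, int] = 0.1,
    num_test: Union[float, int] = 0.1,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # With "out" the first row holds the sources; with "in" the second row does.
    source_row = 0 if sampling_direction == "out" else 1
    # Deduplicate so that every anchor node lands in exactly one split.
    anchors = torch.unique(edge_index[source_row])
    # Sorting by hash gives a deterministic, pseudo-random ordering of the anchors.
    shuffled = anchors[torch.argsort(knuth_hash(anchors))]
    n_val = resolve_count(num_val, anchors.numel())
    n_test = resolve_count(num_test, anchors.numel())
    val = shuffled[:n_val]
    test = shuffled[n_val : n_val + n_test]
    train = shuffled[n_val + n_test :]
    return train, val, test


# Node 0 has two label edges; both follow node 0 into a single split.
edge_index = torch.tensor([[0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
                           [1, 2, 2, 3, 4, 5, 6, 7, 8, 9, 0]])
train, val, test = hashed_node_split(edge_index, "out", num_val=0.2, num_test=1)
print(len(train), len(val), len(test))  # 7 2 1
```

Because a node's split depends only on its id, rerunning the sketch yields the same partition, which is presumably one motivation for a hash-based design.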

class gigl.utils.data_splitters.NodeAnchorLinkSplitter

Bases: Protocol

Protocol that should be satisfied for anything that is used to split on edges.

The edges must be provided in COO format as dense tensors (see https://tbetcke.github.io/hpc_lecture_notes/sparse_data_structures.html).

Parameters:

edge_index – The edges to split on in COO format. 2 x N

Returns:

The train (1 x X), val (1 x Y), and test (1 x Z) nodes, where X + Y + Z = N.
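Because NodeAnchorLinkSplitter is a typing Protocol, any object with a compatible call signature satisfies it structurally, with no inheritance required. The sketch below assumes the contract is expressed as `__call__(edge_index)` returning a (train, val, test) tuple; `EdgeSplitterLike` and `PositionalSplitter` are hypothetical names, and the real gigl signature may differ.

```python
from typing import Protocol, Tuple

import torch


class EdgeSplitterLike(Protocol):
    # Hypothetical mirror of the documented contract: a 2 x N COO edge index in,
    # (train, val, test) node tensors out.
    def __call__(
        self, edge_index: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ...


class PositionalSplitter:
    """Toy splitter: slices the source row by position (no hashing, no deduplication)."""

    def __init__(self, num_val: int, num_test: int) -> None:
        self.num_val = num_val
        self.num_test = num_test

    def __call__(
        self, edge_index: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        sources = edge_index[0]
        val = sources[: self.num_val]
        test = sources[self.num_val : self.num_val + self.num_test]
        train = sources[self.num_val + self.num_test :]
        return train, val, test


def split_edges(splitter: EdgeSplitterLike, edge_index: torch.Tensor):
    # Static type checkers accept PositionalSplitter here because it matches the protocol.
    return splitter(edge_index)


edge_index = torch.tensor([[0, 1, 2, 3, 4], [1, 2, 3, 4, 0]])
train, val, test = split_edges(PositionalSplitter(num_val=1, num_test=1), edge_index)
print(train.tolist(), val.tolist(), test.tolist())  # [2, 3, 4] [0] [1]
```

The toy splitter keeps X + Y + Z = N simply because it slices the source row without deduplicating it.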

gigl.utils.data_splitters.get_labels_for_anchor_nodes(dataset, node_ids, positive_label_edge_type, negative_label_edge_type=None)

Selects labels for the given node ids based on the provided edge types.

The returned labels are padded with PADDING_NODE up to the maximum number of labels, so that callers don’t need to work with jagged tensors. The labels tensor has shape N x M, where N is the number of nodes and M is the maximum number of labels; the ith row contains the labels for the ith node id. For example, given the following topology:

0 -> 1 -> 2
0 -> 2

and node_ids = [0, 1], the returned tensor will be:

[
    [
        1,   # Positive node (0 -> 1)
        2,   # Positive node (0 -> 2)
    ],
    [
        2,   # Positive node (1 -> 2)
        -1,  # Positive node (padded)
    ],
]
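The padding scheme can be reproduced on a plain COO edge index with PyTorch's pad_sequence. This sketch bypasses gigl's Dataset entirely and assumes the padding value is -1, matching the example above; `padded_positive_labels` is a hypothetical helper, not the library function.

```python
from typing import List

import torch
from torch.nn.utils.rnn import pad_sequence

PADDING_VALUE = -1  # assumption: mirrors the -1 padding shown in the example


def padded_positive_labels(edge_index: torch.Tensor, node_ids: torch.Tensor) -> torch.Tensor:
    """Return an N x M tensor whose ith row lists the destinations of node_ids[i]'s edges."""
    rows: List[torch.Tensor] = []
    for node in node_ids.tolist():
        # Destinations of every edge whose source is `node`, in edge order.
        rows.append(edge_index[1][edge_index[0] == node])
    # Right-pad each row to the longest one so the result is a dense (non-jagged) tensor.
    return pad_sequence(rows, batch_first=True, padding_value=PADDING_VALUE)


# Topology from the example: 0 -> 1 -> 2 and 0 -> 2.
edge_index = torch.tensor([[0, 1, 0], [1, 2, 2]])
labels = padded_positive_labels(edge_index, torch.tensor([0, 1]))
print(labels.tolist())  # [[1, 2], [2, -1]]
```

To recover the unpadded labels of one row, filter out the padding value, for example labels[i][labels[i] != PADDING_VALUE].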

Parameters:
  • dataset (Dataset) – The dataset storing the graph info, must be heterogeneous.

  • node_ids (torch.Tensor) – The node ids to use for the labels. [N]

  • positive_label_edge_type (PyGEdgeType) – The edge type to use for the positive labels.

  • negative_label_edge_type (Optional[PyGEdgeType]) – The edge type to use for the negative labels. Defaults to None. If not provided, no negative labels are returned.

Returns:

A tuple of (positive_labels, negative_labels). negative_labels is None if negative_label_edge_type is not provided. Each returned tensor has shape N x M, where N is the number of nodes and M is the maximum number of labels of that type.

Return type:

tuple[torch.Tensor, Optional[torch.Tensor]]

gigl.utils.data_splitters.select_ssl_positive_label_edges(edge_index, positive_label_percentage)

Selects a percentage of edges from an edge index to use as self-supervised positive labels. Note that this function does not remove the selected edges from the edge index tensor.

Parameters:
  • edge_index (torch.Tensor) – Edge index tensor of shape [2, num_edges]

  • positive_label_percentage (float) – Percentage of edges to select as positive labels

Returns:

Tensor of positive edges of shape [2, num_labels]

Return type:

torch.Tensor
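A minimal sketch of the documented behaviour, assuming positive_label_percentage is a fraction in [0, 1], that the number of selected edges is rounded down, and that edges are drawn uniformly without replacement. These are assumptions, and `sample_positive_label_edges` is a hypothetical name. As documented, the input edge index is left untouched.

```python
from typing import Optional

import torch


def sample_positive_label_edges(
    edge_index: torch.Tensor,
    positive_label_percentage: float,
    generator: Optional[torch.Generator] = None,
) -> torch.Tensor:
    if not 0.0 <= positive_label_percentage <= 1.0:
        raise ValueError("positive_label_percentage must be in [0, 1]")
    num_edges = edge_index.size(1)
    # Assumed rounding: floor of the requested fraction of edges.
    num_labels = int(num_edges * positive_label_percentage)
    # Pick distinct edge columns uniformly at random.
    chosen = torch.randperm(num_edges, generator=generator)[:num_labels]
    # Indexing returns a new tensor; edge_index itself is not modified or masked.
    return edge_index[:, chosen]


edge_index = torch.stack([torch.arange(10), torch.arange(10) + 100])
labels = sample_positive_label_edges(edge_index, 0.3, torch.Generator().manual_seed(0))
print(labels.shape)  # torch.Size([2, 3])
```

If the selected edges should also be hidden from message passing, the caller has to remove them from edge_index separately.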

gigl.utils.data_splitters.PADDING_NODE: Final[torch.Tensor]
gigl.utils.data_splitters.logger