Source code for gigl.distributed.sampler
from typing import Any, Final, Optional, Union
import torch
from graphlearn_torch.sampler import NodeSamplerInput
from gigl.src.common.types.graph_data import EdgeType, NodeType
[docs]
def metadata_key_with_prefix(key: str) -> str:
    """Return ``key`` with the ``"#META."`` prefix prepended.

    GLT prefixes its metadata keys the same way, so this mirrors that convention:
    https://github.com/alibaba/graphlearn-for-pytorch/blob/88ff111ac0d9e45c6c9d2d18cfc5883dca07e9f9/graphlearn_torch/python/distributed/dist_neighbor_sampler.py#L714

    Args:
        key (str): The un-prefixed metadata key.

    Returns:
        str: The key prefixed with ``"#META."``.
    """
    prefix = "#META."
    return f"{prefix}{key}"
[docs]
class ABLPNodeSamplerInput(NodeSamplerInput):
    """
    Sampler input specific for the ABLP use case.

    On top of the anchor nodes and their node type held by ``NodeSamplerInput``,
    this carries the positive and negative label tensors, keyed by edge type.
    Indexing an instance slices the anchor nodes and every label tensor with
    the same index.
    """

    def __init__(
        self,
        node: torch.Tensor,
        input_type: Optional[Union[str, NodeType]],
        positive_label_by_edge_types: dict[EdgeType, torch.Tensor],
        negative_label_by_edge_types: dict[EdgeType, torch.Tensor],
    ):
        """
        Args:
            node (torch.Tensor): Anchor nodes to fanout from
            input_type (Optional[Union[str, NodeType]]): Node type of the anchor nodes
            positive_label_by_edge_types (dict[EdgeType, torch.Tensor]): Positive label nodes to fanout from
            negative_label_by_edge_types (dict[EdgeType, torch.Tensor]): Negative label nodes to fanout from
        """
        super().__init__(node, input_type)
        self._positive_label_by_edge_types = positive_label_by_edge_types
        self._negative_label_by_edge_types = negative_label_by_edge_types

    @property
    def positive_label_by_edge_types(self) -> dict[EdgeType, torch.Tensor]:
        """Positive label tensors, keyed by edge type."""
        return self._positive_label_by_edge_types

    @property
    def negative_label_by_edge_types(self) -> dict[EdgeType, torch.Tensor]:
        """Negative label tensors, keyed by edge type."""
        return self._negative_label_by_edge_types

    def __len__(self) -> int:
        """Return the number of anchor nodes (size of the first dimension of ``node``)."""
        # Indexing with a scalar (see __getitem__) yields a 0-dim node tensor,
        # which has no first dimension; treat it as a single anchor node
        # instead of raising IndexError on shape[0].
        if self.node.dim() == 0:
            return 1
        return self.node.shape[0]

    def __getitem__(self, index: Union[torch.Tensor, Any]) -> "ABLPNodeSamplerInput":
        """Return a new sampler input restricted to ``index``.

        Args:
            index (Union[torch.Tensor, Any]): Index into the anchor nodes. Anything
                that is not already a tensor is converted to a ``torch.long`` tensor.

        Returns:
            ABLPNodeSamplerInput: Same ``input_type``, with the anchor nodes and every
                positive / negative label tensor indexed by ``index``.
        """
        if not isinstance(index, torch.Tensor):
            index = torch.tensor(index, dtype=torch.long)
        # The index is moved to the anchor nodes' device before any indexing.
        # NOTE(review): the label tensors are indexed with this same index, so
        # this presumably expects them on that device too and aligned with
        # `node` along dim 0 -- confirm against callers.
        index = index.to(self.node.device)
        return ABLPNodeSamplerInput(
            node=self.node[index],
            input_type=self.input_type,
            positive_label_by_edge_types={
                edge_type: labels[index]
                for edge_type, labels in self._positive_label_by_edge_types.items()
            },
            negative_label_by_edge_types={
                edge_type: labels[index]
                for edge_type, labels in self._negative_label_by_edge_types.items()
            },
        )

    def __repr__(self) -> str:
        """Return a multi-line representation listing nodes, input type and both label dicts."""
        return f"ABLPNodeSamplerInput(\n\tnode={self.node},\n\tinput_type={self.input_type},\n\tpositive_label_by_edge_types={self._positive_label_by_edge_types},\n\tnegative_label_by_edge_types={self._negative_label_by_edge_types}\n)"
