from collections import defaultdict
from typing import Optional, Tuple
import torch
import torch_geometric.transforms as T
from torch_geometric.data import HeteroData
from torch_geometric.loader import NeighborLoader
from gigl.common.env_config import get_available_cpus
from gigl.src.common.types.graph_data import (
    CondensedNodeType,
    EdgeType,
    EdgeUsageType,
    NodeId,
    NodeType,
)
from gigl.src.common.types.pb_wrappers.graph_data_types import GraphPbWrapper
from gigl.src.common.types.pb_wrappers.graph_metadata import GraphMetadataPbWrapper
from gigl.src.common.utils.data.feature_serialization import FeatureSerializationUtils
from gigl.src.mocking.lib.mocked_dataset_resources import MockedDatasetInfo
from gigl.src.mocking.lib.user_defined_edge_sampling import sample_hydrate_user_def_edge
from snapchat.research.gbml import graph_schema_pb2, training_samples_schema_pb2
DEFAULT_NUM_HOPS_FOR_DATASETS = 1  # Number of hops to consider for each subgraph.
DEFAULT_NUM_NODES_PER_HOP = 5  # Neighbors sampled per edge type at each hop; -1 means select all nodes at each hop.
DEFAULT_NUM_NEGATIVE_SAMPLES_PER_POS_EDGE = 1  # Random negatives drawn per positive edge (for samples taken from main edges).
def build_pyg_heterodata_from_mocked_dataset_info(
    mocked_dataset_info: MockedDatasetInfo,
) -> HeteroData:
    """
    Build a PyG HeteroData object from a MockedDatasetInfo so PyG convenience
    functions (e.g. NeighborLoader) can be used on the mocked graph.

    For every node type:
      - ``x`` holds the node features;
      - ``y`` holds the node labels, when the dataset provides any;
      - ``n_id`` is set to ``arange(num_nodes)``. Sampled subgraphs can then be
        mapped back to the original (global) node indices.

    For every edge type:
      - ``edge_index`` holds the edge index;
      - ``x`` on the edge store holds the edge features, when the dataset provides
        any. This is ``x``, not PyG's conventional ``edge_attr``, because the
        rest of this module reads edge features from ``x``.

    Args:
        mocked_dataset_info: Source of node features/labels and edge
            indices/features.

    Returns:
        The populated HeteroData.
    """
    hetero_data = HeteroData()
    for node_type, node_feats in mocked_dataset_info.node_feats.items():
        node_store = hetero_data[node_type]
        node_store.x = node_feats
        # num_nodes is inferred from `x`, assigned just above.
        node_store.n_id = torch.arange(node_store.num_nodes)
    if mocked_dataset_info.node_labels is not None:
        for node_type, node_labels in mocked_dataset_info.node_labels.items():
            hetero_data[node_type].y = node_labels
    for edge_type, edge_index in mocked_dataset_info.edge_index.items():
        pyg_edge_type = (
            edge_type.src_node_type,
            edge_type.relation,
            edge_type.dst_node_type,
        )
        hetero_data[pyg_edge_type].edge_index = edge_index
    if mocked_dataset_info.edge_feats is not None:
        for edge_type, edge_attr in mocked_dataset_info.edge_feats.items():
            pyg_edge_type = (
                edge_type.src_node_type,
                edge_type.relation,
                edge_type.dst_node_type,
            )
            hetero_data[pyg_edge_type].x = edge_attr
    return hetero_data
def _build_graph_pb_wrapper_from_hetero_data(
    hetero_data: HeteroData, graph_metadata_pb_wrapper: GraphMetadataPbWrapper
) -> GraphPbWrapper:
    """
    Convert a (typically sampled) HeteroData graph into a GraphPbWrapper.

    Edge endpoints in each edge store's ``edge_index`` are local indices. They
    are translated to global node ids through the ``n_id`` tensor of the
    endpoint node types. Edge features (optional) and node features (required)
    are read from the ``x`` attribute of the corresponding store and serialized
    with FeatureSerializationUtils.

    Args:
        hetero_data: Graph to convert. Every edge type must have
            ``edge_index``. Every node type must have ``x`` and ``n_id``.
        graph_metadata_pb_wrapper: Provides the edge/node type to condensed
            type mappings.

    Returns:
        A GraphPbWrapper around a ``graph_schema_pb2.Graph`` containing one
        Edge per column of every ``edge_index`` and one Node per entry of
        every ``n_id``.

    Raises:
        AssertionError: If a node type is missing ``x`` or ``n_id``.
    """
    edge_pbs: list[graph_schema_pb2.Edge] = []
    node_pbs: list[graph_schema_pb2.Node] = []

    for pyg_edge_type in hetero_data.edge_types:
        src_pyg_node_type, relation, dst_pyg_node_type = pyg_edge_type
        edge_store = hetero_data[pyg_edge_type]
        edge_index = edge_store.get("edge_index")
        edge_attr = edge_store.get("x")
        edge_type = EdgeType(
            src_node_type=src_pyg_node_type,
            relation=relation,
            dst_node_type=dst_pyg_node_type,
        )
        condensed_edge_type = (
            graph_metadata_pb_wrapper.edge_type_to_condensed_edge_type_map[edge_type]
        )
        src_local_ids: torch.Tensor
        dst_local_ids: torch.Tensor
        src_local_ids, dst_local_ids = edge_index
        # Convert to plain Python ints once. The proto int fields then receive
        # ints instead of 0-d tensors, and the per-edge loop does no tensor
        # indexing.
        global_src_node_ids = torch.take(
            hetero_data[src_pyg_node_type].get("n_id"), src_local_ids
        ).tolist()
        global_dst_node_ids = torch.take(
            hetero_data[dst_pyg_node_type].get("n_id"), dst_local_ids
        ).tolist()
        # Convert once; each row below is a view of this array.
        edge_feats = edge_attr.numpy() if edge_attr is not None else None
        for idx, (global_src_node_id, global_dst_node_id) in enumerate(
            zip(global_src_node_ids, global_dst_node_ids)
        ):
            edge_feature_value = (
                FeatureSerializationUtils.serialize_edge_features(
                    features=edge_feats[idx, :]
                )
                if edge_feats is not None
                else None
            )
            edge_pb = graph_schema_pb2.Edge(
                src_node_id=global_src_node_id,
                dst_node_id=global_dst_node_id,
                condensed_edge_type=condensed_edge_type,
                feature_values=edge_feature_value,  # type: ignore
            )
            edge_pbs.append(edge_pb)

    for pyg_node_type in hetero_data.node_types:
        node_store = hetero_data[pyg_node_type]
        node_attr = node_store.get("x")
        assert node_attr is not None
        condensed_node_type = (
            graph_metadata_pb_wrapper.node_type_to_condensed_node_type_map[
                NodeType(pyg_node_type)
            ]
        )
        global_node_ids = node_store.get("n_id")
        assert global_node_ids is not None
        node_feats = node_attr.numpy()
        for idx, global_node_id in enumerate(global_node_ids.tolist()):
            node_feature_value = FeatureSerializationUtils.serialize_node_features(
                node_feats[idx, :]
            )
            node_pb = graph_schema_pb2.Node(
                node_id=global_node_id,
                condensed_node_type=condensed_node_type,
                feature_values=node_feature_value,  # type: ignore
            )
            node_pbs.append(node_pb)

    return GraphPbWrapper(
        pb=graph_schema_pb2.Graph(
            nodes=node_pbs,
            edges=edge_pbs,
        )
    )
def build_k_hop_subgraphs_from_pyg_heterodata(
    hetero_data: HeteroData,
    graph_metadata_pb_wrapper: GraphMetadataPbWrapper,
    root_node_type: NodeType,
    root_node_idxs: Optional[torch.Tensor] = None,
    num_hops: int = DEFAULT_NUM_HOPS_FOR_DATASETS,
    num_neighbors: int = DEFAULT_NUM_NODES_PER_HOP,
) -> dict[NodeId, GraphPbWrapper]:
    """
    Map each root node of type `root_node_type` (at the indices in
    `root_node_idxs`) to a GraphPbWrapper describing its sampled
    `num_hops`-hop surrounding subgraph.

    Args:
        hetero_data: Graph to sample from, e.g. as built by
            build_pyg_heterodata_from_mocked_dataset_info. Every sampled node
            type needs ``x`` features, because each sample is converted with
            _build_graph_pb_wrapper_from_hetero_data.
        graph_metadata_pb_wrapper: Provides condensed node/edge type mappings.
        root_node_type: Node type of the root nodes.
        root_node_idxs: Indices of the roots within `root_node_type`.
            Defaults to every node of that type.
        num_hops: Number of sampling hops.
        num_neighbors: Neighbors sampled per edge type at each hop
            (-1 selects all neighbors).

    Returns:
        Mapping from ``NodeId(root index)`` to that root's subgraph, inserted
        in the order of `root_node_idxs`.
    """
    pyg_root_node_type = str(root_node_type)
    if root_node_idxs is None:
        root_node_idxs = torch.arange(hetero_data[pyg_root_node_type].num_nodes)
    num_neighbors_per_edge_type = {
        edge_type: [num_neighbors] * num_hops for edge_type in hetero_data.edge_types
    }
    # Use all available CPUs except one for sampling. Clamp at 0 (sample in the
    # main process) because DataLoader rejects a negative worker count.
    num_workers = max(get_available_cpus() - 1, 0)
    loader = NeighborLoader(
        data=hetero_data,
        num_neighbors=num_neighbors_per_edge_type,
        input_nodes=(pyg_root_node_type, root_node_idxs),
        batch_size=1,
        num_workers=num_workers,
    )
    k_hop_subgraphs: dict[NodeId, GraphPbWrapper] = {}
    sample: HeteroData
    # The zip below relies on batch_size=1 with no shuffling requested: the
    # loader must yield exactly one sample per root, in `root_node_idxs` order.
    for root_node_idx, sample in zip(root_node_idxs.tolist(), loader):
        k_hop_subgraphs[NodeId(root_node_idx)] = (
            _build_graph_pb_wrapper_from_hetero_data(
                hetero_data=sample,
                graph_metadata_pb_wrapper=graph_metadata_pb_wrapper,
            )
        )
    return k_hop_subgraphs
def _get_random_negative_samples_for_pos_edges(
    edge_index: torch.LongTensor,
    num_nodes: int,
    num_negative_samples_per_pos_edge: int = 1,
    generator: Optional[torch.Generator] = None,
) -> torch.LongTensor:
    """
    Given a "positive" edge index (edges which exist), return a "negative" edge
    index (edges which likely don't).

    Each positive edge produces `num_negative_samples_per_pos_edge` negative
    edges. Each negative edge keeps the positive edge's source node and pairs it
    with a destination drawn uniformly at random from ``[0, num_nodes)``.
    Destinations are random, so a sampled pair may coincide with an existing
    edge.

    Args:
        edge_index: Positive edges. Only ``edge_index[0]`` (the source node ids)
            is read; presumably shaped (2, num_edges) like a PyG edge index.
        num_nodes: Number of candidate destination nodes.
        num_negative_samples_per_pos_edge: Negatives generated per positive edge.
        generator: Optional RNG for reproducible sampling. Defaults to torch's
            global RNG.

    Returns:
        Tensor with 2 rows and ``len(edge_index[0]) * num_negative_samples_per_pos_edge``
        columns, on the same device as `edge_index`. The source row is the
        positive sources tiled (repeated as a whole), not interleaved.
    """
    src_node_ids = edge_index[0].repeat(num_negative_samples_per_pos_edge)
    # Draw on the sources' device so the vstack below works for non-CPU inputs.
    neg_dst_node_ids = torch.randint(
        low=0,
        high=num_nodes,
        size=(src_node_ids.numel(),),
        generator=generator,
        device=src_node_ids.device,
    )
    return torch.vstack((src_node_ids, neg_dst_node_ids))  # type: ignore
def _build_rooted_node_neighborhood_samples_from_subgraphs(
    subgraph_dict: dict[NodeId, GraphPbWrapper], condensed_node_type: CondensedNodeType
) -> list[training_samples_schema_pb2.RootedNodeNeighborhood]:
    """
    Wrap each precomputed subgraph into a RootedNodeNeighborhood sample.

    Args:
        subgraph_dict: Root node id -> neighborhood subgraph around that root.
        condensed_node_type: Condensed node type given to every root node.

    Returns:
        One RootedNodeNeighborhood per entry of `subgraph_dict`, in the dict's
        iteration order. Root nodes are emitted without feature values.
    """
    return [
        training_samples_schema_pb2.RootedNodeNeighborhood(
            root_node=graph_schema_pb2.Node(
                node_id=int(node_id),
                condensed_node_type=condensed_node_type,
                feature_values=None,  # type: ignore
            ),
            neighborhood=wrapped_subgraph.pb,
        )
        for node_id, wrapped_subgraph in subgraph_dict.items()
    ]
def build_supervised_node_classification_samples_from_pyg_heterodata(
    hetero_data: HeteroData,
    root_node_type: NodeType,
    graph_metadata_pb_wrapper: GraphMetadataPbWrapper,
) -> list[training_samples_schema_pb2.SupervisedNodeClassificationSample]:
    """
    Build one SupervisedNodeClassificationSample per node of `root_node_type`.

    Each sample holds:
      - the root node;
      - its sampled `DEFAULT_NUM_HOPS_FOR_DATASETS`-hop neighborhood;
      - a single "classification" label read from the node type's ``y`` tensor.

    Args:
        hetero_data: Graph to sample from. `root_node_type` must carry labels
            in ``y``.
        root_node_type: Node type whose nodes become sample roots.
        graph_metadata_pb_wrapper: Provides condensed node/edge type mappings.

    Returns:
        The samples, in root node index order.

    Raises:
        AssertionError: If `root_node_type` has no ``y`` labels.
    """
    node_labels = hetero_data[str(root_node_type)].get("y")
    # Without labels for this node type there is no supervised task.
    assert (
        node_labels is not None
    ), f"Node type {root_node_type} has no labels (`y`); cannot build supervised samples."
    k_hop_subgraphs_for_root_node_type = build_k_hop_subgraphs_from_pyg_heterodata(
        hetero_data=hetero_data,
        graph_metadata_pb_wrapper=graph_metadata_pb_wrapper,
        root_node_type=root_node_type,
        num_hops=DEFAULT_NUM_HOPS_FOR_DATASETS,
    )
    condensed_node_type = (
        graph_metadata_pb_wrapper.node_type_to_condensed_node_type_map[root_node_type]
    )
    samples: list[
        training_samples_schema_pb2.SupervisedNodeClassificationSample
    ] = []
    for root_node_id, subgraph in k_hop_subgraphs_for_root_node_type.items():
        # `.item()` hands the proto a plain Python scalar instead of a 0-d tensor.
        root_node_label = node_labels[int(root_node_id)].item()
        sample = training_samples_schema_pb2.SupervisedNodeClassificationSample(
            root_node=graph_schema_pb2.Node(
                node_id=int(root_node_id),
                condensed_node_type=condensed_node_type,
                feature_values=None,  # type: ignore
            ),
            neighborhood=subgraph.pb,
            root_node_labels=[
                training_samples_schema_pb2.Label(
                    label_type="classification",
                    label=root_node_label,
                )
            ],
        )
        samples.append(sample)
    return samples
def build_node_anchor_link_prediction_samples_from_pyg_heterodata(
    hetero_data: HeteroData,
    sample_edge_type: EdgeType,
    graph_metadata_pb_wrapper: GraphMetadataPbWrapper,
    mocked_dataset_info: MockedDatasetInfo,
) -> Tuple[
    list[training_samples_schema_pb2.NodeAnchorBasedLinkPredictionSample],
    list[training_samples_schema_pb2.RootedNodeNeighborhood],
    list[training_samples_schema_pb2.RootedNodeNeighborhood],
]:
    """
    Build node-anchor-based link prediction samples for `sample_edge_type`.

    Positive edges come from:
      - the user-defined POSITIVE edges in `mocked_dataset_info` when present
        for this edge type;
      - otherwise, ``edge_label_index`` on the `sample_edge_type` edge store of
        `hetero_data`.

    Hard negatives come from:
      - the user-defined NEGATIVE edges when present;
      - otherwise, `DEFAULT_NUM_NEGATIVE_SAMPLES_PER_POS_EDGE` random
        destinations drawn per positive edge.

    Args:
        hetero_data: Full graph. It must carry ``edge_label_index`` on the sample
            edge type when no user-defined positive edges are given.
        sample_edge_type: Edge type whose edges supervise the samples.
        graph_metadata_pb_wrapper: Provides condensed node/edge type mappings.
        mocked_dataset_info: Source of optional user-defined edges and edge
            features.

    Returns:
        A tuple of:
          - one NodeAnchorBasedLinkPredictionSample per source node with at
            least one positive edge;
          - RootedNodeNeighborhood samples for every node of the source type;
          - RootedNodeNeighborhood samples for every node of the destination type.
    """
    src_node_id_to_k_hop_subgraph = build_k_hop_subgraphs_from_pyg_heterodata(
        hetero_data=hetero_data,
        graph_metadata_pb_wrapper=graph_metadata_pb_wrapper,
        root_node_type=sample_edge_type.src_node_type,
        num_hops=DEFAULT_NUM_HOPS_FOR_DATASETS,
    )
    if sample_edge_type.src_node_type == sample_edge_type.dst_node_type:
        # Same node type on both ends: reuse the source subgraphs.
        dst_node_id_to_k_hop_subgraph = src_node_id_to_k_hop_subgraph
    else:
        # Otherwise build a separate set of subgraphs for the destination type.
        dst_node_id_to_k_hop_subgraph = build_k_hop_subgraphs_from_pyg_heterodata(
            hetero_data=hetero_data,
            graph_metadata_pb_wrapper=graph_metadata_pb_wrapper,
            root_node_type=sample_edge_type.dst_node_type,
            num_hops=DEFAULT_NUM_HOPS_FOR_DATASETS,
        )
    condensed_src_node_type = (
        graph_metadata_pb_wrapper.node_type_to_condensed_node_type_map[
            sample_edge_type.src_node_type
        ]
    )
    condensed_dst_node_type = (
        graph_metadata_pb_wrapper.node_type_to_condensed_node_type_map[
            sample_edge_type.dst_node_type
        ]
    )
    condensed_sample_edge_type = (
        graph_metadata_pb_wrapper.edge_type_to_condensed_edge_type_map[sample_edge_type]
    )

    # Create RootedNodeNeighborhood samples for both endpoint node types.
    rooted_neighborhoods_for_src_node_type = (
        _build_rooted_node_neighborhood_samples_from_subgraphs(
            subgraph_dict=src_node_id_to_k_hop_subgraph,
            condensed_node_type=condensed_src_node_type,
        )
    )
    rooted_neighborhoods_for_dst_node_type = (
        _build_rooted_node_neighborhood_samples_from_subgraphs(
            subgraph_dict=dst_node_id_to_k_hop_subgraph,
            condensed_node_type=condensed_dst_node_type,
        )
    )

    # User-defined edges and edge features for this edge type, keyed by
    # EdgeUsageType. The fallback path below is used, instead of raising
    # KeyError, when:
    #   - there is no mapping at all;
    #   - the mapping has no entry for this edge type;
    #   - the mapping has no entry for this usage.
    user_defined_edge_index_by_usage = (
        mocked_dataset_info.user_defined_edge_index or {}
    ).get(sample_edge_type) or {}
    user_defined_edge_feats_by_usage = (
        mocked_dataset_info.user_defined_edge_feats or {}
    ).get(sample_edge_type) or {}

    # ---- Positive edges ----
    user_defined_pos_edges = user_defined_edge_index_by_usage.get(
        EdgeUsageType.POSITIVE
    )
    user_def_pos_edge_feats = user_defined_edge_feats_by_usage.get(
        EdgeUsageType.POSITIVE
    )
    if user_defined_pos_edges is not None:
        pos_node_map = sample_hydrate_user_def_edge(
            mocked_dataset_info=mocked_dataset_info,
            edge_usage_type=EdgeUsageType.POSITIVE,
        )
        # If random negatives are needed below, draw them per user-defined
        # positive edge.
        # NOTE(review): assumes user-defined edge indices share the
        # (2, num_edges) layout of edge_label_index. Only row 0 (the sources)
        # is read.
        pos_edge_index = user_defined_pos_edges
    else:
        # Map each source node to its positive destination nodes.
        pos_node_map = defaultdict(list)
        pos_edge_index = hetero_data[
            (
                str(sample_edge_type.src_node_type),
                str(sample_edge_type.relation),
                str(sample_edge_type.dst_node_type),
            )
        ].edge_label_index
        for src, dst in zip(pos_edge_index[0].tolist(), pos_edge_index[1].tolist()):
            pos_node_map[src].append(dst)
    # Hydrated user-defined entries are [node_id, edge_feats] pairs when edge
    # features are defined. The default path always stores bare node ids.
    pos_samples_have_feats = (
        user_defined_pos_edges is not None and user_def_pos_edge_feats is not None
    )

    # ---- Hard negative edges ----
    user_defined_neg_edges = user_defined_edge_index_by_usage.get(
        EdgeUsageType.NEGATIVE
    )
    user_def_neg_edge_feats = user_defined_edge_feats_by_usage.get(
        EdgeUsageType.NEGATIVE
    )
    if user_defined_neg_edges is not None:
        hard_neg_node_map = sample_hydrate_user_def_edge(
            mocked_dataset_info=mocked_dataset_info,
            edge_usage_type=EdgeUsageType.NEGATIVE,
        )
    else:
        # Map each source node to randomly sampled negative destinations.
        # `pos_edge_index` is bound on both positive paths above, so this also
        # works when only the positives are user-defined.
        hard_neg_node_map = defaultdict(list)
        hard_neg_edge_index = _get_random_negative_samples_for_pos_edges(
            edge_index=pos_edge_index,
            num_nodes=hetero_data[str(sample_edge_type.dst_node_type)].num_nodes,
            num_negative_samples_per_pos_edge=DEFAULT_NUM_NEGATIVE_SAMPLES_PER_POS_EDGE,
        )
        for src, dst in zip(
            hard_neg_edge_index[0].tolist(), hard_neg_edge_index[1].tolist()
        ):
            hard_neg_node_map[src].append(dst)
    hard_neg_samples_have_feats = (
        user_defined_neg_edges is not None and user_def_neg_edge_feats is not None
    )

    # ---- Samples: one per source node with at least 1 positive edge ----
    unsup_node_anchor_samples: list[
        training_samples_schema_pb2.NodeAnchorBasedLinkPredictionSample
    ] = []
    for root_node_id, pos_samples in pos_node_map.items():
        pos_edge_pbs: list[graph_schema_pb2.Edge] = []
        hard_neg_edge_pbs: list[graph_schema_pb2.Edge] = []
        root_node_pb = graph_schema_pb2.Node(
            node_id=root_node_id,
            condensed_node_type=condensed_src_node_type,
            feature_values=None,  # type: ignore
        )
        # The neighborhood is the union of the root's subgraph and the
        # subgraphs of every positive and hard-negative destination.
        subgraphs_to_merge: list[GraphPbWrapper] = [
            src_node_id_to_k_hop_subgraph[root_node_id]
        ]
        for pos_sample in pos_samples:
            if pos_samples_have_feats:  # pos_sample = [pos_node_id, edge_feats]
                pos_node_id = pos_sample[0]
                pos_edge_feats = pos_sample[1]
                edge_pb = graph_schema_pb2.Edge(
                    src_node_id=root_node_id,
                    dst_node_id=pos_node_id,
                    condensed_edge_type=condensed_sample_edge_type,
                    feature_values=pos_edge_feats,
                )
            else:
                pos_node_id = pos_sample
                edge_pb = graph_schema_pb2.Edge(
                    src_node_id=root_node_id,
                    dst_node_id=pos_node_id,
                    condensed_edge_type=condensed_sample_edge_type,
                )
            pos_edge_pbs.append(edge_pb)
            subgraphs_to_merge.append(dst_node_id_to_k_hop_subgraph[pos_node_id])
        # `.get` so that roots without hard negatives yield none, even when the
        # map is a plain dict rather than a defaultdict.
        for hard_neg_sample in hard_neg_node_map.get(root_node_id, []):
            if hard_neg_samples_have_feats:  # hard_neg_sample = [node_id, edge_feats]
                hard_neg_node_id = hard_neg_sample[0]
                hard_neg_edge_feats = hard_neg_sample[1]
                edge_pb = graph_schema_pb2.Edge(
                    src_node_id=root_node_id,
                    dst_node_id=hard_neg_node_id,
                    condensed_edge_type=condensed_sample_edge_type,
                    feature_values=hard_neg_edge_feats,
                )
            else:
                hard_neg_node_id = hard_neg_sample
                edge_pb = graph_schema_pb2.Edge(
                    src_node_id=root_node_id,
                    dst_node_id=hard_neg_node_id,
                    condensed_edge_type=condensed_sample_edge_type,
                )
            hard_neg_edge_pbs.append(edge_pb)
            subgraphs_to_merge.append(dst_node_id_to_k_hop_subgraph[hard_neg_node_id])
        neighborhood_pb = GraphPbWrapper.merge_subgraphs(
            subgraphs=subgraphs_to_merge
        ).pb
        sample = training_samples_schema_pb2.NodeAnchorBasedLinkPredictionSample(
            root_node=root_node_pb,
            pos_edges=pos_edge_pbs,
            hard_neg_edges=hard_neg_edge_pbs,
            neighborhood=neighborhood_pb,
        )
        unsup_node_anchor_samples.append(sample)
    return (
        unsup_node_anchor_samples,
        rooted_neighborhoods_for_src_node_type,
        rooted_neighborhoods_for_dst_node_type,
    )