Source code for gigl.src.mocking.lib.pyg_to_training_samples

from collections import defaultdict
from typing import Dict, List, Optional, Tuple

import torch
import torch_geometric.transforms as T
from torch_geometric.data import HeteroData
from torch_geometric.loader import NeighborLoader

from gigl.common.env_config import get_available_cpus
from gigl.src.common.types.graph_data import (
    CondensedNodeType,
    EdgeType,
    EdgeUsageType,
    NodeId,
    NodeType,
)
from gigl.src.common.types.pb_wrappers.graph_data_types import GraphPbWrapper
from gigl.src.common.types.pb_wrappers.graph_metadata import GraphMetadataPbWrapper
from gigl.src.common.utils.data.feature_serialization import FeatureSerializationUtils
from gigl.src.mocking.lib.mocked_dataset_resources import MockedDatasetInfo
from gigl.src.mocking.lib.user_defined_edge_sampling import sample_hydrate_user_def_edge
from snapchat.research.gbml import graph_schema_pb2, training_samples_schema_pb2

DEFAULT_NUM_HOPS_FOR_DATASETS = 1  # Number of hops to consider for each subgraph.
DEFAULT_NUM_NODES_PER_HOP = 5  # -1 means select all nodes at each hop.
DEFAULT_NUM_NEGATIVE_SAMPLES_PER_POS_EDGE = 1  # For samples taken from main edges.
DEFAULT_PYG_NODE_ANCHOR_SPLIT_TRANSFORM = T.RandomLinkSplit(
    num_val=0.1,
    num_test=0.2,
    is_undirected=True,
    add_negative_train_samples=False,
    neg_sampling_ratio=0,
)
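# A minimal sketch of how this transform behaves (illustrative, not part of the
# module): `T.RandomLinkSplit` returns (train, val, test) copies of the data with
# edges partitioned per the ratios above, e.g. for a homogeneous `Data` object:
#
#     train_data, val_data, test_data = DEFAULT_PYG_NODE_ANCHOR_SPLIT_TRANSFORM(data)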
def build_pyg_heterodata_from_mocked_dataset_info(
    mocked_dataset_info: MockedDatasetInfo,
) -> HeteroData:
    """
    Given a MockedDatasetInfo object, build a HeteroData object to use PyG convenience functions.
    """
    hetero_data = HeteroData()
    for node_type, node_feats in mocked_dataset_info.node_feats.items():
        hetero_data[node_type].x = node_feats
        hetero_data[node_type].n_id = torch.arange(hetero_data[node_type].num_nodes)
    if mocked_dataset_info.node_labels is not None:
        for node_type, node_labels in mocked_dataset_info.node_labels.items():
            hetero_data[node_type].y = node_labels
    for edge_type, edge_index in mocked_dataset_info.edge_index.items():
        hetero_data[
            edge_type.src_node_type, edge_type.relation, edge_type.dst_node_type
        ].edge_index = edge_index
    if mocked_dataset_info.edge_feats is not None:
        for edge_type, edge_attr in mocked_dataset_info.edge_feats.items():
            hetero_data[
                (edge_type.src_node_type, edge_type.relation, edge_type.dst_node_type)
            ].x = edge_attr
    return hetero_data
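# Hypothetical usage sketch (the attribute names mirror those read above; the
# `MockedDatasetInfo` instance and the "user" node type are assumed for illustration):
#
#     hetero_data = build_pyg_heterodata_from_mocked_dataset_info(mocked_dataset_info)
#     hetero_data["user"].x     # node features for the "user" node type
#     hetero_data["user"].n_id  # global node ids, i.e. torch.arange(num_nodes)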
def _build_graph_pb_wrapper_from_hetero_data(
    hetero_data: HeteroData, graph_metadata_pb_wrapper: GraphMetadataPbWrapper
) -> GraphPbWrapper:
    khop_subgraph_edges: List[graph_schema_pb2.Edge] = list()
    khop_subgraph_nodes: List[graph_schema_pb2.Node] = list()
    for pyg_edge_type in hetero_data.edge_types:
        edge_type_metadata = hetero_data[pyg_edge_type]
        edge_index = edge_type_metadata.get("edge_index")
        edge_attr = edge_type_metadata.get("x")
        src_pyg_node_type = pyg_edge_type[0]
        dst_pyg_node_type = pyg_edge_type[2]
        edge_type = EdgeType(
            src_node_type=src_pyg_node_type,
            relation=pyg_edge_type[1],
            dst_node_type=dst_pyg_node_type,
        )
        condensed_edge_type = (
            graph_metadata_pb_wrapper.edge_type_to_condensed_edge_type_map[edge_type]
        )
        src_node_ids: torch.Tensor
        dst_node_ids: torch.Tensor
        src_node_ids, dst_node_ids = edge_index
        global_src_node_ids = torch.take(
            hetero_data[src_pyg_node_type].get("n_id"), src_node_ids
        )
        global_dst_node_ids = torch.take(
            hetero_data[dst_pyg_node_type].get("n_id"), dst_node_ids
        )
        for idx, (global_src_node_id, global_dst_node_id) in enumerate(
            zip(global_src_node_ids, global_dst_node_ids)
        ):
            edge_feature_value = (
                FeatureSerializationUtils.serialize_edge_features(
                    features=edge_attr[idx, :].numpy()
                )
                if edge_attr is not None
                else None
            )
            edge = graph_schema_pb2.Edge(
                src_node_id=global_src_node_id,
                dst_node_id=global_dst_node_id,
                condensed_edge_type=condensed_edge_type,
                feature_values=edge_feature_value,  # type: ignore
            )
            khop_subgraph_edges.append(edge)
    for pyg_node_type in hetero_data.node_types:
        node_type_metadata = hetero_data[pyg_node_type]
        node_attr = node_type_metadata.get("x")
        assert node_attr is not None
        node_type = NodeType(pyg_node_type)
        condensed_node_type = (
            graph_metadata_pb_wrapper.node_type_to_condensed_node_type_map[node_type]
        )
        global_node_ids = node_type_metadata.get("n_id")
        assert global_node_ids is not None
        for idx, global_node_id in enumerate(global_node_ids):
            node_feature_value = FeatureSerializationUtils.serialize_node_features(
                node_attr[idx, :].numpy()
            )
            node = graph_schema_pb2.Node(
                node_id=global_node_id,
                condensed_node_type=condensed_node_type,
                feature_values=node_feature_value,  # type: ignore
            )
            khop_subgraph_nodes.append(node)
    subgraph = GraphPbWrapper(
        pb=graph_schema_pb2.Graph(
            nodes=khop_subgraph_nodes,
            edges=khop_subgraph_edges,
        )
    )
    return subgraph
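# Sketch of the local-to-global id remapping done above with `torch.take`
# (values are illustrative): a NeighborLoader sample's edge_index uses local
# indices, and `n_id` translates them back to ids in the full graph.
#
#     n_id = torch.tensor([10, 42, 7])     # global ids of local nodes 0, 1, 2
#     local_src = torch.tensor([0, 2, 2])  # local endpoints from edge_index
#     torch.take(n_id, local_src)          # -> tensor([10, 7, 7])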
def build_k_hop_subgraphs_from_pyg_heterodata(
    hetero_data: HeteroData,
    graph_metadata_pb_wrapper: GraphMetadataPbWrapper,
    root_node_type: NodeType,
    root_node_idxs: Optional[torch.Tensor] = None,
    num_hops: int = DEFAULT_NUM_HOPS_FOR_DATASETS,
    num_neighbors: int = DEFAULT_NUM_NODES_PER_HOP,
) -> Dict[NodeId, GraphPbWrapper]:
    """
    Return a map from each root node of type `root_node_type` with index in
    `root_node_idxs` to a GraphPbWrapper describing its `num_hops`-hop
    surrounding subgraph.
    """
    if root_node_idxs is None:
        root_node_idxs = torch.arange(hetero_data[str(root_node_type)].num_nodes)
    num_neighbors_dict = {
        edge_type: [num_neighbors] * num_hops for edge_type in hetero_data.edge_types
    }
    loader = NeighborLoader(
        data=hetero_data,
        num_neighbors=num_neighbors_dict,
        input_nodes=(str(root_node_type), root_node_idxs),
        batch_size=1,
        num_workers=get_available_cpus() - 1,  # Use all available CPUs except one for this task.
    )
    k_hop_subgraphs: Dict[NodeId, GraphPbWrapper] = dict()
    sample: HeteroData
    for root_node_idx, sample in zip(root_node_idxs.tolist(), loader):
        graph_pb_wrapper = _build_graph_pb_wrapper_from_hetero_data(
            hetero_data=sample, graph_metadata_pb_wrapper=graph_metadata_pb_wrapper
        )
        k_hop_subgraphs[NodeId(root_node_idx)] = graph_pb_wrapper
    return k_hop_subgraphs
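# Hypothetical usage sketch (assumes `hetero_data` carries the `n_id` tensors set
# by build_pyg_heterodata_from_mocked_dataset_info, plus a matching metadata wrapper;
# the "user" node type is illustrative):
#
#     subgraphs = build_k_hop_subgraphs_from_pyg_heterodata(
#         hetero_data=hetero_data,
#         graph_metadata_pb_wrapper=graph_metadata_pb_wrapper,
#         root_node_type=NodeType("user"),
#     )
#     subgraphs[NodeId(0)]  # GraphPbWrapper for the subgraph rooted at node 0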
def _get_random_negative_samples_for_pos_edges(
    edge_index: torch.LongTensor,
    num_nodes: int,
    num_negative_samples_per_pos_edge: int = 1,
) -> torch.LongTensor:
    """
    Given a "positive" edge index (edges which exist), return a "negative" edge index
    (edges which likely don't) with `num_negative_samples_per_pos_edge` negatives per
    positive edge. We effectively sample the endpoints of these negative edges
    randomly from the node set.
    """
    pos_node_ids = edge_index[0].repeat(num_negative_samples_per_pos_edge)
    neg_node_ids = torch.randint(low=0, high=num_nodes, size=[pos_node_ids.numel()])
    return torch.vstack((pos_node_ids, neg_node_ids))  # type: ignore


def _build_rooted_node_neighborhood_samples_from_subgraphs(
    subgraph_dict: Dict[NodeId, GraphPbWrapper], condensed_node_type: CondensedNodeType
) -> List[training_samples_schema_pb2.RootedNodeNeighborhood]:
    samples: List[training_samples_schema_pb2.RootedNodeNeighborhood] = list()
    for root_node_id, subgraph in subgraph_dict.items():
        sample = training_samples_schema_pb2.RootedNodeNeighborhood(
            root_node=graph_schema_pb2.Node(
                node_id=int(root_node_id),
                condensed_node_type=condensed_node_type,
                feature_values=None,  # type: ignore
            ),
            neighborhood=subgraph.pb,
        )
        samples.append(sample)
    return samples
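# Worked sketch of the negative sampler above (shapes are what matter):
#
#     pos = torch.tensor([[0, 1, 2], [3, 4, 5]])  # 3 positive edges
#     neg = _get_random_negative_samples_for_pos_edges(pos, num_nodes=6)
#     neg.shape  # -> torch.Size([2, 3]); row 0 repeats sources, row 1 is random ids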
def build_supervised_node_classification_samples_from_pyg_heterodata(
    hetero_data: HeteroData,
    root_node_type: NodeType,
    graph_metadata_pb_wrapper: GraphMetadataPbWrapper,
) -> List[training_samples_schema_pb2.SupervisedNodeClassificationSample]:
    samples: List[
        training_samples_schema_pb2.SupervisedNodeClassificationSample
    ] = list()
    assert (
        hetero_data[str(root_node_type)].get("y") is not None
    )  # Ensure labels exist for this node type (else we cannot have a supervised task).
    node_labels = hetero_data[str(root_node_type)].y
    k_hop_subgraphs_for_root_node_type = build_k_hop_subgraphs_from_pyg_heterodata(
        hetero_data=hetero_data,
        graph_metadata_pb_wrapper=graph_metadata_pb_wrapper,
        root_node_type=root_node_type,
        num_hops=DEFAULT_NUM_HOPS_FOR_DATASETS,
    )
    condensed_node_type = (
        graph_metadata_pb_wrapper.node_type_to_condensed_node_type_map[root_node_type]
    )
    for root_node_id, subgraph in k_hop_subgraphs_for_root_node_type.items():
        sample = training_samples_schema_pb2.SupervisedNodeClassificationSample(
            root_node=graph_schema_pb2.Node(
                node_id=int(root_node_id),
                condensed_node_type=condensed_node_type,
                feature_values=None,  # type: ignore
            ),
            neighborhood=subgraph.pb,
            root_node_labels=[
                training_samples_schema_pb2.Label(
                    label_type="classification",
                    label=node_labels[int(root_node_id)],
                )
            ],
        )
        samples.append(sample)
    return samples
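# Hypothetical usage sketch (assumes `hetero_data[str(root_node_type)].y` holds
# labels; the "user" node type is illustrative):
#
#     samples = build_supervised_node_classification_samples_from_pyg_heterodata(
#         hetero_data=hetero_data,
#         root_node_type=NodeType("user"),
#         graph_metadata_pb_wrapper=graph_metadata_pb_wrapper,
#     )
#     samples[0].root_node_labels[0].label  # classification label of the first root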