Source code for gigl.src.common.utils.gbml_protos

from typing import List, Optional, Tuple

from gigl.src.common.types.pb_wrappers.graph_data_types import (
    EdgePbWrapper,
    GraphPbWrapper,
    NodePbWrapper,
)
from snapchat.research.gbml import training_samples_schema_pb2


class TrainingSamplesSchemaProtoUtils:
    """Static helpers that assemble ``training_samples_schema_pb2`` messages.

    Every builder takes the project's protobuf wrapper objects, reads the
    underlying message through their ``.pb`` attribute, and returns a newly
    constructed training-sample protobuf.
    """

    @staticmethod
    def build_NodeAnchorBasedLinkPredictionSamplePb(
        target_node: NodePbWrapper,
        target_neighborhood: GraphPbWrapper,
        pos_neighborhoods: List[Tuple[EdgePbWrapper, GraphPbWrapper]],
        hard_neg_neighborhoods: Optional[
            List[Tuple[EdgePbWrapper, GraphPbWrapper]]
        ] = None,
        random_neg_neighborhoods: Optional[
            List[Tuple[EdgePbWrapper, GraphPbWrapper]]
        ] = None,
    ) -> training_samples_schema_pb2.NodeAnchorBasedLinkPredictionSample:
        """Build a ``NodeAnchorBasedLinkPredictionSample`` message.

        Args:
            target_node: Wrapper whose ``.pb`` becomes the sample's ``root_node``.
            target_neighborhood: Neighborhood of the target node; always the
                first subgraph handed to the merge step.
            pos_neighborhoods: ``(edge, neighborhood)`` pairs; each edge's
                ``.pb`` is appended to ``pos_edges``.
            hard_neg_neighborhoods: Optional ``(edge, neighborhood)`` pairs;
                each edge's ``.pb`` is appended to ``hard_neg_edges``.
                ``None`` or an empty list contributes nothing.
            random_neg_neighborhoods: Optional ``(edge, neighborhood)`` pairs;
                each edge's ``.pb`` is appended to ``neg_edges``.
                ``None`` or an empty list contributes nothing.

        Returns:
            The populated sample. Its ``neighborhood`` field is a copy of the
            result of ``GraphPbWrapper.merge_subgraphs`` applied to the target
            neighborhood followed by every supplied positive, hard-negative
            and random-negative neighborhood, in that order.
        """
        training_sample = (
            training_samples_schema_pb2.NodeAnchorBasedLinkPredictionSample(
                root_node=target_node.pb,
            )
        )

        # Subgraphs are collected in a fixed order (target, positives, hard
        # negatives, random negatives) and merged once at the end.
        neighborhoods: List[GraphPbWrapper] = [target_neighborhood]

        for pos_sample, pos_neighborhood in pos_neighborhoods:
            training_sample.pos_edges.append(pos_sample.pb)
            neighborhoods.append(pos_neighborhood)

        if hard_neg_neighborhoods:
            for hard_neg_sample, hard_neg_neighborhood in hard_neg_neighborhoods:
                training_sample.hard_neg_edges.append(hard_neg_sample.pb)
                neighborhoods.append(hard_neg_neighborhood)

        if random_neg_neighborhoods:
            for random_neg_sample, random_neg_neighborhood in random_neg_neighborhoods:
                training_sample.neg_edges.append(random_neg_sample.pb)
                neighborhoods.append(random_neg_neighborhood)

        merged_neighborhood = GraphPbWrapper.merge_subgraphs(subgraphs=neighborhoods)
        training_sample.neighborhood.CopyFrom(merged_neighborhood.pb)
        return training_sample

    @staticmethod
    def build_SupervisedNodeClassificationSamplePb(
        target_node: NodePbWrapper,
        neighborhood: GraphPbWrapper,
        node_labels: List[training_samples_schema_pb2.Label],
    ) -> training_samples_schema_pb2.SupervisedNodeClassificationSample:
        """Build a ``SupervisedNodeClassificationSample`` message.

        Args:
            target_node: Wrapper whose ``.pb`` becomes the sample's ``root_node``.
            neighborhood: Wrapper whose ``.pb`` becomes the sample's
                ``neighborhood``.
            node_labels: Labels stored in ``root_node_labels``. When falsy
                (empty or ``None``) the field is not passed to the constructor.

        Returns:
            The constructed sample message.
        """
        if node_labels:
            return training_samples_schema_pb2.SupervisedNodeClassificationSample(
                root_node=target_node.pb,
                neighborhood=neighborhood.pb,
                root_node_labels=node_labels,
            )
        return training_samples_schema_pb2.SupervisedNodeClassificationSample(
            root_node=target_node.pb,
            neighborhood=neighborhood.pb,
        )

    @staticmethod
    def build_SupervisedLinkBasedTaskSamplePb() -> (
        training_samples_schema_pb2.SupervisedLinkBasedTaskSample
    ):
        """Placeholder for building a ``SupervisedLinkBasedTaskSample``.

        Raises:
            NotImplementedError: Always; no builder exists for this sample
                type yet.
        """
        # The ``NotImplemented`` singleton is reserved for binary special
        # methods. Returning it here would hand callers a non-protobuf value
        # that only fails later, so signal the missing feature explicitly.
        raise NotImplementedError(
            "build_SupervisedLinkBasedTaskSamplePb is not implemented yet."
        )

    @staticmethod
    def build_RootedNodeNeighborhoodPb(
        root_node: NodePbWrapper,
        neighborhood: GraphPbWrapper,
    ) -> training_samples_schema_pb2.RootedNodeNeighborhood:
        """Build a ``RootedNodeNeighborhood`` message.

        Args:
            root_node: Wrapper whose ``.pb`` becomes the message's ``root_node``.
            neighborhood: Wrapper whose ``.pb`` becomes the message's
                ``neighborhood``.

        Returns:
            The constructed ``RootedNodeNeighborhood`` message.
        """
        return training_samples_schema_pb2.RootedNodeNeighborhood(
            root_node=root_node.pb,
            neighborhood=neighborhood.pb,
        )