import json
import tempfile
from functools import partial
from typing import Tuple
import tensorflow as tf
from google.cloud import bigquery
import gigl.src.common.utils.model as model_utils
from gigl.common import UriFactory
from gigl.common.logger import Logger
from gigl.common.utils.os_utils import import_obj
from gigl.src.common.graph_builder.abstract_graph_builder import GraphBuilder
from gigl.src.common.graph_builder.pyg_graph_builder import PygGraphBuilder
from gigl.src.common.translators.training_samples_protos_translator import (
    RootedNodeNeighborhoodSample,
    SupervisedNodeClassificationSample,
)
from gigl.src.common.types.graph_data import NodeType
from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper
from gigl.src.common.types.pb_wrappers.graph_metadata import GraphMetadataPbWrapper
from gigl.src.common.types.pb_wrappers.preprocessed_metadata import (
    PreprocessedMetadataPbWrapper,
)
from gigl.src.common.types.task_metadata import TaskMetadataType
from gigl.src.common.utils.bq import BqUtils
from gigl.src.inference.v1.lib.base_inferencer import BaseInferencer, InferBatchResults
from gigl.src.inference.v1.lib.inference_output_schema import (
    DEFAULT_EMBEDDING_FIELD,
    DEFAULT_EMBEDDINGS_TABLE_SCHEMA,
    DEFAULT_NODE_ID_FIELD,
    DEFAULT_PREDICTION_FIELD,
    DEFAULT_PREDICTIONS_TABLE_SCHEMA,
)
from gigl.src.training.v1.lib.data_loaders.rooted_node_neighborhood_data_loader import (
    RootedNodeNeighborhoodBatch,
)
from gigl.src.training.v1.lib.data_loaders.supervised_node_classification_data_loader import (
    SupervisedNodeClassificationBatch,
)
from snapchat.research.gbml import (
    flattened_graph_metadata_pb2,
    gbml_config_pb2,
    training_samples_schema_pb2,
)

logger = Logger()


def _initialize_inferencer_with_gbml_config_pb(
    gbml_config_pb: gbml_config_pb2.GbmlConfig,
) -> Tuple[BaseInferencer, GbmlConfigPbWrapper]:
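    """Instantiates the configured inferencer and loads its trained model.

    Dynamically imports the class at ``inferencer_config.inferencer_cls_path``,
    constructs it with the configured inferencer args, loads the trained
    model's state dict from the URI in
    ``shared_config.trained_model_metadata.trained_model_uri``, and initializes
    the inferencer's model with it.
    """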
    inferencer_cls = import_obj(gbml_config_pb.inferencer_config.inferencer_cls_path)
    kwargs = dict(gbml_config_pb.inferencer_config.inferencer_args)
    inferencer = inferencer_cls(**kwargs)
    gbml_config_pb_wrapper = GbmlConfigPbWrapper(gbml_config_pb=gbml_config_pb)
    model_save_path_uri = UriFactory.create_uri(
        gbml_config_pb.shared_config.trained_model_metadata.trained_model_uri
    )
    logger.info(
        f"Loading model state dict from: {model_save_path_uri}, for inferencer: {inferencer}"
    )
    model_state_dict = model_utils.load_state_dict_from_uri(
        load_from_uri=model_save_path_uri
    )
    inferencer.init_model(
        gbml_config_pb_wrapper=gbml_config_pb_wrapper,
        state_dict=model_state_dict,
    )
    return (
        inferencer,
        gbml_config_pb_wrapper,
    )


def infer_model(
    gbml_config_pb: gbml_config_pb2.GbmlConfig,
):
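    """Runs inference for the task described by the given GbmlConfig.

    Dispatches on the task metadata type: node-based tasks run supervised
    node classification inference; node-anchor-based link prediction tasks
    run embedding-only inference over rooted node neighborhoods.
    """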
    (
        inferencer,
        gbml_config_pb_wrapper,
    ) = _initialize_inferencer_with_gbml_config_pb(gbml_config_pb=gbml_config_pb)
    task_metadata_pb_wrapper = gbml_config_pb_wrapper.task_metadata_pb_wrapper
    if task_metadata_pb_wrapper.task_metadata_type == TaskMetadataType.NODE_BASED_TASK:
        _infer_supervised_node_classification_model(
            inferencer=inferencer,
            gbml_config_pb_wrapper=gbml_config_pb_wrapper,
        )
    elif (
        task_metadata_pb_wrapper.task_metadata_type
        == TaskMetadataType.NODE_ANCHOR_BASED_LINK_PREDICTION_TASK
    ):
        _infer_node_anchor_based_link_prediction_model(
            inferencer=inferencer,
            gbml_config_pb_wrapper=gbml_config_pb_wrapper,
        )
    else:
        raise NotImplementedError(
            f"Inference is not implemented for task metadata type: "
            f"{task_metadata_pb_wrapper.task_metadata_type}"
        )


def _infer_supervised_node_classification_model(
    inferencer: BaseInferencer,
    gbml_config_pb_wrapper: GbmlConfigPbWrapper,
):
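    """Runs per-sample inference over all labeled and unlabeled
    SupervisedNodeClassificationSample TFRecords and loads the resulting
    embeddings and predictions into BigQuery.
    """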
    builder = PygGraphBuilder()
    def _collate_node_classification_batch(
        elements: list[SupervisedNodeClassificationSample],
        graph_metadata_pb_wrapper: GraphMetadataPbWrapper,
        preprocessed_metadata_pb_wrapper: PreprocessedMetadataPbWrapper,
        builder: GraphBuilder,
    ) -> SupervisedNodeClassificationBatch:
        batch = (
            SupervisedNodeClassificationBatch.collate_pyg_node_classification_minibatch(
                builder=builder,
                graph_metadata_pb_wrapper=graph_metadata_pb_wrapper,
                preprocessed_metadata_pb_wrapper=preprocessed_metadata_pb_wrapper,
                samples=elements,
            )
        )
        return batch
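
    # Translate each raw proto into a SupervisedNodeClassificationSample that
    # the collate function above accepts.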
    translator = partial(
        SupervisedNodeClassificationBatch.preprocess_node_classification_sample_fn,
        builder=builder,
        gbml_config_pb_wrapper=gbml_config_pb_wrapper,
    )
    assert isinstance(
        gbml_config_pb_wrapper.flattened_graph_metadata_pb_wrapper.output_metadata,
        flattened_graph_metadata_pb2.SupervisedNodeClassificationOutput,
    ), f"Flattened graph metadata output of wrong type: expected {flattened_graph_metadata_pb2.SupervisedNodeClassificationOutput.__name__}"
    supervised_node_classification_output = (
        gbml_config_pb_wrapper.flattened_graph_metadata_pb_wrapper.output_metadata
    )
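    # Inference runs over both the labeled and unlabeled sample sets.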
    labeled_tfrecord_files = tf.io.gfile.glob(
        f"{supervised_node_classification_output.labeled_tfrecord_uri_prefix}*"
    )
    unlabeled_tfrecord_files = tf.io.gfile.glob(
        f"{supervised_node_classification_output.unlabeled_tfrecord_uri_prefix}*"
    )
    all_tfrecord_files = labeled_tfrecord_files + unlabeled_tfrecord_files
    ds_iter = tf.data.TFRecordDataset(filenames=all_tfrecord_files).as_numpy_iterator()
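    # Stage results as newline-delimited JSON in local temp files; delete=False
    # keeps the files on disk for the BigQuery loads below.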
    emb_tfh = tempfile.NamedTemporaryFile(delete=False, mode="w")
    emb_tfh.close()  # close the initial handle; the file is reopened by name
    emb_file = open(emb_tfh.name, "w")
    pred_tfh = tempfile.NamedTemporaryFile(delete=False, mode="w")
    pred_tfh.close()
    pred_file = open(pred_tfh.name, "w")
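    # Captured inside the loop below; used afterwards to look up this node
    # type's output table paths.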
    node_type: NodeType
    for sample_bytes in ds_iter:
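        # Each record is a serialized SupervisedNodeClassificationSample proto.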
        pb = training_samples_schema_pb2.SupervisedNodeClassificationSample()
        pb.ParseFromString(sample_bytes)
        training_sample = translator(pb)
        batch = _collate_node_classification_batch(
            elements=[training_sample],
            graph_metadata_pb_wrapper=gbml_config_pb_wrapper.graph_metadata_pb_wrapper,
            preprocessed_metadata_pb_wrapper=gbml_config_pb_wrapper.preprocessed_metadata_pb_wrapper,
            builder=builder,
        )
        infer_batch_results: InferBatchResults = inferencer.infer_batch(batch=batch)
        node = batch.root_nodes[0]
        node_id = node.id
        node_type = node.type
        assert (
            infer_batch_results.embeddings is not None
            and infer_batch_results.predictions is not None
        )
        emb = infer_batch_results.embeddings[0].tolist()
        pred = infer_batch_results.predictions[0].tolist()
        emb_file.write(
            json.dumps(
                {
                    DEFAULT_NODE_ID_FIELD: node_id,
                    DEFAULT_EMBEDDING_FIELD: emb,
                }
            )
            + "\n"
        )
        pred_file.write(
            json.dumps(
                {
                    DEFAULT_NODE_ID_FIELD: node_id,
                    DEFAULT_PREDICTION_FIELD: pred,
                }
            )
            + "\n"
        )
    emb_file.close()
    pred_file.close()
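
    # Load the staged JSONL files into the configured embeddings and
    # predictions tables in BigQuery.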
    bq_utils = BqUtils()
    emb_path = gbml_config_pb_wrapper.shared_config.inference_metadata.node_type_to_inferencer_output_info_map[
        node_type
    ].embeddings_path
    assert DEFAULT_EMBEDDINGS_TABLE_SCHEMA.schema is not None
    bq_utils.load_file_to_bq(
        source_path=UriFactory.create_uri(emb_tfh.name),
        bq_path=emb_path,
        job_config=bigquery.LoadJobConfig(
            source_format=bigquery.SourceFormat.NEWLINE_DELIMITED_JSON,
            write_disposition=bigquery.WriteDisposition.WRITE_TRUNCATE,
            schema=DEFAULT_EMBEDDINGS_TABLE_SCHEMA.schema["fields"],
        ),
        retry=True,
    )
    logger.info(f"Embeddings for {node_type} loaded to BQ table {emb_path}")
    pred_path = gbml_config_pb_wrapper.shared_config.inference_metadata.node_type_to_inferencer_output_info_map[
        node_type
    ].predictions_path
    assert DEFAULT_PREDICTIONS_TABLE_SCHEMA.schema is not None
    bq_utils.load_file_to_bq(
        source_path=UriFactory.create_uri(pred_tfh.name),
        bq_path=pred_path,
        job_config=bigquery.LoadJobConfig(
            source_format=bigquery.SourceFormat.NEWLINE_DELIMITED_JSON,
            write_disposition=bigquery.WriteDisposition.WRITE_TRUNCATE,
            schema=DEFAULT_PREDICTIONS_TABLE_SCHEMA.schema["fields"],
        ),
        retry=True,
    )
    logger.info(f"Predictions for {node_type} loaded to BQ table {pred_path}")


def _infer_node_anchor_based_link_prediction_model(
    inferencer: BaseInferencer,
    gbml_config_pb_wrapper: GbmlConfigPbWrapper,
):
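    """Computes embeddings for each node type's random-negative
    RootedNodeNeighborhood samples and loads them into that type's
    BigQuery embeddings table.
    """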
    def _collate_rooted_node_neighborhood_batch(
        elements: list[RootedNodeNeighborhoodSample],
        graph_metadata_pb_wrapper: GraphMetadataPbWrapper,
        preprocessed_metadata_pb_wrapper: PreprocessedMetadataPbWrapper,
        builder: GraphBuilder,
    ):
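        # The collator expects each sample wrapped in a dict keyed by its
        # root node's type.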
        dataloaded_elements = [
            {element.root_node.type: element} for element in elements
        ]
        batch = (
            RootedNodeNeighborhoodBatch.collate_pyg_rooted_node_neighborhood_minibatch(
                builder=builder,
                graph_metadata_pb_wrapper=graph_metadata_pb_wrapper,
                preprocessed_metadata_pb_wrapper=preprocessed_metadata_pb_wrapper,
                samples=dataloaded_elements,
            )
        )
        return batch
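
    # Build the graph builder and a translator that turns each raw proto into
    # a RootedNodeNeighborhoodSample for collation.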
    builder = PygGraphBuilder()
    translator = partial(
        RootedNodeNeighborhoodBatch.preprocess_rooted_node_neighborhood_sample_fn,
        builder=builder,
        gbml_config_pb_wrapper=gbml_config_pb_wrapper,
    )
    assert isinstance(
        gbml_config_pb_wrapper.flattened_graph_metadata_pb_wrapper.output_metadata,
        flattened_graph_metadata_pb2.NodeAnchorBasedLinkPredictionOutput,
    ), f"Flattened graph metadata output of wrong type: expected {flattened_graph_metadata_pb2.NodeAnchorBasedLinkPredictionOutput.__name__}"
    node_anchor_output = (
        gbml_config_pb_wrapper.flattened_graph_metadata_pb_wrapper.output_metadata
    )
    bq_utils = BqUtils()
    for (
        node_type,
        random_negative_tfrecord_uri_prefix,
    ) in node_anchor_output.node_type_to_random_negative_tfrecord_uri_prefix.items():
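        # Each node type has its own TFRecord prefix; its embeddings are staged
        # locally and then loaded into that type's BigQuery table.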
        tfrecord_files = tf.io.gfile.glob(f"{random_negative_tfrecord_uri_prefix}*")
        ds_iter = tf.data.TFRecordDataset(filenames=tfrecord_files).as_numpy_iterator()
        emb_tfh = tempfile.NamedTemporaryFile(delete=False, mode="w")
        emb_tfh.close()  # close the initial handle; the file is reopened by name
        emb_file = open(emb_tfh.name, "w")
        for sample_bytes in ds_iter:
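            # Each record is a serialized RootedNodeNeighborhood proto.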
            pb = training_samples_schema_pb2.RootedNodeNeighborhood()
            pb.ParseFromString(sample_bytes)
            training_sample = translator(pb)
            batch = _collate_rooted_node_neighborhood_batch(
                elements=[training_sample],
                graph_metadata_pb_wrapper=gbml_config_pb_wrapper.graph_metadata_pb_wrapper,
                preprocessed_metadata_pb_wrapper=gbml_config_pb_wrapper.preprocessed_metadata_pb_wrapper,
                builder=builder,
            )
            infer_batch_results: InferBatchResults = inferencer.infer_batch(batch=batch)
            node = batch.root_nodes[0]
            node_id = node.id
            assert (
                NodeType(node_type) == node.type
            ), "Expected node type at this tfrecord_uri_prefix to match batch root node type"
            assert (
                infer_batch_results.embeddings is not None
            ), "Expected embeddings to be returned by inferencer"
            emb = infer_batch_results.embeddings[0].tolist()
            emb_file.write(
                json.dumps(
                    {
                        DEFAULT_NODE_ID_FIELD: node_id,
                        DEFAULT_EMBEDDING_FIELD: emb,
                    }
                )
                + "\n"
            )
        emb_file.close()
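
        # Load this node type's staged embeddings JSONL into BigQuery.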
        emb_path = gbml_config_pb_wrapper.shared_config.inference_metadata.node_type_to_inferencer_output_info_map[
            node_type
        ].embeddings_path
        assert DEFAULT_EMBEDDINGS_TABLE_SCHEMA.schema is not None
        bq_utils.load_file_to_bq(
            source_path=UriFactory.create_uri(emb_tfh.name),
            bq_path=emb_path,
            job_config=bigquery.LoadJobConfig(
                source_format=bigquery.SourceFormat.NEWLINE_DELIMITED_JSON,
                write_disposition=bigquery.WriteDisposition.WRITE_TRUNCATE,
                schema=DEFAULT_EMBEDDINGS_TABLE_SCHEMA.schema["fields"],
            ),
            retry=True,
        )
        logger.info(
            f"Embeddings for node type {node_type} loaded to BQ table {emb_path}"
        )
    logger.info("Finished loading all inferred embeddings to BQ")