Source code for gigl.src.mocking.lib.mock_input_for_inference
import torch.nn
import gigl.src.common.utils.model as model_utils
from gigl.common import UriFactory
from gigl.common.logger import Logger
from gigl.common.utils.os_utils import import_obj
from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper
from gigl.src.common.types.task_metadata import TaskMetadataType
from snapchat.research.gbml import dataset_metadata_pb2, gbml_config_pb2
[docs]
def train_model(
    gbml_config_pb: gbml_config_pb2.GbmlConfig,
) -> None:
    """Train the configured trainer on CPU and save the resulting model weights.

    The trainer class is imported from ``trainer_config.trainer_cls_path`` and
    instantiated with ``trainer_config.trainer_args`` as keyword arguments.
    After the model is initialized and set up for training, the dataset's
    output metadata is checked against the task type, the trainer is run on
    the CPU device, and the model's state dict is written to
    ``shared_config.trained_model_metadata.trained_model_uri``.

    Args:
        gbml_config_pb: Config proto naming the trainer class, its arguments,
            and the URI the trained model is saved to.

    Raises:
        NotImplementedError: If the task type is neither ``NODE_BASED_TASK``
            nor ``NODE_ANCHOR_BASED_LINK_PREDICTION_TASK``.
        AssertionError: If the dataset output metadata is not an instance of
            the dataset class expected for the task type.
    """
    # Bound locally so the final log call cannot hit an unbound ``logger``;
    # the visible module imports ``Logger`` but never instantiates it.
    logger = Logger()

    trainer_cls = import_obj(gbml_config_pb.trainer_config.trainer_cls_path)
    trainer_kwargs = dict(gbml_config_pb.trainer_config.trainer_args)
    trainer = trainer_cls(**trainer_kwargs)

    gbml_config_pb_wrapper = GbmlConfigPbWrapper(gbml_config_pb=gbml_config_pb)
    trainer.init_model(gbml_config_pb_wrapper=gbml_config_pb_wrapper)
    trainer.setup_for_training()

    dataset_metadata_pb_wrapper = gbml_config_pb_wrapper.dataset_metadata_pb_wrapper
    task_metadata_pb_wrapper = gbml_config_pb_wrapper.task_metadata_pb_wrapper
    task_type = task_metadata_pb_wrapper.task_metadata_type

    # Pick the dataset metadata class that must accompany this task type.
    if task_type == TaskMetadataType.NODE_BASED_TASK:
        expected_dataset_cls = dataset_metadata_pb2.SupervisedNodeClassificationDataset
    elif task_type == TaskMetadataType.NODE_ANCHOR_BASED_LINK_PREDICTION_TASK:
        expected_dataset_cls = dataset_metadata_pb2.NodeAnchorBasedLinkPredictionDataset
    else:
        raise NotImplementedError(
            f"Training is not supported for task metadata type: {task_type}"
        )

    # Kept as an assert (not a raise) so callers still see AssertionError.
    assert isinstance(
        dataset_metadata_pb_wrapper.output_metadata,
        expected_dataset_cls,
    ), f"Did not find {expected_dataset_cls.__name__} instance"

    trainer.train(
        gbml_config_pb_wrapper=gbml_config_pb_wrapper,
        device=torch.device("cpu"),
    )

    model_save_path_uri = UriFactory.create_uri(
        gbml_config_pb.shared_config.trained_model_metadata.trained_model_uri
    )
    model_utils.save_state_dict(
        model=trainer.model, save_to_path_uri=model_save_path_uri
    )
    logger.info(f"Saved model to: {model_save_path_uri}.")