Source code for gigl.src.common.utils.model
import tempfile
from collections import OrderedDict
import torch
from gigl.common import LocalUri, Uri
from gigl.src.common.utils.file_loader import FileLoader
[docs]
def save_scripted_model(model: torch.nn.Module, save_to_path_uri: Uri) -> None:
    """Compile ``model`` to TorchScript and write it to ``save_to_path_uri``.

    The model is compiled with :func:`torch.jit.script`, saved to a local
    temporary file, and then transferred to the destination with
    :meth:`FileLoader.load_file`. The local temporary file is removed
    afterwards, including when scripting, saving, or the transfer fails.

    Args:
        model: Module to script and save.
        save_to_path_uri: Destination URI for the serialized TorchScript model.

    Raises:
        AssertionError: If ``model`` is not a ``torch.nn.Module``.
    """
    import os

    assert isinstance(
        model, torch.nn.Module
    ), "Can only save model of type torch.nn.Module"
    file_loader = FileLoader()
    model_scripted = torch.jit.script(model)  # Export to TorchScript
    # delete=False keeps the path valid after close(); we remove it ourselves
    # in the finally block. Close our handle before torch writes to the path,
    # since reopening an open NamedTemporaryFile by name is not portable.
    tmp_save_model_file = tempfile.NamedTemporaryFile(delete=False)
    tmp_save_model_file.close()
    try:
        model_scripted.save(tmp_save_model_file.name)  # Save
        file_loader.load_file(
            file_uri_src=LocalUri(tmp_save_model_file.name),
            file_uri_dst=save_to_path_uri,
        )
    finally:
        # NOTE(review): if load_file moves (rather than copies) local files,
        # the temp path may already be gone; that case is tolerated here.
        try:
            os.remove(tmp_save_model_file.name)
        except FileNotFoundError:
            pass
[docs]
def save_state_dict(model: torch.nn.Module, save_to_path_uri: Uri) -> None:
    """Serialize ``model.state_dict()`` with :func:`torch.save` to ``save_to_path_uri``.

    If ``model`` is wrapped in ``DistributedDataParallel``, the wrapped
    ``module`` is saved instead. Parameter names are then not prefixed with
    ``module.``. The state dict is written to a local temporary file and then
    transferred with :meth:`FileLoader.load_file`. The temporary file is
    removed afterwards, including when saving or the transfer fails.

    Args:
        model: Module (optionally DDP-wrapped) whose state dict is saved.
        save_to_path_uri: Destination URI for the serialized state dict.

    Raises:
        AssertionError: If ``model`` (after unwrapping) is not a
            ``torch.nn.Module``.
    """
    import os

    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        model = model.module
    assert isinstance(
        model, torch.nn.Module
    ), "Can only save model of type torch.nn.Module"
    file_loader = FileLoader()
    # delete=False keeps the path valid after close(); we remove it ourselves
    # in the finally block. Close our handle before torch writes to the path,
    # since reopening an open NamedTemporaryFile by name is not portable.
    tmp_save_model_file = tempfile.NamedTemporaryFile(delete=False)
    tmp_save_model_file.close()
    try:
        torch.save(model.state_dict(), tmp_save_model_file.name)
        file_loader.load_file(
            file_uri_src=LocalUri(tmp_save_model_file.name),
            file_uri_dst=save_to_path_uri,
        )
    finally:
        # NOTE(review): if load_file moves (rather than copies) local files,
        # the temp path may already be gone; that case is tolerated here.
        try:
            os.remove(tmp_save_model_file.name)
        except FileNotFoundError:
            pass
[docs]
def load_state_dict_from_uri(
    load_from_uri: Uri,
    device: torch.device = torch.device("cpu"),
) -> OrderedDict[str, torch.Tensor]:
    """Fetch a file saved with :func:`torch.save` from ``load_from_uri`` and load it.

    The file is first copied locally with :meth:`FileLoader.load_to_temp_file`.
    It is then deserialized with :func:`torch.load`.

    Args:
        load_from_uri: URI of the serialized state dict.
        device: Passed to ``torch.load`` as ``map_location``; loaded tensors
            are mapped onto this device. Defaults to CPU.

    Returns:
        The deserialized state dict.

    Note:
        ``torch.load`` unpickles its input, so only load files from trusted
        locations.
    """
    file_loader = FileLoader()
    tmp_file = file_loader.load_to_temp_file(load_from_uri)
    try:
        state_dict: OrderedDict[str, torch.Tensor] = torch.load(
            tmp_file.name, map_location=device
        )
    finally:
        # Close the temporary file handle even if deserialization fails.
        tmp_file.close()
    return state_dict
[docs]
def load_scripted_model_from_uri(
    load_from_uri: Uri,
) -> torch.nn.Module:
    """Fetch a TorchScript model from ``load_from_uri`` and load it.

    The file is first copied locally with :meth:`FileLoader.load_to_temp_file`.
    It is then loaded with :func:`torch.jit.load`.

    Args:
        load_from_uri: URI of a model saved with ``ScriptModule.save``.

    Returns:
        The loaded ``torch.jit.ScriptModule``.
    """
    file_loader = FileLoader()
    tmp_file = file_loader.load_to_temp_file(load_from_uri)
    try:
        scripted_model: torch.jit.ScriptModule = torch.jit.load(tmp_file.name)
    finally:
        # Close the temporary file handle even if loading fails.
        tmp_file.close()
    return scripted_model