Source code for gigl.src.common.utils.model
import tempfile
from typing import OrderedDict
import torch
from gigl.common import LocalUri, Uri
from gigl.src.common.utils.file_loader import FileLoader
[docs]
def save_scripted_model(model: torch.nn.Module, save_to_path_uri: Uri) -> None:
assert isinstance(
model, torch.nn.Module
), "Can only save model of type torch.nn.Module"
file_loader = FileLoader()
tmp_save_model_file = tempfile.NamedTemporaryFile(delete=False)
model_scripted = torch.jit.script(model) # Export to TorchScript
model_scripted.save(tmp_save_model_file.name) # Save
file_loader.load_file(
file_uri_src=LocalUri(tmp_save_model_file.name), file_uri_dst=save_to_path_uri
)
tmp_save_model_file.close()
[docs]
def save_state_dict(model: torch.nn.Module, save_to_path_uri: Uri) -> None:
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
model = model.module
assert isinstance(
model, torch.nn.Module
), "Can only save model of type torch.nn.Module"
file_loader = FileLoader()
tmp_save_model_file = tempfile.NamedTemporaryFile(delete=False)
torch.save(model.state_dict(), tmp_save_model_file.name)
file_loader.load_file(
file_uri_src=LocalUri(tmp_save_model_file.name), file_uri_dst=save_to_path_uri
)
tmp_save_model_file.close()
[docs]
def load_state_dict_from_uri(
load_from_uri: Uri,
device: torch.device = torch.device("cpu"),
) -> OrderedDict[str, torch.Tensor]:
state_dict: OrderedDict[str, torch.Tensor]
file_loader = FileLoader()
tmp_file = file_loader.load_to_temp_file(load_from_uri)
state_dict = torch.load(tmp_file.name, map_location=device)
tmp_file.close()
return state_dict
[docs]
def load_scripted_model_from_uri(
load_from_uri: Uri,
) -> torch.nn.Module:
scripted_model: torch.jit.ScriptModule
file_loader = FileLoader()
tmp_file = file_loader.load_to_temp_file(load_from_uri)
scripted_model = torch.jit.load(tmp_file.name)
tmp_file.close()
return scripted_model