gigl.experimental.knowledge_graph_embedding.common.dist_checkpoint#
This module provides functions to load and save distributed checkpoints using the Torch Distributed Checkpointing API.
Classes#
- AppState — This is a useful wrapper for checkpointing an application state.
Functions#
- load_checkpoint_from_uri
- save_checkpoint_to_uri — Saves the state_dict to a specified checkpoint_id URI using the Torch Distributed Checkpointing API.
Module Contents#
- class gigl.experimental.knowledge_graph_embedding.common.dist_checkpoint.AppState(model, optimizer=None)[source]#
Bases:
torch.distributed.checkpoint.stateful.Stateful
This is a useful wrapper for checkpointing an application state. Since this object is compliant with the Stateful protocol, DCP will automatically call state_dict/load_state_dict as needed in the dcp.save/load APIs.
We take advantage of this wrapper to handle calling distributed state dict methods on the model and optimizer.
See https://docs.pytorch.org/tutorials/recipes/distributed_async_checkpoint_recipe.html for more details.
- Parameters:
model (torch.nn.Module)
optimizer (Optional[torch.optim.Optimizer])
- load_state_dict(state_dict)[source]#
Restore the object’s state from the provided state_dict.
- Parameters:
state_dict – The state dict to restore from
- state_dict()[source]#
Objects should return their state_dict representation as a dictionary. The output of this function will be checkpointed, and later restored in load_state_dict().
Warning
Because of the inplace nature of restoring a checkpoint, this function is also called during torch.distributed.checkpoint.load.
- Returns:
The object's state dict
- Return type:
Dict
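To make the Stateful contract concrete, here is a minimal, self-contained sketch of the pattern AppState follows. The `ToyModel` and `ToyOptimizer` classes are hypothetical stand-ins for `torch.nn.Module` and `torch.optim.Optimizer`; the protocol only relies on them exposing `state_dict()`/`load_state_dict()`:

```python
# Hypothetical sketch of the Stateful-wrapper pattern AppState implements.
# ToyModel / ToyOptimizer are stand-ins for torch.nn.Module / torch.optim.Optimizer.

class ToyModel:
    def __init__(self):
        self.weights = {"w": 1.0}

    def state_dict(self):
        return dict(self.weights)

    def load_state_dict(self, sd):
        self.weights.update(sd)


class ToyOptimizer:
    def __init__(self):
        self.hyperparams = {"lr": 0.1}

    def state_dict(self):
        return dict(self.hyperparams)

    def load_state_dict(self, sd):
        self.hyperparams.update(sd)


class AppStateSketch:
    """Bundles model (and optionally optimizer) state behind the Stateful
    protocol, so a checkpointing API can call state_dict()/load_state_dict()
    on a single object."""

    def __init__(self, model, optimizer=None):
        self.model = model
        self.optimizer = optimizer

    def state_dict(self):
        out = {"model": self.model.state_dict()}
        if self.optimizer is not None:
            out["optimizer"] = self.optimizer.state_dict()
        return out

    def load_state_dict(self, state_dict):
        self.model.load_state_dict(state_dict["model"])
        if self.optimizer is not None and "optimizer" in state_dict:
            self.optimizer.load_state_dict(state_dict["optimizer"])


app_state = AppStateSketch(ToyModel(), ToyOptimizer())
snapshot = app_state.state_dict()

# Mutate the model, then restore it from the snapshot.
app_state.model.weights["w"] = 99.0
app_state.load_state_dict(snapshot)
```

With the real AppState, `dcp.save`/`dcp.load` invoke these same two methods automatically because the object conforms to the Stateful protocol.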
- gigl.experimental.knowledge_graph_embedding.common.dist_checkpoint.load_checkpoint_from_uri(state_dict, checkpoint_id)[source]#
Loads a checkpoint from the given checkpoint_id URI, restoring values into the provided state_dict in place.
- Parameters:
state_dict (torch.distributed.checkpoint.metadata.STATE_DICT_TYPE)
checkpoint_id (gigl.common.Uri)
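As noted in the Warning above, restoring a checkpoint is an in-place operation: the caller allocates the state_dict, and the loader writes checkpointed values into it rather than returning a new dict. A minimal sketch of that contract, with a hypothetical `fake_load` standing in for the real loader:

```python
# Hypothetical sketch of the in-place restore contract used by dcp.load-style
# APIs. fake_load stands in for the real loader backed by checkpoint_id.

def fake_load(state_dict, checkpointed):
    # Write checkpointed values into the caller-provided dict in place,
    # mirroring how torch.distributed.checkpoint.load populates state_dict.
    for key in state_dict:
        if key in checkpointed:
            state_dict[key] = checkpointed[key]

state_dict = {"w": 0.0, "b": 0.0}     # freshly initialized parameters
checkpointed = {"w": 1.5, "b": -0.5}  # values previously saved to checkpoint_id
fake_load(state_dict, checkpointed)   # state_dict is mutated, not replaced
```

This is why `load_checkpoint_from_uri` takes a fully populated `state_dict` argument instead of returning one.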
- gigl.experimental.knowledge_graph_embedding.common.dist_checkpoint.save_checkpoint_to_uri(state_dict, checkpoint_id, should_save_asynchronously=False)[source]#
Saves the state_dict to a specified checkpoint_id URI using the Torch Distributed Checkpointing API.
If the checkpoint_id is a GCS URI, it will first save the checkpoint locally and then upload it to GCS.
If should_save_asynchronously is True, the save operation will be performed asynchronously, returning a Future object. Otherwise, it will block until the save operation is complete.
- Parameters:
state_dict (STATE_DICT_TYPE) – The state dictionary to save.
checkpoint_id (Uri) – The URI where the checkpoint will be saved.
should_save_asynchronously (bool) – If True, saves the checkpoint asynchronously.
- Returns:
The URI where the checkpoint was saved, or a Future object if saved asynchronously.
- Return type:
Union[Uri, Future]
- Raises:
AssertionError – If checkpoint_id is not a LocalUri or GcsUri.
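The sync/async contract described above can be sketched with a thread-pool Future in place of the `torch.futures.Future` that `dcp.async_save` returns. Everything here is a hypothetical stand-in (`fake_save`, `save_checkpoint_sketch`), not the module's actual implementation:

```python
# Hypothetical sketch of the blocking vs. asynchronous save contract.
# fake_save stands in for the actual DCP serializer; a ThreadPoolExecutor
# Future stands in for torch.futures.Future.
import concurrent.futures
import json
import os
import tempfile

def fake_save(state_dict, checkpoint_id):
    # Write the state dict to the given path and return the path, mirroring
    # the blocking branch of save_checkpoint_to_uri.
    with open(checkpoint_id, "w") as f:
        json.dump(state_dict, f)
    return checkpoint_id

_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)

def save_checkpoint_sketch(state_dict, checkpoint_id,
                           should_save_asynchronously=False):
    if should_save_asynchronously:
        # Return immediately; the caller can call .result() on the Future
        # later to wait for the save to finish.
        return _executor.submit(fake_save, state_dict, checkpoint_id)
    # Block until the save completes, then return the checkpoint location.
    return fake_save(state_dict, checkpoint_id)

path = os.path.join(tempfile.mkdtemp(), "ckpt.json")
result = save_checkpoint_sketch({"step": 10}, path)        # blocks; returns path
future = save_checkpoint_sketch({"step": 20}, path, True)  # returns a Future
future.result()                                            # wait for completion
```

With the real API, the asynchronous path lets training continue while the checkpoint is written; for GCS URIs, the upload happens after the local write, as described above.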