gigl.experimental.knowledge_graph_embedding.common.dist_checkpoint#
This module provides functions to load and save distributed checkpoints using the Torch Distributed Checkpointing API.
Classes#
| AppState | This is a useful wrapper for checkpointing an application state. |
Functions#
| load_checkpoint_from_uri | Loads the state_dict from a specified checkpoint_id URI using the Torch Distributed Checkpointing API. |
| save_checkpoint_to_uri | Saves the state_dict to a specified checkpoint_id URI using the Torch Distributed Checkpointing API. |
Module Contents#
- class gigl.experimental.knowledge_graph_embedding.common.dist_checkpoint.AppState(model, optimizer=None)[source]#
- Bases: torch.distributed.checkpoint.stateful.Stateful

  This is a useful wrapper for checkpointing an application state. Since this object is compliant with the Stateful protocol, DCP will automatically call state_dict/load_state_dict as needed in the dcp.save/load APIs.

  We take advantage of this wrapper to handle calling distributed state dict methods on the model and optimizer.

  See https://docs.pytorch.org/tutorials/recipes/distributed_async_checkpoint_recipe.html for more details.

  - Parameters:
- model (torch.nn.Module) 
- optimizer (Optional[torch.optim.Optimizer]) 
 
 - load_state_dict(state_dict)[source]#
- Restore the object’s state from the provided state_dict.

  - Parameters:
- state_dict – The state dict to restore from 
 
 - state_dict()[source]#
- Objects should return their state_dict representation as a dictionary. The output of this function will be checkpointed, and later restored in load_state_dict().

  Warning

  Because of the in-place nature of restoring a checkpoint, this function is also called during torch.distributed.checkpoint.load.

  - Returns:
- The objects state dict 
- Return type:
- Dict 
 
 
- gigl.experimental.knowledge_graph_embedding.common.dist_checkpoint.load_checkpoint_from_uri(state_dict, checkpoint_id)[source]#
- Loads the state_dict from the specified checkpoint_id URI using the Torch Distributed Checkpointing API.

  - Parameters:
- state_dict (torch.distributed.checkpoint.metadata.STATE_DICT_TYPE) 
- checkpoint_id (gigl.common.Uri) 
 
 
- gigl.experimental.knowledge_graph_embedding.common.dist_checkpoint.save_checkpoint_to_uri(state_dict, checkpoint_id, should_save_asynchronously=False)[source]#
- Saves the state_dict to a specified checkpoint_id URI using the Torch Distributed Checkpointing API.

  If the checkpoint_id is a GCS URI, it will first save the checkpoint locally and then upload it to GCS.

  If should_save_asynchronously is True, the save operation will be performed asynchronously, returning a Future object. Otherwise, it will block until the save operation is complete.

  - Parameters:
- state_dict (STATE_DICT_TYPE) – The state dictionary to save. 
- checkpoint_id (Uri) – The URI where the checkpoint will be saved. 
- should_save_asynchronously (bool) – If True, saves the checkpoint asynchronously. 
 
- Returns:
- The URI where the checkpoint was saved, or a Future object if saved asynchronously.
- Raises:
- AssertionError – If checkpoint_id is not a LocalUri or GcsUri.
 
