gigl.experimental.knowledge_graph_embedding.common.dist_checkpoint#

This module provides functions to load and save distributed checkpoints using the Torch Distributed Checkpointing API.

Attributes#

logger

Classes#

AppState

This is a useful wrapper for checkpointing an application state.

Functions#

load_checkpoint_from_uri(state_dict, checkpoint_id)

Loads a checkpoint from the specified checkpoint_id URI into the provided state_dict using the Torch Distributed Checkpointing API.

save_checkpoint_to_uri(state_dict, checkpoint_id[, ...])

Saves the state_dict to a specified checkpoint_id URI using the Torch Distributed Checkpointing API.

Module Contents#

class gigl.experimental.knowledge_graph_embedding.common.dist_checkpoint.AppState(model, optimizer=None)[source]#

Bases: torch.distributed.checkpoint.stateful.Stateful

This is a useful wrapper for checkpointing an application state. Since this object is compliant with the Stateful protocol, DCP will automatically call state_dict/load_state_dict as needed in the dcp.save/load APIs.

We take advantage of this wrapper to handle calling the distributed state dict methods on the model and optimizer.

See https://docs.pytorch.org/tutorials/recipes/distributed_async_checkpoint_recipe.html for more details.

Parameters:
  • model (torch.nn.Module)

  • optimizer (Optional[torch.optim.Optimizer])
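
The snippet below is a minimal sketch, not taken from this module, of handing an AppState directly to the Torch Distributed Checkpointing (DCP) APIs; the model, optimizer, and checkpoint path are placeholders.

```python
import torch
import torch.distributed.checkpoint as dcp

from gigl.experimental.knowledge_graph_embedding.common.dist_checkpoint import AppState

# Placeholder model and optimizer; in practice this runs within an
# initialized torch.distributed process group.
model = torch.nn.Linear(8, 4)
optimizer = torch.optim.Adam(model.parameters())

# Because AppState implements the Stateful protocol, DCP calls
# state_dict()/load_state_dict() on it automatically during save/load.
app_state = AppState(model, optimizer)
dcp.save({AppState.APP_STATE_KEY: app_state}, checkpoint_id="/tmp/kge_ckpt")

# Later (e.g. after a restart), restore the wrapped model/optimizer in place.
dcp.load({AppState.APP_STATE_KEY: AppState(model, optimizer)}, checkpoint_id="/tmp/kge_ckpt")
```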

load_state_dict(state_dict)[source]#

Restore the object’s state from the provided state_dict.

Parameters:

state_dict – The state dict to restore from

state_dict()[source]#

Objects should return their state_dict representation as a dictionary. The output of this function will be checkpointed, and later restored in load_state_dict().

Warning

Because of the inplace nature of restoring a checkpoint, this function is also called during torch.distributed.checkpoint.load.

Returns:

The object's state dict

Return type:

Dict
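
The helper below is purely illustrative, not part of DCP or this module; it sketches why state_dict() also runs during load: restoration happens in place, so DCP first asks the Stateful object for its current tensors, loads checkpoint data into them, and then hands the result back through load_state_dict().

```python
# Illustrative only: approximates how DCP drives a Stateful object during load.
def restore_stateful_in_place(stateful, load_into):
    current = stateful.state_dict()    # also called during load (see Warning above)
    load_into(current)                 # checkpoint tensors are copied in, in place
    stateful.load_state_dict(current)  # the object reconciles the restored state
```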

to_state_dict()[source]#

Converts the AppState to a state dict that can be used with DCP.

Return type:

torch.distributed.checkpoint.metadata.STATE_DICT_TYPE

APP_STATE_KEY = 'app'[source]#
MODEL_KEY = 'model'[source]#
OPTIMIZER_KEY = 'optimizer'[source]#
model[source]#
optimizer = None[source]#
gigl.experimental.knowledge_graph_embedding.common.dist_checkpoint.load_checkpoint_from_uri(state_dict, checkpoint_id)[source]#

Loads a checkpoint from the specified checkpoint_id URI into the provided state_dict using the Torch Distributed Checkpointing API.

Parameters:
  • state_dict (torch.distributed.checkpoint.metadata.STATE_DICT_TYPE)

  • checkpoint_id (gigl.common.Uri)
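
A minimal usage sketch follows; the model, checkpoint path, and the LocalUri import location are assumptions, and the checkpoint is presumed to have been written earlier by save_checkpoint_to_uri.

```python
import torch

from gigl.common import LocalUri  # assumed import path for the Uri types
from gigl.experimental.knowledge_graph_embedding.common.dist_checkpoint import (
    AppState,
    load_checkpoint_from_uri,
)

model = torch.nn.Linear(8, 4)  # placeholder module
optimizer = torch.optim.Adam(model.parameters())

# Build the DCP-compatible state dict, then restore it in place from the URI.
state_dict = AppState(model, optimizer).to_state_dict()
load_checkpoint_from_uri(state_dict, checkpoint_id=LocalUri("/tmp/kge_ckpt"))
```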

gigl.experimental.knowledge_graph_embedding.common.dist_checkpoint.save_checkpoint_to_uri(state_dict, checkpoint_id, should_save_asynchronously=False)[source]#

Saves the state_dict to a specified checkpoint_id URI using the Torch Distributed Checkpointing API.

If the checkpoint_id is a GCS URI, it will first save the checkpoint locally and then upload it to GCS.

If should_save_asynchronously is True, the save operation will be performed asynchronously, returning a Future object. Otherwise, it will block until the save operation is complete.

Parameters:
  • state_dict (STATE_DICT_TYPE) – The state dictionary to save.

  • checkpoint_id (Uri) – The URI where the checkpoint will be saved.

  • should_save_asynchronously (bool) – If True, saves the checkpoint asynchronously.

Returns:

The URI where the checkpoint was saved, or a Future object if saved asynchronously.

Return type:

Union[Future[Uri], Uri]

Raises:

AssertionError – If checkpoint_id is not a LocalUri or GcsUri.
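
A minimal sketch of the blocking and asynchronous save paths; the model, optimizer, bucket name, and the LocalUri/GcsUri import locations are assumptions.

```python
import torch

from gigl.common import GcsUri, LocalUri  # assumed import path for the Uri types
from gigl.experimental.knowledge_graph_embedding.common.dist_checkpoint import (
    AppState,
    save_checkpoint_to_uri,
)

model = torch.nn.Linear(8, 4)  # placeholder module
optimizer = torch.optim.Adam(model.parameters())
state_dict = AppState(model, optimizer).to_state_dict()

# Blocking save to a local path; returns the Uri once the checkpoint is written.
saved_uri = save_checkpoint_to_uri(state_dict, checkpoint_id=LocalUri("/tmp/kge_ckpt"))

# Asynchronous save to GCS (staged locally first, per the docstring); returns a Future.
future = save_checkpoint_to_uri(
    state_dict,
    checkpoint_id=GcsUri("gs://my-bucket/kge_ckpt"),  # hypothetical bucket
    should_save_asynchronously=True,
)
saved_uri = future.result()  # block only when the result is actually needed
```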

gigl.experimental.knowledge_graph_embedding.common.dist_checkpoint.logger[source]#