Source code for gigl.experimental.knowledge_graph_embedding.lib.checkpoint
from concurrent.futures import Future
from typing import Optional, Union
import torch
import torch.nn as nn
from gigl.common import Uri, UriFactory
from gigl.common.logger import Logger
from gigl.experimental.knowledge_graph_embedding.common.dist_checkpoint import (
    AppState,
    load_checkpoint_from_uri,
    save_checkpoint_to_uri,
)
from gigl.experimental.knowledge_graph_embedding.lib.config.training import (
    CheckpointingConfig,
)

logger = Logger()

def maybe_load_checkpoint(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    checkpointing_config: CheckpointingConfig,
) -> bool:
    """
    Load model and optimizer state from a checkpoint if a load path is configured.

    Args:
        model: The model to load the checkpoint into.
        optimizer: The optimizer to load the checkpoint into.
        checkpointing_config: The training configuration containing the checkpointing paths.

    Returns:
        bool: True if a checkpoint was loaded into the model and optimizer,
            False if no checkpoint path was specified.
    """
    if not checkpointing_config.load_from_path:
        logger.info(
            "No checkpoint specified to load from. Skipping checkpoint loading."
        )
        return False
    load_from_checkpoint_path: Uri = UriFactory.create_uri(
        checkpointing_config.load_from_path
    )
    logger.info(
        f"Loading model and optimizer from checkpoint path: {load_from_checkpoint_path}"
    )
    app_state = AppState(model=model, optimizer=optimizer)
    load_checkpoint_from_uri(
        state_dict=app_state.to_state_dict(),
        checkpoint_id=load_from_checkpoint_path,
    )
    return True
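
# Illustrative usage sketch (not part of this module): resuming training from a
# previously saved checkpoint before the training loop starts. `build_model` and
# the optimizer hyperparameters below are hypothetical placeholders; only
# `maybe_load_checkpoint` and the `load_from_path` field come from the code above,
# and whether CheckpointingConfig accepts it as a constructor kwarg is an assumption.
#
#   model = build_model()  # hypothetical model factory
#   optimizer = torch.optim.Adagrad(model.parameters(), lr=0.1)
#   checkpointing_config = CheckpointingConfig(
#       load_from_path="gs://my-bucket/kge/checkpoints/latest"  # assumed kwarg
#   )
#   if not maybe_load_checkpoint(model, optimizer, checkpointing_config):
#       logger.info("No checkpoint configured; training from scratch.")
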
def maybe_save_checkpoint(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    checkpointing_config: CheckpointingConfig,
    checkpoint_id: str = "",
) -> Optional[Union[Future[Uri], Uri]]:
    """
    Save model and optimizer state to a checkpoint if a save path is specified
    in the training configuration.

    Args:
        model: The model to save the checkpoint for.
        optimizer: The optimizer to save the checkpoint for.
        checkpointing_config: The training configuration containing the checkpointing paths.
        checkpoint_id: An optional identifier for the checkpoint, used to differentiate between checkpoints if needed.

    Returns:
        Optional[Union[Future[Uri], Uri]]: The URI where the checkpoint was saved,
            or a Future resolving to that URI if the checkpoint is saved
            asynchronously. Returns None if no checkpoint save path is specified.
    """
    should_save_checkpoint_async = checkpointing_config.should_save_async
    logger.info(
        f"Should save checkpoint asynchronously: {should_save_checkpoint_async}"
    )
    if not checkpointing_config.save_to_path:
        logger.info("No checkpoint specified to save to. Skipping saving checkpoint.")
        return None
    # Set up the checkpoint saving path.
    save_to_checkpoint_path = UriFactory.create_uri(checkpointing_config.save_to_path)
    checkpoint_id_uri = Uri.join(save_to_checkpoint_path, checkpoint_id)
    app_state = AppState(model=model, optimizer=optimizer)
    return save_checkpoint_to_uri(
        state_dict=app_state.to_state_dict(),
        checkpoint_id=checkpoint_id_uri,
        should_save_asynchronously=should_save_checkpoint_async,
    )
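
# Illustrative usage sketch (not part of this module): saving a checkpoint at the
# end of each training epoch. The epoch loop, `train_one_epoch`, and `num_epochs`
# are hypothetical, and the CheckpointingConfig constructor kwargs are assumptions
# that mirror the attributes accessed above (`save_to_path`, `should_save_async`).
#
#   checkpointing_config = CheckpointingConfig(
#       save_to_path="gs://my-bucket/kge/checkpoints",  # assumed kwargs
#       should_save_async=True,
#   )
#   for epoch in range(num_epochs):
#       train_one_epoch(model, optimizer)  # hypothetical training step
#       result = maybe_save_checkpoint(
#           model, optimizer, checkpointing_config, checkpoint_id=f"epoch_{epoch}"
#       )
#       if isinstance(result, Future):
#           saved_uri = result.result()  # block until the async save completes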