from typing import Any, Dict, Iterable, Optional, Type
import torch
import torch.nn as nn
from torch.distributed.optim import (
    _apply_optimizer_in_backward as apply_optimizer_in_backward,
)
from torch.optim import Optimizer
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.fbgemm_qcomm_codec import (
    CommType,
    QCommsConfig,
    get_qcomm_codecs_registry,
)
from torchrec.distributed.model_parallel import DistributedModelParallel
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.planner.storage_reservations import (
    HeuristicalStorageReservation,
)
from torchrec.distributed.types import ShardingPlan
from torchrec.optim.keyed import KeyedOptimizerWrapper
from torchrec.optim.optimizers import in_backward_optimizer_filter
from torchrec.optim.rowwise_adagrad import RowWiseAdagrad
from gigl.common.logger import Logger

logger = Logger()

def maybe_shard_model(
    model: nn.Module,
    device: torch.device,
    sharding_plan: Optional[ShardingPlan] = None,
) -> nn.Module:
    """
    If in a distributed environment, apply DistributedModelParallel to the model,
    using an optionally specified ShardingPlan.
    If not in a distributed environment, return the model directly.
    Args:
        model: The model to be wrapped.
        device: The device to use for the model.
        sharding_plan: An optional ShardingPlan to use for the DistributedModelParallel.
    Returns:
        The model wrapped in DistributedModelParallel if in a distributed environment,
        otherwise the model itself.
    """
    if torch.distributed.is_initialized():
        # Wrap the model in DistributedModelParallel, applying the provided sharding plan if any
        logger.info("***** Wrapping in DistributedModelParallel *****")
        logger.info(f"Model before wrapping: {model}")
        model = DistributedModelParallel(
            module=model,
            device=device,
            plan=sharding_plan,
        )
        logger.info(f"Model after wrapping: {model}")
    return model
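
# Illustrative usage sketch (not part of this module's API): assumes the launcher
# (e.g. torchrun) has already called torch.distributed.init_process_group, and that
# `my_model` is a hypothetical nn.Module defined by the caller.
#
#   device = torch.device("cuda", torch.cuda.current_device())
#   sharded_model = maybe_shard_model(my_model, device=device)
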
def get_sharding_plan(
    model: nn.Module,
    batch_size: int,
    local_world_size: int,
    world_size: int,
    use_cuda: bool = False,
    storage_reservation_percentage: float = 0.15,
    qcomm_forward_precision: CommType = CommType.FP32,
    qcomm_backward_precision: CommType = CommType.FP32,
) -> ShardingPlan:
    """
    Create a sharding plan for the model using the EmbeddingShardingPlanner.
    Args:
        model: The model to be sharded.
        batch_size: The batch size used by the planner when estimating embedding storage and cost.
        local_world_size: The number of ranks per node.
        world_size: The total number of ranks participating in training.
        use_cuda: Whether the plan targets CUDA devices.
        storage_reservation_percentage: The fraction of device memory reserved as headroom
            for non-embedding storage (see HeuristicalStorageReservation).
        qcomm_forward_precision: The precision for forward-pass communication (e.g. FP32, FP16).
        qcomm_backward_precision: The precision for backward-pass communication (e.g. FP32, FP16).
    Returns:
        A ShardingPlan object representing the sharding plan for the model.
    """
    topology = Topology(
        world_size=world_size,
        local_world_size=local_world_size,  # TODO(nshah): We should expose this in torch_training.py
        compute_device="cuda" if use_cuda else "cpu",
        hbm_cap=torch.cuda.get_device_properties(0).total_memory if use_cuda else 0,
    )
    planner = EmbeddingShardingPlanner(
        topology=topology,
        batch_size=batch_size,
        storage_reservation=HeuristicalStorageReservation(
            percentage=storage_reservation_percentage
        ),  # bumping this % can alleviate OOM issues by being more conservative
    )
    # Enable custom fwd/bkwd precisions for QComms when using GPU
    qcomm_codecs_registry = (
        get_qcomm_codecs_registry(
            qcomms_config=QCommsConfig(
                forward_precision=qcomm_forward_precision,
                backward_precision=qcomm_backward_precision,
            )
        )
        if use_cuda
        else None
    )
    ebc_sharder = EmbeddingBagCollectionSharder(
        qcomm_codecs_registry=qcomm_codecs_registry
    )
    plan = planner.collective_plan(
        model, [ebc_sharder], torch.distributed.GroupMember.WORLD
    )
    return plan
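
# Illustrative sketch combining planning and wrapping (assumptions: distributed is
# already initialized, `my_model` is a hypothetical model containing TorchRec
# EmbeddingBagCollection modules, and the batch size / precisions are arbitrary choices):
#
#   plan = get_sharding_plan(
#       model=my_model,
#       batch_size=512,
#       local_world_size=torch.cuda.device_count(),
#       world_size=torch.distributed.get_world_size(),
#       use_cuda=True,
#       qcomm_forward_precision=CommType.FP16,
#       qcomm_backward_precision=CommType.BF16,
#   )
#   sharded_model = maybe_shard_model(my_model, device=device, sharding_plan=plan)
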
def apply_sparse_optimizer(
    parameters: Iterable[nn.Parameter],
    optimizer_cls: Optional[Type[Optimizer]] = None,
    optimizer_kwargs: Dict[str, Any] = dict(),
) -> None:
    """
    Apply a sparse optimizer to the sparse/EBC parts of a model.
    This optimizer is fused, so it will be applied directly in the backward pass.
    This should only be used for sparse parameters.
    Args:
        parameters (Iterable[nn.Parameter]): The sparse parameters to apply the optimizer to.
        optimizer_cls (Type[Optimizer], optional): The optimizer class to use. Defaults to RowWiseAdagrad.
        optimizer_kwargs (Dict[str, Any], optional): Additional keyword arguments for the optimizer.
    """
    if optimizer_cls is None:
        # Default to row-wise Adagrad when no optimizer class is specified.
        optimizer_cls = RowWiseAdagrad
        if not optimizer_kwargs:
            optimizer_kwargs = {"lr": 0.01}
    apply_optimizer_in_backward(optimizer_cls, parameters, optimizer_kwargs)
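
# Illustrative sketch (the `ebc` attribute name is hypothetical). The fused optimizer
# is typically attached to the sparse parameters *before* the model is wrapped with
# DistributedModelParallel, so the sharded embedding modules pick up the fused update:
#
#   apply_sparse_optimizer(
#       my_model.ebc.parameters(),
#       optimizer_cls=RowWiseAdagrad,
#       optimizer_kwargs={"lr": 0.01},
#   )
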
def apply_dense_optimizer(
    model: nn.Module,
    optimizer_cls: Type[Optimizer],
    optimizer_kwargs: Dict[str, Any] = dict(),
) -> Optional[KeyedOptimizerWrapper]:
    """
    Create an optimizer for the dense parameters of the model, i.e. those not already
    handled by a fused (in-backward) optimizer, and wrap it in a `KeyedOptimizerWrapper`.
    Args:
        model (nn.Module): The model containing dense parameters.
        optimizer_cls (Type[Optimizer]): The optimizer class to use for dense parameters.
        optimizer_kwargs (Dict[str, Any], optional): Additional keyword arguments for the optimizer.
    Returns:
        Optional[KeyedOptimizerWrapper]: A wrapped optimizer for dense parameters, or
            None if no dense parameters are found.
    """
    dense_params = dict(in_backward_optimizer_filter(model.named_parameters()))
    if not dense_params:
        # We cannot apply a dense optimizer if there are no dense parameters.
        logger.warning("No dense parameters found in the model.")
        return None
    dense_optimizer = KeyedOptimizerWrapper(
        dense_params,
        lambda params: optimizer_cls(params, **optimizer_kwargs),
    )
    return dense_optimizer
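
# Illustrative end-to-end sketch tying the helpers together (assumptions: `my_model`
# and its `ebc` sub-module are hypothetical; torch.optim.Adam and the learning rates
# are arbitrary choices, not defaults of this module):
#
#   apply_sparse_optimizer(my_model.ebc.parameters())  # fused sparse updates in backward
#   sharded_model = maybe_shard_model(my_model, device=device)
#   dense_optimizer = apply_dense_optimizer(sharded_model, torch.optim.Adam, {"lr": 1e-3})
#   if dense_optimizer is not None:
#       dense_optimizer.step()  # called each iteration after loss.backward()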