Source code for gigl.common.utils.torch_training
import os
from typing import Optional
import torch.distributed
from gigl.common.logger import Logger
[docs]
def get_world_size() -> int:
    """
    Read the distributed world size from the ``WORLD_SIZE`` environment variable.
    The variable is expected to be populated automatically by the Kubeflow
    PyTorchJob launcher; when it is absent a single-process run is assumed.
    Returns:
        int: Total number of processes involved in distributed training
            (1 when ``WORLD_SIZE`` is not set)
    """
    # The default stays an int; int() accepts both the str from the
    # environment and the int fallback.
    raw_world_size = os.environ.get("WORLD_SIZE", 1)
    return int(raw_world_size)
[docs]
def get_rank() -> int:
    """
    Read this process's rank from the ``RANK`` environment variable.
    The variable is expected to be populated automatically by the Kubeflow
    PyTorchJob launcher; when it is absent the process is treated as rank 0.
    Returns:
        int: The index of the process involved in distributed training
            (0 when ``RANK`` is not set)
    """
    # The default stays an int; int() accepts both the str from the
    # environment and the int fallback.
    raw_rank = os.environ.get("RANK", 0)
    return int(raw_rank)
[docs]
def is_distributed_local_debug() -> bool:
    """
    Check the ``DISTRIBUTED_LOCAL_DEBUG`` flag and, when it is on, prepare the
    environment for a single-process distributed run on the local machine.
    For local debugging purposes only.
    Side effects (only when the flag is a non-zero integer):
        Overwrites ``WORLD_SIZE`` ("1"), ``RANK`` ("0") and ``MASTER_ADDR``
        ("localhost") in ``os.environ``; ``MASTER_PORT`` keeps its current
        value if already set, otherwise it becomes "29501". The resulting
        values are logged.
    Returns:
        bool: True if local distributed debugging is enabled (callers such as
            should_distribute then enable distributed training), False otherwise
    """
    debug_flag = int(os.environ.get("DISTRIBUTED_LOCAL_DEBUG", 0))
    if debug_flag == 0:
        # Flag unset or zero: leave the environment untouched.
        return False

    # Keys are written in the same order as before; an existing MASTER_PORT
    # is preserved rather than overridden.
    os.environ.update(
        {
            "WORLD_SIZE": "1",
            "RANK": "0",
            "MASTER_ADDR": "localhost",
            "MASTER_PORT": os.environ.get("MASTER_PORT", "29501"),
        }
    )
    logger.info(
        f'Overriding local environment variables for debugging WORLD_SIZE={os.environ["WORLD_SIZE"]}, RANK={os.environ["RANK"]}, MASTER_ADDR={os.environ["MASTER_ADDR"]}, MASTER_PORT={os.environ["MASTER_PORT"]}'
    )
    return True
[docs]
def should_distribute() -> bool:
    """
    Decide whether this process should be set up for distributed training.
    Local-debug mode (see is_distributed_local_debug) takes precedence and
    always enables distribution. Otherwise distribution is enabled only when
    torch.distributed is available and the world size is greater than one.
    The decision is logged.
    Returns:
        bool: True if the process is configured for distributed training
    """
    if is_distributed_local_debug():
        logger.info("Distributed training enabled for local debugging")
        return True

    # Short-circuit order matters: the world size is only read when
    # torch.distributed is available.
    distribute = torch.distributed.is_available() and get_world_size() > 1
    logger.info(f"Should we distribute training? {distribute}")
    return distribute
[docs]
def get_distributed_backend(use_cuda: bool) -> Optional[str]:
    """
    Pick the torch.distributed backend for the current process.
    Args:
        use_cuda (bool): Whether CUDA is used for training
    Returns:
        Optional[str]: torch.distributed.Backend.NCCL when distributed training
            is enabled and CUDA is used, torch.distributed.Backend.GLOO when it
            is enabled without CUDA, and None when distributed training is not
            enabled
    """
    if should_distribute():
        if use_cuda:
            return torch.distributed.Backend.NCCL
        return torch.distributed.Backend.GLOO
    # Not distributing: there is no backend to report.
    return None
[docs]
def is_distributed_available_and_initialized() -> bool:
    """
    Report whether torch.distributed can be used right now.
    Returns:
        bool: True if distributed training is available and initialized, False otherwise
    """
    # Check availability first so is_initialized() is only consulted when
    # torch.distributed is available.
    if not torch.distributed.is_available():
        return False
    return torch.distributed.is_initialized()