Source code for gigl.distributed.utils.device
import torch
def get_available_device(local_process_rank: int) -> torch.device:
    r"""Returns the available device for the current process.

    Args:
        local_process_rank (int): The local rank of the current process within a node.

    Returns:
        torch.device: The device to use.
    """
    device = torch.device(
        "cpu"
        if not torch.cuda.is_available()
        # If there are more processes than available GPUs, assign each
        # process to a GPU in a round-robin manner.
        else f"cuda:{local_process_rank % torch.cuda.device_count()}"
    )
    return device
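A minimal usage sketch for the round-robin mapping above. The torchrun launcher and the LOCAL_RANK environment variable are assumptions about the deployment setup, not part of gigl itself:

import os

import torch

from gigl.distributed.utils.device import get_available_device

# torchrun exports LOCAL_RANK for each process it spawns on a node
# (assumption: this script is launched via torchrun or similar).
local_rank = int(os.environ.get("LOCAL_RANK", 0))
device = get_available_device(local_process_rank=local_rank)

# On a node with 2 GPUs and 4 local processes, ranks map to devices
# round-robin: 0 -> cuda:0, 1 -> cuda:1, 2 -> cuda:0, 3 -> cuda:1.
model = torch.nn.Linear(8, 8).to(device)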
def get_device_from_process_group() -> torch.device:
    """Returns the device for the current process group.

    Raises:
        ValueError: If the distributed environment is not initialized.

    Returns:
        torch.device: The device to use.
    """
    if not torch.distributed.is_initialized():
        raise ValueError(
            "Distributed environment must be initialized to get device from process group"
        )
    return torch.device("cuda" if torch.distributed.get_backend() == "nccl" else "cpu")
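A minimal sketch of calling get_device_from_process_group after initializing the default process group. The gloo backend and env:// init method here are illustrative assumptions; the function simply reports cuda for nccl and cpu for anything else:

import torch
import torch.distributed as dist

from gigl.distributed.utils.device import get_device_from_process_group

# Assumes a launcher (e.g. torchrun) has set MASTER_ADDR, MASTER_PORT,
# RANK, and WORLD_SIZE so that env:// initialization can succeed.
dist.init_process_group(backend="gloo", init_method="env://")

# "gloo" backend -> cpu; "nccl" backend -> cuda.
device = get_device_from_process_group()
tensor = torch.zeros(4, device=device)
dist.all_reduce(tensor)

dist.destroy_process_group()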