Source code for gigl.distributed.utils.device

import torch


[docs] def get_available_device(local_process_rank: int) -> torch.device: r"""Returns the available device for the current process. Args: local_process_rank (int): The local rank of the current process within a node. Returns: torch.device: The device to use. """ device = torch.device( "cpu" if not torch.cuda.is_available() # If the number of processes are larger than the available GPU, # we assign each process to one GPU in a round robin manner. else f"cuda:{local_process_rank % torch.cuda.device_count()}" ) return device