Source code for itwinai.distributed

import abc
import os
from pydantic import BaseModel
import builtins as __builtin__


class DistributedStrategy(abc.ABC):
    """Abstract class to define the distributed backend methods.

    Concrete distributed backends are expected to derive from this class.
    No abstract methods are declared in this block, so it currently acts
    as a common base type only.
    """
class ClusterEnvironment(BaseModel):
    """Stores information about distributed environment.

    All fields default to the values of a single, non-distributed worker
    (rank 0 in a world of size 1).
    """

    #: Global rank of current worker, in a distributed environment.
    #: ``global_rank==0`` identifies the main worker.
    #: Defaults to 0.
    global_rank: int = 0

    #: Local rank of current worker, in a distributed environment.
    #: Defaults to 0.
    local_rank: int = 0

    #: Total number of workers in a distributed environment.
    #: Defaults to 1.
    global_world_size: int = 1

    #: Number of workers on the same node in a distributed environment.
    #: Defaults to 1.
    local_world_size: int = 1
def detect_distributed_environment() -> ClusterEnvironment:
    """Detects distributed environment, extracting information like
    global and local ranks, and world size.

    The launcher is identified from environment variables, checked in
    this order: Torch elastic (``TORCHELASTIC_RUN_ID``), Open MPI
    (``OMPI_COMM_WORLD_SIZE``), SLURM (``SLURM_JOB_ID``).

    Returns:
        ClusterEnvironment: ranks and world sizes read from the detected
        launcher. When only SLURM is detected, or no launcher is detected
        at all, a default ``ClusterEnvironment()`` is returned.
    """
    # NOTE: values are forwarded exactly as returned by ``os.getenv`` (i.e.
    # strings); conversion to ``int`` is presumably done by the pydantic
    # validation of ``ClusterEnvironment`` -- confirm if strict mode is used.
    if os.getenv('TORCHELASTIC_RUN_ID') is not None:
        # Torch elastic environment
        # https://pytorch.org/docs/stable/elastic/run.html#environment-variables
        return ClusterEnvironment(
            global_rank=os.getenv('RANK'),
            local_rank=os.getenv('LOCAL_RANK'),
            local_world_size=os.getenv('LOCAL_WORLD_SIZE'),
            global_world_size=os.getenv('WORLD_SIZE'),
        )

    if os.getenv('OMPI_COMM_WORLD_SIZE') is not None:
        # Open MPI environment
        # https://docs.open-mpi.org/en/v5.0.x/tuning-apps/environment-var.html
        return ClusterEnvironment(
            global_rank=os.getenv('OMPI_COMM_WORLD_RANK'),
            local_rank=os.getenv('OMPI_COMM_WORLD_LOCAL_RANK'),
            local_world_size=os.getenv('OMPI_COMM_WORLD_LOCAL_SIZE'),
            global_world_size=os.getenv('OMPI_COMM_WORLD_SIZE'),
        )

    if os.getenv('SLURM_JOB_ID') is not None:
        # SLURM is recognised, but its rank variables are not parsed here:
        # warn and fall through to the single-worker defaults.
        print("WARNING: detected SLURM environment, but "
              "unable to determine ranks and world sizes!")

    return ClusterEnvironment()
#: Save original builtin print before patching it in distributed environments
builtin_print = __builtin__.print
def distributed_patch_print(is_main: bool) -> None:
    """Disable ``print()`` when not in main worker.

    Replaces ``print`` in the ``builtins`` module with a wrapper, so the
    change affects every caller of ``print`` in the current process. The
    wrapper always delegates to the module-level ``builtin_print``, so
    calling this function more than once does not stack wrappers.

    Args:
        is_main (bool): whether it is called from main worker.
    """
    def print(*args, **kwargs):
        """Print is disabled on workers different from the main one,
        unless the print is called with ``force=True`` argument.
        """
        # ``force`` is removed before delegating: the real ``print`` does
        # not accept it as a keyword argument.
        force = kwargs.pop('force', False)
        if is_main or force:
            builtin_print(*args, **kwargs)

    # Patch builtin print
    __builtin__.print = print