# Source code for itwinai.distributed

# --------------------------------------------------------------------------------------
# Part of the interTwin Project: https://www.intertwin.eu/
#
# Created by: Matteo Bunino
#
# Credit:
# - Matteo Bunino <matteo.bunino@cern.ch> - CERN
# - Linus Eickhoff <linus.maximilian.eickhoff@cern.ch> - CERN
# --------------------------------------------------------------------------------------

import builtins as __builtin__
import functools
import logging
import os
import subprocess
import sys
from typing import TYPE_CHECKING, Any, Callable

from pydantic import BaseModel

if TYPE_CHECKING:
    from ray.train import ScalingConfig


py_logger = logging.getLogger(__name__)


class ClusterEnvironment(BaseModel):
    """Stores information about distributed environment.

    Describes where the current process sits in a (possibly) distributed job:
    its global and local rank and the global and per-node number of workers.
    Instances are built by ``detect_distributed_environment``; the defaults
    describe a single, non-distributed process (rank 0 of a world of size 1).
    """

    #: Global rank of current worker, in a distributed environment.
    #: ``global_rank==0`` identifies the main worker.
    #: Defaults to 0.
    global_rank: int = 0
    #: Local rank of current worker, in a distributed environment.
    #: Defaults to 0.
    local_rank: int = 0
    #: Total number of workers in a distributed environment.
    #: Defaults to 1.
    global_world_size: int = 1
    #: Number of workers on the same node in a distributed environment.
    #: Defaults to 1.
    local_world_size: int = 1
def ray_cluster_is_running() -> bool:
    """Tell whether a Ray cluster can be reached from this process.

    Instead of calling ``ray.init()`` (comparatively heavy), this shells out to
    the ``ray status`` CLI and inspects what it prints.

    Returns:
        bool: ``True`` when ``ray status`` exits successfully and its stdout
        mentions all of "Node status", "Resources" and "Usage". ``False`` when
        the command exits with a non-zero status, when the ``ray`` executable
        is not found, or when any of those markers is missing.
    """
    status_command = ["ray", "status"]
    try:
        completed = subprocess.run(
            status_command,
            capture_output=True,
            text=True,
            check=True,
        )
    except FileNotFoundError:
        # No `ray` executable on PATH: most likely Ray is not installed.
        py_logger.debug(
            "Error: 'ray' command not found while checking if a Ray cluster "
            "exists. Is Ray installed?"
        )
        return False
    except subprocess.CalledProcessError:
        # Non-zero exit status: there is no cluster we can talk to.
        py_logger.debug(
            "Ray was checking for the existence of a Ray cluster by trying to "
            "connect to it, but could not do it. This is not a problem if you "
            "are not planning to connect to a Ray cluster."
        )
        return False

    # Section headers printed by `ray status` when a cluster is up.
    required_sections = ("Node status", "Resources", "Usage")
    return all(section in completed.stdout for section in required_sections)
def _get_env( name: str, *, default: Any | None = None, cast: Callable[[str], Any] = lambda x: x, required: bool = False, ) -> Any: """Fetches and casts an environment variable. Args: name: the ENV var name. default: returned if var is unset (and required is False). cast: function to transform the raw str into the desired type. required: if True and var is missing, raises KeyError. """ raw = os.getenv(name) if raw is None: if required: raise KeyError(f"Required environment variable {name!r} not set") return default try: return cast(raw) except Exception as e: py_logger.warning("Failed to cast env %s=%r using %s: %s", name, raw, cast, e) return default def _get_int(name: str, default: int | None = None, required: bool = False) -> int: return _get_env(name, default=default, cast=int, required=required)
def detect_distributed_environment() -> ClusterEnvironment:
    """Detects a distributed environment by probing known env vars.

    Launchers are probed in a fixed priority order:

    1. TorchElastic (``torchrun``), recognised by ``TORCHELASTIC_RUN_ID``.
    2. Open MPI, selected when ``OMPI_COMM_WORLD_SIZE`` is at least
       ``SLURM_NTASKS``. An unset or unparsable ``OMPI_COMM_WORLD_SIZE`` counts
       as -1 and ``SLURM_NTASKS`` as 0, so Open MPI is skipped when absent.
    3. SLURM, recognised by ``SLURM_JOB_ID``.

    When none of them matches, a default ``ClusterEnvironment`` (a single,
    non-distributed process) is returned.

    Returns:
        ClusterEnvironment: rank and world-size information for this worker.

    Raises:
        KeyError: if a launcher is detected but one of its required
            rank/size variables is missing (raised by ``_get_int``).
    """
    if os.getenv("TORCHELASTIC_RUN_ID") is not None:
        # https://pytorch.org/docs/stable/elastic/run.html#environment-variables
        py_logger.debug("Using TorchElastic distributed cluster")
        elastic_rank = _get_int("RANK", required=True)
        elastic_local_rank = _get_int("LOCAL_RANK", required=True)
        elastic_local_size = _get_int("LOCAL_WORLD_SIZE", required=True)
        elastic_world_size = _get_int("WORLD_SIZE", required=True)
        return ClusterEnvironment(
            global_rank=elastic_rank,
            local_rank=elastic_local_rank,
            local_world_size=elastic_local_size,
            global_world_size=elastic_world_size,
        )

    # Prefer Open MPI only when its world is at least as large as the SLURM
    # task count; this guards against stale SLURM_* variables being set.
    mpi_world_size = _get_int("OMPI_COMM_WORLD_SIZE", default=-1)
    slurm_ntasks = _get_int("SLURM_NTASKS", default=0)
    if slurm_ntasks <= mpi_world_size:
        # https://docs.open-mpi.org/en/v5.0.x/tuning-apps/environment-var.html
        py_logger.debug("Using Open MPI distributed cluster")
        mpi_rank = _get_int("OMPI_COMM_WORLD_RANK", required=True)
        mpi_local_rank = _get_int("OMPI_COMM_WORLD_LOCAL_RANK", required=True)
        mpi_local_size = _get_int("OMPI_COMM_WORLD_LOCAL_SIZE", required=True)
        return ClusterEnvironment(
            global_rank=mpi_rank,
            local_rank=mpi_local_rank,
            local_world_size=mpi_local_size,
            global_world_size=mpi_world_size,
        )

    if os.getenv("SLURM_JOB_ID") is not None:
        # https://hpcc.umd.edu/hpcc/help/slurmenv.html
        py_logger.debug("Using SLURM distributed environment")
        slurm_rank = _get_int("SLURM_PROCID", required=True)
        slurm_local_rank = _get_int("SLURM_LOCALID", required=True)
        # Not always exported by SLURM, hence the default instead of required.
        slurm_local_size = _get_int("SLURM_NTASKS_PER_NODE", default=1)
        slurm_world_size = _get_int("SLURM_NTASKS", required=True)
        return ClusterEnvironment(
            global_rank=slurm_rank,
            local_rank=slurm_local_rank,
            local_world_size=slurm_local_size,
            global_world_size=slurm_world_size,
        )

    py_logger.debug("No distributed environment was detected")
    return ClusterEnvironment()
#: Save original builtin print before patching it in distributed environments builtin_print = __builtin__.print
[docs] def distributed_patch_print(is_main: bool) -> Callable: """Disable ``print()`` when not in main worker. Args: is_main (bool): whether it is called from main worker. Returns: Callable: patched ``print()``. """ def patched_print(*args, **kwself): """Print is disables on workers different from the main one, unless the print is called with ``force=True`` argument. """ force = kwself.pop("force", False) if is_main or force: builtin_print(*args, **kwself) return patched_print
def suppress_workers_print(func: Callable) -> Callable:
    """Decorator to suppress ``print()`` calls in workers having global rank
    different from 0.

    To force printing on all workers you need to use ``print(..., force=True)``.

    The global rank is looked up with ``detect_distributed_environment()`` on
    every call of the decorated function. ``builtins.print`` is swapped for
    the patched version only for the duration of that call. It is always put
    back afterwards, whether ``func`` returns or raises.

    Args:
        func (Callable): function whose ``print()`` output should be limited
            to the main worker.

    Returns:
        Callable: the wrapped function.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs) -> Any:
        # Disable print in workers different from the main one,
        # when in distributed environments.
        dist_grank = detect_distributed_environment().global_rank
        patched_print = distributed_patch_print(is_main=dist_grank == 0)
        # Back up the *currently active* print (not necessarily the pristine
        # builtin) so that nested decorated calls unwind correctly.
        previous_print_backup = __builtin__.print
        __builtin__.print = patched_print
        try:
            return func(*args, **kwargs)
        finally:
            # Using ``finally`` rather than ``except Exception`` also covers
            # BaseException subclasses such as KeyboardInterrupt and
            # SystemExit. Otherwise ``print`` would stay patched for the
            # rest of the process, and re-raising would add a traceback frame.
            __builtin__.print = previous_print_backup

    return wrapper
def suppress_workers_output(func):
    """Decorator to silence ``stdout`` and ``stderr`` on non-main workers.

    On the worker with global rank 0 the wrapped function runs normally. On
    every other worker, ``sys.stdout`` and ``sys.stderr`` are pointed at
    ``os.devnull`` while the function runs.

    In both cases the streams that were active before the call are restored
    when it finishes, even if it raises. Only writes that go through the
    ``sys.stdout`` / ``sys.stderr`` Python objects are affected.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        saved_streams = (sys.stdout, sys.stderr)
        on_main_worker = detect_distributed_environment().global_rank == 0
        try:
            if on_main_worker:
                return func(*args, **kwargs)
            # Non-main worker: discard everything written to the streams.
            with open(os.devnull, "w") as sink:
                sys.stdout = sys.stderr = sink
                return func(*args, **kwargs)
        finally:
            sys.stdout, sys.stderr = saved_streams

    return wrapper
def get_adaptive_ray_scaling_config() -> "ScalingConfig":
    """Returns a Ray scaling config for distributed ML training depending on
    the resources available in the Ray cluster.

    Initializes Ray with ``ray.init()`` if it is not already initialized,
    then reads the GPU count from ``ray.cluster_resources()``:

    - with 2 or more GPUs, one GPU worker is used per available GPU;
    - with 0 or 1 GPU, GPUs are not used and two CPU-only workers are
      returned instead.

    Returns:
        ScalingConfig: Ray Train scaling configuration for the cluster.
    """
    import ray
    from ray.train import ScalingConfig

    # Initialize Ray if not already initialized
    if not ray.is_initialized():
        ray.init()

    # GPU resources may be reported as floats; truncate to whole devices.
    num_gpus = int(ray.cluster_resources().get("GPU", 0))

    if num_gpus <= 1:
        # If 0 or 1 GPU, don't use GPU for training (presumably so that the
        # run still has more than one worker -- TODO confirm intent).
        py_logger.debug("Returning a scaling config to run distributed ML on 2 CPUs")
        return ScalingConfig(
            num_workers=2,  # Default to 2 CPU workers
            use_gpu=False,
        )

    # If multiple GPUs, use all available GPUs
    py_logger.debug("Returning a scaling config to run distributed ML on %d GPUs", num_gpus)
    return ScalingConfig(num_workers=num_gpus, use_gpu=True)