Source code for itwinai.distributed

# --------------------------------------------------------------------------------------
# Part of the interTwin Project: https://www.intertwin.eu/
#
# Created by: Matteo Bunino
#
# Credit:
# - Matteo Bunino <matteo.bunino@cern.ch> - CERN
# - Linus Eickhoff <linus.maximilian.eickhoff@cern.ch> - CERN
# --------------------------------------------------------------------------------------

import builtins as __builtin__
import functools
import logging
import os
import re
import subprocess
import sys
from typing import TYPE_CHECKING, Any, Callable

from pydantic import BaseModel

if TYPE_CHECKING:
    from ray.train import ScalingConfig


py_logger = logging.getLogger(__name__)


class ClusterEnvironment(BaseModel):
    """Describe where the current worker sits inside a distributed job.

    All fields default to the values of a single, non-distributed process
    (rank 0 in a world of size 1).
    """

    #: Rank of this worker among all workers of the job.
    #: The main worker is the one with ``global_rank == 0``.
    #: Defaults to 0.
    global_rank: int = 0

    #: Rank of this worker among the workers running on the same node.
    #: Defaults to 0.
    local_rank: int = 0

    #: Total number of workers taking part in the distributed job.
    #: Defaults to 1.
    global_world_size: int = 1

    #: Number of workers that share the node with this worker (itself included).
    #: Defaults to 1.
    local_world_size: int = 1
def ray_cluster_is_running() -> bool:
    """Tell whether a Ray cluster can be reached from this machine.

    The check shells out to ``ray status`` instead of calling ``ray.init()``,
    which is expected to be cheaper.

    Returns:
        bool: True when ``ray status`` succeeds and its output contains the
        sections of a live cluster report. False when the command fails or
        the ``ray`` executable cannot be found.
    """
    status_command = ["ray", "status"]
    try:
        completed = subprocess.run(
            status_command,
            capture_output=True,
            text=True,
            check=True,
        )
    except subprocess.CalledProcessError:
        # Non-zero exit status: there is no cluster to talk to.
        py_logger.debug(
            "Ray was checking for the existence of a Ray cluster by trying to "
            "connect to it, but could not do it. This is not a problem if you "
            "are not planning to connect to a Ray cluster."
        )
        return False
    except FileNotFoundError:
        # The executable itself is missing.
        py_logger.debug(
            "Error: 'ray' command not found while checking if a Ray cluster "
            "exists. Is Ray installed?"
        )
        return False

    # A live cluster report contains all of these section titles.
    report = completed.stdout
    return all(section in report for section in ("Node status", "Resources", "Usage"))
def _get_env( name: str, *, default: Any | None = None, cast: Callable[[str], Any] = lambda x: x, required: bool = False, ) -> Any | None: """Fetches and casts an environment variable. Args: name (str): the ENV var name. default (Any): returned if var is unset (and required is False). Defaults to None. cast (Callable[[str], Any]): function to transform the raw str into the desired type. Defaults to the identity function. required (bool): if True and var is missing, raises KeyError. Defaults to False. Raises: KeyError: when a required env variable is required but not found. """ raw = os.getenv(name) if raw is None: if required: raise KeyError(f"Required environment variable {name!r} not set") return default try: return cast(raw) except Exception as e: py_logger.warning("Failed to cast env %s=%r using %s: %s", name, raw, cast, e) return default def _get_int(name: str, default: int | None = None, required: bool = False) -> int: val = _get_env(name, default=default, cast=int, required=required) if val is None: raise ValueError( f"Could not cast variable {name!r} to int because it is None. " f"Default value was set to {default!r}" ) return int(val) def _has_all(*names: str) -> bool: """True iff all env vars are set (not None).""" return all(os.getenv(n) is not None for n in names) def _has_any(*names: str) -> bool: """True iff at least one env var is set (not None).""" return any(os.getenv(n) is not None for n in names) def _is_torchrun_env() -> bool: """Detect torchrun / TorchElastic cluster.""" # https://pytorch.org/docs/stable/elastic/run.html#environment-variables return _has_all("RANK", "WORLD_SIZE", "LOCAL_RANK", "LOCAL_WORLD_SIZE") and ( os.getenv("TORCHELASTIC_RUN_ID") is not None or _has_all("MASTER_ADDR", "MASTER_PORT") ) def _parse_slurm_tasks_per_node(*vals: str | None) -> int | None: """Parse SLURM tasks-per-node env vars. Accepts one or more candidate strings (e.g. SLURM_NTASKS_PER_NODE, SLURM_TASKS_PER_NODE). Returns the first successfully parsed value. 
Handles formats like: "2" "2(x3)" "2(x2),1" "4,4,4" Returns: max tasks on any node, or None if nothing is parseable. """ for val in vals: if not val: continue parts = [p.strip() for p in val.strip().split(",")] counts: list[int] = [] for p in parts: m = re.match(r"^(\d+)(?:\(x\d+\))?$", p) if m: counts.append(int(m.group(1))) if counts: return max(counts) return None def _get_n_visible_devices() -> int | None: """Return the number of visible GPUs from common env vars. Checks NVIDIA and AMD/ROCm conventions: - CUDA_VISIBLE_DEVICES (NVIDIA; also sometimes set for CUDA apps) - ROCR_VISIBLE_DEVICES (ROCm/HIP runtime) - HIP_VISIBLE_DEVICES (ROCm/HIP; sometimes preferred by frameworks) Notes: - Values can be comma-separated IDs ("0,1,2") or UUIDs. - Some runtimes use "-1" to mean "no devices"; treat that as 0. - If the variable is unset/empty, returns None. """ for var in ("CUDA_VISIBLE_DEVICES", "ROCR_VISIBLE_DEVICES", "HIP_VISIBLE_DEVICES"): visible = os.getenv(var) if not visible: continue s = visible.strip() if s in ("-1", "none", "None", ""): return None # Typical format: "0,1,2" (can include whitespace) tokens = [t.strip() for t in s.split(",") if t.strip() != ""] if tokens: return len(tokens) # Some setups may set a single value without commas; count as 1 return 1 return None
def detect_distributed_environment() -> ClusterEnvironment:
    """Work out rank and world-size information from environment variables.

    Launchers are probed in a fixed order and the first match wins:

    1. torchrun / TorchElastic (rank/size variables plus extra torch markers)
    2. Open MPI (``OMPI_COMM_WORLD_*`` variables)
    3. SLURM job step (``SLURM_PROCID`` / ``SLURM_LOCALID`` plus task counts)

    When none of them matches (e.g. a laptop, or an interactive SLURM
    allocation without a job step) the default single-process environment
    (rank 0, world size 1) is returned.

    Returns:
        ClusterEnvironment: The detected cluster environment.

    Raises:
        ValueError: if SLURM is detected but no tasks-per-node variable can be
            parsed into a positive number.
    """
    # Field name -> env var. Insertion order is the order the vars are read in.
    torchrun_vars = {
        "global_rank": "RANK",
        "local_rank": "LOCAL_RANK",
        "local_world_size": "LOCAL_WORLD_SIZE",
        "global_world_size": "WORLD_SIZE",
    }
    ompi_vars = {
        "global_rank": "OMPI_COMM_WORLD_RANK",
        "local_rank": "OMPI_COMM_WORLD_LOCAL_RANK",
        "local_world_size": "OMPI_COMM_WORLD_LOCAL_SIZE",
        "global_world_size": "OMPI_COMM_WORLD_SIZE",
    }
    slurm_per_node_vars = (
        "SLURM_NTASKS_PER_NODE",
        "SLURM_TASKS_PER_NODE",
        "SLURM_STEP_TASKS_PER_NODE",
    )

    # 1) torchrun / TorchElastic
    if _is_torchrun_env():
        py_logger.debug(
            "Detected torchrun/TorchElastic distributed environment (elastic_run_id=%s)",
            os.getenv("TORCHELASTIC_RUN_ID"),
        )
        return ClusterEnvironment(
            **{
                field: _get_int(env_name, required=True)
                for field, env_name in torchrun_vars.items()
            }
        )

    # 2) Open MPI
    if _has_all(*ompi_vars.values()):
        py_logger.debug("Detected OpenMPI distributed cluster")
        return ClusterEnvironment(
            **{
                field: _get_int(env_name, required=True)
                for field, env_name in ompi_vars.items()
            }
        )

    # 3) SLURM: needs rank vars, a task count and a tasks-per-node var
    slurm_detected = (
        _has_all("SLURM_PROCID", "SLURM_LOCALID")
        and _has_any("SLURM_STEP_NUM_TASKS", "SLURM_NTASKS")
        and _has_any(*slurm_per_node_vars)
    )
    if not slurm_detected:
        # 4) Default: no distributed env (e.g., laptop)
        py_logger.debug("No distributed environment was detected")
        return ClusterEnvironment()

    py_logger.debug("Detected SLURM distributed environment")

    # The step-level task count wins over the job-level one when present
    world_size = _get_env("SLURM_STEP_NUM_TASKS", default=None, cast=int)
    if world_size is None:
        world_size = _get_int("SLURM_NTASKS", required=True)

    tasks_per_node = _parse_slurm_tasks_per_node(
        *(os.getenv(env_name) for env_name in slurm_per_node_vars)
    )
    if tasks_per_node is None or tasks_per_node <= 0:
        # Nothing usable was found: refuse to guess
        raise ValueError(
            "Could not determine local_world_size from SLURM tasks-per-node env vars "
            "(SLURM_NTASKS_PER_NODE / SLURM_TASKS_PER_NODE / SLURM_STEP_TASKS_PER_NODE)."
        )

    return ClusterEnvironment(
        global_rank=_get_int("SLURM_PROCID", required=True),
        local_rank=_get_int("SLURM_LOCALID", required=True),
        local_world_size=int(tasks_per_node),
        global_world_size=int(world_size),
    )
#: Reference to the original builtin ``print``, taken when this module is
#: imported, i.e. before ``print`` gets patched in distributed environments.
builtin_print: Callable[..., None] = __builtin__.print
def distributed_patch_print(is_main: bool) -> Callable:
    """Build a ``print()`` replacement that is silent outside the main worker.

    Args:
        is_main (bool): whether it is called from main worker.

    Returns:
        Callable: patched ``print()``.
    """

    def patched_print(*args, **kwargs):
        """Forward to the original ``print`` only on the main worker, or on
        any worker when called with ``force=True``.
        """
        # ``force`` is not a real print() argument: always strip it first
        forced = kwargs.pop("force", False)
        if not (is_main or forced):
            return
        builtin_print(*args, **kwargs)

    return patched_print
def suppress_workers_print(func: Callable) -> Callable:
    """Decorator to suppress ``print()`` calls in workers having global rank
    different from 0. To force printing on all workers you need to use
    ``print(..., force=True)``.

    The builtin ``print`` is swapped for the duration of the call and is
    always put back afterwards, whether ``func`` returns or raises.

    Args:
        func (Callable): function to wrap.

    Returns:
        Callable: wrapped function, returning whatever ``func`` returns.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs) -> Any:
        # Disable print in workers different from the main one,
        # when in distributed environments.
        dist_grank = detect_distributed_environment().global_rank
        patched_print = distributed_patch_print(is_main=dist_grank == 0)
        # Remember whatever print is installed right now (it may already be a
        # patched one when decorated functions are nested).
        previous_print_backup = __builtin__.print
        __builtin__.print = patched_print
        try:
            return func(*args, **kwargs)
        finally:
            # ``finally`` also covers KeyboardInterrupt and SystemExit, which
            # ``except Exception`` would miss, leaving print patched globally.
            __builtin__.print = previous_print_backup

    return wrapper
def suppress_workers_output(func):
    """Decorator that silences ``stdout`` and ``stderr`` on every worker
    whose global rank is not 0.

    On those workers both streams point to ``os.devnull`` while ``func`` runs;
    the original streams are put back when the call ends, also on error.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Keep the current streams so they can be reinstated afterwards
        saved_streams = (sys.stdout, sys.stderr)
        on_main_worker = detect_distributed_environment().global_rank == 0

        try:
            if not on_main_worker:
                # Secondary worker: send everything written to the null device
                with open(os.devnull, "w") as sink:
                    sys.stdout = sink
                    sys.stderr = sink
                    return func(*args, **kwargs)

            # Main worker: run untouched
            return func(*args, **kwargs)
        finally:
            sys.stdout, sys.stderr = saved_streams

    return wrapper
def get_adaptive_ray_scaling_config() -> "ScalingConfig":
    """Build a Ray scaling config sized on the resources of the Ray cluster.

    Ray is initialized first if that has not happened yet. With more than one
    GPU in the cluster, one GPU worker per GPU is requested. With zero or one
    GPU, two CPU-only workers are requested instead.

    Returns:
        ScalingConfig: scaling configuration for distributed ML training.
    """
    import ray
    from ray.train import ScalingConfig

    # Connect to (or start) Ray before asking for its resources
    if not ray.is_initialized():
        ray.init()

    num_gpus = int(ray.cluster_resources().get("GPU", 0))

    if num_gpus > 1:
        # Several GPUs: one worker per GPU
        py_logger.debug(f"Returning a scaling config to run distributed ML on {num_gpus} GPUs")
        return ScalingConfig(num_workers=num_gpus, use_gpu=True)

    # Zero or one GPU: fall back to two CPU-only workers
    py_logger.debug("Returning a scaling config to run distributed ML on 2 CPUs")
    return ScalingConfig(
        num_workers=2,
        use_gpu=False,
    )