# --------------------------------------------------------------------------------------
# Part of the interTwin Project: https://www.intertwin.eu/
#
# Created by: Matteo Bunino
#
# Credit:
# - Matteo Bunino <matteo.bunino@cern.ch> - CERN
# - Linus Eickhoff <linus.maximilian.eickhoff@cern.ch> - CERN
# --------------------------------------------------------------------------------------
import builtins as __builtin__
import functools
import logging
import os
import re
import subprocess
import sys
from typing import TYPE_CHECKING, Any, Callable
from pydantic import BaseModel
if TYPE_CHECKING:
from ray.train import ScalingConfig
#: Module-level logger, used for the debug/warning messages emitted while
#: probing Ray and the distributed environment variables.
py_logger = logging.getLogger(__name__)
[docs]
class ClusterEnvironment(BaseModel):
    """Stores information about the distributed environment.

    Every field defaults to the value of a single, non-distributed process
    (rank 0 in a world of size 1), so ``ClusterEnvironment()`` describes a run
    that was not started by any distributed launcher.
    """

    #: Global rank of current worker, in a distributed environment.
    #: ``global_rank==0`` identifies the main worker.
    #: Defaults to 0.
    global_rank: int = 0
    #: Local rank of current worker, in a distributed environment
    #: (i.e., its rank among the workers on the same node).
    #: Defaults to 0.
    local_rank: int = 0
    #: Total number of workers in a distributed environment.
    #: Defaults to 1.
    global_world_size: int = 1
    #: Number of workers on the same node in a distributed environment.
    #: Defaults to 1.
    local_world_size: int = 1
[docs]
def ray_cluster_is_running(timeout: float | None = 60.0) -> bool:
    """Check if a Ray cluster is running.

    Runs ``ray status`` in a subprocess (less overhead than ``ray.init()``)
    and looks for the sections it prints when connected to a live cluster.
    This is a best-effort probe: any failure to run the command, or to get an
    answer in time, is reported as "no cluster".

    Args:
        timeout (float | None): maximum number of seconds to wait for
            ``ray status`` to complete. When it expires, the child process is
            killed and the cluster is considered not reachable. ``None`` waits
            indefinitely. Defaults to 60.

    Returns:
        bool: True if a running Ray cluster is detected. False otherwise.
    """
    try:
        # Run the `ray status` command. It should be less overhead than ray.init()
        result = subprocess.run(
            ["ray", "status"],
            capture_output=True,
            text=True,
            check=True,
            timeout=timeout,
        )
    except subprocess.CalledProcessError:
        # If the command fails, the cluster is not running
        py_logger.debug(
            "Ray was checking for the existence of a Ray cluster by trying to "
            "connect to it, but could not do it. This is not a problem if you "
            "are not planning to connect to a Ray cluster."
        )
        return False
    except subprocess.TimeoutExpired:
        # `ray status` can block while trying to reach an unresponsive cluster;
        # never let the probe hang the caller.
        py_logger.debug(
            "'ray status' did not complete within %s seconds while checking if "
            "a Ray cluster exists. Assuming no reachable Ray cluster.",
            timeout,
        )
        return False
    except FileNotFoundError:
        # If `ray` command is not found, Ray is not installed
        py_logger.debug(
            "Error: 'ray' command not found while checking if a Ray cluster "
            "exists. Is Ray installed?"
        )
        return False
    except OSError as err:
        # Any other failure to launch the executable (e.g. missing permissions)
        py_logger.debug(
            "Could not run 'ray status' while checking if a Ray cluster exists: %s",
            err,
        )
        return False

    # Check if the output indicates a running cluster
    return (
        "Node status" in result.stdout
        and "Resources" in result.stdout
        and "Usage" in result.stdout
    )
def _get_env(
name: str,
*,
default: Any | None = None,
cast: Callable[[str], Any] = lambda x: x,
required: bool = False,
) -> Any | None:
"""Fetches and casts an environment variable.
Args:
name (str): the ENV var name.
default (Any): returned if var is unset (and required is False). Defaults to None.
cast (Callable[[str], Any]): function to transform the raw str into the desired type.
Defaults to the identity function.
required (bool): if True and var is missing, raises KeyError. Defaults to False.
Raises:
KeyError: when a required env variable is required but not found.
"""
raw = os.getenv(name)
if raw is None:
if required:
raise KeyError(f"Required environment variable {name!r} not set")
return default
try:
return cast(raw)
except Exception as e:
py_logger.warning("Failed to cast env %s=%r using %s: %s", name, raw, cast, e)
return default
def _get_int(name: str, default: int | None = None, required: bool = False) -> int:
    """Read the environment variable ``name`` as an ``int``.

    Thin wrapper around ``_get_env`` with ``cast=int``.

    Args:
        name (str): the ENV var name.
        default (int | None): value used when the variable is unset, or set
            but not castable to ``int``. Defaults to None.
        required (bool): forwarded to ``_get_env``; when True and the variable
            is unset, ``_get_env`` raises KeyError. Defaults to False.

    Returns:
        int: the parsed value, or ``int(default)`` when falling back.

    Raises:
        KeyError: when ``required`` is True and the variable is unset.
        ValueError: when no usable value is available, i.e. the variable is
            unset or not castable and ``default`` is None.
    """
    value = _get_env(name, default=default, cast=int, required=required)
    if value is not None:
        return int(value)
    raise ValueError(
        f"Could not cast variable {name!r} to int because it is None. "
        f"Default value was set to {default!r}"
    )
def _has_all(*names: str) -> bool:
"""True iff all env vars are set (not None)."""
return all(os.getenv(n) is not None for n in names)
def _has_any(*names: str) -> bool:
"""True iff at least one env var is set (not None)."""
return any(os.getenv(n) is not None for n in names)
def _is_torchrun_env() -> bool:
    """Detect a torchrun / TorchElastic launch from its environment variables.

    All of ``RANK``, ``WORLD_SIZE``, ``LOCAL_RANK`` and ``LOCAL_WORLD_SIZE``
    must be set, plus a torch-specific marker: either ``TORCHELASTIC_RUN_ID``,
    or both ``MASTER_ADDR`` and ``MASTER_PORT``.
    """
    # https://pytorch.org/docs/stable/elastic/run.html#environment-variables
    has_rank_info = _has_all("RANK", "WORLD_SIZE", "LOCAL_RANK", "LOCAL_WORLD_SIZE")
    if not has_rank_info:
        return False
    launched_by_elastic = os.getenv("TORCHELASTIC_RUN_ID") is not None
    return launched_by_elastic or _has_all("MASTER_ADDR", "MASTER_PORT")
def _parse_slurm_tasks_per_node(*vals: str | None) -> int | None:
"""Parse SLURM tasks-per-node env vars.
Accepts one or more candidate strings (e.g. SLURM_NTASKS_PER_NODE,
SLURM_TASKS_PER_NODE). Returns the first successfully parsed value.
Handles formats like:
"2"
"2(x3)"
"2(x2),1"
"4,4,4"
Returns: max tasks on any node, or None if nothing is parseable.
"""
for val in vals:
if not val:
continue
parts = [p.strip() for p in val.strip().split(",")]
counts: list[int] = []
for p in parts:
m = re.match(r"^(\d+)(?:\(x\d+\))?$", p)
if m:
counts.append(int(m.group(1)))
if counts:
return max(counts)
return None
def _get_n_visible_devices() -> int | None:
"""Return the number of visible GPUs from common env vars.
Checks NVIDIA and AMD/ROCm conventions:
- CUDA_VISIBLE_DEVICES (NVIDIA; also sometimes set for CUDA apps)
- ROCR_VISIBLE_DEVICES (ROCm/HIP runtime)
- HIP_VISIBLE_DEVICES (ROCm/HIP; sometimes preferred by frameworks)
Notes:
- Values can be comma-separated IDs ("0,1,2") or UUIDs.
- Some runtimes use "-1" to mean "no devices"; treat that as 0.
- If the variable is unset/empty, returns None.
"""
for var in ("CUDA_VISIBLE_DEVICES", "ROCR_VISIBLE_DEVICES", "HIP_VISIBLE_DEVICES"):
visible = os.getenv(var)
if not visible:
continue
s = visible.strip()
if s in ("-1", "none", "None", ""):
return None
# Typical format: "0,1,2" (can include whitespace)
tokens = [t.strip() for t in s.split(",") if t.strip() != ""]
if tokens:
return len(tokens)
# Some setups may set a single value without commas; count as 1
return 1
return None
[docs]
def detect_distributed_environment() -> ClusterEnvironment:
    """Detect a distributed environment by probing known env vars.

    Launchers are probed in a fixed order and the first match wins:

    1. torchrun / TorchElastic: ``RANK``, ``WORLD_SIZE``, ``LOCAL_RANK`` and
       ``LOCAL_WORLD_SIZE`` are set, plus ``TORCHELASTIC_RUN_ID`` or both
       ``MASTER_ADDR`` and ``MASTER_PORT``.
    2. Open MPI: ``OMPI_COMM_WORLD_RANK``, ``OMPI_COMM_WORLD_SIZE``,
       ``OMPI_COMM_WORLD_LOCAL_RANK`` and ``OMPI_COMM_WORLD_LOCAL_SIZE`` are set.
    3. SLURM: ``SLURM_PROCID`` and ``SLURM_LOCALID`` are set, together with a
       task count and a tasks-per-node variable. Step-level variables
       (``SLURM_STEP_*``) are preferred over job-level ones for both the
       global and the local world size, so the two sizes always describe
       the same launch.
    4. Otherwise a default, non-distributed environment is returned.

    Robust across:
    - laptop (no SLURM/OMPI/torchrun env): returns default (rank=0, world=1)
    - interactive SLURM allocation without a job step: returns default
    - SLURM batch / srun step: detects via SLURM_PROCID
    - OpenMPI: detects via OMPI rank/size
    - torchrun / TorchElastic: detects via rank/size + extra torch markers

    Returns:
        ClusterEnvironment: The detected cluster environment.

    Raises:
        ValueError: if a variable of the detected launcher is not an integer,
            or if the SLURM local world size cannot be determined from the
            tasks-per-node variables.
        KeyError: in the SLURM branch, if ``SLURM_STEP_NUM_TASKS`` is present
            but not an integer and ``SLURM_NTASKS`` is unset.
    """
    # 1) torchrun / TorchElastic
    if _is_torchrun_env():
        py_logger.debug(
            "Detected torchrun/TorchElastic distributed environment (elastic_run_id=%s)",
            os.getenv("TORCHELASTIC_RUN_ID"),
        )
        return ClusterEnvironment(
            global_rank=_get_int("RANK", required=True),
            local_rank=_get_int("LOCAL_RANK", required=True),
            local_world_size=_get_int("LOCAL_WORLD_SIZE", required=True),
            global_world_size=_get_int("WORLD_SIZE", required=True),
        )

    # 2) Open MPI
    if _has_all(
        "OMPI_COMM_WORLD_RANK",
        "OMPI_COMM_WORLD_SIZE",
        "OMPI_COMM_WORLD_LOCAL_SIZE",
        "OMPI_COMM_WORLD_LOCAL_RANK",
    ):
        py_logger.debug("Detected OpenMPI distributed cluster")
        return ClusterEnvironment(
            global_rank=_get_int("OMPI_COMM_WORLD_RANK", required=True),
            local_rank=_get_int("OMPI_COMM_WORLD_LOCAL_RANK", required=True),
            local_world_size=_get_int("OMPI_COMM_WORLD_LOCAL_SIZE", required=True),
            global_world_size=_get_int("OMPI_COMM_WORLD_SIZE", required=True),
        )

    # 3) SLURM
    if (
        _has_all("SLURM_PROCID", "SLURM_LOCALID")
        and _has_any("SLURM_STEP_NUM_TASKS", "SLURM_NTASKS")
        and _has_any(
            "SLURM_NTASKS_PER_NODE", "SLURM_TASKS_PER_NODE", "SLURM_STEP_TASKS_PER_NODE"
        )
    ):
        py_logger.debug("Detected SLURM distributed environment")
        # Prefer step-level task count if available, else fall back to job-level
        global_world_size = _get_env("SLURM_STEP_NUM_TASKS", default=None, cast=int)
        if global_world_size is None:
            global_world_size = _get_int("SLURM_NTASKS", required=True)
        # Same preference for the tasks-per-node layout: the step-level value
        # describes the processes launched by the current `srun`, which may
        # differ from the job-level allocation. Using the job-level layout here
        # while the global size comes from the step would mix two launches.
        # For heterogeneous layouts the parser returns the max tasks on any node.
        local_world_size = _parse_slurm_tasks_per_node(
            os.getenv("SLURM_STEP_TASKS_PER_NODE"),
            os.getenv("SLURM_NTASKS_PER_NODE"),
            os.getenv("SLURM_TASKS_PER_NODE"),
        )
        if local_world_size is None or local_world_size <= 0:
            # No usable tasks-per-node value: fail loudly rather than guess.
            raise ValueError(
                "Could not determine local_world_size from SLURM tasks-per-node env vars "
                "(SLURM_NTASKS_PER_NODE / SLURM_TASKS_PER_NODE / SLURM_STEP_TASKS_PER_NODE)."
            )
        return ClusterEnvironment(
            global_rank=_get_int("SLURM_PROCID", required=True),
            local_rank=_get_int("SLURM_LOCALID", required=True),
            local_world_size=int(local_world_size),
            global_world_size=int(global_world_size),
        )

    # 4) Default: no distributed env (e.g., laptop)
    py_logger.debug("No distributed environment was detected")
    return ClusterEnvironment()
#: Save original builtin print before patching it in distributed environments.
#: This captures the ``print`` in effect when this module is imported, so that
#: ``distributed_patch_print`` can always delegate to it even while
#: ``builtins.print`` is temporarily replaced by the patched version.
builtin_print = __builtin__.print
[docs]
def distributed_patch_print(is_main: bool) -> Callable:
    """Build a ``print()`` replacement that is silent outside the main worker.

    Args:
        is_main (bool): whether it is called from main worker.

    Returns:
        Callable: patched ``print()``. It forwards to the original builtin
        ``print`` when ``is_main`` is True, or when called with
        ``force=True``; otherwise it does nothing.
    """

    def patched_print(*args, **kwargs):
        """Print is disabled on workers other than the main one, unless the
        call passes the ``force=True`` argument.
        """
        # `force` is ours, not print()'s: strip it before forwarding.
        forced = kwargs.pop("force", False)
        if not (is_main or forced):
            return
        builtin_print(*args, **kwargs)

    return patched_print
[docs]
def suppress_workers_print(func: Callable) -> Callable:
    """Decorator to suppress ``print()`` calls in workers having global rank
    different from 0. To force printing on all workers you need to use
    ``print(..., force=True)``.

    The global rank is detected on every call of the decorated function.
    ``builtins.print`` is restored to whatever it was before the call once the
    function returns or raises, including ``BaseException`` subclasses such as
    ``SystemExit`` and ``KeyboardInterrupt``.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs) -> Any:
        # Disable print in workers different from the main one,
        # when in distributed environments.
        dist_grank = detect_distributed_environment().global_rank
        patched_print = distributed_patch_print(is_main=dist_grank == 0)
        previous_print_backup = __builtin__.print
        __builtin__.print = patched_print
        try:
            return func(*args, **kwargs)
        finally:
            # Always restore print. Restoring only on `Exception` would leave it
            # patched for the rest of the process after e.g. sys.exit() in func.
            __builtin__.print = previous_print_backup

    return wrapper
[docs]
def suppress_workers_output(func):
    """Decorator to suppress ``stdout`` and ``stderr`` in workers having global
    rank different from 0.

    On non-main workers, ``sys.stdout`` and ``sys.stderr`` point to
    ``os.devnull`` while the decorated function runs. On every worker, both
    streams are reset to the objects they referenced before the call, whether
    the function returns or raises.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        saved_streams = (sys.stdout, sys.stderr)
        is_main_worker = detect_distributed_environment().global_rank == 0
        try:
            if not is_main_worker:
                # Send everything written by non-main workers to the null device.
                with open(os.devnull, "w") as sink:
                    sys.stdout = sys.stderr = sink
                    return func(*args, **kwargs)
            return func(*args, **kwargs)
        finally:
            sys.stdout, sys.stderr = saved_streams

    return wrapper
[docs]
def get_adaptive_ray_scaling_config() -> "ScalingConfig":
    """Build a Ray ``ScalingConfig`` sized to the GPUs of the Ray cluster.

    Ray is initialized first if it is not already. When the cluster reports
    more than one GPU, one GPU worker is created per GPU. With zero or one
    GPU, two CPU-only workers are used instead.

    Returns:
        ScalingConfig: the scaling config for distributed ML training.
    """
    import ray
    from ray.train import ScalingConfig

    if not ray.is_initialized():
        ray.init()

    gpu_count = int(ray.cluster_resources().get("GPU", 0))

    if gpu_count > 1:
        # Multiple GPUs: one worker per available GPU.
        py_logger.debug(f"Returning a scaling config to run distributed ML on {gpu_count} GPUs")
        return ScalingConfig(num_workers=gpu_count, use_gpu=True)

    # Zero or one GPU: don't use GPUs, fall back to two CPU workers.
    py_logger.debug("Returning a scaling config to run distributed ML on 2 CPUs")
    return ScalingConfig(num_workers=2, use_gpu=False)