# --------------------------------------------------------------------------------------
# Part of the interTwin Project: https://www.intertwin.eu/
#
# Created by: Matteo Bunino
#
# Credit:
# - Matteo Bunino <matteo.bunino@cern.ch> - CERN
# - Linus Eickhoff <linus.maximilian.eickhoff@cern.ch> - CERN
# --------------------------------------------------------------------------------------
import abc
import builtins as __builtin__
import functools
import logging
import os
import subprocess
import sys
from typing import TYPE_CHECKING, Any, Callable
from pydantic import BaseModel
if TYPE_CHECKING:
from ray.train import ScalingConfig
py_logger = logging.getLogger(__name__)
[docs]
class DistributedStrategy(abc.ABC):
    """Abstract base for classes defining the methods of a distributed backend."""
[docs]
class ClusterEnvironment(BaseModel):
    """Ranks and world sizes describing a worker in a distributed environment.

    Every field has a default, so that an instance built without arguments
    describes a single worker with rank 0.
    """

    #: Rank of the current worker among all the workers of a distributed
    #: environment. The main worker is identified by ``global_rank==0``.
    #: Defaults to 0.
    global_rank: int = 0

    #: Rank of the current worker in a distributed environment, local to its node.
    #: Defaults to 0.
    local_rank: int = 0

    #: Overall number of workers in a distributed environment.
    #: Defaults to 1.
    global_world_size: int = 1

    #: How many workers share the same node in a distributed environment.
    #: Defaults to 1.
    local_world_size: int = 1
[docs]
def ray_cluster_is_running() -> bool:
    """Tell whether the ``ray status`` command reports a running Ray cluster.

    Returns:
        bool: ``True`` when ``ray status`` exits successfully and its output
        contains the "Node status", "Resources" and "Usage" sections. ``False``
        when the command exits with an error, when the ``ray`` executable is
        not found, or when any of those sections is missing.
    """
    status_cmd = ["ray", "status"]
    try:
        # Shelling out to `ray status` should be less overhead than ray.init()
        completed = subprocess.run(
            status_cmd,
            capture_output=True,
            text=True,
            check=True,
        )
    except FileNotFoundError:
        # No `ray` executable could be found: Ray is not installed
        py_logger.debug(
            "Error: 'ray' command not found while checking if a Ray cluster "
            "exists. Is Ray installed?"
        )
        return False
    except subprocess.CalledProcessError as err:
        # A failing command means that the cluster is not running
        py_logger.debug(
            f"Subprocess failed with return code {err.returncode} while checking if "
            "a Ray cluster exists.\n"
            f"Stdout: {err.stdout}\n"
            f"Stderr: {err.stderr}"
        )
        return False

    # The output of a running cluster is expected to contain all these sections
    expected_sections = ("Node status", "Resources", "Usage")
    return all(section in completed.stdout for section in expected_sections)
[docs]
def detect_distributed_environment() -> ClusterEnvironment:
    """Detects distributed environment, extracting information like
    global and local ranks, and world size.

    The environment variables of the supported launchers are probed in this
    order: torch elastic, Open MPI, SLURM. When none of them is detected, a
    default :class:`ClusterEnvironment` (single worker, rank 0) is returned.

    Returns:
        ClusterEnvironment: ranks and world sizes of the current worker.

    Raises:
        ValueError: if ``OMPI_COMM_WORLD_SIZE`` or ``SLURM_NTASKS`` is set to a
            value that cannot be parsed as an integer.
    """
    if os.getenv("TORCHELASTIC_RUN_ID") is not None:
        # Torch elastic environment
        # https://pytorch.org/docs/stable/elastic/run.html#environment-variables
        return ClusterEnvironment(
            global_rank=os.getenv("RANK"),
            local_rank=os.getenv("LOCAL_RANK"),
            local_world_size=os.getenv("LOCAL_WORLD_SIZE"),
            global_world_size=os.getenv("WORLD_SIZE"),
        )
    # Fixes issue that OMPI_* env vars might be set despite using srun instead of mpirun.
    # When neither variable is set the comparison is -1 >= 0, hence this branch is skipped.
    elif int(os.getenv("OMPI_COMM_WORLD_SIZE", -1)) >= int(os.getenv("SLURM_NTASKS", 0)):
        py_logger.debug("Using Open MPI environment")
        # Open MPI environment
        # https://docs.open-mpi.org/en/v5.0.x/tuning-apps/environment-var.html
        return ClusterEnvironment(
            global_rank=os.getenv("OMPI_COMM_WORLD_RANK"),
            local_rank=os.getenv("OMPI_COMM_WORLD_LOCAL_RANK"),
            local_world_size=os.getenv("OMPI_COMM_WORLD_LOCAL_SIZE"),
            global_world_size=os.getenv("OMPI_COMM_WORLD_SIZE"),
        )
    # NOTE: a running Ray cluster is deliberately not probed here. It is difficult to
    # understand ranks from a Ray cluster... It could have been set up to tune a model
    # with non-distributed strategy.
    elif os.getenv("SLURM_JOB_ID") is not None:
        # https://hpcc.umd.edu/hpcc/help/slurmenv.html
        # SLURM_JOB_ID alone does not guarantee that the per-task variables are
        # exported as well (presumably the case when the process is not launched as
        # a SLURM task -- verify on the target cluster). Fall back to the
        # single-worker values instead of passing None to the integer fields.
        return ClusterEnvironment(
            global_rank=os.getenv("SLURM_PROCID", 0),
            local_rank=os.getenv("SLURM_LOCALID", 0),
            local_world_size=os.getenv("SLURM_NTASKS_PER_NODE", 1),
            global_world_size=os.getenv("SLURM_NTASKS", 1),
        )
    else:
        # No known launcher detected: single worker with default ranks
        return ClusterEnvironment()
#: Reference to the original builtin ``print``, saved when this module is imported,
#: before ``print`` gets patched in distributed environments. The patched print
#: delegates to it when output is allowed.
builtin_print = __builtin__.print
[docs]
def distributed_patch_print(is_main: bool) -> Callable:
    """Build a replacement for ``print()`` which is silent when not in main worker.

    Args:
        is_main (bool): whether it is called from main worker.

    Returns:
        Callable: patched ``print()``.
    """

    def patched_print(*args, **kwargs):
        """Print is disabled on workers different from the main one, unless
        it is called with the ``force=True`` argument. The ``force`` keyword
        is consumed here and never forwarded to the builtin print.
        """
        forced = kwargs.pop("force", False)
        if not (is_main or forced):
            return
        builtin_print(*args, **kwargs)

    return patched_print
[docs]
def suppress_workers_print(func: Callable) -> Callable:
    """Decorator to suppress ``print()`` calls in workers having global rank
    different from 0. To force printing on all workers you need to use
    ``print(..., force=True)``.

    Args:
        func (Callable): function to decorate.

    Returns:
        Callable: wrapper which patches the builtin ``print`` while ``func`` is
        running and puts back the previous one afterwards.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs) -> Any:
        # Disable print in workers different from the main one,
        # when in distributed environments.
        dist_grank = detect_distributed_environment().global_rank
        patched_print = distributed_patch_print(is_main=dist_grank == 0)
        # Back up whatever print is currently installed (it may already be a
        # patched one, when decorated functions are nested).
        previous_print_backup = __builtin__.print
        __builtin__.print = patched_print
        try:
            return func(*args, **kwargs)
        finally:
            # Restore the previous print in every case, including exceptions
            # which do not derive from Exception (e.g., KeyboardInterrupt and
            # SystemExit). Otherwise print would stay patched process-wide.
            __builtin__.print = previous_print_backup

    return wrapper
[docs]
def suppress_workers_output(func):
    """Decorator to suppress ``stdout`` and ``stderr`` in workers having global rank
    different from 0, for as long as the decorated function is running.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Remember the current streams, to put them back once the call is over
        saved_streams = (sys.stdout, sys.stderr)
        # Only the worker with global rank 0 is the main one
        is_main_worker = detect_distributed_environment().global_rank == 0
        try:
            if not is_main_worker:
                # Not on main worker: send both streams to devnull during the call
                with open(os.devnull, "w") as sink:
                    sys.stdout = sink
                    sys.stderr = sink
                    return func(*args, **kwargs)
            # On main worker the output is left untouched
            return func(*args, **kwargs)
        finally:
            # Put back the streams which were in place before the call
            sys.stdout, sys.stderr = saved_streams

    return wrapper
[docs]
def get_adaptive_ray_scaling_config() -> "ScalingConfig":
    """Returns a Ray scaling config for distributed ML training depending on the resources
    available in the Ray cluster. When more than one GPU is available, the number of
    workers is equal to the number of GPUs. With zero or one GPU, two CPU-only workers
    are used instead.

    Ray is initialized by this function if it was not initialized already.
    """
    import ray
    from ray.train import ScalingConfig

    # The cluster resources can be queried only once Ray is initialized
    if not ray.is_initialized():
        ray.init()

    # Number of GPUs reported by the cluster, 0 when there is no "GPU" resource
    gpu_count = int(ray.cluster_resources().get("GPU", 0))

    if gpu_count > 1:
        # Multiple GPUs: one worker for each available GPU
        py_logger.debug(f"Returning a scaling config to run distributed ML on {gpu_count} GPUs")
        return ScalingConfig(num_workers=gpu_count, use_gpu=True)

    # Zero or one GPU: train on two CPU workers, without using the GPU
    py_logger.debug("Returning a scaling config to run distributed ML on 2 CPUs")
    return ScalingConfig(
        num_workers=2,
        use_gpu=False,
    )