# --------------------------------------------------------------------------------------
# Part of the interTwin Project: https://www.intertwin.eu/
#
# Created by: Matteo Bunino
#
# Credit:
# - Matteo Bunino <matteo.bunino@cern.ch> - CERN
# - Jarl Sondre Sæther <jarl.sondre.saether@cern.ch> - CERN
# - Henry Mutegeki <henry.mutegeki@cern.ch> - CERN
# - Anna Lappe <anna.elisa.lappe@cern.ch> - CERN
# - Rakesh Sarma <r.sarma@fz-juelich.de> - Juelich
# - Linus Eickhoff <linus.maximilian.eickhoff@cern.ch> - CERN
# --------------------------------------------------------------------------------------
import abc
import functools
import logging
import os
from typing import Any, Callable, Iterable, List, Literal, Tuple
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.optim.optimizer import Optimizer
from torch.utils.data import (
DataLoader,
Dataset,
DistributedSampler,
RandomSampler,
Sampler,
SequentialSampler,
)
from torch.utils.data.dataloader import _collate_fn_t, _worker_init_fn_t
from ..distributed import detect_distributed_environment, ray_cluster_is_running
from .type import DistributedStrategyError, UninitializedStrategyError
# Module-level logger used by every function and strategy class in this module.
py_logger = logging.getLogger(__name__)
def distributed_resources_available() -> bool:
    """Check if the current execution environment
    has (enough) resources available to allow for distributed ML.

    Setting the ``ITWINAI_FORCE_DIST`` environment variable to a non-zero
    integer, or to ``true``/``yes``/``on`` (case-insensitive, surrounding
    whitespace ignored), forces the result to ``True``. Any other value
    (including unset, ``""``, ``"0"`` or ``"false"``) falls back to
    detecting the environment, which must report a global world size
    greater than one.

    Returns:
        bool: env can support distributed ML.
    """
    force_dist = os.environ.get("ITWINAI_FORCE_DIST", "0").strip().lower()
    if force_dist in ("true", "yes", "on"):
        return True
    try:
        if int(force_dist):
            return True
    except ValueError:
        # Non-numeric values other than the accepted truthy words above
        # (e.g. "", "false", "no") do not force distributed mode.
        pass
    cluster = detect_distributed_environment()
    return cluster.global_world_size > 1
def check_initialized(method: Callable) -> Callable:
    """Guard a strategy method so it can only run after ``init()``.

    Args:
        method (Callable): strategy method whose first positional argument
            is the strategy instance.

    Returns:
        Callable: a wrapper carrying ``method``'s metadata
        (``functools.wraps``). It delegates to ``method`` when
        ``self.is_initialized`` is truthy and raises
        ``UninitializedStrategyError`` otherwise.
    """

    @functools.wraps(method)
    def guarded(self: "TorchDistributedStrategy", *args, **kwargs):
        if self.is_initialized:
            return method(self, *args, **kwargs)
        strategy_name = self.__class__.__name__
        raise UninitializedStrategyError(
            f"{strategy_name} has not been initialized. Use the init method."
        )

    return guarded
def initialize_ray() -> None:
    """Initialize the Ray backend if it is not already initialized.

    This is used by the Ray-based strategies and is meant to be called before
    submitting a function to Ray (as a trial in tuning, or as a worker in
    distributed ML). Cluster detection is delegated to
    ``ray_cluster_is_running``. If this process has not initialized Ray yet,
    it connects to the running cluster (``address="auto"``). It also forwards
    this process's ``MLFLOW_TRACKING_USERNAME`` / ``MLFLOW_TRACKING_PASSWORD``
    to all Ray workers through the runtime environment.

    Raises:
        RuntimeError: when no running Ray cluster is detected.
    """
    import ray
    from ray.runtime_env import RuntimeEnv

    if not ray_cluster_is_running():
        raise RuntimeError(
            "You are trying to initialize Ray, but the cluster seems not to be running"
        )
    if ray.is_initialized():
        # Ray is already initialized in this process: nothing to do.
        return

    mlflow_username = os.environ.get("MLFLOW_TRACKING_USERNAME", "")
    mlflow_password = os.environ.get("MLFLOW_TRACKING_PASSWORD", "")
    if not mlflow_username:
        py_logger.warning("MLFLOW_TRACKING_USERNAME env variable is not set.")
    if not mlflow_password:
        py_logger.warning("MLFLOW_TRACKING_PASSWORD env variable is not set.")

    # Set mlflow credentials to be accessible for all the workers.
    # Unset variables are forwarded as empty strings.
    runtime_env = RuntimeEnv(
        env_vars={
            "MLFLOW_TRACKING_USERNAME": mlflow_username,
            "MLFLOW_TRACKING_PASSWORD": mlflow_password,
        }
    )
    ray.init(address="auto", runtime_env=runtime_env)
    py_logger.info("Nodes in the cluster: %s", ray.nodes())
    py_logger.info("Available cluster resources: %s", ray.available_resources())
class TorchDistributedStrategy(abc.ABC):
    """Abstract class to define the distributed backend methods for
    PyTorch models.

    Subclasses implement backend initialization (:meth:`init`), wrapping of
    model/optimizer/scheduler (:meth:`distributed`), rank and world-size
    queries, and collective communication helpers. Methods decorated with
    :func:`check_initialized` (e.g. :meth:`device`,
    :meth:`create_dataloader`) raise ``UninitializedStrategyError`` when
    called before :meth:`init`.
    """

    #: Allows to discriminate distributed strategies from non-distributed.
    #: Defaults to True.
    is_distributed: bool = True
    #: Set to True when the current strategy is initialized.
    #: Defaults to False.
    is_initialized: bool = False
    #: Provides the name of the strategy for logging purposes etc.
    #: Set by subclasses.
    name: str

    @property
    @check_initialized
    def is_main_worker(self) -> bool:
        """Checks if local worker has global rank equal to zero.

        Returns:
            bool: True if main worker.
        """
        return self.global_rank() == 0

    @abc.abstractmethod
    def init(self) -> None:
        """Initializes the chosen distributed backend."""

    @abc.abstractmethod
    def distributed(
        self,
        model: nn.Module,
        optimizer: Optimizer,
        lr_scheduler: LRScheduler | None = None,
    ) -> Tuple[nn.Module, Optimizer, LRScheduler | None]:
        """Setup model, optimizer and scheduler for distributed.

        Args:
            model (nn.Module): model to distribute.
            optimizer (Optimizer): optimizer of the model.
            lr_scheduler (LRScheduler | None): optional learning rate scheduler.

        Returns:
            Tuple[nn.Module, Optimizer, LRScheduler | None]: the (possibly
            wrapped) model, optimizer and scheduler.
        """

    @abc.abstractmethod
    def global_world_size(self) -> int:
        """Returns the total number of processes (global world size).

        Returns:
            int: global world size.
        """

    @abc.abstractmethod
    def local_world_size(self) -> int:
        """Returns the number of local workers available on a node
        (local world size).

        Usually it is equal to the number of available GPUs.

        Returns:
            int: local world size.
        """

    @abc.abstractmethod
    def global_rank(self) -> int:
        """Returns the global rank of the current process.

        Rank ranges from 0 to ``world_size - 1``.

        Returns:
            int: global rank.
        """

    @abc.abstractmethod
    def local_rank(self) -> int:
        """Returns the local rank of the current process.

        Returns:
            int: local rank.
        """

    @abc.abstractmethod
    def barrier(self) -> None:
        """Forces all the workers to wait for each other."""

    @check_initialized
    def device(self) -> str:
        """Device used by local worker.

        Returns:
            str: ``'cuda:<local_rank>'`` when CUDA is available, otherwise
            ``'cpu'``.
        """
        if torch.cuda.is_available():
            return f"cuda:{self.local_rank()}"
        return "cpu"

    def set_device(self) -> None:
        """Set the local CUDA device to ``local_rank()``.

        Does nothing when CUDA is unavailable.
        """
        if torch.cuda.is_available():
            # Only ``set_device`` changes the current device; the previous
            # bare ``torch.cuda.device(...)`` call merely built a context
            # manager that was never entered. The current device is needed
            # by torch.distributed.gather_object.
            torch.cuda.set_device(self.local_rank())

    @check_initialized
    def create_dataloader(
        self,
        dataset: Dataset,
        batch_size: int | None = 1,
        shuffle: bool | None = None,
        sampler: Sampler | Iterable | None = None,
        batch_sampler: Sampler[List] | Iterable[List] | None = None,
        num_workers: int = 0,
        collate_fn: _collate_fn_t | None = None,
        pin_memory: bool = False,
        drop_last: bool = False,
        timeout: float = 0,
        worker_init_fn: _worker_init_fn_t | None = None,
        multiprocessing_context=None,
        generator=None,
        *,
        prefetch_factor: int | None = None,
        persistent_workers: bool = False,
        pin_memory_device: str = "",
    ):
        """Create a DataLoader whose sampler matches the strategy.

        How the sampler is chosen:

        * In distributed mode, a ``DistributedSampler`` sharding the dataset
          over ``global_world_size()`` replicas is created, unless the user
          passes a ``DistributedSampler`` themselves.
        * In non-distributed mode, a user-provided ``sampler`` is used as-is.
          Otherwise a ``RandomSampler`` (if ``shuffle``, driven by
          ``generator``) or a ``SequentialSampler`` is created.
        * ``batch_sampler`` is always ignored, and a warning is logged.

        Args:
            dataset (Dataset): dataset from which to load the data.
            batch_size (int, optional): how many samples per batch to load
                (default: ``1``).
            shuffle (bool, optional): set to ``True`` to have the data
                reshuffled at every epoch (default: ``False``). Only used
                when ``sampler`` is not given.
            sampler (Sampler or Iterable, optional): defines the strategy to
                draw samples from the dataset. Can be any ``Iterable`` with
                ``__len__`` implemented. In distributed mode it must be a
                ``DistributedSampler``. When given, :attr:`shuffle` is
                ignored.
            batch_sampler (Sampler or Iterable, optional): ignored by this
                method (a warning is logged when it is given).
            num_workers (int, optional): how many subprocesses to use for
                data loading. ``0`` means that the data will be loaded in the
                main process. (default: ``0``)
            collate_fn (Callable, optional): merges a list of samples to form
                a mini-batch of Tensor(s). Used when using batched loading
                from a map-style dataset.
            pin_memory (bool, optional): If ``True``, the data loader will
                copy Tensors into device/CUDA pinned memory before returning
                them. If your data elements are a custom type, or your
                :attr:`collate_fn` returns a batch that is a custom type, see
                the memory-pinning section of the PyTorch data documentation.
            drop_last (bool, optional): set to ``True`` to drop the last
                incomplete batch, if the dataset size is not divisible by the
                batch size. If ``False`` and the size of dataset is not
                divisible by the batch size, then the last batch will be
                smaller. (default: ``False``)
            timeout (numeric, optional): if positive, the timeout value for
                collecting a batch from workers. Should always be
                non-negative. (default: ``0``)
            worker_init_fn (Callable, optional): If not ``None``, this will
                be called on each worker subprocess with the worker id (an
                int in ``[0, num_workers - 1]``) as input, after seeding and
                before data loading. (default: ``None``)
            multiprocessing_context (str or
                multiprocessing.context.BaseContext, optional): If ``None``,
                the default `multiprocessing context`_ of your operating
                system will be used. (default: ``None``)
            generator (torch.Generator, optional): If not ``None``, this RNG
                will be used by RandomSampler to generate random indexes and
                by multiprocessing to generate ``base_seed`` for workers.
                (default: ``None``)
            prefetch_factor (int, optional, keyword-only arg): Number of
                batches loaded in advance by each worker. ``2`` means there
                will be a total of 2 * num_workers batches prefetched across
                all workers. The default depends on ``num_workers``: it is
                ``None`` when ``num_workers=0`` and ``2`` when
                ``num_workers > 0``.
            persistent_workers (bool, optional): If ``True``, the data loader
                will not shut down the worker processes after a dataset has
                been consumed once. This allows to maintain the workers
                `Dataset` instances alive. (default: ``False``)
            pin_memory_device (str, optional): the device to
                :attr:`pin_memory` to if ``pin_memory`` is ``True``.

        Returns:
            DataLoader: the configured data loader.

        Raises:
            UninitializedStrategyError: when this method is called for a
                strategy which had not been initialized.
            RuntimeError: in distributed mode, when a user-provided sampler
                is not of type ``DistributedSampler``.

        .. warning:: If the ``spawn`` start method is used,
            :attr:`worker_init_fn` cannot be an unpicklable object, e.g., a
            lambda function. See `Multiprocessing best practices`_ on more
            details related to multiprocessing in PyTorch.

        .. warning:: ``len(dataloader)`` heuristic is based on the length of
            the sampler used. When :attr:`dataset` is an
            :class:`~torch.utils.data.IterableDataset`, it instead returns an
            estimate based on ``len(dataset) / batch_size``, with proper
            rounding depending on :attr:`drop_last`, regardless of
            multi-process loading configurations. This represents the best
            guess PyTorch can make because PyTorch trusts user
            :attr:`dataset` code in correctly handling multi-process loading
            to avoid duplicate data. However, if sharding results in multiple
            workers having incomplete last batches, this estimate can still
            be inaccurate, because (1) an otherwise complete batch can be
            broken into multiple ones and (2) more than one batch worth of
            samples can be dropped when :attr:`drop_last` is set.
            Unfortunately, PyTorch can not detect such cases in general. See
            `Dataset Types`_ for more details on these two types of datasets
            and how :class:`~torch.utils.data.IterableDataset` interacts with
            `Multi-process data loading`_.

        .. warning:: See `Reproducibility`_, and
            `My data loader workers return identical random numbers`_, and
            `Randomness in multi-process data loading`_ notes for random seed
            related questions.

        .. _multiprocessing context:
            https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods
        .. _Multiprocessing best practices:
            https://pytorch.org/docs/stable/notes/multiprocessing.html#multiprocessing-best-practices
        .. _Reproducibility:
            https://pytorch.org/docs/stable/notes/randomness.html#reproducibility
        .. _My data loader workers return identical random numbers:
            https://pytorch.org/docs/stable/notes/faq.html#dataloader-workers-random-seed
        .. _Randomness in multi-process data loading:
            https://pytorch.org/docs/stable/data.html#data-loading-randomness
        .. _Multi-process data loading:
            https://pytorch.org/docs/stable/data.html#multi-process-data-loading
        .. _Dataset Types:
            https://pytorch.org/docs/stable/data.html#dataset-types
        """
        if batch_sampler is not None:
            py_logger.warning("WARNING: batch_sampler is ignored by TorchDistributedStrategy")

        if self.is_distributed:
            if sampler is None:
                sampler = DistributedSampler(
                    dataset,
                    num_replicas=self.global_world_size(),
                    rank=self.global_rank(),
                    shuffle=bool(shuffle),
                )
            elif not isinstance(sampler, DistributedSampler):
                raise RuntimeError("User-provided sampler must implement DistributedSampler.")
        elif sampler is None:
            # Only build a default sampler when the user did not provide one.
            # ``generator`` is forwarded so that shuffling is reproducible,
            # as it would be if DataLoader created the RandomSampler itself.
            if shuffle:
                sampler = RandomSampler(dataset, generator=generator)
            else:
                sampler = SequentialSampler(dataset)

        # shuffle and batch_sampler must be unset: the sampler handles ordering.
        return DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            sampler=sampler,
            num_workers=num_workers,
            collate_fn=collate_fn,
            pin_memory=pin_memory,
            drop_last=drop_last,
            timeout=timeout,
            worker_init_fn=worker_init_fn,
            multiprocessing_context=multiprocessing_context,
            generator=generator,
            prefetch_factor=prefetch_factor,
            persistent_workers=persistent_workers,
            pin_memory_device=pin_memory_device,
        )

    @abc.abstractmethod
    def clean_up(self) -> None:
        """Cleans up resources allocated by distributed strategy."""

    @abc.abstractmethod
    def allgather_obj(self, obj: Any) -> List[Any]:
        """All-gathers any object from the whole group in a list
        (to all workers).

        Args:
            obj (Any): object to gather from all workers.

        Returns:
            List[Any]: list of objects gathered from all workers.
        """

    @abc.abstractmethod
    def gather_obj(self, obj: Any, dst_rank: int = 0) -> List[Any] | None:
        """Gathers any object from the whole group in a list on
        ``dst_rank``.

        Args:
            obj (Any): object to gather from all workers.
            dst_rank (int): rank of the worker on which the objects list
                is gathered.

        Returns:
            List[Any] | None: list of objects gathered from all workers on
            ``dst_rank``. Implementations may return ``None`` on the other
            ranks.
        """

    @abc.abstractmethod
    def gather(self, tensor: torch.Tensor, dst_rank: int = 0) -> List | None:
        """Gathers a tensor from the whole group in a list on ``dst_rank``.

        Args:
            tensor (torch.Tensor): tensor to gather from all workers.
            dst_rank (int): rank of the worker on which the tensors list
                is gathered.

        Returns:
            Optional[List[torch.Tensor]]: list of tensors gathered from all
            workers on ``dst_rank``, otherwise ``None``.
        """

    @abc.abstractmethod
    def broadcast_obj(self, obj: Any, src_rank: int) -> Any:
        """Broadcasts an object to all workers.

        Args:
            obj (Any): object to broadcast to all workers.
            src_rank (int): the rank that broadcasts.

        Returns:
            Any: broadcasted object.
        """
class TorchDDPStrategy(TorchDistributedStrategy):
    """PyTorch ``DistributedDataParallel`` distributed strategy class.

    Args:
        backend (Literal['nccl', 'gloo', 'mpi']): Name of the
            distributed communication backend to employ.
    """

    #: Torch distributed communication backend.
    backend: Literal["nccl", "gloo", "mpi"]

    def __init__(self, backend: Literal["nccl", "gloo", "mpi"]) -> None:
        super().__init__()
        self.backend = backend
        self.name = "torch-ddp"

    def init(self) -> None:
        """Initializes the distributed process group and the distributed
        package.

        Raises:
            RuntimeError: when there are not (enough) GPUs available.
            DistributedStrategyError: when trying to initialize a strategy
                which is already initialized.
        """
        if not distributed_resources_available():
            raise RuntimeError("Trying to run distributed on insufficient resources.")
        if self.is_initialized:
            raise DistributedStrategyError("Strategy was already initialized")
        dist.init_process_group(backend=self.backend)
        self.is_initialized = True
        self.set_device()

    @check_initialized
    def distributed(
        self,
        model: nn.Module,
        optimizer: Optimizer,
        lr_scheduler: LRScheduler | None = None,
        find_unused_parameters: bool = False,
        **kwargs,
    ) -> Tuple[nn.Module, Optimizer, LRScheduler | None]:
        """Setup model, optimizer and scheduler for distributed.

        With CUDA, the model is moved to :meth:`device` and wrapped in
        ``DistributedDataParallel`` bound to that device. Without CUDA but
        with distributed resources, it is wrapped in a CPU
        ``DistributedDataParallel``. Otherwise it is returned unchanged.
        The optimizer and scheduler are always returned unchanged.

        Args:
            model (nn.Module): model to distribute.
            optimizer (Optimizer): optimizer of the model.
            lr_scheduler (LRScheduler | None): optional learning rate scheduler.
            find_unused_parameters (bool): forwarded to
                ``DistributedDataParallel``.
            **kwargs: accepted for interface compatibility and ignored.

        Returns:
            Tuple[nn.Module, Optimizer, LRScheduler | None]: distributed
            model, optimizer and scheduler.
        """
        if torch.cuda.is_available():
            # If GPUs are available
            device = self.device()
            model = model.to(device)
            dist_model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[device],
                output_device=device,
                find_unused_parameters=find_unused_parameters,
            )
        elif distributed_resources_available():
            # If GPUs are not available, but running distributed ML on CPUs
            dist_model = torch.nn.parallel.DistributedDataParallel(
                model,
                find_unused_parameters=find_unused_parameters,
            )
        else:
            dist_model = model
        return dist_model, optimizer, lr_scheduler

    @check_initialized
    def barrier(self) -> None:
        """Forces all the workers to wait for each other."""
        return dist.barrier()

    @check_initialized
    def global_world_size(self) -> int:
        """Returns the total number of processes (global world size).

        Returns:
            int: global world size.
        """
        return dist.get_world_size()

    @check_initialized
    def local_world_size(self) -> int:
        """Returns the local number of workers available per node,
        which is usually the number of GPUs available.

        Returns:
            int: local world size.

        Raises:
            RuntimeError: when the local world size cannot be retrieved.
        """
        if torch.cuda.is_available():
            return torch.cuda.device_count()
        if "LOCAL_WORLD_SIZE" not in os.environ:
            raise RuntimeError(
                "Could not retrieve local world size as CUDA is unavailable and there is "
                "no 'LOCAL_WORLD_SIZE' environment variable."
            )
        return int(os.environ["LOCAL_WORLD_SIZE"])

    @check_initialized
    def global_rank(self) -> int:
        """Returns the global rank of the current process, where
        rank ranges from 0 to ``world_size - 1``.

        Returns:
            int: global rank.
        """
        return dist.get_rank()

    @check_initialized
    def local_rank(self) -> int:
        """Returns the local rank of the current process.

        Returns:
            int: local rank.
        """
        # NOTE(review): assumes global ranks are assigned contiguously per
        # node and every node runs ``local_world_size()`` workers — confirm
        # for the launcher in use.
        return dist.get_rank() % self.local_world_size()

    @check_initialized
    def clean_up(self) -> None:
        """Destroys the current process group."""
        if distributed_resources_available():
            dist.barrier()
        dist.destroy_process_group()

    @check_initialized
    def allgather_obj(self, obj: Any) -> List[Any]:
        """All-gathers any object from the whole group
        in a list (to all workers).

        Args:
            obj (Any): Object to gather from all workers.

        Returns:
            List[Any]: List of gathered objects.
        """
        # https://pytorch.org/docs/stable/distributed.html#collective-functions
        res = [None] * self.global_world_size()
        dist.all_gather_object(res, obj)
        return res

    @check_initialized
    def gather_obj(self, obj: Any, dst_rank: int = 0) -> List | None:
        """Gathers any object from the whole group in a list on
        ``dst_rank``.

        Args:
            obj (Any): object to gather from all workers.
            dst_rank (int): rank of the worker on which the objects list
                is gathered.

        Returns:
            List | None: list of objects gathered from all workers
            or ``None`` on non-destination ranks.
        """
        # https://pytorch.org/docs/stable/distributed.html#collective-functions
        if self.global_rank() == dst_rank:
            res = [None] * self.global_world_size()
            dist.gather_object(obj, res, dst=dst_rank)
            return res
        dist.gather_object(obj, dst=dst_rank)
        return None

    @check_initialized
    def gather(self, tensor: torch.Tensor, dst_rank: int = 0) -> List | None:
        """Gathers a tensor from the whole group in a list on ``dst_rank``.

        The tensor is moved to :meth:`device` before communication. The
        receive buffers are allocated with ``torch.zeros_like(tensor)``, so
        every rank is expected to send a tensor of the same shape and dtype.

        Args:
            tensor (torch.Tensor): tensor to gather from all workers.
            dst_rank (int): rank of the worker on which the tensors list
                is gathered.

        Returns:
            List | None: list of tensors (moved to CPU) gathered from all
            workers on ``dst_rank``; ``None`` on every other rank.
        """
        # https://pytorch.org/docs/stable/distributed.html#collective-functions
        # Ensure that the tensor is on the correct device (CUDA)
        tensor = tensor.to(self.device())
        if self.global_rank() != dst_rank:
            dist.gather(tensor, dst=dst_rank)
            return None

        res = [
            torch.zeros_like(tensor, device=self.device())
            for _ in range(self.global_world_size())
        ]
        dist.gather(tensor, gather_list=res, dst=dst_rank)
        # Moving everything to the CPU before returning
        return [val.cpu() for val in res]

    @check_initialized
    def broadcast_obj(self, obj: Any, src_rank: int) -> Any:
        """Broadcasts an object to all workers. (object must be picklable)

        Args:
            obj (Any): object to broadcast to all workers.
            src_rank (int): the rank that broadcasts.

        Returns:
            Any: broadcasted object.
        """
        obj_list = [obj]
        # https://pytorch.org/docs/stable/distributed.html#collective-functions
        dist.broadcast_object_list(obj_list, src=src_rank)
        return obj_list[0]
class DeepSpeedStrategy(TorchDistributedStrategy):
    """DeepSpeed distributed strategy class.

    Args:
        backend (Literal['nccl', 'gloo', 'mpi']): Name of the
            distributed communication backend to employ.
    """

    #: Torch distributed communication backend.
    backend: Literal["nccl", "gloo", "mpi"]

    def __init__(self, backend: Literal["nccl", "gloo", "mpi"]) -> None:
        super().__init__()
        self.backend = backend
        self.name = "deepspeed"

    def init(self) -> None:
        """Initializes the distributed process group and the distributed
        package.

        Before initializing, DeepSpeed's Triton ``AutotuneCacheManager.put()``
        is replaced by a no-op as a workaround for a temporary-file bug. The
        workaround is skipped when that extension cannot be imported.

        Raises:
            RuntimeError: when there are not (enough) GPUs available.
            DistributedStrategyError: when trying to initialize a strategy
                already initialized.
        """
        import deepspeed

        # Removing the .put() method of the cache manager.
        # This is the same bug that was earlier removed in the generic_torch.sh
        # script, using the sed command.
        try:
            from deepspeed.ops.transformer.inference.triton.matmul_ext import (
                AutotuneCacheManager,
            )
        except ImportError:
            # The extension module could not be imported (presumably Triton
            # is not installed), so there is no cache manager to patch.
            py_logger.debug(
                "Could not import DeepSpeed's Triton matmul extension; "
                "skipping the AutotuneCacheManager.put() workaround."
            )
        else:

            def noop_put(cache_manager, table):
                pass

            AutotuneCacheManager.put = noop_put
            py_logger.warning(
                "[WARNING]: Disabling Triton's AutotuneCacheManager's `put()` method to fix "
                "bug with temporary files. This might be fixed in the future by DeepSpeed, "
                "in which case our fix should be removed."
            )

        self.deepspeed = deepspeed
        if not distributed_resources_available():
            raise RuntimeError("Trying to run distributed on insufficient resources.")
        if self.is_initialized:
            raise DistributedStrategyError("Strategy was already initialized")

        # https://github.com/Lightning-AI/pytorch-lightning/issues/13567
        # This block of code should be removed as some point
        if os.environ.get("LOCAL_RANK"):
            os.environ["OMPI_COMM_WORLD_LOCAL_RANK"] = os.environ.get("LOCAL_RANK")

        # https://deepspeed.readthedocs.io/en/latest/initialize.html#training-initialization
        self.deepspeed.init_distributed(dist_backend=self.backend)
        self.is_initialized = True
        self.set_device()

    @check_initialized
    def distributed(
        self,
        model: nn.Module,
        optimizer: Optimizer | None = None,
        lr_scheduler: LRScheduler | None = None,
        model_parameters: Any | None = None,
        **init_kwargs,
    ) -> Tuple[nn.Module, Optimizer, LRScheduler | None]:
        """Setup model, optimizer and scheduler for distributed.

        Delegates to ``deepspeed.initialize``.

        Args:
            model (nn.Module): model to distribute.
            optimizer (Optimizer | None): optional optimizer.
            lr_scheduler (LRScheduler | None): optional learning rate scheduler.
            model_parameters (Any | None): forwarded to
                ``deepspeed.initialize``.
            **init_kwargs: additional keyword arguments forwarded to
                ``deepspeed.initialize``.

        Returns:
            Tuple[nn.Module, Optimizer, LRScheduler | None]: the model engine,
            optimizer and scheduler returned by ``deepspeed.initialize``.
        """
        py_logger.debug(f"Distributing the model using device: {self.device()}")
        distrib_model, optimizer, _, lr_scheduler = self.deepspeed.initialize(
            model=model,
            model_parameters=model_parameters,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            dist_init_required=True,
            **init_kwargs,
        )
        return distrib_model, optimizer, lr_scheduler

    @check_initialized
    def barrier(self) -> None:
        """Forces all the workers to wait for each other."""
        return dist.barrier()

    @check_initialized
    def global_world_size(self) -> int:
        """Returns the total number of processes (global world size).

        Returns:
            int: global world size.
        """
        return dist.get_world_size()

    @check_initialized
    def local_world_size(self) -> int:
        """Returns the local number of workers available per node,
        which is usually the number of GPUs available.

        Returns:
            int: local world size.

        Raises:
            RuntimeError: when the local world size cannot be retrieved.
        """
        if torch.cuda.is_available():
            return torch.cuda.device_count()
        if "LOCAL_WORLD_SIZE" not in os.environ:
            raise RuntimeError(
                "Could not retrieve local world size as CUDA is unavailable and there is "
                "no 'LOCAL_WORLD_SIZE' environment variable."
            )
        return int(os.environ["LOCAL_WORLD_SIZE"])

    @check_initialized
    def global_rank(self) -> int:
        """Returns the global rank of the current process, where
        rank ranges from 0 to ``world_size - 1``.

        Returns:
            int: global rank.
        """
        return dist.get_rank()

    @check_initialized
    def local_rank(self) -> int:
        """Returns the local rank of the current process.

        Returns:
            int: local rank.
        """
        # NOTE(review): assumes global ranks are assigned contiguously per
        # node and every node runs ``local_world_size()`` workers — confirm.
        return dist.get_rank() % self.local_world_size()

    @check_initialized
    def clean_up(self) -> None:
        """Destroys the current process group."""
        if distributed_resources_available():
            dist.barrier()
        dist.destroy_process_group()

    @check_initialized
    def allgather_obj(self, obj: Any) -> List[Any]:
        """All-gathers any object from the whole group
        in a list (to all workers).

        Args:
            obj (Any): Object to gather from all workers.

        Returns:
            List[Any]: List of gathered objects.
        """
        # https://pytorch.org/docs/stable/distributed.html#collective-functions
        res = [None] * self.global_world_size()
        dist.all_gather_object(res, obj)
        return res

    @check_initialized
    def gather_obj(self, obj: Any, dst_rank: int = 0) -> List[Any] | None:
        """Gathers any object from the whole group in a list on
        ``dst_rank``.

        Args:
            obj (Any): object to gather from all workers.
            dst_rank (int): rank of the worker on which the objects list
                is gathered.

        Returns:
            Optional[List[Any]]: list of objects gathered from all workers
            or ``None`` on non-destination ranks.
        """
        # https://pytorch.org/docs/stable/distributed.html#collective-functions
        if self.global_rank() == dst_rank:
            res = [None] * self.global_world_size()
            dist.gather_object(obj, res, dst=dst_rank)
            return res
        dist.gather_object(obj, dst=dst_rank)
        return None

    @check_initialized
    def gather(self, tensor: torch.Tensor, dst_rank: int = 0) -> List[torch.Tensor] | None:
        """Gathers a tensor from the whole group in a list on ``dst_rank``.

        The tensor is moved to :meth:`device` before communication. The
        receive buffers are allocated with ``torch.zeros_like(tensor)``, so
        every rank is expected to send a tensor of the same shape and dtype.

        Args:
            tensor (torch.Tensor): tensor to gather from all workers.
            dst_rank (int): rank of the worker on which the tensors list
                is gathered.

        Returns:
            Optional[List[torch.Tensor]]: list of tensors (moved to CPU)
            gathered from all workers or ``None`` on non-destination ranks.
        """
        # https://pytorch.org/docs/stable/distributed.html#collective-functions
        # Ensure that the tensor is on the correct device (CUDA)
        tensor = tensor.to(self.device())
        if self.global_rank() != dst_rank:
            dist.gather(tensor, dst=dst_rank)
            return None

        res = [
            torch.zeros_like(tensor, device=self.device())
            for _ in range(self.global_world_size())
        ]
        dist.gather(tensor, gather_list=res, dst=dst_rank)
        # Moving all the tensors to CPU before returning
        return [val.cpu() for val in res]

    @check_initialized
    def broadcast_obj(self, obj: Any, src_rank: int) -> Any:
        """Broadcasts an object to all workers. (object must be picklable)

        Args:
            obj (Any): object to broadcast to all workers.
            src_rank (int): the rank that broadcasts.

        Returns:
            Any: broadcasted object.
        """
        obj_list = [obj]
        # https://pytorch.org/docs/stable/distributed.html#collective-functions
        dist.broadcast_object_list(obj_list, src=src_rank)
        return obj_list[0]
class HorovodStrategy(TorchDistributedStrategy):
    """Horovod distributed strategy class."""

    def __init__(self):
        super().__init__()
        self.name = "horovod"

    def init(self) -> None:
        """Initializes the Horovod distributed backend.

        Raises:
            RuntimeError: when there are not (enough) GPUs available.
            DistributedStrategyError: when trying to initialize a strategy
                already initialized.
        """
        if not distributed_resources_available():
            raise RuntimeError("Trying to run distributed on insufficient resources.")
        if self.is_initialized:
            raise DistributedStrategyError("Strategy was already initialized")

        import horovod.torch as hvd

        self.hvd = hvd
        self.hvd.init()
        self.is_initialized = True
        self.set_device()

    @check_initialized
    def distributed(
        self,
        model: nn.Module,
        optimizer: Optimizer | None = None,
        lr_scheduler: LRScheduler | None = None,
        **optim_kwargs,
    ) -> Tuple[nn.Module, Optimizer, LRScheduler | None]:
        """Setup model, optimizer and scheduler for distributed.

        The model is moved to :meth:`device` and its parameters are
        broadcast from rank 0. When an optimizer is given:

        * its learning rates are scaled according to ``optim_kwargs["op"]``;
        * its state is broadcast from rank 0;
        * it is wrapped in ``hvd.DistributedOptimizer``.

        Args:
            model (nn.Module): model to distribute.
            optimizer (Optimizer | None): optimizer to wrap. If ``None``,
                only the model parameters are synchronized and ``None`` is
                returned in place of the optimizer.
            lr_scheduler (LRScheduler | None): returned unchanged.
            **optim_kwargs: forwarded to ``hvd.DistributedOptimizer``.

        Returns:
            Tuple[nn.Module, Optimizer, LRScheduler | None]: model,
            distributed optimizer (or ``None``) and scheduler.
        """
        model.to(self.device())

        if optimizer is None:
            # Nothing to rescale or wrap: only make sure every worker starts
            # from rank 0's model parameters.
            self._broadcast_params(model, None)
            return model, None, lr_scheduler

        # Scale learning rate
        # https://github.com/horovod/horovod/issues/1653#issuecomment-574764452
        # NOTE(review): when ``op`` is not passed, ``lr_scaler`` stays 1.
        # Confirm this matches the op DistributedOptimizer uses by default.
        # Only ``param_groups`` are scaled; a scheduler created before this
        # call may keep unscaled base learning rates — verify.
        op = optim_kwargs.get("op")
        lr_scaler = 1
        if op == self.hvd.Adasum:
            lr_scaler = self.hvd.local_size()
        elif op == self.hvd.Average:
            lr_scaler = self.hvd.size()
        for g in optimizer.param_groups:
            g["lr"] *= lr_scaler

        self._broadcast_params(model, optimizer)
        dist_optimizer = self.hvd.DistributedOptimizer(
            optimizer, named_parameters=model.named_parameters(), **optim_kwargs
        )
        return model, dist_optimizer, lr_scheduler

    @check_initialized
    def barrier(self) -> None:
        """Forces all the workers to wait for each other."""
        self.hvd.barrier()

    def _broadcast_params(
        self, model: nn.Module, optimizer: optim.Optimizer | None
    ) -> None:
        """Broadcasts variables from root rank to all other processes.

        Args:
            model (nn.Module): ML model that is to be broadcasted
                across processes.
            optimizer (optim.Optimizer | None): Optimizer whose state is to
                be broadcasted across processes; skipped when ``None``.
        """
        self.hvd.broadcast_parameters(model.state_dict(), root_rank=0)
        if optimizer is not None:
            self.hvd.broadcast_optimizer_state(optimizer, root_rank=0)

    @check_initialized
    def global_world_size(self) -> int:
        """Returns the total number of processes (global world size).

        Returns:
            int: global world size.
        """
        return self.hvd.size()

    @check_initialized
    def local_world_size(self) -> int:
        """Returns the local number of workers available per node,
        which is usually the number of GPUs available.

        Returns:
            int: local world size.
        """
        return self.hvd.local_size()

    @check_initialized
    def global_rank(self) -> int:
        """Returns the global rank of the current process, where
        rank ranges from 0 to ``world_size - 1``.

        Returns:
            int: global rank.
        """
        return self.hvd.rank()

    @check_initialized
    def local_rank(self) -> int:
        """Returns the local rank of the current process.

        Returns:
            int: local rank.
        """
        return self.hvd.local_rank()

    @check_initialized
    def clean_up(self) -> None:
        """Shuts Horovod down."""
        self.hvd.barrier()
        self.hvd.shutdown()

    @check_initialized
    def allgather_obj(self, obj: Any) -> List[Any]:
        """All-gathers any object from the whole group
        in a list (to all workers).

        Args:
            obj (Any): Object to gather from all workers.

        Returns:
            List[Any]: List of gathered objects.
        """
        return self.hvd.allgather_object(obj)

    @check_initialized
    def gather_obj(self, obj: Any, dst_rank: int = 0) -> List[Any] | None:
        """Gathers any object from the whole group in a list on
        ``dst_rank``.

        Under the hood it relies on allgather, as gather is not supported by
        Horovod, so every rank takes part in the collective.

        Args:
            obj (Any): object to gather from all workers.
            dst_rank (int): rank of the worker on which the objects list
                is returned.

        Returns:
            Optional[List[Any]]: list of objects gathered from all workers
            or ``None`` on non-destination ranks.
        """
        result = self.allgather_obj(obj)
        if self.global_rank() == dst_rank:
            # Return only if on rank == dst_rank
            return result
        return None

    @check_initialized
    def gather(self, tensor: torch.Tensor, dst_rank: int = 0) -> List[torch.Tensor] | None:
        """Gathers a tensor from the whole group in a list on ``dst_rank``.

        Under the hood it relies on object allgather, as gather is not
        supported by Horovod, so every rank takes part in the collective.

        Args:
            tensor (torch.Tensor): tensor to gather from all workers.
            dst_rank (int): rank of the worker on which the tensors list
                is returned.

        Returns:
            Optional[List[torch.Tensor]]: list of tensors (moved to CPU)
            gathered from all workers or ``None`` on non-destination ranks.
        """
        result = self.allgather_obj(tensor)
        if self.global_rank() == dst_rank:
            # Return only if on rank == dst_rank.
            # Moving all the tensors to CPU before returning.
            return [val.cpu() for val in result]
        return None

    @check_initialized
    def broadcast_obj(self, obj: Any, src_rank: int) -> Any:
        """Broadcasts an object to all workers. (object must be picklable)

        Args:
            obj (Any): object to broadcast to all workers.
            src_rank (int): the rank that broadcasts.

        Returns:
            Any: broadcasted object.
        """
        if obj is None:
            py_logger.warning(
                "Broadcasting None object in Horovod. This might lead to unexpected behavior"
                " such as deadlocks."
            )
        # https://horovod.readthedocs.io/en/stable/_modules/horovod/torch/functions.html#broadcast_object
        return self.hvd.broadcast_object(obj, root_rank=src_rank)
class NonDistributedStrategy(TorchDistributedStrategy):
    """Single-process strategy: every collective degenerates to a local no-op."""

    #: Marks this strategy as non-distributed (the base class default is True).
    is_distributed: bool = False

    def __init__(self):
        super().__init__()
        self.name = "non-distributed"

    def init(self) -> None:
        """Select the local CUDA device (if any) and mark the strategy as ready.

        Raises:
            DistributedStrategyError: if the strategy was initialized before.
        """
        if self.is_initialized:
            raise DistributedStrategyError("Strategy was already initialized")
        self.set_device()
        self.is_initialized = True

    @check_initialized
    def distributed(
        self,
        model: nn.Module,
        optimizer: Optimizer | None = None,
        lr_scheduler: LRScheduler | None = None,
        **kwargs,
    ) -> Tuple[nn.Module, Optimizer, LRScheduler | None]:
        """Hand back the inputs, moving the model to CUDA when it is present.

        Extra keyword arguments are accepted and ignored.

        Returns:
            Tuple[nn.Module, Optimizer, LRScheduler | None]: the (possibly
            CUDA-resident) model, the optimizer and the scheduler.
        """
        placed_model = model.cuda() if torch.cuda.is_available() else model
        return placed_model, optimizer, lr_scheduler

    @check_initialized
    def barrier(self) -> None:
        """No-op: a lone process has nobody to wait for."""

    def global_world_size(self) -> int:
        """Size of the (trivial) global group.

        Returns:
            int: always ``1``.
        """
        return 1

    def local_world_size(self) -> int:
        """Number of workers on this node.

        Returns:
            int: always ``1``.
        """
        return 1

    def global_rank(self) -> int:
        """Rank of this process in the global group.

        Returns:
            int: always ``0``.
        """
        return 0

    def local_rank(self) -> int:
        """Rank of this process on its node.

        Returns:
            int: always ``0``.
        """
        return 0

    def clean_up(self) -> None:
        """Nothing to release."""

    def allgather_obj(self, obj: Any) -> list[Any]:
        """Return ``obj`` as the sole element of a list.

        Args:
            obj (Any): the local object.

        Returns:
            list[Any]: ``[obj]``.
        """
        return [obj]

    def gather_obj(self, obj: Any, dst_rank: int = 0) -> list[Any]:
        """Return ``obj`` as the sole element of a list.

        Args:
            obj (Any): the local object.
            dst_rank (int): unused.

        Returns:
            list[Any]: ``[obj]``.
        """
        return [obj]

    def gather(self, tensor: torch.Tensor, dst_rank: int = 0):
        """Return ``tensor`` as the sole element of a list.

        Args:
            tensor (Any): the local tensor.
            dst_rank (int): unused.

        Returns:
            list[Any]: ``[tensor]``.
        """
        return [tensor]

    def broadcast_obj(self, obj: Any, src_rank: int) -> Any:
        """Return ``obj`` unchanged, since there are no other workers.

        Args:
            obj (Any): object to "broadcast".
            src_rank (int): unused.

        Returns:
            Any: ``obj`` itself.
        """
        return obj
class RayTorchDistributedStrategy(TorchDistributedStrategy):
    """Base class for all ray distributed strategies.

    Shared base for the Ray-backed strategies (e.g. the Ray DDP and Ray
    DeepSpeed strategies). It adds no behavior on top of
    ``TorchDistributedStrategy``.
    """
class RayDDPStrategy(TorchDDPStrategy, RayTorchDistributedStrategy):
    """A distributed data-parallel (DDP) strategy using Ray Train for PyTorch training.

    Rank and world-size queries are answered by the Ray Train context of the
    current worker. Model wrapping is delegated to
    ``ray.train.torch.prepare_model``.
    """

    def __init__(self) -> None:
        initialize_ray()
        import ray.train

        # ``ray.train.torch`` is a subpackage; import it explicitly so that
        # ``self.ray_train.torch.prepare_model`` (used in ``distributed``)
        # resolves even if nothing else has imported it yet.
        import ray.train.torch  # noqa: F401

        self.ray_train = ray.train
        self.name = "ray-torch-ddp"

    def init(self) -> None:
        """Initializes Ray trial/worker.

        Only checks preconditions and marks the strategy as initialized. This
        method itself creates no process group and sets no device.

        Raises:
            RuntimeError: when the Ray cluster is not detected.
            DistributedStrategyError: when trying to initialize a strategy
                already initialized.
        """
        if not ray_cluster_is_running():
            raise RuntimeError("Ray cluster was not detected")
        if self.is_initialized:
            raise DistributedStrategyError("Strategy was already initialized")
        self.is_initialized = True

    @check_initialized
    def global_world_size(self) -> int:
        """Returns the total number of workers, from the Ray Train context.

        Returns:
            int: global world size.
        """
        return self.ray_train.get_context().get_world_size()

    @check_initialized
    def local_world_size(self) -> int:
        """Returns the number of workers on this node, from the Ray Train context.

        Returns:
            int: local world size.
        """
        return self.ray_train.get_context().get_local_world_size()

    @check_initialized
    def global_rank(self) -> int:
        """Returns the global rank of this worker, from the Ray Train context.

        Returns:
            int: global rank.
        """
        return self.ray_train.get_context().get_world_rank()

    @check_initialized
    def local_rank(self) -> int:
        """Returns the rank of this worker on its node, from the Ray Train context.

        Returns:
            int: local rank.
        """
        return self.ray_train.get_context().get_local_rank()

    @check_initialized
    def distributed(
        self,
        model: nn.Module,
        optimizer: Optimizer,
        lr_scheduler: LRScheduler | None = None,
    ) -> Tuple[nn.Module, Optimizer, LRScheduler | None]:
        """Prepare the model for distributed training with Ray Train.

        Args:
            model (nn.Module): model to wrap via ``ray.train.torch.prepare_model``.
            optimizer (Optimizer): returned unchanged.
            lr_scheduler (LRScheduler | None): returned unchanged.

        Returns:
            Tuple[nn.Module, Optimizer, LRScheduler | None]: prepared model,
            optimizer and scheduler.
        """
        model = self.ray_train.torch.prepare_model(model)
        return model, optimizer, lr_scheduler
[docs]
class RayDeepSpeedStrategy(DeepSpeedStrategy, RayTorchDistributedStrategy):
"""A distributed strategy using Ray and DeepSpeed for PyTorch training.
Args:
backend (Literal["nccl", "gloo", "mpi"]): The backend for distributed communication.
"""
def __init__(self, backend: Literal["nccl", "gloo", "mpi"]) -> None:
    """Set up Ray, then configure the underlying DeepSpeed strategy.

    Args:
        backend (Literal["nccl", "gloo", "mpi"]): communication backend,
            forwarded to the ``DeepSpeedStrategy`` constructor.
    """
    # Ray is brought up first; only then does the parent constructor run.
    initialize_ray()
    super().__init__(backend=backend)
    # Assigned after the parent constructor, so this name wins over any set there.
    self.name = "ray-deepspeed"
[docs]
def init(self) -> None:
    """Initializes the distributed process group and the distributed
    package.

    The method runs these steps in order:

    1. Imports ``deepspeed`` lazily and stores it on ``self.deepspeed``.
    2. Checks that a Ray cluster is running.
    3. Mirrors ``LOCAL_RANK`` into ``OMPI_COMM_WORLD_LOCAL_RANK`` when it is set.
    4. Calls ``deepspeed.init_distributed`` with ``self.backend``.
    5. Marks the strategy as initialized and calls ``set_device``.

    Raises:
        RuntimeError: when there is not a Ray cluster running.
        DistributedStrategyError: when trying to initialize a strategy
            already initialized.
    """
    # Lazy import: presumably so deepspeed is only required when this
    # strategy is actually initialized.
    import deepspeed
    self.deepspeed = deepspeed
    if not ray_cluster_is_running():
        raise RuntimeError("Ray cluster was not detected")
    if self.is_initialized:
        raise DistributedStrategyError("Strategy was already initialized")
    # https://github.com/Lightning-AI/pytorch-lightning/issues/13567
    # This block of code should be removed at some point.
    # Workaround: copy LOCAL_RANK into the Open MPI local-rank variable when
    # LOCAL_RANK is set (see the issue linked above).
    if os.environ.get("LOCAL_RANK"):
        os.environ["OMPI_COMM_WORLD_LOCAL_RANK"] = os.environ.get("LOCAL_RANK")
    # https://deepspeed.readthedocs.io/en/latest/initialize.html#training-initialization
    self.deepspeed.init_distributed(dist_backend=self.backend)
    self.is_initialized = True
    # The device is selected only after the distributed backend is initialized.
    self.set_device()