# --------------------------------------------------------------------------------------
# Part of the interTwin Project: https://www.intertwin.eu/
#
# Created by: Matteo Bunino
#
# Credit:
# - Matteo Bunino <matteo.bunino@cern.ch> - CERN
# - Anna Lappe <anna.elisa.lappe@cern.ch> - CERN
# - Jarl Sondre Sæther <jarl.sondre.saether@cern.ch> - CERN
# - Linus Eickhoff <linus.maximilian.eickhoff@cern.ch> - CERN
# --------------------------------------------------------------------------------------
"""Provides training logic for PyTorch models via Trainer classes."""
import logging
import os
import sys
import tempfile
import time
from collections import defaultdict
from pathlib import Path
from time import perf_counter as default_timer
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple, Union
import ray.train
import ray.train.torch
import ray.tune
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import yaml
from ray.train import Checkpoint, DataConfig, RunConfig, ScalingConfig
from ray.train.torch import TorchConfig
from ray.train.torch import TorchTrainer as RayTorchTrainer
from ray.tune import TuneConfig
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim.lr_scheduler import LRScheduler
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
# Cyclic imports...
from itwinai.torch.monitoring.monitoring import measure_gpu_utilization
from itwinai.torch.profiling.profiler import profile_torch_trainer
from ..components import Trainer, monitor_exec
from ..distributed import ray_cluster_is_running
from ..loggers import EpochTimeTracker, Logger, LogMixin
from ..utils import generate_random_name, load_yaml, to_uri
from .config import TrainingConfiguration
from .distributed import (
DeepSpeedStrategy,
HorovodStrategy,
NonDistributedStrategy,
RayDDPStrategy,
RayDeepSpeedStrategy,
RayHorovodStrategy,
RayTorchDistributedStrategy,
TorchDDPStrategy,
TorchDistributedStrategy,
distributed_resources_available,
)
from .reproducibility import seed_worker, set_seed
from .tuning import search_space
from .type import Batch, Metric
if TYPE_CHECKING:
from ray.train.horovod import HorovodConfig
py_logger = logging.getLogger(__name__)
def _get_tuning_metric_name(tune_config: TuneConfig | None) -> str:
"""Extracts the metric name from TuneConfig or scheduler in a generic way."""
DEFAULT_NAME = "loss"
if not tune_config:
return DEFAULT_NAME
# Try to get from TuneConfig
if tune_config.metric:
return tune_config.metric
# Try to get from the scheduler (if defined)
scheduler = tune_config.scheduler
if scheduler and hasattr(scheduler, "metric") and scheduler.metric:
return scheduler.metric
return DEFAULT_NAME
[docs]
class TorchTrainer(Trainer, LogMixin):
"""Trainer class for torch training algorithms.
Args:
config (Dict | TrainingConfiguration): training configuration
containing hyperparameters.
epochs (int): number of training epochs.
model (Union[nn.Module, str] | None, optional): pytorch model to
train or a string identifier. Defaults to None.
strategy (Literal['ddp', 'deepspeed', 'horovod'], optional):
distributed strategy. Defaults to 'ddp'.
test_every (int | None, optional): run a test epoch
every ``test_every`` epochs. Disabled if None. Defaults to None.
random_seed (int | None, optional): set random seed for
reproducibility. If None, the seed is not set. Defaults to None.
logger (Logger | None, optional): logger for ML tracking.
Defaults to None.
metrics (Dict[str, Callable] | None, optional): map of torchmetrics
metrics. Defaults to None.
checkpoints_location (str): path to checkpoints directory.
Defaults to "checkpoints".
checkpoint_every (int | None): save a checkpoint every
``checkpoint_every`` epochs. Disabled if None. Defaults to None.
disable_tqdm (bool): whether to disable tqdm progress bar(s).
name (str | None, optional): trainer custom name. Defaults to None.
profiling_wait_epochs (int): how many epochs to wait before starting
the profiler.
profiling_warmup_epochs (int): length of the profiler warmup phase in terms of
number of epochs.
measure_gpu_data (bool): enable the collection of data on average GPU utilization and
total energy consumption throughout training. Defaults to False.
measure_communication_overhead (bool): enable the profiling of computation and
multi-worker communication operations. It uses the torch profiler and it may
slow down training. Dafults to False.
measure_epoch_time (bool): enable the measurement of epoch duration (in seconds).
Defaults to False,
ray_scaling_config (ScalingConfig, optional): scaling config for Ray Trainer.
Defaults to None,
ray_tune_config (TuneConfig, optional): tune config for Ray Tuner.
Defaults to None.
ray_run_config (RunConfig, optional): run config for Ray Trainer.
Defaults to None.
ray_search_space (Dict[str, Any], optional): search space for Ray Tuner.
Defaults to None.
ray_torch_config (TorchConfig, optional): torch configuration for Ray's TorchTrainer.
Defaults to None.
ray_data_config (DataConfig, optional): dataset configuration for Ray.
Defaults to None.
ray_horovod_config (HorovodConfig, optional): horovod configuration for Ray's
HorovodTrainer. Defaults to None.
from_checkpoint (str | Path, optional): path to checkpoint directory. Defaults to None.
initial_best_validation_metric (str): initial value for the best validation metric.
Usually the validation metric is a loss to be minimized and this value exceeds the
highest possible loss value, so that it will be overwritten when the first
vaidation loss is computed. Example values are "inf" and "-inf", depending on
wether the best validation metric should be minimized or maximized.
Defaults to "inf".
run_id (str, optional): name used to identify a specific run when collecting
metrics on the trainer (e.g. GPU utilization). Defaults to None.
time_ray (bool): whether to time and log the execution of Ray functions.
Defaults to False.
"""
_strategy: TorchDistributedStrategy | None = None
#: PyTorch ``DataLoader`` for training dataset.
train_dataloader: DataLoader | None = None
#: PyTorch ``DataLoader`` for validation dataset.
validation_dataloader: DataLoader | None = None
#: PyTorch ``DataLoader`` for test dataset.
test_dataloader: DataLoader | None = None
#: PyTorch model to train.
model: nn.Module | None = None
#: Loss criterion.
loss: Callable | None = None
#: Optimizer.
optimizer: Optimizer | None = None
#: Learning rate scheduler.
lr_scheduler: LRScheduler | None = None
#: PyTorch random number generator (PRNG).
torch_rng: torch.Generator | None = None
#: itwinai ``itwinai.Logger``
logger: Logger | None = None
#: Total number training batches used so far, across all epochs.
train_glob_step: int = 0
#: Total number validation batches used so far, across all epochs.
validation_glob_step: int = 0
#: Total number test batches used so far, across all epochs.
test_glob_step: int = 0
#: Dictionary of ``torchmetrics`` metrics, indexed by user-defined names.
metrics: Dict[str, Callable]
#: PyTorch Profiler for communication vs. computation comparison
profiler: Any | None
#: Toggle for GPU utilization monitoring
measure_gpu_data: bool = False
#: Toggle for communication vs computation fraction profiling
measure_communication_overhead: bool = False
#: Toggle for epoch time tracking
measure_epoch_time: bool = False
#: Run ID
run_id: str
#: Toggle for Ray time logging
time_ray: bool = False
def __init__(
self,
config: Dict | TrainingConfiguration,
epochs: int,
model: Union[nn.Module, str] | None = None,
strategy: Literal["ddp", "deepspeed", "horovod"] | None = "ddp",
test_every: int | None = None,
random_seed: int | None = None,
logger: Logger | None = None,
metrics: Dict[str, Metric] | None = None,
checkpoints_location: str | Path = "checkpoints",
checkpoint_every: int | None = None,
disable_tqdm: bool = False,
name: str | None = None,
profiling_wait_epochs: int = 1,
profiling_warmup_epochs: int = 2,
measure_gpu_data: bool = False,
measure_communication_overhead: bool = False,
measure_epoch_time: bool = False,
ray_scaling_config: ScalingConfig | None = None,
ray_tune_config: TuneConfig | None = None,
ray_run_config: RunConfig | None = None,
ray_search_space: Dict[str, Any] | None = None,
ray_torch_config: TorchConfig | None = None,
ray_data_config: DataConfig | None = None,
ray_horovod_config: Optional["HorovodConfig"] = None,
from_checkpoint: str | Path | None = None,
initial_best_validation_metric: str = "inf",
run_id: str | None = None,
time_ray: bool = False,
) -> None:
super().__init__(name)
self.save_parameters(**self.locals2params(locals()))
# config is mean to store all hyperparameters, which can very from use
# case to use case and include learning_rate, batch_size....
config = {} if config is None else config
if isinstance(config, dict):
config = TrainingConfiguration(**config)
self.config = config
self.epochs = epochs
self.model = model
self.strategy = strategy
self.test_every = test_every
self.random_seed = random_seed
self.logger = logger
self.metrics = metrics if metrics is not None else {}
self.checkpoints_location = checkpoints_location
os.makedirs(self.checkpoints_location, exist_ok=True)
self.checkpoint_every = checkpoint_every
self.disable_tqdm = disable_tqdm
self.profiler = None
self.profiling_wait_epochs = profiling_wait_epochs
self.profiling_warmup_epochs = profiling_warmup_epochs
self.measure_gpu_data = measure_gpu_data
self.measure_communication_overhead = measure_communication_overhead
self.measure_epoch_time = measure_epoch_time
self.ray_scaling_config = ray_scaling_config
self.ray_tune_config = ray_tune_config
self.ray_run_config = ray_run_config
self.ray_search_space = ray_search_space
self.ray_horovod_config = ray_horovod_config
self.ray_torch_config = ray_torch_config
self.ray_data_config = ray_data_config
self.from_checkpoint = from_checkpoint
self.time_ray = time_ray
if self.from_checkpoint:
self.from_checkpoint = Path(from_checkpoint)
if not self.from_checkpoint.exists():
raise RuntimeError(
"from_checkpoint argument was passed, but the checkpoint is not found "
f"at {self.from_checkpoint}"
)
if self.checkpoints_location:
Path(self.checkpoints_location).mkdir(exist_ok=True, parents=True)
py_logger.debug(f"ray_scaling_config: {ray_scaling_config}")
py_logger.debug(f"ray_tune_config: {ray_tune_config}")
py_logger.debug(f"ray_run_config: {ray_run_config}")
py_logger.debug(f"ray_horovod_config: {ray_horovod_config}")
py_logger.debug(f"ray_torch_config: {ray_torch_config}")
py_logger.debug(f"ray_data_config: {ray_data_config}")
py_logger.debug(f"ray_search_space: {ray_search_space}")
# Initial training state -- can be resumed from a checkpoint
self.model_state_dict = None
self.optimizer_state_dict = None
self.lr_scheduler_state_dict = None
self.torch_rng_state = None
# This is initialized to inf as it usually represents a loss to minimize.
# If the validation metric is meant to be maximized, change this to -inf.
self.best_validation_metric = float(initial_best_validation_metric)
self.current_epoch = 0
if run_id is None:
run_id = generate_random_name()
self.run_id = run_id
@property
def strategy(self) -> TorchDistributedStrategy:
"""Strategy currently in use."""
return self._strategy
@strategy.setter
def strategy(self, strategy: str | TorchDistributedStrategy) -> None:
if isinstance(strategy, TorchDistributedStrategy):
self._strategy = strategy
else:
self._strategy = self._detect_distributed_strategy(strategy)
@property
def device(self) -> str:
"""Current device from distributed strategy."""
return self.strategy.device()
def _detect_distributed_strategy(self, strategy: str) -> TorchDistributedStrategy:
"""When a Ray cluster is detected the Ray-equivalent distributed strategy is
automatically selected, without needing the user to explicitly set it.
"""
py_logger.debug(f"Strategy was set to {strategy}")
enough_resources = distributed_resources_available() or ray_cluster_is_running()
py_logger.debug(
f"Enough resources? {enough_resources} "
f"(distributed_resources_available: {distributed_resources_available()}) "
f"(ray_cluster_is_running: {ray_cluster_is_running()})"
)
# NOTE: setting strategy to None prevents the trainer to run distribtued ML, regardless
# of the availability of the resources.
if strategy is None or not enough_resources:
py_logger.warning("Falling back to non-distributed strategy.")
return NonDistributedStrategy()
if ray_cluster_is_running():
py_logger.info(
f"Ray cluster was detected, thus the Ray equivalent for {strategy} is used"
)
match strategy, ray_cluster_is_running():
case "ddp", True:
return RayDDPStrategy()
case "ddp", False:
return TorchDDPStrategy(backend=self.config.dist_backend)
case "horovod", True:
return RayHorovodStrategy()
case "horovod", False:
return HorovodStrategy()
case "deepspeed", True:
return RayDeepSpeedStrategy(backend=self.config.dist_backend)
case "deepspeed", False:
return DeepSpeedStrategy(backend=self.config.dist_backend)
case _:
raise RuntimeError(f"Strategy '{strategy}' is not recognized.")
def _init_distributed_strategy(self) -> None:
if not self.strategy.is_initialized:
self.strategy.init()
def _set_optimizer_from_config(self) -> None:
match self.config.optimizer:
case "adadelta":
self.optimizer = optim.Adadelta(
self.model.parameters(),
lr=self.config.optim_lr,
weight_decay=self.config.optim_weight_decay,
)
case "adam":
self.optimizer = optim.Adam(
self.model.parameters(),
lr=self.config.optim_lr,
betas=self.config.optim_betas,
weight_decay=self.config.optim_weight_decay,
)
case "adamw":
self.optimizer = optim.AdamW(
self.model.parameters(),
lr=self.config.optim_lr,
betas=self.config.optim_betas,
weight_decay=self.config.optim_weight_decay,
)
case "rmsprop":
self.optimizer = optim.RMSprop(
self.model.parameters(),
lr=self.config.optim_lr,
weight_decay=self.config.optim_weight_decay,
momentum=self.config.optim_momentum,
)
case "sgd":
self.optimizer = optim.SGD(
self.model.parameters(),
lr=self.config.optim_lr,
weight_decay=self.config.optim_weight_decay,
momentum=self.config.optim_momentum,
)
case _:
raise ValueError(
"Unrecognized self.config.optimizer! Check the docs for "
"supported values and consider overriding "
"create_model_loss_optimizer method for more flexibility."
)
def _set_lr_scheduler_from_config(self) -> None:
"""Parse Lr scheduler from training config"""
if not self.config.lr_scheduler:
return
if not self.optimizer:
raise ValueError("Trying to instantiate a LR scheduler but the optimizer is None!")
match self.config.lr_scheduler:
case "constant":
self.lr_scheduler = lr_scheduler.ConstantLR(self.optimizer)
case "polynomial":
self.lr_scheduler = lr_scheduler.PolynomialLR(self.optimizer)
case "exponential":
self.lr_scheduler = lr_scheduler.ExponentialLR(
self.optimizer, gamma=self.config.lr_scheduler_gamma
)
case "linear":
self.lr_scheduler = lr_scheduler.LinearLR(self.optimizer)
case "multistep":
self.lr_scheduler = lr_scheduler.MultiStepLR(
self.optimizer, milestones=self.config.lr_scheduler_step_size
)
case "step":
self.lr_scheduler = lr_scheduler.StepLR(
self.optimizer, step_size=self.config.lr_scheduler_step_size
)
case _:
raise ValueError(
"Unrecognized self.config.lr_scheduler! Check the docs for "
"supported values and consider overriding "
"create_model_loss_optimizer method for more flexibility."
)
def _set_loss_from_config(self) -> None:
match self.config.loss:
case "nllloss":
self.loss = nn.functional.nll_loss
case "cross_entropy":
self.loss = nn.functional.cross_entropy
case "mse":
self.loss = nn.functional.mse_loss
case "bceloss":
self.loss = nn.functional.binary_cross_entropy
case _:
raise ValueError(
"Unrecognized self.config.loss! Check the docs for "
"supported values and consider overriding "
"create_model_loss_optimizer method for more flexibility."
)
def _time_and_log(self, fn: Callable, identifier: str, step: int | None = None) -> Any:
"""Time and log the execution of a function (using time.monotonic()).
Args:
fn (Callable): function to execute, time and log (pass args using lambda)
identifier (str): identifier for the logged metric
step (int | None): step for logging
Returns:
result (Any): result of the function call
"""
if not self.logger:
py_logger.warning(f"No logger set! Cannot log time for {identifier}! ")
return fn()
if not self.logger.is_initialized:
py_logger.warning(
f"Logger context not initialized for timing {identifier}. Setting context."
)
self.logger.create_logger_context()
step = step or self.current_epoch
if step is None:
py_logger.warning("current_epoch is not set and no explicit step was provided!")
# Use monotonic time to avoid time drift
t_start = time.monotonic()
result = fn()
t_end = time.monotonic()
# already in seconds
t_delta = t_end - t_start
self.log(
item=t_delta,
identifier=identifier,
kind="metric",
step=step,
)
return result
[docs]
def get_default_distributed_kwargs(self) -> Dict:
"""Gives the default kwargs for the trainer's strategy's distributed() method."""
if isinstance(self.strategy, DeepSpeedStrategy):
# Batch size definition is not optional for DeepSpeedStrategy!
distribute_kwargs = dict(
config_params=dict(train_micro_batch_size_per_gpu=self.config.batch_size)
)
elif isinstance(self.strategy, HorovodStrategy):
import horovod.torch as hvd
distribute_kwargs = dict(
compression=(
hvd.Compression.fp16
if self.config.fp16_allreduce
else hvd.Compression.none
),
op=hvd.Adasum if self.config.use_adasum else hvd.Average,
gradient_predivide_factor=self.config.gradient_predivide_factor,
)
else:
distribute_kwargs = {}
return distribute_kwargs
[docs]
def create_model_loss_optimizer(self) -> None:
"""Instantiate a torch model, loss, optimizer, and LR scheduler using the
configuration provided in the Trainer constructor.
Generally a user-defined method.
"""
###################################
# Dear user, this is a method you #
# may be interested to override! #
###################################
# Model, optimizer, and lr scheduler may have already been loaded from a checkpoint
if self.model is None:
raise ValueError(
"self.model is None! Either pass it to the constructor, load a checkpoint, or "
"override create_model_loss_optimizer method."
)
if self.model_state_dict:
# Load model from checkpoint
self.model.load_state_dict(self.model_state_dict, strict=False)
# Parse optimizer from training configuration
# Optimizer can be changed with a custom one here!
self._set_optimizer_from_config()
# Parse LR scheduler from training configuration
# LR scheduler can be changed with a custom one here!
self._set_lr_scheduler_from_config()
if self.optimizer_state_dict:
# Load optimizer state from checkpoint
# IMPORTANT: this must be after the learning rate scheduler was already initialized
# by passing to it the optimizer. Otherwise the optimizer state just loaded will
# be modified by the lr scheduler.
self.optimizer.load_state_dict(self.optimizer_state_dict)
if self.lr_scheduler_state_dict and self.lr_scheduler:
# Load LR scheduler state from checkpoint
self.lr_scheduler.load_state_dict(self.lr_scheduler_state_dict)
# Parse loss from training configuration
# Loss can be changed with a custom one here!
self._set_loss_from_config()
# IMPORTANT: model, optimizer, and scheduler need to be distributed from here on
distribute_kwargs = self.get_default_distributed_kwargs()
# Distributed model, optimizer, and scheduler
(self.model, self.optimizer, self.lr_scheduler) = self.strategy.distributed(
self.model, self.optimizer, self.lr_scheduler, **distribute_kwargs
)
[docs]
def save_checkpoint(
self,
name: str,
best_validation_metric: torch.Tensor | None = None,
checkpoints_root: str | Path | None = None,
force: bool = False,
) -> str | None:
"""Save training checkpoint.
Args:
name (str): name of the checkpoint directory.
best_validation_metric (torch.Tensor | None): best validation metric throughout
training so far (if available). Usually this is the validation loss.
checkpoints_root (str | None): path for root checkpoints dir. If None, uses
``self.checkpoints_location`` as base.
force (bool): force checkpointign now.
Returns:
path to the checkpoint file or ``None`` when the checkpoint is not created.
"""
# Determine whether a checkpoint should be created
should_checkpoint = self.strategy.is_main_worker and (
force
or self.checkpoint_every
and (self.current_epoch + 1) % self.checkpoint_every == 0
)
ckpt_dir = Path(checkpoints_root or self.checkpoints_location) / name
py_logger.info(f"Saving checkpoint at {ckpt_dir.resolve()}? {should_checkpoint}")
if not should_checkpoint:
# Do nothing and return
return
ckpt_dir = Path(checkpoints_root or self.checkpoints_location) / name
ckpt_dir.mkdir(parents=True, exist_ok=True)
# Save state (epoch, loss, optimizer, scheduler)
state = {
"epoch": self.current_epoch,
# This could store the best validation loss
"best_validation_metric": (
best_validation_metric.item() if best_validation_metric is not None else None
),
"optimizer_state_dict": self.optimizer.state_dict(),
"lr_scheduler_state_dict": (
self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None
),
"torch_rng_state": self.torch_rng.get_state(),
"random_seed": self.random_seed,
}
state_path = ckpt_dir / "state.pt"
torch.save(state, state_path)
# Save PyTorch model separately
model_path = ckpt_dir / "model.pt"
# TODO: check that the state dict is stripped from any distributed info
torch.save(self.model.state_dict(), model_path)
# Save Pydantic config as YAML
config_path = ckpt_dir / "config.yaml"
with config_path.open("w") as f:
yaml.safe_dump(self.config.model_dump(), f)
# Log each file with an appropriate identifier
self.log(str(state_path), f"{name}_state", kind="artifact")
self.log(str(model_path), f"{name}_model", kind="artifact")
self.log(str(config_path), f"{name}_config", kind="artifact")
assert state_path.exists()
assert model_path.exists()
assert config_path.exists()
py_logger.info(f"Saved checkpoint at {ckpt_dir.resolve()}")
return str(ckpt_dir)
[docs]
def load_checkpoint(self) -> None:
"""Reload training state from checkpoint."""
if not self.from_checkpoint:
# A checkpoint path was NOT provided
return
# A checkpoint path was provided
py_logger.info(f"Loading from existing checkpoint at {self.from_checkpoint}")
if not isinstance(self.strategy, RayTorchDistributedStrategy):
# Not using Ray, falling back to simple checkpoint reload
py_logger.debug("Loading from existing checkpoint without using Ray")
self._load_checkpoint(checkpoint_dir=self.from_checkpoint)
return
# A Ray checkpoint directory was passed to the trainer -- assuming to be inside a trial
checkpoint = ray.train.get_checkpoint()
if not checkpoint:
py_logger.warning(
"A checkpoint path was passed, but Ray could not find a valid "
"checkpoint directory. Skipping loading from checkpoint."
)
return
with checkpoint.as_directory() as checkpoint_dir:
py_logger.debug("Loading from existing Ray checkpoint")
self._load_checkpoint(checkpoint_dir=checkpoint_dir)
def _load_checkpoint(self, checkpoint_dir: str | Path) -> None:
"""Load checkpoint from path."""
checkpoint_dir = Path(checkpoint_dir)
state = torch.load(checkpoint_dir / "state.pt")
# Override initial training state
self.model_state_dict = torch.load(checkpoint_dir / "model.pt")
self.optimizer_state_dict = state["optimizer_state_dict"]
self.lr_scheduler_state_dict = state["lr_scheduler_state_dict"]
self.torch_rng_state = state["torch_rng_state"]
# Direct overrides (don't require further attention)
self.random_seed = state["random_seed"]
self.current_epoch = state["epoch"] + 1 # Start from next epoch
if state["best_validation_metric"]:
self.best_validation_metric = state["best_validation_metric"]
[docs]
def create_dataloaders(
self,
train_dataset: Dataset,
validation_dataset: Dataset | None = None,
test_dataset: Dataset | None = None,
) -> None:
"""
Create train, validation and test dataloaders using the
configuration provided in the Trainer constructor.
Generally a user-defined method.
Args:
train_dataset (Dataset): training dataset object.
validation_dataset (Dataset | None): validation dataset object.
Default None.
test_dataset (Dataset | None): test dataset object.
Default None.
"""
###################################
# Dear user, this is a method you #
# may be interested to override! #
###################################
self.train_dataloader = self.strategy.create_dataloader(
dataset=train_dataset,
batch_size=self.config.batch_size,
num_workers=self.config.num_workers_dataloader,
pin_memory=self.config.pin_gpu_memory,
generator=self.torch_rng,
shuffle=self.config.shuffle_train,
)
if validation_dataset is not None:
self.validation_dataloader = self.strategy.create_dataloader(
dataset=validation_dataset,
batch_size=self.config.batch_size,
num_workers=self.config.num_workers_dataloader,
pin_memory=self.config.pin_gpu_memory,
generator=self.torch_rng,
shuffle=self.config.shuffle_validation,
)
if test_dataset is not None:
self.test_dataloader = self.strategy.create_dataloader(
dataset=test_dataset,
batch_size=self.config.batch_size,
num_workers=self.config.num_workers_dataloader,
pin_memory=self.config.pin_gpu_memory,
generator=self.torch_rng,
shuffle=self.config.shuffle_test,
)
def _setup_metrics(self) -> None:
"""Move metrics to current device."""
for m_name, metric in self.metrics.items():
self.metrics[m_name] = metric.to(self.device)
[docs]
@monitor_exec
def execute(
self,
train_dataset: Dataset,
validation_dataset: Dataset | None = None,
test_dataset: Dataset | None = None,
) -> Tuple[Dataset, Dataset, Dataset, Any]:
"""Prepares distributed environment and data structures
for the actual training.
Args:
train_dataset (Dataset): training dataset.
validation_dataset (Dataset | None, optional): validation
dataset. Defaults to None.
test_dataset (Dataset | None, optional): test dataset.
Defaults to None.
Returns:
Tuple[Dataset, Dataset, Dataset, Any]: training dataset,
validation dataset, test dataset, trained model.
"""
if isinstance(self.strategy, RayTorchDistributedStrategy):
# Execute with Ray
return self._execute_with_ray(
train_dataset=train_dataset,
validation_dataset=validation_dataset,
test_dataset=test_dataset,
)
# Execute without ray
if self.ray_scaling_config:
py_logger.warning(
"Ray scaling config was passed, but it's ignored as Ray is not used"
)
if self.ray_run_config:
py_logger.warning("Ray run config was passed, but it's ignored as Ray is not used")
if self.ray_tune_config:
py_logger.warning(
"Ray tune config was passed, but it's ignored as Ray is not used"
)
if self.ray_search_space:
py_logger.warning(
"Ray search space was passed, but it's ignored as Ray is not used"
)
if self.ray_horovod_config:
py_logger.warning(
"Ray horovod config was passed, but it's ignored as Ray is not used"
)
if self.ray_torch_config:
py_logger.warning(
"Ray torch config was passed, but it's ignored as Ray is not used"
)
if self.ray_data_config:
py_logger.warning(
"Ray dataset config was passed, but it's ignored as Ray is not used"
)
self._run_worker(
config={},
train_dataset=train_dataset,
validation_dataset=validation_dataset,
test_dataset=test_dataset,
)
return train_dataset, validation_dataset, test_dataset, None
def _execute_with_ray(
self,
train_dataset: Dataset,
validation_dataset: Dataset | None = None,
test_dataset: Dataset | None = None,
) -> Tuple[Dataset, Dataset, Dataset, Any]:
"""Launch training and, optionally, hyperarameter tuning with Ray"""
train_with_data = ray.tune.with_parameters(
self._run_worker,
train_dataset=train_dataset,
validation_dataset=validation_dataset,
test_dataset=test_dataset,
)
if self.ray_run_config:
# Create Ray checkpoints dir if it does not exist yet
ckpt_dir = Path(self.ray_run_config.storage_path)
ckpt_dir.mkdir(parents=True, exist_ok=True)
if isinstance(self.strategy, RayHorovodStrategy):
# Using Horovod with Ray
from ray.train.horovod import HorovodTrainer
if self.ray_torch_config:
py_logger.warning(
"Ray torch config was passed, but it's ignored as "
f"{self.strategy.__class__.__name__} strategy is used"
)
if self.from_checkpoint:
# Create trainer from checkpoint
if HorovodTrainer.can_restore(to_uri(self.from_checkpoint)):
trainer = HorovodTrainer.restore(
path=to_uri(self.from_checkpoint),
train_loop_per_worker=train_with_data,
train_loop_config=None,
)
else:
# Ray is unable to restore the checkpoint implicitly, but it's passing
# it to the trial
trainer = HorovodTrainer(
train_loop_per_worker=train_with_data,
train_loop_config=None,
horovod_config=self.ray_horovod_config,
scaling_config=self.ray_scaling_config,
run_config=self.ray_run_config,
dataset_config=self.ray_data_config,
resume_from_checkpoint=Checkpoint(to_uri(self.from_checkpoint)),
)
else:
# Create trainer without checkpoint
trainer = HorovodTrainer(
train_loop_per_worker=train_with_data,
train_loop_config=None,
horovod_config=self.ray_horovod_config,
scaling_config=self.ray_scaling_config,
run_config=self.ray_run_config,
dataset_config=self.ray_data_config,
)
else:
# Using DDP or DeepSpeed with Ray
if self.ray_horovod_config:
py_logger.warning(
"Ray horovod config was passed, but it's ignored as "
f"{self.strategy.__class__.__name__} strategy is used"
)
if self.from_checkpoint:
# Create trainer from checkpoint
if RayTorchTrainer.can_restore(to_uri(self.from_checkpoint)):
trainer = RayTorchTrainer.restore(
path=to_uri(self.from_checkpoint),
train_loop_per_worker=train_with_data,
train_loop_config=None,
)
else:
# Ray is unable to restore the checkpoint implicitly, but it's passing
# it to the trial
trainer = RayTorchTrainer(
train_loop_per_worker=train_with_data,
train_loop_config=None,
scaling_config=self.ray_scaling_config,
run_config=self.ray_run_config,
torch_config=self.ray_torch_config,
dataset_config=self.ray_data_config,
resume_from_checkpoint=Checkpoint(to_uri(self.from_checkpoint)),
)
else:
# Create trainer without checkpoint
trainer = RayTorchTrainer(
train_loop_per_worker=train_with_data,
train_loop_config=None,
scaling_config=self.ray_scaling_config,
run_config=self.ray_run_config,
torch_config=self.ray_torch_config,
dataset_config=self.ray_data_config,
)
# Wrap the trainer into a Tuner
param_space = {"train_loop_config": search_space(self.ray_search_space)}
tuner = ray.tune.Tuner(
trainable=trainer,
param_space=param_space,
tune_config=self.ray_tune_config,
)
if self.time_ray:
self.tune_result_grid = self._time_and_log(
lambda: tuner.fit(), "ray_fit_time_s", step=0
)
else:
self.tune_result_grid = tuner.fit()
return train_dataset, validation_dataset, test_dataset, None
def _run_worker(
self,
config: Dict,
train_dataset: Dataset,
validation_dataset: Dataset | None = None,
test_dataset: Dataset | None = None,
) -> None:
self.load_checkpoint()
self._override_config(config)
self._set_seed()
self._init_distributed_strategy()
self._setup_metrics()
if self.logger:
py_logger.debug(f"Using logger: {self.logger.__class__.__name__}")
self.logger.create_logger_context(rank=self.strategy.global_rank())
py_logger.debug("...the logger has been initialized")
hparams = self.config.model_dump()
hparams["distributed_strategy"] = self.strategy.__class__.__name__
self.logger.save_hyperparameters(hparams)
self.create_dataloaders(
train_dataset=train_dataset,
validation_dataset=validation_dataset,
test_dataset=test_dataset,
)
self.create_model_loss_optimizer()
self.train()
if self.logger:
self.logger.destroy_logger_context()
self.strategy.clean_up()
def _set_seed(self) -> None:
py_logger.debug(f"Using random seed: {self.random_seed}")
self.torch_rng = set_seed(self.random_seed)
if self.torch_rng_state is not None:
# Resume state from checkpoint
py_logger.debug("Resuming torch PRNG state from checkpoint")
self.torch_rng.set_state(self.torch_rng_state)
def _override_config(self, config: Dict) -> None:
"""Overrid self.config with a sample from the search space from the Ray tuner."""
self.config = self.config.model_copy(update=config)
py_logger.debug("Overridden self.config with trial config (if given)")
def _set_epoch_dataloaders(self, epoch: int) -> None:
"""Sets epoch in the distributed sampler of a dataloader when using it."""
if self.strategy.is_distributed:
self.train_dataloader.sampler.set_epoch(epoch)
if self.validation_dataloader is not None:
self.validation_dataloader.sampler.set_epoch(epoch)
if self.test_dataloader is not None:
self.test_dataloader.sampler.set_epoch(epoch)
[docs]
def set_epoch(self) -> None:
"""Set current epoch at the beginning of training."""
if self.profiler is not None and self.current_epoch > 0:
# We don't want to start stepping until after the first epoch
self.profiler.step()
if self.lr_scheduler:
self.lr_scheduler.step()
self._set_epoch_dataloaders(self.current_epoch)
[docs]
def log(
self,
item: Any | List[Any],
identifier: str | List[str],
kind: str = "metric",
step: int | None = None,
batch_idx: int | None = None,
**kwargs,
) -> None:
"""Log ``item`` with ``identifier`` name of ``kind`` type at ``step``
time step.
Args:
item (Any | List[Any]): element to be logged (e.g., metric).
identifier (str | List[str]): unique identifier for the
element to log(e.g., name of a metric).
kind (str, optional): type of the item to be logged. Must be one
among the list of self.supported_types. Defaults to 'metric'.
step (int | None, optional): logging step. Defaults to None.
batch_idx (int | None, optional): DataLoader batch counter
(i.e., batch idx), if available. Defaults to None.
"""
if self.logger:
self.logger.log(
item=item,
identifier=identifier,
kind=kind,
step=step,
batch_idx=batch_idx,
**kwargs,
)
[docs]
def ray_report(
self,
metrics: Dict[str, float],
checkpoint_file: str | Path | None = None,
checkpoint_dir: str | Path | None = None,
checkpoint_data: Any | None = None,
) -> None:
"""Report a dictionary of metrics and optionally a checkpoint to Ray, only when using
Ray distributed strategies. The checkpoint could be in the form of a Python object
(passed as ``checkpoint_data``), the path to a single file (passed as
``checkpoint_file``), or the path to an existing checkpoint directory (passed as
``checkpoint_dir``).
Args:
metrics (Dict[str, float]): metrics to be reported.
checkpoint_file (str | Path | None, optional): path to the checkpoint file.
Defaults to None.
checkpoint_dir (str | Path | None, optional): path to the checkpoint directory.
Defaults to None.
checkpoint_data (Any | None, optional):object to serialize as a checkpoint.
Defaults to None.
"""
if not isinstance(self.strategy, RayTorchDistributedStrategy):
# Ray is not used, thus do nothing
return
if checkpoint_file:
# A checkpoint is given as a file
with tempfile.TemporaryDirectory() as tmp_dir:
import shutil
shutil.copy(checkpoint_file, tmp_dir)
checkpoint = ray.train.Checkpoint.from_directory(tmp_dir)
ray.train.report(metrics, checkpoint=checkpoint)
elif checkpoint_data:
# A checkpoint is given as a python object which needs to be serialized
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_dir = Path(tmp_dir)
ckpt_file = tmp_dir / "ckpt.pt"
torch.save(checkpoint_data, ckpt_file)
checkpoint = ray.train.Checkpoint.from_directory(tmp_dir)
ray.train.report(metrics, checkpoint=checkpoint)
elif checkpoint_dir:
# A checkpoint is given as a directory
checkpoint = ray.train.Checkpoint.from_directory(checkpoint_dir)
ray.train.report(metrics, checkpoint=checkpoint)
else:
# No checkpoint is given: only report metrics
ray.train.report(metrics)
[docs]
def compute_metrics(
self,
true: Batch,
pred: Batch,
logger_step: int,
batch_idx: int | None,
stage: str = "train",
) -> Dict[str, Any]:
"""Compute and log metrics.
Args:
metrics (Dict[str, Callable]): metrics dict. Can be
``self.train_metrics`` or ``self.validation_metrics``.
true (Batch): true values.
pred (Batch): predicted values.
logger_step (int): global step to pass to the logger.
stage (str): 'train', 'validation'...
Returns:
Dict[str, Any]: metric values.
"""
m_values = {}
for m_name, metric in self.metrics.items():
# metric = metric.to(self.device)
m_val = metric(pred, true).detach().cpu().numpy()
self.log(
item=m_val,
identifier=f"{stage}_{m_name}",
kind="metric",
step=logger_step,
batch_idx=batch_idx,
)
m_values[m_name] = m_val
return m_values
[docs]
@profile_torch_trainer
@measure_gpu_utilization
def train(self) -> None:
"""Trains a machine learning model.
Main training loop/logic.
Args:
train_dataset (Dataset): training dataset.
validation_dataset (Dataset): validation dataset.
test_dataset (Dataset): test dataset.
Returns:
Tuple[Dataset, Dataset, Dataset, Any]: training dataset,
validation dataset, test dataset, trained model.
"""
epoch_time_logger: EpochTimeTracker | None = None
if self.strategy.is_main_worker and self.strategy.is_distributed:
if "SLURM_NNODES" not in os.environ:
raise EnvironmentError(
"'SLURM_NNODES' is not present in 'os.environ', but is required"
" when running distributed training!"
)
num_nodes = int(os.environ["SLURM_NNODES"])
epoch_time_output_dir = Path(f"scalability-metrics/{self.run_id}/epoch-time")
epoch_time_file_name = f"epochtime_{self.strategy.name}_{num_nodes}N.csv"
epoch_time_output_path = epoch_time_output_dir / epoch_time_file_name
epoch_time_logger = EpochTimeTracker(
strategy_name=self.strategy.name,
save_path=epoch_time_output_path,
num_nodes=num_nodes,
should_log=self.measure_epoch_time,
)
progress_bar = tqdm(
range(self.current_epoch, self.epochs),
desc="Epochs",
disable=self.disable_tqdm or not self.strategy.is_main_worker,
)
for self.current_epoch in progress_bar:
epoch_start_time = default_timer()
progress_bar.set_description(f"Epoch {self.current_epoch + 1}/{self.epochs}")
self.set_epoch()
self.train_epoch()
val_metric = self.validation_epoch()
# Periodic checkpointing
periodic_ckpt_path = self.save_checkpoint(name=f"epoch_{self.current_epoch}")
# Checkpointing current best model
best_ckpt_path = None
worker_val_metrics = self.strategy.gather(val_metric, dst_rank=0)
if self.strategy.is_main_worker:
avg_metric = torch.mean(torch.stack(worker_val_metrics)).detach().cpu()
if avg_metric < self.best_validation_metric:
best_ckpt_path = self.save_checkpoint(
name="best_model",
best_validation_metric=avg_metric,
force=True,
)
self.best_validation_metric = avg_metric
# Report validation metrics to Ray (useful for tuning!)
metric_name = _get_tuning_metric_name(self.ray_tune_config)
if metric_name is None:
raise ValueError("Could not find a metric in the TuneConfig")
if self.time_ray:
# time and log the ray_report call
self._time_and_log(
lambda: self.ray_report(
metrics={metric_name: val_metric.item()},
checkpoint_dir=best_ckpt_path or periodic_ckpt_path,
),
"ray_report_time_s_per_epoch",
step=self.current_epoch,
)
else:
self.ray_report(
metrics={metric_name: val_metric.item()},
checkpoint_dir=best_ckpt_path or periodic_ckpt_path,
)
if self.test_every and (self.current_epoch + 1) % self.test_every == 0:
self.test_epoch()
if self.strategy.is_main_worker and self.strategy.is_distributed:
assert epoch_time_logger is not None
epoch_time = default_timer() - epoch_start_time
epoch_time_logger.add_epoch_time(self.current_epoch + 1, epoch_time)
[docs]
def train_epoch(self) -> torch.Tensor:
"""Perform a complete sweep over the training dataset, completing an
epoch of training.
Args:
epoch (int): current epoch number, from 0 to ``self.epochs - 1``.
Returns:
Loss: average training loss for the current epoch.
"""
self.model.train()
train_loss_sum = 0.0
train_metrics_sum = defaultdict(float)
batch_counter = 0
progress_bar = tqdm(
enumerate(self.train_dataloader),
total=len(self.train_dataloader) // self.strategy.global_world_size(),
desc="Train batches",
disable=self.disable_tqdm or not self.strategy.is_main_worker,
leave=False, # Set this to true to see how many batches were used
)
for batch_idx, train_batch in progress_bar:
loss, metrics = self.train_step(batch=train_batch, batch_idx=batch_idx)
train_loss_sum += loss
batch_counter += 1
for name, val in metrics.items():
train_metrics_sum[name] += val
# Important: update counter
self.train_glob_step += 1
# Aggregate and log losses
avg_loss = train_loss_sum / batch_counter
self.log(
item=avg_loss.item(),
identifier="train_loss_epoch",
kind="metric",
step=self.train_glob_step,
)
# Aggregate and log metrics
for m_name, m_val in train_metrics_sum.items():
self.log(
item=m_val / batch_counter,
identifier="train_" + m_name + "_epoch",
kind="metric",
step=self.train_glob_step,
)
return avg_loss
[docs]
def train_step(self, batch: Batch, batch_idx: int) -> Tuple[torch.Tensor, Dict[str, Any]]:
"""Perform a single optimization step using a batch sampled from the
training dataset.
Args:
batch (Batch): batch sampled by a dataloader.
batch_idx (int): batch index in the dataloader.
Returns:
Tuple[Loss, Dict[str, Any]]: batch loss and dictionary of metric
values with the same structure of ``self.metrics``.
"""
x, y = batch
x, y = x.to(self.device), y.to(self.device)
self.optimizer.zero_grad()
pred_y = self.model(x)
loss = self.loss(pred_y, y)
loss.backward()
self.optimizer.step()
# Log metrics
self.log(
item=loss.item(),
identifier="train_loss",
kind="metric",
step=self.train_glob_step,
batch_idx=batch_idx,
)
metrics: Dict[str, Any] = self.compute_metrics(
true=y,
pred=pred_y,
logger_step=self.train_glob_step,
batch_idx=batch_idx,
stage="train",
)
return loss, metrics
[docs]
def validation_epoch(self) -> torch.Tensor:
"""Perform a complete sweep over the validation dataset, completing an
epoch of validation.
Returns:
Loss | None: average validation loss for the current epoch if
self.validation_dataloader is not None
"""
if self.validation_dataloader is None:
return
progress_bar = tqdm(
enumerate(self.validation_dataloader),
total=len(self.validation_dataloader) // self.strategy.global_world_size(),
desc="Validation batches",
disable=self.disable_tqdm or not self.strategy.is_main_worker,
leave=False, # Set this to true to see how many batches were used
)
self.model.eval()
validation_loss_sum = 0.0
validation_metrics_sum = defaultdict(float)
batch_counter = 0
for batch_idx, val_batch in progress_bar:
loss, metrics = self.validation_step(batch=val_batch, batch_idx=batch_idx)
validation_loss_sum += loss
batch_counter += 1
for name, val in metrics.items():
validation_metrics_sum[name] += val
# Important: update counter
self.validation_glob_step += 1
# Aggregate and log losses
avg_loss = validation_loss_sum / batch_counter
self.log(
item=avg_loss.item(),
identifier="validation_loss_epoch",
kind="metric",
step=self.validation_glob_step,
)
# Aggregate and log metrics
for m_name, m_val in validation_metrics_sum.items():
self.log(
item=m_val / batch_counter,
identifier="validation_" + m_name + "_epoch",
kind="metric",
step=self.validation_glob_step,
)
return avg_loss
[docs]
def validation_step(
self, batch: Batch, batch_idx: int
) -> Tuple[torch.Tensor, Dict[str, Any]]:
"""Perform a single optimization step using a batch sampled from the
validation dataset.
Args:
batch (Batch): batch sampled by a dataloader.
batch_idx (int): batch index in the dataloader.
Returns:
Tuple[Loss, Dict[str, Any]]: batch loss and dictionary of metric
values with the same structure of ``self.metrics``.
"""
x, y = batch
x, y = x.to(self.device), y.to(self.device)
with torch.no_grad():
pred_y = self.model(x)
loss: torch.Tensor = self.loss(pred_y, y)
self.log(
item=loss.item(),
identifier="validation_loss",
kind="metric",
step=self.validation_glob_step,
batch_idx=batch_idx,
)
metrics: Dict[str, Any] = self.compute_metrics(
true=y,
pred=pred_y,
logger_step=self.validation_glob_step,
batch_idx=batch_idx,
stage="validation",
)
return loss, metrics
[docs]
def test_epoch(self) -> torch.Tensor:
"""Perform a complete sweep over the test dataset, completing an
epoch of test.
Returns:
Loss: average test loss for the current epoch.
"""
raise NotImplementedError()
[docs]
def test_step(self, batch: Batch, batch_idx: int) -> Tuple[torch.Tensor, Dict[str, Any]]:
"""Perform a single predictions step using a batch sampled from the
test dataset.
Args:
batch (Batch): batch sampled by a dataloader.
batch_idx (int): batch index in the dataloader.
Returns:
Tuple[Loss, Dict[str, Any]]: batch loss and dictionary of metric
values with the same structure of ``self.metrics``.
"""
raise NotImplementedError()
[docs]
class TorchLightningTrainer(Trainer):
"""Generic trainer for torch Lightning workflows.
Args:
config (Dict | str): `Lightning configuration`_
which can be the path to a file or a Python dictionary.
mlflow_saved_model (str, optional): name of the model created in
MLFlow. Defaults to 'my_model'.
.. _Lightning configuration:
https://pytorch-lightning.readthedocs.io/en/1.6.5/common/lightning_cli.html
"""
def __init__(self, config: Dict | str, mlflow_saved_model: str = "my_model"):
self.save_parameters(**self.locals2params(locals()))
super().__init__()
if isinstance(config, str) and os.path.isfile(config):
# Load from YAML
config = load_yaml(config)
self.conf = config
self.mlflow_saved_model = mlflow_saved_model
[docs]
@monitor_exec
def execute(self) -> Any:
import lightning as L
from lightning.pytorch.cli import LightningCLI
from .mlflow import init_lightning_mlflow, teardown_lightning_mlflow
init_lightning_mlflow(
self.conf, tmp_dir="/tmp", registered_model_name=self.mlflow_saved_model
)
old_argv = sys.argv
sys.argv = ["some_script_placeholder.py"]
cli = LightningCLI(
args=self.conf,
model_class=L.LightningModule,
datamodule_class=L.LightningDataModule,
run=False,
save_config_kwargs={
"overwrite": True,
"config_filename": "pl-training.yml",
},
subclass_mode_model=True,
subclass_mode_data=True,
)
sys.argv = old_argv
cli.trainer.fit(cli.model, datamodule=cli.datamodule)
teardown_lightning_mlflow()
def _distributed_dataloader(dataloader: DataLoader, gwsize, grank):
"""Makes a Dataloader distributed."""
sampler = DistributedSampler(
dataloader.dataset, num_replicas=gwsize, rank=grank, shuffle=True
)
# Recreate dataloader, with updated sampler
return DataLoader(
dataloader.dataset,
batch_size=dataloader.batch_size,
sampler=sampler,
num_workers=dataloader.num_workers,
collate_fn=dataloader.collate_fn,
pin_memory=dataloader.pin_memory,
drop_last=dataloader.drop_last,
timeout=dataloader.timeout,
worker_init_fn=seed_worker, # dataloader.worker_init_fn,
multiprocessing_context=dataloader.multiprocessing_context,
generator=dataloader.generator,
prefetch_factor=dataloader.prefetch_factor,
persistent_workers=dataloader.persistent_workers,
pin_memory_device=dataloader.pin_memory_device,
)
[docs]
def distributed(func):
"""The decorated function must have a standard signature.
Its first arguments must be:
model, train_dataloader, validation_dataloader, device (in this order).
Additional args or kwargs are allowed consistently with the signature
of the decorated function.
"""
def dist_train(
model, train_dataloader, validation_dataloader=None, device="cpu", *args, **kwargs
):
if torch.cuda.is_available():
dist.init_process_group(backend="nccl")
if torch.cuda.is_available():
lwsize = torch.cuda.device_count() # local world size - per node
gwsize = dist.get_world_size() # global world size - per run
grank = dist.get_rank() # global rank - assign per run
lrank = dist.get_rank() % lwsize # local rank - assign per node
else:
gwsize = 1
grank = 0
lrank = 0
device = torch.device("cuda" if torch.cuda.is_available() else "cpu", lrank)
if torch.cuda.is_available():
torch.cuda.set_device(lrank)
model = model.to(device)
model = DDP(model, device_ids=[device], output_device=device)
train_dataloader = _distributed_dataloader(train_dataloader, gwsize, grank)
if validation_dataloader is not None:
validation_dataloader = _distributed_dataloader(
validation_dataloader, gwsize, grank
)
try:
func(model, train_dataloader, validation_dataloader, device, *args, **kwargs)
finally:
if torch.cuda.is_available():
dist.barrier()
dist.destroy_process_group()
return dist_train