# --------------------------------------------------------------------------------------
# Part of the interTwin Project: https://www.intertwin.eu/
#
# Created by: Matteo Bunino
#
# Credit:
# - Matteo Bunino <matteo.bunino@cern.ch> - CERN
# - Anna Lappe <anna.elisa.lappe@cern.ch> - CERN
# - Jarl Sondre Sæther <jarl.sondre.saether@cern.ch> - CERN
# - Linus Eickhoff <linus.maximilian.eickhoff@cern.ch> - CERN
# --------------------------------------------------------------------------------------
"""Provides training logic for PyTorch models via Trainer classes."""
import logging
import sys
import tempfile
from collections import defaultdict
from functools import partial
from pathlib import Path
from time import perf_counter
from typing import Any, Callable, Dict, List, Literal, Tuple
import mlflow
import ray.train
import ray.tune
import torch
import torch.nn as nn
import torch.optim.lr_scheduler as lr_scheduler
import yaml
from ray.train import Checkpoint, DataConfig, ScalingConfig
from ray.train.torch import TorchConfig
from ray.train.torch import TorchTrainer as RayTorchTrainer
from ray.tune import RunConfig, TuneConfig
from torch.optim import SGD, Adadelta, Adam, AdamW, RMSprop
from torch.optim.lr_scheduler import LRScheduler
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader, Dataset
from torchmetrics import Metric
from tqdm import tqdm
from itwinai.torch.monitoring.monitoring import measure_gpu_utilization
from itwinai.torch.profiling.profiler import profile_torch_trainer
from ..components import Trainer, monitor_exec
from ..distributed import ray_cluster_is_running
from ..loggers import Logger, LogMixin, get_mlflow_logger
from ..utils import generate_random_name, load_yaml, time_and_log, to_uri
from .config import TrainingConfiguration
from .distributed import (
DeepSpeedStrategy,
HorovodStrategy,
NonDistributedStrategy,
RayDDPStrategy,
RayDeepSpeedStrategy,
RayTorchDistributedStrategy,
TorchDDPStrategy,
TorchDistributedStrategy,
distributed_resources_available,
)
from .reproducibility import set_seed
from .tuning import search_space
py_logger = logging.getLogger(__name__)
def _get_tuning_metric_name(tune_config: TuneConfig | None) -> str:
"""Extracts the metric name from TuneConfig or scheduler in a generic way."""
DEFAULT_NAME = "loss"
if not tune_config:
return DEFAULT_NAME
# Try to get from TuneConfig
if tune_config.metric:
return tune_config.metric
# Try to get from the scheduler (if defined)
scheduler = tune_config.scheduler
if scheduler and hasattr(scheduler, "metric") and scheduler.metric:
return scheduler.metric
return DEFAULT_NAME
[docs]
class TorchTrainer(Trainer, LogMixin):
"""Trainer class for torch training algorithms.
Args:
config (Dict | TrainingConfiguration): training configuration containing
hyperparameters.
epochs (int): number of training epochs.
model (nn.Module | None, optional): pytorch model to train or a string identifier.
Defaults to None.
strategy (Literal['ddp', 'deepspeed', 'horovod'], optional): distributed strategy.
Defaults to 'ddp'.
test_every (int | None, optional): run a test epoch every ``test_every`` epochs.
Disabled if None. Defaults to None.
random_seed (int | None, optional): set random seed for reproducibility. If None, the
seed is not set. Defaults to None.
logger (Logger | None, optional): logger for ML tracking. Defaults to None.
metrics (Dict[str, Callable] | None, optional): map of torchmetrics metrics. Defaults
to None.
checkpoints_location (str): path to checkpoints directory. Defaults to "checkpoints".
checkpoint_every (int | None): save a checkpoint every ``checkpoint_every`` epochs.
Disabled if None. Defaults to None.
disable_tqdm (bool): whether to disable tqdm progress bar(s).
name (str | None, optional): trainer custom name. Defaults to None.
profiling_wait_epochs (int): how many epochs to wait before starting the profiler.
profiling_warmup_epochs (int): length of the profiler warmup phase in terms of number
of epochs.
measure_gpu_data (bool): enable the collection of data on average GPU utilization and
total energy consumption throughout training. Defaults to False.
enable_torch_profiling (bool): enable the profiling of computation. It uses the torch
profiler and it may slow down training. Defaults to False.
measure_epoch_time (bool): enable the measurement of epoch duration (in seconds).
Defaults to False.
ray_scaling_config (ScalingConfig, optional): scaling config for Ray Trainer. Defaults
to None.
ray_tune_config (TuneConfig, optional): tune config for Ray Tuner. Defaults to None.
ray_run_config (ray.tune.RunConfig, optional): run config for Ray Tuner. Distributed
training with Ray but without HPO will still be wrapped into a Ray Tuner, to keep
everything homogeneous. Defaults to None.
ray_search_space (Dict[str, Any], optional): search space for Ray Tuner. Defaults to
None.
ray_torch_config (TorchConfig, optional): torch configuration for Ray's TorchTrainer.
Defaults to None.
ray_data_config (DataConfig, optional): dataset configuration for Ray. Defaults to
None.
from_checkpoint (str | Path, optional): path to checkpoint directory. Defaults to None.
initial_best_validation_metric (str): initial value for the best validation metric.
Usually the validation metric is a loss to be minimized and this value exceeds the
highest possible loss value, so that it will be overwritten when the first
validation loss is computed. Example values are "inf" and "-inf", depending on
wether the best validation metric should be minimized or maximized.
Defaults to "inf".
run_name (str, optional): name used to identify a specific run when collecting
metrics on the trainer (e.g. GPU utilization). Defaults to None.
time_ray (bool): whether to time and log the execution of Ray functions. Defaults to
False.
"""
_strategy: TorchDistributedStrategy | None = None
#: PyTorch ``DataLoader`` for training dataset.
train_dataloader: DataLoader | None = None
#: PyTorch ``DataLoader`` for validation dataset.
validation_dataloader: DataLoader | None = None
#: PyTorch ``DataLoader`` for test dataset.
test_dataloader: DataLoader | None = None
#: PyTorch model to train.
model: nn.Module | None = None
#: Loss criterion.
loss: Callable | None = None
#: Optimizer.
optimizer: Optimizer | None = None
#: Learning rate scheduler.
lr_scheduler: LRScheduler | None = None
#: PyTorch random number generator (PRNG).
torch_rng: torch.Generator | None = None
#: itwinai ``itwinai.Logger``
logger: Logger | None = None
#: Total number training batches used so far, across all epochs.
train_glob_step: int = 0
#: Total number validation batches used so far, across all epochs.
validation_glob_step: int = 0
#: Total number test batches used so far, across all epochs.
test_glob_step: int = 0
#: Dictionary of ``torchmetrics`` metrics, indexed by user-defined names.
metrics: Dict[str, Callable]
#: PyTorch Profiler for computation ratio profiling.
profiler: Any | None
#: Toggle for GPU utilization monitoring
measure_gpu_data: bool = False
#: Toggle for computation fraction profiling
enable_torch_profiling: bool = False
#: Store PyTorch Profiling traces
store_torch_profiling_traces: bool = False
#: Toggle for epoch time tracking
measure_epoch_time: bool = False
#: Run ID
run_name: str
#: Toggle for Ray time logging
time_ray: bool = False
# Tune run id for nested runs in mlflow
mlflow_tune_run_id: str | None = None
# train run id
mlflow_train_run_id: str | None = None
# worker run_id
mlflow_worker_run_id: str | None = None
def __init__(
self,
config: Dict | TrainingConfiguration,
epochs: int,
model: nn.Module | None = None,
strategy: Literal["ddp", "deepspeed", "horovod"] = "ddp",
test_every: int | None = None,
random_seed: int | None = None,
logger: Logger | None = None,
metrics: Dict[str, Metric] | None = None,
checkpoints_location: str | Path = "checkpoints",
checkpoint_every: int | None = None,
disable_tqdm: bool = False,
name: str | None = None,
profiling_wait_epochs: int = 0,
profiling_warmup_epochs: int = 0,
measure_gpu_data: bool = False,
enable_torch_profiling: bool = False,
store_torch_profiling_traces: bool = False,
measure_epoch_time: bool = False,
ray_scaling_config: ScalingConfig | None = None,
ray_tune_config: TuneConfig | None = None,
ray_run_config: RunConfig | None = None,
ray_search_space: Dict[str, Any] | None = None,
ray_torch_config: TorchConfig | None = None,
ray_data_config: DataConfig | None = None,
from_checkpoint: Path | str | None = None,
initial_best_validation_metric: str = "inf",
run_name: str | None = None,
time_ray: bool = False,
) -> None:
super().__init__(name)
self.save_parameters(**self.locals2params(locals()))
# config is meant to store all hyperparameters, which can vary from use
# case to use case and include learning_rate, batch_size....
config = {} if config is None else config
if isinstance(config, dict):
config = TrainingConfiguration(**config)
if store_torch_profiling_traces and not enable_torch_profiling:
raise ValueError(
"`store_torch_profiling_traces` is True, but `enable_torch_profiling` is"
" False. Cannot store traces without enabling profiling."
)
self.config = config
self.epochs = epochs
self.model = model
self.strategy = strategy
self.test_every = test_every
self.random_seed = random_seed
self.logger = logger
self.metrics = metrics if metrics is not None else {}
self.checkpoints_location = checkpoints_location
Path(self.checkpoints_location).mkdir(exist_ok=True, parents=True)
self.checkpoint_every = checkpoint_every
self.disable_tqdm = disable_tqdm
self.profiler = None
self.profiling_wait_epochs = profiling_wait_epochs
self.profiling_warmup_epochs = profiling_warmup_epochs
self.measure_gpu_data = measure_gpu_data
self.enable_torch_profiling = enable_torch_profiling
self.store_torch_profiling_traces = store_torch_profiling_traces
self.measure_epoch_time = measure_epoch_time
self.ray_scaling_config = ray_scaling_config
self.ray_tune_config = ray_tune_config
self.ray_run_config = ray_run_config
self.ray_search_space = ray_search_space
self.ray_torch_config = ray_torch_config
self.ray_data_config = ray_data_config
self.from_checkpoint = from_checkpoint
self.time_ray = time_ray
if self.from_checkpoint:
self.from_checkpoint = Path(self.from_checkpoint)
if not self.from_checkpoint.exists():
raise RuntimeError(
"from_checkpoint argument was passed, but the checkpoint is not found "
f"at {self.from_checkpoint}"
)
if self.checkpoints_location:
Path(self.checkpoints_location).mkdir(exist_ok=True, parents=True)
py_logger.debug(f"ray_scaling_config: {ray_scaling_config}")
py_logger.debug(f"ray_tune_config: {ray_tune_config}")
py_logger.debug(f"ray_run_config: {ray_run_config}")
py_logger.debug(f"ray_torch_config: {ray_torch_config}")
py_logger.debug(f"ray_data_config: {ray_data_config}")
py_logger.debug(f"ray_search_space: {ray_search_space}")
# Initial training state -- can be resumed from a checkpoint
self.model_state_dict = None
self.optimizer_state_dict = None
self.lr_scheduler_state_dict = None
self.torch_rng_state = None
# This is initialized to inf as it usually represents a loss to minimize.
# If the validation metric is meant to be maximized, change this to -inf.
self.best_validation_metric = float(initial_best_validation_metric)
self.current_epoch = 0
self.mlflow_logger = get_mlflow_logger(logger)
if run_name is None:
run_name = generate_random_name()
self.run_name = run_name
@property
def strategy(self) -> TorchDistributedStrategy:
"""Strategy currently in use."""
assert self._strategy is not None, "Expected strategy to be initialized before access"
return self._strategy
@strategy.setter
def strategy(self, strategy: str | TorchDistributedStrategy) -> None:
if isinstance(strategy, TorchDistributedStrategy):
self._strategy = strategy
else:
self._strategy = self._detect_distributed_strategy(strategy)
@property
def device(self) -> str:
"""Current device from distributed strategy."""
return self.strategy.device()
def _detect_distributed_strategy(self, strategy: str) -> TorchDistributedStrategy:
"""When a Ray cluster is detected the Ray-equivalent distributed strategy is
automatically selected, without needing the user to explicitly set it.
"""
py_logger.debug(f"Strategy was set to {strategy}")
dist_resources = distributed_resources_available()
ray_cluster_running = ray_cluster_is_running()
enough_resources = dist_resources or ray_cluster_running
py_logger.debug(
f"Enough resources? {enough_resources}"
f" (distributed_resources_available: {dist_resources},"
f" ray_cluster_is_running: {ray_cluster_running})"
)
# NOTE: setting strategy to None prevents the trainer to run distributed ML, regardless
# of the availability of the resources.
if strategy is None or not enough_resources:
py_logger.warning("Falling back to non-distributed strategy.")
return NonDistributedStrategy()
if ray_cluster_is_running():
py_logger.info(
f"Ray cluster was detected, thus the Ray equivalent for {strategy} is used"
)
match strategy, ray_cluster_is_running():
case "ddp", True:
return RayDDPStrategy()
case "ddp", False:
return TorchDDPStrategy(backend=self.config.dist_backend)
case "horovod", True:
py_logger.warning(
"Horovod strategy is no longer supported with Ray V2. See "
"https://github.com/ray-project/ray/issues/49454#issuecomment-2899138398. "
"Falling back to HorovodStrategy without Ray."
)
return HorovodStrategy()
case "horovod", False:
return HorovodStrategy()
case "deepspeed", True:
return RayDeepSpeedStrategy(backend=self.config.dist_backend)
case "deepspeed", False:
return DeepSpeedStrategy(backend=self.config.dist_backend)
case _:
raise ValueError(f"Strategy '{strategy}' is not recognized.")
def _init_distributed_strategy(self) -> None:
if not self.strategy.is_initialized:
self.strategy.init()
def _set_optimizer_from_config(self) -> None:
if self.model is None:
raise ValueError(
"self.model should be initialized before setting optimizer from configuration."
)
match self.config.optimizer:
case "adadelta":
self.optimizer = Adadelta(
self.model.parameters(),
lr=self.config.optim_lr,
weight_decay=self.config.optim_weight_decay,
)
case "adam":
self.optimizer = Adam(
self.model.parameters(),
lr=self.config.optim_lr,
betas=self.config.optim_betas,
weight_decay=self.config.optim_weight_decay,
)
case "adamw":
self.optimizer = AdamW(
self.model.parameters(),
lr=self.config.optim_lr,
betas=self.config.optim_betas,
weight_decay=self.config.optim_weight_decay,
)
case "rmsprop":
self.optimizer = RMSprop(
self.model.parameters(),
lr=self.config.optim_lr,
weight_decay=self.config.optim_weight_decay,
momentum=self.config.optim_momentum,
)
case "sgd":
self.optimizer = SGD(
self.model.parameters(),
lr=self.config.optim_lr,
weight_decay=self.config.optim_weight_decay,
momentum=self.config.optim_momentum,
)
case _:
raise ValueError(
"Unrecognized self.config.optimizer! Check the docs for "
"supported values and consider overriding "
"create_model_loss_optimizer method for more flexibility."
)
def _set_lr_scheduler_from_config(self) -> None:
"""Parse Lr scheduler from training config"""
if not self.config.lr_scheduler:
return
if not self.optimizer:
raise ValueError("Trying to instantiate a LR scheduler but the optimizer is None!")
match self.config.lr_scheduler:
case "constant":
self.lr_scheduler = lr_scheduler.ConstantLR(self.optimizer)
case "polynomial":
self.lr_scheduler = lr_scheduler.PolynomialLR(self.optimizer)
case "exponential":
self.lr_scheduler = lr_scheduler.ExponentialLR(
self.optimizer, gamma=self.config.lr_scheduler_gamma
)
case "linear":
self.lr_scheduler = lr_scheduler.LinearLR(self.optimizer)
case "multistep":
self.lr_scheduler = lr_scheduler.MultiStepLR(
self.optimizer, milestones=self.config.lr_scheduler_step_size
)
case "step":
self.lr_scheduler = lr_scheduler.StepLR(
self.optimizer, step_size=self.config.lr_scheduler_step_size
)
case _:
raise ValueError(
"Unrecognized self.config.lr_scheduler! Check the docs for "
"supported values and consider overriding "
"create_model_loss_optimizer method for more flexibility."
)
def _set_loss_from_config(self) -> None:
match self.config.loss:
case "nllloss":
self.loss = nn.functional.nll_loss
case "cross_entropy":
self.loss = nn.functional.cross_entropy
case "mse":
self.loss = nn.functional.mse_loss
case "bceloss":
self.loss = nn.functional.binary_cross_entropy
case _:
raise ValueError(
"Unrecognized self.config.loss! Check the docs for "
"supported values and consider overriding "
"create_model_loss_optimizer method for more flexibility."
)
[docs]
def get_default_distributed_kwargs(self) -> Dict:
"""Gives the default kwargs for the trainer's strategy's distributed() method."""
if isinstance(self.strategy, DeepSpeedStrategy):
# Batch size definition is not optional for DeepSpeedStrategy!
distribute_kwargs = {
"config_params": {"train_micro_batch_size_per_gpu": self.config.batch_size}
}
elif isinstance(self.strategy, HorovodStrategy):
import horovod.torch as hvd
distribute_kwargs = {
"compression": (
hvd.Compression.fp16
if self.config.fp16_allreduce
else hvd.Compression.none
),
"op": hvd.Adasum if self.config.use_adasum else hvd.Average,
"gradient_predivide_factor": self.config.gradient_predivide_factor,
}
else:
distribute_kwargs = {}
return distribute_kwargs
[docs]
def create_model_loss_optimizer(self) -> None:
"""Instantiate a torch model, loss, optimizer, and LR scheduler using the
configuration provided in the Trainer constructor.
Generally a user-defined method.
Raises:
ValueError: If ``self.model`` is None.
"""
###################################
# Dear user, this is a method you #
# may be interested to override! #
###################################
# Model, optimizer, and lr scheduler may have already been loaded from a checkpoint
if self.model is None:
raise ValueError(
"self.model is None! Either pass it to the constructor, load a checkpoint, or "
"override create_model_loss_optimizer method."
)
if self.model_state_dict:
# Load model from checkpoint
self.model.load_state_dict(self.model_state_dict, strict=False)
# Parse optimizer from training configuration
# Optimizer can be changed with a custom one here!
self._set_optimizer_from_config()
# Parse LR scheduler from training configuration
# LR scheduler can be changed with a custom one here!
self._set_lr_scheduler_from_config()
if self.optimizer_state_dict:
# Load optimizer state from checkpoint
# IMPORTANT: this must be after the learning rate scheduler was already initialized
# by passing to it the optimizer. Otherwise the optimizer state just loaded will
# be modified by the lr scheduler.
self.optimizer.load_state_dict(self.optimizer_state_dict)
if self.lr_scheduler_state_dict and self.lr_scheduler:
# Load LR scheduler state from checkpoint
self.lr_scheduler.load_state_dict(self.lr_scheduler_state_dict)
# Parse loss from training configuration
# Loss can be changed with a custom one here!
self._set_loss_from_config()
# IMPORTANT: model, optimizer, and scheduler need to be distributed from here on
distribute_kwargs = self.get_default_distributed_kwargs()
# Distributed model, optimizer, and scheduler
(self.model, self.optimizer, self.lr_scheduler) = self.strategy.distributed(
self.model, self.optimizer, self.lr_scheduler, **distribute_kwargs
)
[docs]
def save_checkpoint(
self,
name: str,
best_validation_metric: torch.Tensor | None = None,
checkpoints_root: str | Path | None = None,
force: bool = False,
) -> str | None:
"""Save training checkpoint.
Args:
name (str): name of the checkpoint directory.
best_validation_metric (torch.Tensor | None): best validation metric throughout
training so far (if available). Usually this is the validation loss.
checkpoints_root (str | None): path for root checkpoints dir. If None, uses
``self.checkpoints_location`` as base.
force (bool): force checkpointing now.
Returns:
path to the checkpoint file or ``None`` when the checkpoint is not created.
"""
# Determine whether a checkpoint should be created
should_checkpoint = self.strategy.is_main_worker and (
force
or self.checkpoint_every
and (self.current_epoch + 1) % self.checkpoint_every == 0
)
ckpt_dir = Path(checkpoints_root or self.checkpoints_location) / name
py_logger.info(f"Saving checkpoint at {ckpt_dir.resolve()}? {should_checkpoint}")
if not should_checkpoint:
# Do nothing and return
return
ckpt_dir = Path(checkpoints_root or self.checkpoints_location) / name
ckpt_dir.mkdir(parents=True, exist_ok=True)
# Save state (epoch, loss, optimizer, scheduler)
state = {
"epoch": self.current_epoch,
# This could store the best validation loss
"best_validation_metric": (
best_validation_metric.item() if best_validation_metric is not None else None
),
"optimizer_state_dict": self.optimizer.state_dict(),
"lr_scheduler_state_dict": (
self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None
),
"torch_rng_state": self.torch_rng.get_state(),
"random_seed": self.random_seed,
}
state_path = ckpt_dir / "state.pt"
torch.save(state, state_path)
# Save PyTorch model separately
model_path = ckpt_dir / "model.pt"
# TODO: check that the state dict is stripped from any distributed info
torch.save(self.model.state_dict(), model_path)
# Save Pydantic config as YAML
config_path = ckpt_dir / "config.yaml"
with config_path.open("w") as f:
yaml.safe_dump(self.config.model_dump(), f)
# Log each file with an appropriate identifier
self.log(str(state_path), f"{name}_state", kind="artifact")
self.log(str(model_path), f"{name}_model", kind="artifact")
self.log(str(config_path), f"{name}_config", kind="artifact")
assert state_path.exists()
assert model_path.exists()
assert config_path.exists()
py_logger.info(f"Saved checkpoint at {ckpt_dir.resolve()}")
return str(ckpt_dir)
[docs]
def load_checkpoint(self) -> None:
"""Reload training state from checkpoint."""
if not self.from_checkpoint:
# A checkpoint path was NOT provided
return
# A checkpoint path was provided
py_logger.info(f"Loading from existing checkpoint at {self.from_checkpoint}")
if not isinstance(self.strategy, RayTorchDistributedStrategy):
# Not using Ray, falling back to simple checkpoint reload
py_logger.debug("Loading from existing checkpoint without using Ray")
self._load_checkpoint(checkpoint_dir=self.from_checkpoint)
return
# A Ray checkpoint directory was passed to the trainer -- assuming to be inside a trial
checkpoint = ray.train.get_checkpoint()
if not checkpoint:
py_logger.warning(
"A checkpoint path was passed, but Ray could not find a valid "
"checkpoint directory. Skipping loading from checkpoint."
)
return
with checkpoint.as_directory() as checkpoint_dir:
py_logger.debug("Loading from existing Ray checkpoint")
self._load_checkpoint(checkpoint_dir=checkpoint_dir)
def _load_checkpoint(self, checkpoint_dir: str | Path) -> None:
"""Load checkpoint from path."""
checkpoint_dir = Path(checkpoint_dir)
state = torch.load(checkpoint_dir / "state.pt")
# Override initial training state
self.model_state_dict = torch.load(checkpoint_dir / "model.pt")
self.optimizer_state_dict = state["optimizer_state_dict"]
self.lr_scheduler_state_dict = state["lr_scheduler_state_dict"]
self.torch_rng_state = state["torch_rng_state"]
# Direct overrides (don't require further attention)
self.random_seed = state["random_seed"]
self.current_epoch = state["epoch"] + 1 # Start from next epoch
if state["best_validation_metric"]:
self.best_validation_metric = state["best_validation_metric"]
[docs]
def create_dataloaders(
self,
train_dataset: Dataset,
validation_dataset: Dataset | None = None,
test_dataset: Dataset | None = None,
) -> None:
"""
Create train, validation and test dataloaders using the
configuration provided in the Trainer constructor.
Generally a user-defined method.
Args:
train_dataset (Dataset): training dataset object.
validation_dataset (Dataset | None): validation dataset object.
Default None.
test_dataset (Dataset | None): test dataset object.
Default None.
"""
###################################
# Dear user, this is a method you #
# may be interested to override! #
###################################
self.train_dataloader = self.strategy.create_dataloader(
dataset=train_dataset,
batch_size=self.config.batch_size,
num_workers=self.config.num_workers_dataloader,
pin_memory=self.config.pin_gpu_memory,
generator=self.torch_rng,
shuffle=self.config.shuffle_train,
)
if validation_dataset is not None:
self.validation_dataloader = self.strategy.create_dataloader(
dataset=validation_dataset,
batch_size=self.config.batch_size,
num_workers=self.config.num_workers_dataloader,
pin_memory=self.config.pin_gpu_memory,
generator=self.torch_rng,
shuffle=self.config.shuffle_validation,
)
if test_dataset is not None:
self.test_dataloader = self.strategy.create_dataloader(
dataset=test_dataset,
batch_size=self.config.batch_size,
num_workers=self.config.num_workers_dataloader,
pin_memory=self.config.pin_gpu_memory,
generator=self.torch_rng,
shuffle=self.config.shuffle_test,
)
def _setup_metrics(self) -> None:
"""Move metrics to current device."""
for m_name, metric in self.metrics.items():
self.metrics[m_name] = metric.to(self.device)
[docs]
@monitor_exec
def execute(
self,
train_dataset: Dataset,
validation_dataset: Dataset | None = None,
test_dataset: Dataset | None = None,
) -> Tuple[Dataset, Dataset, Dataset, Any]:
"""Prepares distributed environment and data structures
for the actual training.
Args:
train_dataset (Dataset): training dataset.
validation_dataset (Dataset | None, optional): validation
dataset. Defaults to None.
test_dataset (Dataset | None, optional): test dataset.
Defaults to None.
Returns:
Tuple[Dataset, Dataset, Dataset, Any]: training dataset,
validation dataset, test dataset, trained model.
"""
if isinstance(self.strategy, RayTorchDistributedStrategy):
# Execute with Ray
return self._execute_with_ray(
train_dataset=train_dataset,
validation_dataset=validation_dataset,
test_dataset=test_dataset,
)
# Execute without ray
if self.ray_scaling_config:
py_logger.warning(
"Ray scaling config was passed, but it's ignored as Ray is not used"
)
if self.ray_run_config:
py_logger.warning("Ray run config was passed, but it's ignored as Ray is not used")
if self.ray_tune_config:
py_logger.warning(
"Ray tune config was passed, but it's ignored as Ray is not used"
)
if self.ray_search_space:
py_logger.warning(
"Ray search space was passed, but it's ignored as Ray is not used"
)
if self.ray_torch_config:
py_logger.warning(
"Ray torch config was passed, but it's ignored as Ray is not used"
)
if self.ray_data_config:
py_logger.warning(
"Ray dataset config was passed, but it's ignored as Ray is not used"
)
self._run_worker(
config={},
train_dataset=train_dataset,
validation_dataset=validation_dataset,
test_dataset=test_dataset,
)
return train_dataset, validation_dataset, test_dataset, None
def _execute_with_ray(
self,
train_dataset: Dataset,
validation_dataset: Dataset | None = None,
test_dataset: Dataset | None = None,
) -> Tuple[Dataset, Dataset, Dataset, Any]:
"""Launch training and, optionally, hyperparameter tuning with Ray
Args:
train_dataset (Dataset): training dataset.
validation_dataset (Dataset | None, optional): validation
dataset. Defaults to None.
test_dataset (Dataset | None, optional): test dataset.
Defaults to None.
Returns:
Dataset: The training dataset
Dataset: The validation dataset
Dataset: The test dataset
Any: The trained model
"""
if self.ray_run_config and self.ray_run_config.storage_path:
# Create Ray checkpoints dir if it does not exist yet
ckpt_dir = Path(self.ray_run_config.storage_path)
ckpt_dir.mkdir(parents=True, exist_ok=True)
if (
self.ray_scaling_config
and getattr(self.ray_scaling_config, "num_workers", 1) > 1
and getattr(self.ray_scaling_config.resources_per_worker, "GPU", 0) > 0.0
and getattr(self.ray_scaling_config.resources_per_worker, "GPU", 0) < 1.0
):
raise ValueError(
"Distributed trials with fractional gpu resources are not supported."
" Please ensure that either num_workers is set to 1 or GPUs in"
" resources_per_worker is 0 or 1"
)
if (
self.ray_tune_config
and self.ray_tune_config.scheduler is not None
and self.measure_gpu_data
):
py_logger.info(
"A Trial scheduler for Ray is specified"
f" ({type(self.ray_tune_config.scheduler)}), while measuring gpu data."
" Trials stopped by the scheduler might not close logger context in time,"
" leaving the status of the mlflow run in 'pending'. This is just a visual"
" caveat and can be ignored."
)
if self.mlflow_logger:
# Create mlflow runs per trial (will be started by the trial's main worker)
tune_run = self.mlflow_logger.mlflow.start_run(
experiment_id=self.mlflow_logger.resolve_experiment_id(),
run_name=self.run_name,
)
self.mlflow_tune_run_id = tune_run.info.run_id
self.mlflow_logger.mlflow.end_run()
# Passes datasets to workers efficiently through Ray storage
train_with_data = ray.tune.with_parameters(
self._run_worker,
train_dataset=train_dataset,
validation_dataset=validation_dataset,
test_dataset=test_dataset,
)
if self.from_checkpoint:
# Create trainer from checkpoint
if RayTorchTrainer.can_restore(to_uri(self.from_checkpoint)):
trainer = RayTorchTrainer.restore(
path=to_uri(self.from_checkpoint),
train_loop_per_worker=train_with_data,
train_loop_config=None,
)
else:
# Ray is unable to restore the checkpoint implicitly, but it's passing
# it to the trial
trainer = RayTorchTrainer(
train_loop_per_worker=train_with_data,
train_loop_config=None,
scaling_config=self.ray_scaling_config,
torch_config=self.ray_torch_config,
dataset_config=self.ray_data_config,
resume_from_checkpoint=Checkpoint(to_uri(self.from_checkpoint)),
)
else:
# Create trainer without checkpoint
trainer = RayTorchTrainer(
train_loop_per_worker=train_with_data,
train_loop_config=None,
scaling_config=self.ray_scaling_config,
torch_config=self.ray_torch_config,
dataset_config=self.ray_data_config,
)
# Create the parameter space for hyperparameter tuning
param_space = {"train_loop_config": search_space(self.ray_search_space)}
# Create the tuner with the driver function
tuner = ray.tune.Tuner(
trainable=trainer,
param_space=param_space,
tune_config=self.ray_tune_config,
run_config=self.ray_run_config,
)
# Run the tuner and capture results
if self.time_ray and self.logger is not None:
self.logger.create_logger_context(run_id=self.mlflow_tune_run_id)
self.tune_result_grid = time_and_log(
func=tuner.fit,
logger=self.logger,
identifier="ray_fit_time_s",
step=0,
)
else:
# Run the tuner and capture results
self.tune_result_grid = tuner.fit()
return train_dataset, validation_dataset, test_dataset, None
def _run_worker(
self,
config: Dict,
train_dataset: Dataset,
validation_dataset: Dataset | None = None,
test_dataset: Dataset | None = None,
) -> None:
self.load_checkpoint()
self._override_config(config)
self._set_seed()
self._init_distributed_strategy()
self._setup_metrics()
if self.logger:
py_logger.debug(f"Using logger: {self.logger.__class__.__name__}")
worker_run_name = f"worker_{self.strategy.global_rank()}"
if self.strategy.is_main_worker and self.mlflow_logger:
# required so env vars are set correctly for ray
mlflow.set_tracking_uri(self.mlflow_logger.tracking_uri)
mlflow.set_experiment(self.mlflow_logger.experiment_name)
experiment_id = self.mlflow_logger.resolve_experiment_id()
self.mlflow_logger.mlflow.set_tracking_uri(self.mlflow_logger.tracking_uri)
if self.mlflow_tune_run_id:
# If a tune_run_id is set, we create a nested run (Ray)
train_run_name = ray.tune.get_context().get_trial_name()
train_run = self.mlflow_logger.mlflow.start_run(
experiment_id=experiment_id,
run_name=train_run_name,
parent_run_id=self.mlflow_tune_run_id,
)
else:
train_run_name = self.run_name
train_run = self.mlflow_logger.mlflow.start_run(
experiment_id=experiment_id,
run_name=train_run_name,
)
# store the mlflow run id as a parent for the worker runs
self.mlflow_train_run_id = train_run.info.run_id
# Stop train run to remove pending status in mlflow
# (metrics are logged to workers)
self.mlflow_logger.mlflow.end_run()
worker_run_name += " (main)"
# Broadcast trial_run_id from main worker to all workers
# Ensure the broadcasted value is not None (Horovod otherwise deadlocks)
self.mlflow_train_run_id = self.strategy.broadcast_obj(
self.mlflow_train_run_id or "", src_rank=0
)
py_logger.debug(
f"Broadcasted mlflow_trial_run_id {self.mlflow_train_run_id} to all workers"
)
if self.mlflow_logger and not self.strategy.is_main_worker:
# Set the tracking uri and experiment for other workers after main worker
mlflow.set_tracking_uri(self.mlflow_logger.tracking_uri)
mlflow.set_experiment(self.mlflow_logger.experiment_name)
# Create logger on worker level
self.logger.create_logger_context(
rank=self.strategy.global_rank(),
parent_run_id=self.mlflow_train_run_id,
run_name=worker_run_name,
)
self.log(
item=self.strategy.name,
identifier="strategy",
kind="param",
)
self.log(
item=self.strategy.global_rank(),
identifier="global_rank",
kind="param",
)
self.log(
item=self.strategy.global_world_size(),
identifier="global_world_size",
kind="param",
)
if self.mlflow_logger and self.mlflow_logger.should_log():
self.mlflow_worker_run_id = self.mlflow_logger.active_run.info.run_id
py_logger.debug("...the logger has been initialized")
hparams = self.config.model_dump()
hparams["distributed_strategy"] = self.strategy.__class__.__name__
self.logger.save_hyperparameters(hparams)
self.create_dataloaders(
train_dataset=train_dataset,
validation_dataset=validation_dataset,
test_dataset=test_dataset,
)
self.create_model_loss_optimizer()
self.train()
if self.logger:
self.logger.destroy_logger_context()
self.strategy.clean_up()
def _set_seed(self) -> None:
py_logger.debug(f"Using random seed: {self.random_seed}")
self.torch_rng = set_seed(self.random_seed)
if self.torch_rng_state is not None:
# Resume state from checkpoint
py_logger.debug("Resuming torch PRNG state from checkpoint")
self.torch_rng.set_state(self.torch_rng_state)
def _override_config(self, config: Dict) -> None:
"""Override self.config with a sample from the search space from the Ray tuner."""
self.config = self.config.model_copy(update=config)
py_logger.debug("Overrode self.config with trial config (if given)")
def _set_epoch_dataloaders(self, epoch: int) -> None:
"""Sets epoch in the distributed sampler of a dataloader when using it."""
if not self.strategy.is_distributed:
return
self.train_dataloader.sampler.set_epoch(epoch)
if self.validation_dataloader is not None:
self.validation_dataloader.sampler.set_epoch(epoch)
if self.test_dataloader is not None:
self.test_dataloader.sampler.set_epoch(epoch)
[docs]
def set_epoch(self) -> None:
"""Set current epoch at the beginning of training."""
# We don't want to start stepping until after the first epoch
if self.profiler and self.current_epoch > 0:
# Always step the profiler at the beginning of the epoch
self.profiler.step()
if self.lr_scheduler:
self.lr_scheduler.step()
self._set_epoch_dataloaders(self.current_epoch)
[docs]
def log(
self,
item: Any | List[Any],
identifier: str | List[str],
kind: str = "metric",
step: int | None = None,
batch_idx: int | None = None,
**kwargs,
) -> None:
"""Log ``item`` with ``identifier`` name of ``kind`` type at ``step``
time step.
Args:
item (Any | List[Any]): element to be logged (e.g., metric).
identifier (str | List[str]): unique identifier for the
element to log(e.g., name of a metric).
kind (str, optional): type of the item to be logged. Must be one
among the list of self.supported_types. Defaults to 'metric'.
step (int | None, optional): logging step. Defaults to None.
batch_idx (int | None, optional): DataLoader batch counter
(i.e., batch idx), if available. Defaults to None.
"""
if self.logger:
self.logger.log(
item=item,
identifier=identifier,
kind=kind,
step=step,
batch_idx=batch_idx,
**kwargs,
)
[docs]
def ray_report(
self,
metrics: Dict[str, float],
checkpoint_file: str | Path | None = None,
checkpoint_dir: str | Path | None = None,
checkpoint_data: Any | None = None,
) -> None:
"""Report a dictionary of metrics and optionally a checkpoint to Ray, only when using
Ray distributed strategies. The checkpoint could be in the form of a Python object
(passed as ``checkpoint_data``), the path to a single file (passed as
``checkpoint_file``), or the path to an existing checkpoint directory (passed as
``checkpoint_dir``).
Args:
metrics (Dict[str, float]): metrics to be reported.
checkpoint_file (str | Path | None, optional): path to the checkpoint file.
Defaults to None.
checkpoint_dir (str | Path | None, optional): path to the checkpoint directory.
Defaults to None.
checkpoint_data (Any | None, optional):object to serialize as a checkpoint.
Defaults to None.
"""
if not isinstance(self.strategy, RayTorchDistributedStrategy):
# Ray is not used, thus do nothing
return
if checkpoint_file:
# A checkpoint is given as a file
with tempfile.TemporaryDirectory() as tmp_dir:
import shutil
shutil.copy(checkpoint_file, tmp_dir)
checkpoint = ray.train.Checkpoint.from_directory(tmp_dir)
ray.train.report(metrics, checkpoint=checkpoint)
elif checkpoint_data:
# A checkpoint is given as a python object which needs to be serialized
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_dir = Path(tmp_dir)
ckpt_file = tmp_dir / "ckpt.pt"
torch.save(checkpoint_data, ckpt_file)
checkpoint = ray.train.Checkpoint.from_directory(tmp_dir)
ray.train.report(metrics, checkpoint=checkpoint)
elif checkpoint_dir:
# A checkpoint is given as a directory
checkpoint = ray.train.Checkpoint.from_directory(checkpoint_dir)
ray.train.report(metrics, checkpoint=checkpoint)
else:
# No checkpoint is given: only report metrics
ray.train.report(metrics)
[docs]
def compute_metrics(
self,
true: torch.Tensor,
pred: torch.Tensor,
logger_step: int,
batch_idx: int | None,
stage: str = "train",
) -> Dict[str, Any]:
"""Compute and log metrics.
Args:
true (torch.Tensor): true values.
pred (torch.Tensor): predicted values.
logger_step (int): global step to pass to the logger.
stage (str): 'train', 'validation'...
Returns:
Dict[str, Any]: metric values.
"""
m_values = {}
for m_name, metric in self.metrics.items():
# metric = metric.to(self.device)
m_val = metric(pred, true).detach().cpu().numpy()
self.log(
item=m_val,
identifier=f"{stage}_{m_name}",
kind="metric",
step=logger_step,
batch_idx=batch_idx,
)
m_values[m_name] = m_val
return m_values
[docs]
@profile_torch_trainer
@measure_gpu_utilization
def train(self) -> None:
"""Trains a machine learning model.
Main training loop/logic.
Args:
train_dataset (Dataset): training dataset.
validation_dataset (Dataset): validation dataset.
test_dataset (Dataset): test dataset.
Returns:
Dataset: The training dataset.
Dataset: The validation dataset.
Dataset: The test dataset.
Any: The trained model
"""
progress_bar = tqdm(
range(self.current_epoch, self.epochs),
desc="Epochs",
disable=self.disable_tqdm or not self.strategy.is_main_worker,
)
for self.current_epoch in progress_bar:
if self.strategy.is_main_worker:
epoch_start_time = perf_counter()
progress_bar.set_description(f"Epoch {self.current_epoch + 1}/{self.epochs}")
self.set_epoch()
self.train_epoch()
val_metric = self.validation_epoch()
# Periodic checkpointing
periodic_ckpt_path = self.save_checkpoint(name=f"epoch_{self.current_epoch}")
# Checkpointing current best model
best_ckpt_path = None
worker_val_metrics = self.strategy.gather(val_metric, dst_rank=0)
if self.strategy.is_main_worker:
avg_metric = torch.mean(torch.stack(worker_val_metrics)).detach().cpu()
self.log(
item=avg_metric.item(),
identifier="global_validation_loss_epoch",
kind="metric",
step=self.current_epoch,
)
if avg_metric < self.best_validation_metric:
best_ckpt_path = self.save_checkpoint(
name="best_model",
best_validation_metric=avg_metric,
force=True,
)
self.best_validation_metric = avg_metric
# Report validation metrics to Ray (useful for tuning!)
metric_name = _get_tuning_metric_name(self.ray_tune_config)
if metric_name is None:
raise ValueError("Could not find a metric in the TuneConfig")
if (
self.time_ray
and self.logger is not None
and isinstance(self.strategy, RayTorchDistributedStrategy)
):
time_and_log(
func=partial(
self.ray_report,
metrics={metric_name: val_metric.item()},
checkpoint_dir=best_ckpt_path or periodic_ckpt_path,
),
logger=self.logger,
identifier="ray_report_time_s_per_epoch",
step=self.current_epoch,
)
else:
self.ray_report(
metrics={metric_name: val_metric.item()},
checkpoint_dir=best_ckpt_path or periodic_ckpt_path,
)
if self.test_every and (self.current_epoch + 1) % self.test_every == 0:
self.test_epoch()
# Measure epoch time and log
if self.strategy.is_main_worker:
epoch_time = perf_counter() - epoch_start_time
self.log(
item=epoch_time,
identifier="epoch_time_s",
kind="metric",
step=self.current_epoch,
)
[docs]
def train_epoch(self) -> torch.Tensor:
"""Perform a complete sweep over the training dataset, completing an epoch of training.
Args:
epoch (int): current epoch number, from 0 to ``self.epochs - 1``.
Returns:
torch.Tensor: average training loss for the current epoch.
"""
self.model.train()
train_loss_sum = 0.0
train_metrics_sum = defaultdict(float)
batch_counter = 0
progress_bar = tqdm(
enumerate(self.train_dataloader),
total=len(self.train_dataloader) // self.strategy.global_world_size(),
desc="Train batches",
disable=self.disable_tqdm or not self.strategy.is_main_worker,
leave=False, # Set this to true to see how many batches were used
)
for batch_idx, train_batch in progress_bar:
loss, metrics = self.train_step(batch=train_batch, batch_idx=batch_idx)
train_loss_sum += loss
batch_counter += 1
for name, val in metrics.items():
train_metrics_sum[name] += val
# Important: update counter
self.train_glob_step += 1
# Aggregate and log losses
avg_loss = train_loss_sum / batch_counter
self.log(
item=avg_loss.item(),
identifier="train_loss_epoch",
kind="metric",
step=self.train_glob_step,
)
# Aggregate and log metrics
for m_name, m_val in train_metrics_sum.items():
self.log(
item=m_val / batch_counter,
identifier="train_" + m_name + "_epoch",
kind="metric",
step=self.train_glob_step,
)
return avg_loss
[docs]
def train_step(
self, batch: torch.Tensor, batch_idx: int
) -> Tuple[torch.Tensor, Dict[str, Any]]:
"""Perform a single optimization step using a batch sampled from the training dataset.
Args:
batch (torch.Tensor): batch sampled by a dataloader.
batch_idx (int): batch index in the dataloader.
Returns:
torch.Tensor: The batch loss.
Dict[str, Any]: Dictionary of metric values (same structure as ``self.metrics``).
"""
x, y = batch
x, y = x.to(self.device), y.to(self.device)
self.optimizer.zero_grad()
pred_y = self.model(x)
loss = self.loss(pred_y, y)
loss.backward()
self.optimizer.step()
# Log metrics
self.log(
item=loss.item(),
identifier="train_loss",
kind="metric",
step=self.train_glob_step,
batch_idx=batch_idx,
)
metrics: Dict[str, Any] = self.compute_metrics(
true=y,
pred=pred_y,
logger_step=self.train_glob_step,
batch_idx=batch_idx,
stage="train",
)
return loss, metrics
[docs]
def validation_epoch(self) -> torch.Tensor | None:
"""Perform a complete sweep over the validation dataset, completing an epoch of
validation.
Returns:
torch.Tensor | None: average validation loss for the current epoch if
self.validation_dataloader is not None
"""
if self.validation_dataloader is None:
return
progress_bar = tqdm(
enumerate(self.validation_dataloader),
total=len(self.validation_dataloader) // self.strategy.global_world_size(),
desc="Validation batches",
disable=self.disable_tqdm or not self.strategy.is_main_worker,
leave=False, # Set this to true to see how many batches were used
)
self.model.eval()
validation_loss_sum = 0.0
validation_metrics_sum = defaultdict(float)
batch_counter = 0
for batch_idx, val_batch in progress_bar:
loss, metrics = self.validation_step(batch=val_batch, batch_idx=batch_idx)
validation_loss_sum += loss
batch_counter += 1
for name, val in metrics.items():
validation_metrics_sum[name] += val
# Important: update counter
self.validation_glob_step += 1
# Aggregate and log losses
avg_loss = validation_loss_sum / batch_counter
self.log(
item=avg_loss.item(),
identifier="validation_loss_epoch",
kind="metric",
step=self.validation_glob_step,
)
# Aggregate and log metrics
for m_name, m_val in validation_metrics_sum.items():
self.log(
item=m_val / batch_counter,
identifier="validation_" + m_name + "_epoch",
kind="metric",
step=self.validation_glob_step,
)
return avg_loss
[docs]
def validation_step(
self, batch: torch.Tensor, batch_idx: int
) -> Tuple[torch.Tensor, Dict[str, Any]]:
"""Perform a single optimization step using a batch sampled from the validation
dataset.
Args:
batch (torch.Tensor): batch sampled by a dataloader.
batch_idx (int): batch index in the dataloader.
Returns:
torch.Tensor: Batch loss.
Dict[str, Any]: Dictionary of metric values (same structure as ``self.metrics``).
"""
x, y = batch
x, y = x.to(self.device), y.to(self.device)
with torch.no_grad():
pred_y = self.model(x)
loss: torch.Tensor = self.loss(pred_y, y)
self.log(
item=loss.item(),
identifier="validation_loss",
kind="metric",
step=self.validation_glob_step,
batch_idx=batch_idx,
)
metrics: Dict[str, Any] = self.compute_metrics(
true=y,
pred=pred_y,
logger_step=self.validation_glob_step,
batch_idx=batch_idx,
stage="validation",
)
return loss, metrics
[docs]
def test_epoch(self) -> torch.Tensor:
"""Perform a complete sweep over the test dataset, completing an epoch of test.
Returns:
torch.Tensor: average test loss for the current epoch.
"""
raise NotImplementedError()
[docs]
def test_step(
self, batch: torch.Tensor, batch_idx: int
) -> Tuple[torch.Tensor, Dict[str, Any]]:
"""Perform a single predictions step using a batch sampled from the test dataset.
Args:
batch (torch.Tensor): batch sampled by a dataloader.
batch_idx (int): batch index in the dataloader.
Returns:
torch.Tensor: The batch loss
Dict[str, Any]: Dictionary of metric values (same structure as ``self.metrics``).
"""
raise NotImplementedError()
[docs]
class TorchLightningTrainer(Trainer):
"""Generic trainer for torch Lightning workflows.
Args:
config (Dict | str): `Lightning configuration`_
which can be the path to a file or a Python dictionary.
mlflow_saved_model (str, optional): name of the model created in
MLFlow. Defaults to 'my_model'.
.. _Lightning configuration:
https://pytorch-lightning.readthedocs.io/en/1.6.5/common/lightning_cli.html
"""
def __init__(self, config: Dict | str, mlflow_saved_model: str = "my_model"):
self.save_parameters(**self.locals2params(locals()))
super().__init__()
if isinstance(config, str) and Path(config).is_file():
# Load from YAML
config = load_yaml(config)
self.conf = config
self.mlflow_saved_model = mlflow_saved_model
[docs]
@monitor_exec
def execute(self) -> Any:
from lightning import LightningDataModule, LightningModule
from lightning.pytorch.cli import LightningCLI
from .mlflow import init_lightning_mlflow, teardown_lightning_mlflow
init_lightning_mlflow(
self.conf, tmp_dir="/tmp", registered_model_name=self.mlflow_saved_model
)
old_argv = sys.argv
sys.argv = ["some_script_placeholder.py"]
cli = LightningCLI(
args=self.conf,
model_class=LightningModule,
datamodule_class=LightningDataModule,
run=False,
save_config_kwargs={
"overwrite": True,
"config_filename": "pl-training.yml",
},
subclass_mode_model=True,
subclass_mode_data=True,
)
sys.argv = old_argv
cli.trainer.fit(cli.model, datamodule=cli.datamodule)
teardown_lightning_mlflow()