# --------------------------------------------------------------------------------------
# Part of the interTwin Project: https://www.intertwin.eu/
#
# Created by: Matteo Bunino
#
# Credit:
# - Matteo Bunino <matteo.bunino@cern.ch> - CERN
# - Anna Lappe <anna.elisa.lappe@cern.ch> - CERN
# - Jarl Sondre Sæther <jarl.sondre.saether@cern.ch> - CERN
# - Henry Mutegeki <henry.mutegeki@cern.ch> - CERN
# --------------------------------------------------------------------------------------
"""Provides training logic for PyTorch models via Trainer classes."""
import os
import sys
import tempfile
from collections import defaultdict
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader, Dataset, Sampler
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
from ..components import Trainer, monitor_exec
from ..loggers import Logger, LogMixin
from ..utils import load_yaml
from .config import TrainingConfiguration
from .distributed import (
DeepSpeedStrategy,
HorovodStrategy,
NonDistributedStrategy,
RayDDPStrategy,
RayDeepSpeedStrategy,
TorchDDPStrategy,
TorchDistributedStrategy,
distributed_resources_available,
)
from .reproducibility import seed_worker, set_seed
from .type import Batch, LrScheduler, Metric
[docs]
class TorchTrainer(Trainer, LogMixin):
    """Trainer class for torch training algorithms.

    Args:
        config (Union[Dict, TrainingConfiguration]): training configuration
            containing hyperparameters. A plain dict is converted to a
            ``TrainingConfiguration``.
        epochs (int): number of training epochs.
        model (Optional[Union[nn.Module, str]], optional): pytorch model to
            train or a string identifier. Defaults to None.
        strategy (Optional[Literal['ddp', 'deepspeed', 'horovod']], optional):
            distributed strategy. If None, or if no distributed resources are
            available, a non-distributed strategy is used. Defaults to 'ddp'.
        validation_every (Optional[int], optional): run a validation epoch
            every ``validation_every`` epochs. Disabled if None. Defaults to 1.
        test_every (Optional[int], optional): run a test epoch
            every ``test_every`` epochs. Disabled if None. Defaults to None.
        random_seed (Optional[int], optional): set random seed for
            reproducibility. If None, the seed is not set. Defaults to None.
        logger (Optional[Logger], optional): logger for ML tracking.
            Defaults to None.
        metrics (Optional[Dict[str, Metric]], optional): map of torchmetrics
            metrics. Defaults to None.
        checkpoints_location (str): path to checkpoints directory.
            Defaults to "checkpoints".
        checkpoint_every (Optional[int]): save a checkpoint every
            ``checkpoint_every`` epochs. Disabled if None. Defaults to None.
        disable_tqdm (bool): whether to disable tqdm progress bar(s).
            Defaults to False.
        name (Optional[str], optional): trainer custom name. Defaults to None.
        profiling_wait_epochs (int): how many epochs to wait before starting
            the profiler. Defaults to 1.
        profiling_warmup_epochs (int): length of the profiler warmup phase in
            terms of number of epochs. Defaults to 2.
    """

    # TODO:
    # - extract BaseTorchTrainer and extend it creating a set of trainer
    #   templates (e.g., GAN, Classifier, Transformer) allowing scientists
    #   to reuse ML algos.
    _strategy: TorchDistributedStrategy = None
    #: PyTorch ``DataLoader`` for training dataset.
    train_dataloader: DataLoader = None
    #: PyTorch ``DataLoader`` for validation dataset.
    validation_dataloader: DataLoader = None
    #: PyTorch ``DataLoader`` for test dataset.
    test_dataloader: DataLoader = None
    #: PyTorch model to train.
    model: nn.Module = None
    #: Loss criterion.
    loss: Callable = None
    #: Optimizer.
    optimizer: Optimizer = None
    #: Learning rate scheduler.
    lr_scheduler: LrScheduler = None
    #: PyTorch random number generator (PRNG).
    torch_rng: torch.Generator = None
    #: itwinai ``itwinai.Logger``
    logger: Logger = None
    #: Total number of training batches used so far, across all epochs.
    train_glob_step: int = 0
    #: Total number of validation batches used so far, across all epochs.
    validation_glob_step: int = 0
    #: Total number of test batches used so far, across all epochs.
    test_glob_step: int = 0
    #: Dictionary of ``torchmetrics`` metrics, indexed by user-defined names.
    metrics: Dict[str, Metric]
    #: PyTorch Profiler for communication vs. computation comparison
    profiler: Optional[Any]
def __init__(
    self,
    config: Union[Dict, TrainingConfiguration],
    epochs: int,
    model: Optional[Union[nn.Module, str]] = None,
    strategy: Optional[Literal["ddp", "deepspeed", "horovod"]] = "ddp",
    validation_every: Optional[int] = 1,
    test_every: Optional[int] = None,
    random_seed: Optional[int] = None,
    logger: Optional[Logger] = None,
    metrics: Optional[Dict[str, Metric]] = None,
    checkpoints_location: str = "checkpoints",
    checkpoint_every: Optional[int] = None,
    disable_tqdm: bool = False,
    name: Optional[str] = None,
    profiling_wait_epochs: int = 1,
    profiling_warmup_epochs: int = 2,
) -> None:
    """Store the training settings and create the checkpoints directory.

    See the class docstring for the meaning of each argument.
    """
    super().__init__(name)
    # Must run before any other local is bound: it snapshots ``locals()``.
    self.save_parameters(**self.locals2params(locals()))

    # config is meant to store all hyperparameters, which can vary from use
    # case to use case and include learning_rate, batch_size....
    self.config = TrainingConfiguration(**config) if isinstance(config, dict) else config
    self.epochs = epochs
    self.model = model
    # The strategy setter reads ``self.config``, so it goes after the config.
    self.strategy = strategy
    self.validation_every = validation_every
    self.test_every = test_every
    self.random_seed = random_seed
    self.logger = logger
    self.metrics = {} if metrics is None else metrics

    # Checkpointing
    self.checkpoints_location = checkpoints_location
    os.makedirs(self.checkpoints_location, exist_ok=True)
    self.checkpoint_every = checkpoint_every

    self.disable_tqdm = disable_tqdm

    # Profiling: no profiler is attached at construction time
    self.profiler = None
    self.profiling_wait_epochs = profiling_wait_epochs
    self.profiling_warmup_epochs = profiling_warmup_epochs
@property
def strategy(self) -> TorchDistributedStrategy:
    """Distributed strategy object used by this trainer."""
    return self._strategy

@strategy.setter
def strategy(self, strategy: str | TorchDistributedStrategy) -> None:
    # Strategy instances are stored untouched; any other value is resolved
    # to a strategy object through ``_detect_strategy``.
    self._strategy = (
        strategy
        if isinstance(strategy, TorchDistributedStrategy)
        else self._detect_strategy(strategy)
    )
@property
def device(self) -> str:
    """Device reported by the distributed strategy in use."""
    active_strategy = self.strategy
    return active_strategy.device()
def _detect_strategy(self, strategy: str) -> TorchDistributedStrategy:
    """Build the strategy object matching the ``strategy`` identifier.

    Falls back to ``NonDistributedStrategy`` (printing a warning) when
    ``strategy`` is None or ``distributed_resources_available()`` is false.

    Raises:
        NotImplementedError: if the identifier is not one of
            'ddp', 'horovod' or 'deepspeed'.
    """
    if strategy is None or not distributed_resources_available():
        print("WARNING: falling back to non-distributed strategy.")
        return NonDistributedStrategy()

    if strategy == "ddp":
        return TorchDDPStrategy(backend=self.config.dist_backend)
    if strategy == "horovod":
        return HorovodStrategy()
    if strategy == "deepspeed":
        return DeepSpeedStrategy(backend=self.config.dist_backend)

    raise NotImplementedError(f"Strategy '{strategy}' is not recognized/implemented.")
def _init_distributed_strategy(self) -> None:
    """Initialize the distributed strategy unless it is already initialized."""
    active_strategy = self.strategy
    if active_strategy.is_initialized:
        return
    active_strategy.init()
def _optimizer_from_config(self) -> None:
if self.config.optimizer == "adadelta":
self.optimizer = optim.Adadelta(
self.model.parameters(),
lr=self.config.optim_lr,
weight_decay=self.config.optim_weight_decay,
)
elif self.config.optimizer == "adam":
self.optimizer = optim.Adam(
self.model.parameters(),
lr=self.config.optim_lr,
weight_decay=self.config.optim_weight_decay,
)
elif self.config.optimizer == "rmsprop":
self.optimizer = optim.RMSprop(
self.model.parameters(),
lr=self.config.optim_lr,
weight_decay=self.config.optim_weight_decay,
momentum=self.config.optim_momentum,
)
elif self.config.optimizer == "sgd":
self.optimizer = optim.SGD(
self.model.parameters(),
lr=self.config.optim_lr,
weight_decay=self.config.optim_weight_decay,
momentum=self.config.optim_momentum,
)
else:
raise ValueError(
"Unrecognized self.config.optimizer! Check the docs for "
"supported values and consider overriding "
"create_model_loss_optimizer method for more flexibility."
)
def _loss_from_config(self) -> None:
if self.config.loss == "nllloss":
self.loss = nn.functional.nll_loss
elif self.config.loss == "cross_entropy":
self.loss = nn.functional.cross_entropy
elif self.config.loss == "mse":
self.loss = nn.functional.mse_loss
else:
raise ValueError(
"Unrecognized self.config.loss! Check the docs for "
"supported values and consider overriding "
"create_model_loss_optimizer method for more flexibility."
)
[docs]
def get_default_distributed_kwargs(self) -> Dict:
    """Return the default kwargs for the strategy's ``distributed()`` method.

    DeepSpeed gets the micro batch size, Horovod gets its compression,
    reduction op and gradient predivide factor, and any other strategy
    gets an empty dict.
    """
    if isinstance(self.strategy, DeepSpeedStrategy):
        # Batch size definition is not optional for DeepSpeedStrategy!
        return {
            "config_params": {"train_micro_batch_size_per_gpu": self.config.batch_size}
        }

    if isinstance(self.strategy, HorovodStrategy):
        # Imported lazily so Horovod is only required when it is used.
        import horovod.torch as hvd

        if self.config.fp16_allreduce:
            compression = hvd.Compression.fp16
        else:
            compression = hvd.Compression.none
        reduce_op = hvd.Adasum if self.config.use_adasum else hvd.Average
        return {
            "compression": compression,
            "op": reduce_op,
            "gradient_predivide_factor": self.config.gradient_predivide_factor,
        }

    return {}
[docs]
def create_model_loss_optimizer(self) -> None:
    """Set up optimizer and loss from the config and distribute the model.

    Requires ``self.model`` to be set already. The model, optimizer and LR
    scheduler are replaced by the objects returned by the strategy's
    ``distributed()`` method. Generally a user-defined method.

    Raises:
        ValueError: if ``self.model`` is None.
    """
    ###################################
    # Dear user, this is a method you #
    # may be interested to override!  #
    ###################################
    if self.model is None:
        raise ValueError(
            "self.model is None! Either pass it to the constructor or "
            "override create_model_loss_optimizer method."
        )

    # Both can be swapped for custom objects in an override.
    self._optimizer_from_config()
    self._loss_from_config()

    # IMPORTANT: model, optimizer, and scheduler need to be distributed
    strategy_kwargs = self.get_default_distributed_kwargs()
    distributed_objects = self.strategy.distributed(
        self.model, self.optimizer, self.lr_scheduler, **strategy_kwargs
    )
    self.model, self.optimizer, self.lr_scheduler = distributed_objects
[docs]
def create_dataloaders(
    self,
    train_dataset: Dataset,
    validation_dataset: Optional[Dataset] = None,
    test_dataset: Optional[Dataset] = None,
) -> None:
    """Build the train, validation and test dataloaders through the strategy.

    Batch size, worker count, memory pinning and the shuffle flags come from
    ``self.config``; the generator is ``self.torch_rng``. The validation and
    test dataloaders are only created when their dataset is given.
    Generally a user-defined method.

    Args:
        train_dataset (Dataset): training dataset object.
        validation_dataset (Optional[Dataset]): validation dataset object.
            Default None.
        test_dataset (Optional[Dataset]): test dataset object.
            Default None.
    """
    ###################################
    # Dear user, this is a method you #
    # may be interested to override!  #
    ###################################
    # Settings shared by the three dataloaders
    shared_settings = {
        "batch_size": self.config.batch_size,
        "num_workers": self.config.num_workers_dataloader,
        "pin_memory": self.config.pin_gpu_memory,
        "generator": self.torch_rng,
    }
    make_loader = self.strategy.create_dataloader

    self.train_dataloader = make_loader(
        dataset=train_dataset, shuffle=self.config.shuffle_train, **shared_settings
    )
    if validation_dataset is not None:
        self.validation_dataloader = make_loader(
            dataset=validation_dataset,
            shuffle=self.config.shuffle_validation,
            **shared_settings,
        )
    if test_dataset is not None:
        self.test_dataloader = make_loader(
            dataset=test_dataset, shuffle=self.config.shuffle_test, **shared_settings
        )
def _setup_metrics(self):
    """Move every metric in ``self.metrics`` to the current device, in place."""
    self.metrics.update(
        {metric_name: metric.to(self.device) for metric_name, metric in self.metrics.items()}
    )
[docs]
@monitor_exec
def execute(
    self,
    train_dataset: Dataset,
    validation_dataset: Optional[Dataset] = None,
    test_dataset: Optional[Dataset] = None,
) -> Tuple[Dataset, Dataset, Dataset, Any]:
    """Set up the distributed environment, run training and clean up.

    Steps, in order: seed the RNG, initialize the strategy, move metrics to
    the device, create dataloaders, create model/loss/optimizer, open the
    logger context (if a logger is set), train, then close the logger
    context and clean up the strategy.

    Args:
        train_dataset (Dataset): training dataset.
        validation_dataset (Optional[Dataset], optional): validation
            dataset. Defaults to None.
        test_dataset (Optional[Dataset], optional): test dataset.
            Defaults to None.

    Returns:
        Tuple[Dataset, Dataset, Dataset, Any]: training dataset,
        validation dataset, test dataset, trained model.
    """
    # The generator returned here is later handed to the dataloaders.
    self.torch_rng = set_seed(self.random_seed)
    self._init_distributed_strategy()
    self._setup_metrics()

    self.create_dataloaders(
        train_dataset=train_dataset,
        validation_dataset=validation_dataset,
        test_dataset=test_dataset,
    )
    self.create_model_loss_optimizer()

    if self.logger:
        self.logger.create_logger_context(rank=self.strategy.global_rank())
        # Log the configuration together with the strategy class name
        hyperparameters = self.config.model_dump()
        hyperparameters["distributed_strategy"] = type(self.strategy).__name__
        self.logger.save_hyperparameters(hyperparameters)

    self.train()

    if self.logger:
        self.logger.destroy_logger_context()
    self.strategy.clean_up()

    return train_dataset, validation_dataset, test_dataset, self.model
def _set_epoch_dataloaders(self, epoch: int):
    """Forward ``epoch`` to the dataloader samplers when running distributed.

    Does nothing for a non-distributed strategy. The validation and test
    dataloaders are skipped when they are None.
    """
    if not self.strategy.is_distributed:
        return

    self.train_dataloader.sampler.set_epoch(epoch)
    for optional_loader in (self.validation_dataloader, self.test_dataloader):
        if optional_loader is not None:
            optional_loader.sampler.set_epoch(epoch)
[docs]
def set_epoch(self, epoch: int) -> None:
    """Prepare the trainer for ``epoch`` before its training sweep starts.

    Steps the profiler (when one is set and ``epoch`` is greater than 0) and
    then propagates the epoch to the dataloaders.

    Args:
        epoch (int): epoch number, from 0 to ``epochs-1``.
    """
    # We don't want to start stepping until after the first epoch
    profiler_should_step = self.profiler is not None and epoch > 0
    if profiler_should_step:
        self.profiler.step()

    self._set_epoch_dataloaders(epoch)
[docs]
def log(
    self,
    item: Union[Any, List[Any]],
    identifier: Union[str, List[str]],
    kind: str = "metric",
    step: Optional[int] = None,
    batch_idx: Optional[int] = None,
    **kwargs,
) -> None:
    """Forward a log request to ``self.logger``; a no-op without a logger.

    Args:
        item (Union[Any, List[Any]]): element to be logged (e.g., metric).
        identifier (Union[str, List[str]]): unique identifier for the
            element to log (e.g., name of a metric).
        kind (str, optional): type of the item to be logged. Must be one
            among the list of self.supported_types. Defaults to 'metric'.
        step (Optional[int], optional): logging step. Defaults to None.
        batch_idx (Optional[int], optional): DataLoader batch counter
            (i.e., batch idx), if available. Defaults to None.
        **kwargs: extra keyword arguments passed through to the logger.
    """
    if not self.logger:
        return

    self.logger.log(
        item=item,
        identifier=identifier,
        kind=kind,
        step=step,
        batch_idx=batch_idx,
        **kwargs,
    )
[docs]
def save_checkpoint(
    self, name: str, epoch: int, loss: Optional[torch.Tensor] = None
) -> None:
    """Write a training checkpoint and log it as an artifact.

    The checkpoint stores the epoch, the loss, the optimizer and model state
    dicts, and the LR scheduler object itself.

    Args:
        name (str): name of the checkpoint.
        epoch (int): current training epoch.
        loss (Optional[torch.Tensor]): current loss (if available).
    """
    checkpoint = {
        "epoch": epoch,
        "loss": loss,
        "optimizer": self.optimizer.state_dict(),
        "model": self.model.state_dict(),
        "lr_scheduler": self.lr_scheduler,
    }
    destination = os.path.join(self.checkpoints_location, name)
    torch.save(checkpoint, destination)

    # Save checkpoint to logger
    self.log(destination, name, kind="artifact")
[docs]
def load_checkpoint(self, name: str) -> None:
    """Restore model, optimizer and LR scheduler from a checkpoint.

    Args:
        name (str): name of the checkpoint to load, assuming it
            is under ``self.checkpoints_location`` location.
    """
    source = os.path.join(self.checkpoints_location, name)
    # Tensors are mapped onto the trainer's current device while loading.
    checkpoint = torch.load(source, map_location=self.device)

    self.model.load_state_dict(checkpoint["model"])
    self.optimizer.load_state_dict(checkpoint["optimizer"])
    # The scheduler is stored as a whole object, not as a state dict.
    self.lr_scheduler = checkpoint["lr_scheduler"]
[docs]
def compute_metrics(
    self,
    true: Batch,
    pred: Batch,
    logger_step: int,
    batch_idx: Optional[int],
    stage: str = "train",
) -> Dict[str, Any]:
    """Evaluate every metric in ``self.metrics`` on a batch and log it.

    Each metric is called as ``metric(pred, true)``; its result is detached,
    moved to CPU, converted to NumPy and logged as ``f"{stage}_{name}"``.

    Args:
        true (Batch): true values.
        pred (Batch): predicted values.
        logger_step (int): global step to pass to the logger.
        batch_idx (Optional[int]): batch index passed to the logger.
        stage (str): 'train', 'validation'...

    Returns:
        Dict[str, Any]: metric values, indexed by metric name.
    """
    computed: Dict[str, Any] = {}
    for metric_name, metric_fn in self.metrics.items():
        raw_value = metric_fn(pred, true)
        value = raw_value.detach().cpu().numpy()
        self.log(
            item=value,
            identifier=f"{stage}_{metric_name}",
            kind="metric",
            step=logger_step,
            batch_idx=batch_idx,
        )
        computed[metric_name] = value
    return computed
[docs]
def train(self):
    """Run the main training loop over ``self.epochs`` epochs.

    Each epoch runs ``set_epoch`` and ``train_epoch``. Every
    ``validation_every`` epochs a validation epoch is run; its loss is
    gathered on rank 0 and, when checkpointing is enabled and the average
    improves on the best seen so far, saved as ``best_model.pth``. Every
    ``test_every`` epochs a test epoch is run, and every ``checkpoint_every``
    epochs the main worker saves ``epoch_{epoch}.pth``.

    Takes no arguments and returns nothing; datasets are consumed through
    the dataloaders already stored on the trainer.
    """
    best_loss = float("inf")

    progress_bar = tqdm(
        range(self.epochs),
        desc="Epochs",
        disable=self.disable_tqdm or not self.strategy.is_main_worker,
    )

    for epoch in progress_bar:
        epoch_n = epoch + 1
        progress_bar.set_description(f"Epoch {epoch_n}/{self.epochs}")

        self.set_epoch(epoch)
        self.train_epoch(epoch)

        if self.validation_every and epoch_n % self.validation_every == 0:
            val_loss = self.validation_epoch(epoch)

            # validation_epoch returns None when there is no validation
            # dataloader; gathering/stacking None would crash, so the
            # best-model logic is skipped in that case.
            # NOTE(review): assumes all workers agree on whether a
            # validation dataloader exists, so none waits in the gather.
            if val_loss is not None:
                # Checkpointing current best model
                worker_val_losses = self.strategy.gather(val_loss, dst_rank=0)
                if self.strategy.is_main_worker:
                    avg_loss = torch.mean(torch.stack(worker_val_losses)).detach().cpu()
                    if avg_loss < best_loss and self.checkpoint_every is not None:
                        ckpt_name = "best_model.pth"
                        self.save_checkpoint(name=ckpt_name, epoch=epoch, loss=avg_loss)
                        best_loss = avg_loss

        if self.test_every and epoch_n % self.test_every == 0:
            self.test_epoch(epoch)

        # Periodic checkpointing
        if (
            self.strategy.is_main_worker
            and self.checkpoint_every
            and epoch_n % self.checkpoint_every == 0
        ):
            ckpt_name = f"epoch_{epoch}.pth"
            self.save_checkpoint(name=ckpt_name, epoch=epoch)
[docs]
def train_epoch(self, epoch: int) -> torch.Tensor:
    """Perform a complete sweep over the training dataset, completing an
    epoch of training.

    Batch losses are detached before being accumulated, so the running sum
    does not keep the autograd history of every batch alive for the whole
    epoch. The returned average is therefore a detached tensor.

    Args:
        epoch (int): current epoch number, from 0 to ``self.epochs - 1``.

    Returns:
        torch.Tensor: average training loss for the current epoch.
    """
    self.model.train()
    train_loss_sum = 0.0
    train_metrics_sum = defaultdict(float)
    batch_counter = 0

    progress_bar = tqdm(
        enumerate(self.train_dataloader),
        total=len(self.train_dataloader) // self.strategy.global_world_size(),
        desc="Train batches",
        disable=self.disable_tqdm or not self.strategy.is_main_worker,
        leave=False,  # Set this to true to see how many batches were used
    )

    for batch_idx, train_batch in progress_bar:
        loss, metrics = self.train_step(batch=train_batch, batch_idx=batch_idx)
        # Detach: only the value is needed for the epoch average.
        train_loss_sum += loss.detach()
        batch_counter += 1
        for name, val in metrics.items():
            train_metrics_sum[name] += val

        # Important: update counter
        self.train_glob_step += 1

    # Aggregate and log losses
    avg_loss = train_loss_sum / batch_counter
    self.log(
        item=avg_loss.item(),
        identifier="train_loss_epoch",
        kind="metric",
        step=self.train_glob_step,
    )

    # Aggregate and log metrics
    for m_name, m_val in train_metrics_sum.items():
        self.log(
            item=m_val / batch_counter,
            identifier="train_" + m_name + "_epoch",
            kind="metric",
            step=self.train_glob_step,
        )

    return avg_loss
[docs]
def train_step(self, batch: Batch, batch_idx: int) -> Tuple[torch.Tensor, Dict[str, Any]]:
    """Run one optimization step on a training batch.

    The batch is unpacked as ``(x, y)``, moved to the device, and used for a
    forward pass, a backward pass and an optimizer step. The batch loss and
    the metrics are logged at ``self.train_glob_step``.

    Args:
        batch (Batch): batch sampled by a dataloader.
        batch_idx (int): batch index in the dataloader.

    Returns:
        Tuple[Loss, Dict[str, Any]]: batch loss and dictionary of metric
        values with the same structure of ``self.metrics``.
    """
    inputs, targets = batch
    inputs = inputs.to(self.device)
    targets = targets.to(self.device)

    # Forward / backward / update
    self.optimizer.zero_grad()
    predictions = self.model(inputs)
    batch_loss = self.loss(predictions, targets)
    batch_loss.backward()
    self.optimizer.step()

    # Log loss and metrics at the same global step
    current_step = self.train_glob_step
    self.log(
        item=batch_loss.item(),
        identifier="train_loss",
        kind="metric",
        step=current_step,
        batch_idx=batch_idx,
    )
    batch_metrics: Dict[str, Any] = self.compute_metrics(
        true=targets,
        pred=predictions,
        logger_step=current_step,
        batch_idx=batch_idx,
        stage="train",
    )
    return batch_loss, batch_metrics
[docs]
def validation_epoch(self, epoch: int) -> torch.Tensor:
    """Sweep once over the validation dataset and log epoch averages.

    Args:
        epoch (int): current epoch number, from 0 to ``self.epochs - 1``.

    Returns:
        Optional[Loss]: average validation loss for the current epoch, or
        None when ``self.validation_dataloader`` is None.
    """
    loader = self.validation_dataloader
    if loader is None:
        return

    batches = tqdm(
        enumerate(loader),
        total=len(loader) // self.strategy.global_world_size(),
        desc="Validation batches",
        disable=self.disable_tqdm or not self.strategy.is_main_worker,
        leave=False,  # Set this to true to see how many batches were used
    )

    self.model.eval()

    loss_total = 0.0
    metric_totals = defaultdict(float)
    n_batches = 0
    for batch_idx, val_batch in batches:
        step_loss, step_metrics = self.validation_step(batch=val_batch, batch_idx=batch_idx)
        loss_total += step_loss
        n_batches += 1
        for metric_name, metric_value in step_metrics.items():
            metric_totals[metric_name] += metric_value

        # Important: update counter
        self.validation_glob_step += 1

    # Epoch-level loss
    avg_loss = loss_total / n_batches
    self.log(
        item=avg_loss.item(),
        identifier="validation_loss_epoch",
        kind="metric",
        step=self.validation_glob_step,
    )

    # Epoch-level metrics
    for metric_name, metric_total in metric_totals.items():
        self.log(
            item=metric_total / n_batches,
            identifier="validation_" + metric_name + "_epoch",
            kind="metric",
            step=self.validation_glob_step,
        )

    return avg_loss
[docs]
def validation_step(
    self, batch: Batch, batch_idx: int
) -> Tuple[torch.Tensor, Dict[str, Any]]:
    """Evaluate the model on one validation batch, without gradients.

    The batch is unpacked as ``(x, y)`` and moved to the device. The loss and
    the metrics are logged at ``self.validation_glob_step``.

    Args:
        batch (Batch): batch sampled by a dataloader.
        batch_idx (int): batch index in the dataloader.

    Returns:
        Tuple[Loss, Dict[str, Any]]: batch loss and dictionary of metric
        values with the same structure of ``self.metrics``.
    """
    inputs, targets = batch
    inputs = inputs.to(self.device)
    targets = targets.to(self.device)

    # Only the forward pass and the loss run under no_grad
    with torch.no_grad():
        predictions = self.model(inputs)
        batch_loss: torch.Tensor = self.loss(predictions, targets)

    current_step = self.validation_glob_step
    self.log(
        item=batch_loss.item(),
        identifier="validation_loss",
        kind="metric",
        step=current_step,
        batch_idx=batch_idx,
    )
    batch_metrics: Dict[str, Any] = self.compute_metrics(
        true=targets,
        pred=predictions,
        logger_step=current_step,
        batch_idx=batch_idx,
        stage="validation",
    )
    return batch_loss, batch_metrics
[docs]
def test_epoch(self, epoch: int) -> torch.Tensor:
    """Perform a complete sweep over the test dataset, completing an
    epoch of test.

    Not implemented in this class: calling it always raises, so it has to
    be overridden before ``test_every`` can be used.

    Args:
        epoch (int): current epoch number, from 0 to ``self.epochs - 1``.

    Returns:
        Loss: average test loss for the current epoch.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()
[docs]
def test_step(self, batch: Batch, batch_idx: int) -> Tuple[torch.Tensor, Dict[str, Any]]:
    """Perform a single prediction step using a batch sampled from the
    test dataset.

    Not implemented in this class: calling it always raises.

    Args:
        batch (Batch): batch sampled by a dataloader.
        batch_idx (int): batch index in the dataloader.

    Returns:
        Tuple[Loss, Dict[str, Any]]: batch loss and dictionary of metric
        values with the same structure of ``self.metrics``.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()
[docs]
class GANTrainer(TorchTrainer):
    """Trainer class for GAN models using pytorch.

    Args:
        config (Union[Dict, TrainingConfiguration]): training configuration
            containing hyperparameters.
        epochs (int): number of training epochs.
        discriminator (nn.Module): pytorch discriminator model to train GAN.
        generator (nn.Module): pytorch generator model to train GAN.
        strategy (Literal['ddp', 'deepspeed'], optional):
            distributed strategy. Defaults to 'ddp'.
        validation_every (Optional[int], optional): run a validation epoch
            every ``validation_every`` epochs. Disabled if None. Defaults to 1.
        test_every (Optional[int], optional): run a test epoch
            every ``test_every`` epochs. Disabled if None. Defaults to None.
        random_seed (Optional[int], optional): set random seed for
            reproducibility. If None, the seed is not set. Defaults to None.
        logger (Optional[Logger], optional): logger for ML tracking.
            Defaults to None.
        metrics (Optional[Dict[str, Metric]], optional): map of torchmetrics
            metrics. Defaults to None.
        checkpoints_location (str): path to checkpoints directory.
            Defaults to "checkpoints".
        checkpoint_every (Optional[int]): save a checkpoint every
            ``checkpoint_every`` epochs. Disabled if None. Defaults to None.
        name (Optional[str], optional): trainer custom name. Defaults to None.
        **kwargs: additional keyword arguments forwarded to the
            ``TorchTrainer`` constructor.
    """
def __init__(
    self,
    config: Union[Dict, TrainingConfiguration],
    epochs: int,
    discriminator: nn.Module,
    generator: nn.Module,
    strategy: Literal["ddp", "deepspeed"] = "ddp",
    validation_every: Optional[int] = 1,
    test_every: Optional[int] = None,
    random_seed: Optional[int] = None,
    logger: Optional[Logger] = None,
    metrics: Optional[Dict[str, Metric]] = None,
    checkpoints_location: str = "checkpoints",
    checkpoint_every: Optional[int] = None,
    name: Optional[str] = None,
    **kwargs,
) -> None:
    """Initialize the base trainer without a model and keep the two networks.

    The parent constructor receives ``model=None``; the discriminator and the
    generator are stored as separate attributes instead.
    """
    # Same arguments as received, except that no single model is passed on.
    super().__init__(
        config=config,
        epochs=epochs,
        model=None,
        strategy=strategy,
        validation_every=validation_every,
        test_every=test_every,
        random_seed=random_seed,
        logger=logger,
        metrics=metrics,
        checkpoints_location=checkpoints_location,
        checkpoint_every=checkpoint_every,
        name=name,
        **kwargs,
    )
    # Snapshots ``locals()``: no new local may be bound before this line.
    self.save_parameters(**self.locals2params(locals()))
    self.discriminator, self.generator = discriminator, generator
[docs]
def create_model_loss_optimizer(self) -> None:
    """Create the Adam optimizers and BCE criterion, then distribute both nets.

    Both networks have their batch-norm layers converted with
    ``SyncBatchNorm.convert_sync_batchnorm`` before being passed, with their
    optimizer, to the strategy's ``distributed()`` method.
    """
    adam_betas = (0.5, 0.999)
    self.optimizerD = optim.Adam(
        self.discriminator.parameters(), lr=self.config.lr, betas=adam_betas
    )
    self.optimizerG = optim.Adam(
        self.generator.parameters(), lr=self.config.lr, betas=adam_betas
    )
    self.criterion = nn.BCELoss()

    # https://stackoverflow.com/a/67437077
    to_sync_bn = torch.nn.SyncBatchNorm.convert_sync_batchnorm
    self.discriminator = to_sync_bn(self.discriminator)
    self.generator = to_sync_bn(self.generator)

    # Strategy-wise optional configuration.
    # Batch size definition is not optional for DeepSpeedStrategy!
    distribute_kwargs = (
        {"config_params": {"train_micro_batch_size_per_gpu": self.config.batch_size}}
        if isinstance(self.strategy, DeepSpeedStrategy)
        else {}
    )

    # Distribute each network together with its own optimizer
    distributed_disc = self.strategy.distributed(
        self.discriminator, self.optimizerD, **distribute_kwargs
    )
    self.discriminator, self.optimizerD, _ = distributed_disc
    distributed_gen = self.strategy.distributed(
        self.generator, self.optimizerG, **distribute_kwargs
    )
    self.generator, self.optimizerG, _ = distributed_gen
[docs]
def train_epoch(self, epoch: int):
    """Train both networks for one sweep over the training dataloader.

    Logs the epoch averages of discriminator accuracy, generator loss and
    discriminator loss (in that order, at step ``epoch``) and then saves a
    grid of generated images.

    Args:
        epoch (int): current epoch number.
    """
    self.discriminator.train()
    self.generator.train()

    generator_losses = []
    discriminator_losses = []
    discriminator_accuracies = []
    for batch_idx, (real_images, _) in enumerate(self.train_dataloader):
        loss_gen, loss_disc, acc_disc = self.train_step(real_images, batch_idx)
        generator_losses.append(loss_gen)
        discriminator_losses.append(loss_disc)
        discriminator_accuracies.append(acc_disc)
        self.train_glob_step += 1

    # Aggregate and log losses and accuracy
    epoch_summaries = (
        ("disc_train_accuracy_per_epoch", discriminator_accuracies),
        ("gen_train_loss_per_epoch", generator_losses),
        ("disc_train_loss_per_epoch", discriminator_losses),
    )
    for identifier, batch_values in epoch_summaries:
        epoch_average = torch.mean(torch.stack(batch_values))
        self.log(
            item=epoch_average.item(),
            identifier=identifier,
            kind="metric",
            step=epoch,
        )

    self.save_fake_generator_images(epoch)
[docs]
def validation_epoch(self, epoch: int):
    """Validate both networks for one sweep over the validation dataloader.

    Expects ``validation_step`` to return four values per batch: generator
    loss, generator accuracy, discriminator loss, discriminator accuracy.
    Their epoch averages are logged at step ``epoch``.

    Args:
        epoch (int): current epoch number.

    Returns:
        Average generator validation loss for the epoch.
    """
    gen_losses = []
    gen_accuracies = []
    disc_losses = []
    disc_accuracies = []

    self.discriminator.eval()
    self.generator.eval()
    for batch_idx, (real_images, _) in enumerate(self.validation_dataloader):
        loss_gen, accuracy_gen, loss_disc, accuracy_disc = self.validation_step(
            real_images, batch_idx
        )
        gen_losses.append(loss_gen)
        gen_accuracies.append(accuracy_gen)
        disc_losses.append(loss_disc)
        disc_accuracies.append(accuracy_disc)
        self.validation_glob_step += 1

    # Aggregate and log metrics
    epoch_summaries = (
        ("disc_valid_loss_per_epoch", disc_losses),
        ("disc_valid_accuracy_epoch", disc_accuracies),
        ("gen_valid_loss_per_epoch", gen_losses),
        ("gen_valid_accuracy_epoch", gen_accuracies),
    )
    epoch_averages = {}
    for identifier, batch_values in epoch_summaries:
        epoch_average = torch.mean(torch.stack(batch_values))
        self.log(
            item=epoch_average.item(),
            identifier=identifier,
            kind="metric",
            step=epoch,
        )
        epoch_averages[identifier] = epoch_average

    return epoch_averages["gen_valid_loss_per_epoch"]
[docs]
def train_step(self, real_images, batch_idx):
    """Run one discriminator update followed by one generator update.

    The discriminator is trained on the real batch and on generated images
    (detached, so that no gradient reaches the generator). The generator is
    then trained to make the updated discriminator label its images as real.

    Args:
        real_images: batch of real images; its first dimension is used as
            the batch size.
        batch_idx (int): batch index in the dataloader.

    Returns:
        Tuple of three tensors: generator loss and discriminator loss (both
        detached), and discriminator accuracy in [0, 1], i.e. the mean of
        the accuracy on real samples and the accuracy on fake samples.
    """
    real_images = real_images.to(self.device)
    batch_size = real_images.size(0)
    real_labels = torch.ones((batch_size,), dtype=torch.float, device=self.device)
    fake_labels = torch.zeros((batch_size,), dtype=torch.float, device=self.device)

    # Train Discriminator with real images
    output_real = self.discriminator(real_images)
    lossD_real = self.criterion(output_real, real_labels)

    # Generate fake images and train Discriminator
    noise = torch.randn(batch_size, self.config.z_dim, 1, 1, device=self.device)
    fake_images = self.generator(noise)
    output_fake = self.discriminator(fake_images.detach())
    lossD_fake = self.criterion(output_fake, fake_labels)

    lossD = (lossD_real + lossD_fake) / 2
    self.optimizerD.zero_grad()
    lossD.backward()
    self.optimizerD.step()

    # A real sample is classified correctly when the output is above 0.5, a
    # fake one when it is below 0.5. Averaging the two rates keeps the
    # accuracy in [0, 1].
    with torch.no_grad():
        accuracy_real = (output_real > 0.5).float().mean()
        accuracy_fake = (output_fake < 0.5).float().mean()
        accuracy_disc = (accuracy_real + accuracy_fake) / 2

    # Train Generator
    output_fake = self.discriminator(fake_images)
    lossG = self.criterion(output_fake, real_labels)
    self.optimizerG.zero_grad()
    lossG.backward()
    self.optimizerG.step()

    # Log plain floats, like every other metric logged by the trainers
    self.log(
        item=accuracy_disc.item(),
        identifier="disc_train_accuracy_per_batch",
        kind="metric",
        step=self.train_glob_step,
        batch_idx=batch_idx,
    )
    self.log(
        item=lossG.item(),
        identifier="gen_train_loss_per_batch",
        kind="metric",
        step=self.train_glob_step,
        batch_idx=batch_idx,
    )
    self.log(
        item=lossD.item(),
        identifier="disc_train_loss_per_batch",
        kind="metric",
        step=self.train_glob_step,
        batch_idx=batch_idx,
    )

    # Detached: callers only aggregate the values, so the autograd graphs
    # do not need to outlive this step.
    return lossG.detach(), lossD.detach(), accuracy_disc
[docs]
def validation_step(self, real_images, batch_idx):
    """Evaluate generator and discriminator on one validation batch.

    Runs entirely under ``torch.no_grad()`` since nothing is optimized here.

    Args:
        real_images: batch of real images; its first dimension is used as
            the batch size.
        batch_idx (int): batch index in the dataloader.

    Returns:
        Tuple of four tensors, in the order unpacked by ``validation_epoch``:
        generator loss, generator accuracy (fraction of generated images the
        discriminator labels as real), discriminator loss and discriminator
        accuracy in [0, 1].
    """
    real_images = real_images.to(self.device)
    batch_size = real_images.size(0)
    real_labels = torch.ones((batch_size,), dtype=torch.float, device=self.device)
    fake_labels = torch.zeros((batch_size,), dtype=torch.float, device=self.device)

    with torch.no_grad():
        # Validate with real images
        output_real = self.discriminator(real_images)
        loss_real = self.criterion(output_real, real_labels)

        # Generate and validate fake images
        noise = torch.randn(batch_size, self.config.z_dim, 1, 1, device=self.device)
        fake_images = self.generator(noise)
        output_fake = self.discriminator(fake_images)
        loss_fake = self.criterion(output_fake, fake_labels)

        # Generator's attempt to fool the discriminator
        loss_gen = self.criterion(output_fake, real_labels)
        accuracy_gen = (output_fake > 0.5).float().mean()

        # Total discriminator loss and accuracy. A fake sample is classified
        # correctly when the output is below 0.5.
        d_total_loss = (loss_real + loss_fake) / 2
        accuracy_real = (output_real > 0.5).float().mean()
        accuracy_fake = (output_fake < 0.5).float().mean()
        d_accuracy = (accuracy_real + accuracy_fake) / 2

    self.log(
        item=loss_gen.item(),
        identifier="gen_valid_loss_per_batch",
        kind="metric",
        step=self.validation_glob_step,
        batch_idx=batch_idx,
    )
    self.log(
        item=accuracy_gen.item(),
        identifier="gen_valid_accuracy_per_batch",
        kind="metric",
        step=self.validation_glob_step,
        batch_idx=batch_idx,
    )
    self.log(
        item=d_total_loss.item(),
        identifier="disc_valid_loss_per_batch",
        kind="metric",
        step=self.validation_glob_step,
        batch_idx=batch_idx,
    )
    self.log(
        item=d_accuracy.item(),
        identifier="disc_valid_accuracy_per_batch",
        kind="metric",
        step=self.validation_glob_step,
        batch_idx=batch_idx,
    )

    # All four are tensors so that the caller can ``torch.stack`` them.
    return loss_gen, accuracy_gen, d_total_loss, d_accuracy
[docs]
def save_checkpoint(self, name, epoch, loss=None):
    """Save training checkpoint with both optimizers.

    The checkpoint stores the epoch, the loss, the state of generator and
    discriminator, the state of both optimizers and, when a learning rate
    scheduler is set, its state.

    Args:
        name (str): file name of the checkpoint, created inside
            ``self.checkpoints_location``.
        epoch (int): epoch number stored in the checkpoint.
        loss (Optional[Union[torch.Tensor, float]]): loss value to store.
            Both tensors and plain Python numbers are accepted.
            Defaults to None.
    """
    # exist_ok avoids the check-then-create race when several workers
    # save checkpoints to the same location at the same time.
    os.makedirs(self.checkpoints_location, exist_ok=True)
    checkpoint_path = os.path.join(self.checkpoints_location, f"{name}")

    # Tensors are converted to Python scalars; plain numbers are stored as is
    if loss is not None and hasattr(loss, "item"):
        loss = loss.item()

    checkpoint = {
        "epoch": epoch,
        "loss": loss,
        "discriminator_state_dict": self.discriminator.state_dict(),
        "generator_state_dict": self.generator.state_dict(),
        "optimizerD_state_dict": self.optimizerD.state_dict(),
        "optimizerG_state_dict": self.optimizerG.state_dict(),
        "lr_scheduler": self.lr_scheduler.state_dict() if self.lr_scheduler else None,
    }
    torch.save(checkpoint, checkpoint_path)
    print(f"Checkpoint saved to {checkpoint_path}")
[docs]
def load_checkpoint(self, checkpoint_path):
    """Load models and optimizers from checkpoint.

    Restores the state of discriminator, generator and both optimizers.
    The learning rate scheduler is restored only when the checkpoint
    contains its state and the trainer has a scheduler.

    Args:
        checkpoint_path (str): path to a checkpoint file holding the keys
            ``discriminator_state_dict``, ``generator_state_dict``,
            ``optimizerD_state_dict`` and ``optimizerG_state_dict``.
    """
    # map_location moves the stored tensors onto the device of this worker,
    # so a checkpoint written on another device (e.g., a GPU that is not
    # available here) can still be loaded.
    # NOTE: torch.load relies on pickle; only load checkpoints from trusted
    # sources.
    checkpoint = torch.load(checkpoint_path, map_location=self.device)
    self.discriminator.load_state_dict(checkpoint["discriminator_state_dict"])
    self.generator.load_state_dict(checkpoint["generator_state_dict"])
    self.optimizerD.load_state_dict(checkpoint["optimizerD_state_dict"])
    self.optimizerG.load_state_dict(checkpoint["optimizerG_state_dict"])

    scheduler_state = checkpoint.get("lr_scheduler")
    lr_scheduler = getattr(self, "lr_scheduler", None)
    # Skip the scheduler when either side has none, instead of failing on
    # ``None.load_state_dict``
    if scheduler_state is not None and lr_scheduler is not None:
        lr_scheduler.load_state_dict(scheduler_state)
    print(f"Checkpoint loaded from {checkpoint_path}")
[docs]
def save_fake_generator_images(self, epoch):
    """Plot and log a grid of fake images sampled from the generator.

    64 noise vectors are passed through the generator, the resulting images
    are arranged in a normalized grid and the figure is sent to the logger.

    Args:
        epoch (int): epoch number, from 0 to ``epochs-1``.
    """
    import matplotlib.pyplot as plt
    import numpy as np

    # Remember the current mode so that this helper does not silently leave
    # the generator in evaluation mode.
    was_training = self.generator.training
    self.generator.eval()
    noise = torch.randn(64, self.config.z_dim, 1, 1, device=self.device)
    # Inference only: no autograd graph is needed for plotting
    with torch.no_grad():
        fake_images = self.generator(noise)
    self.generator.train(was_training)

    fake_images_grid = torchvision.utils.make_grid(fake_images, normalize=True)
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.set_axis_off()
    ax.set_title(f"Fake images for epoch {epoch}")
    ax.imshow(np.transpose(fake_images_grid.cpu().numpy(), (1, 2, 0)))
    try:
        self.log(
            item=fig,
            identifier=f"fake_images_epoch_{epoch}.png",
            kind="figure",
            step=epoch,
        )
    finally:
        # Figures created through pyplot stay registered until closed; one
        # figure per epoch would otherwise accumulate for the whole training.
        plt.close(fig)
[docs]
class TorchLightningTrainer(Trainer):
    """Generic trainer for torch Lightning workflows.

    Args:
        config (Union[Dict, str]): `Lightning configuration`_
            which can be the path to a file or a Python dictionary.
        mlflow_saved_model (str, optional): name of the model created in
            MLFlow. Defaults to 'my_model'.

    .. _Lightning configuration:
        https://pytorch-lightning.readthedocs.io/en/1.6.5/common/lightning_cli.html
    """

    def __init__(self, config: Union[Dict, str], mlflow_saved_model: str = "my_model"):
        # locals() must be captured before any other local name is created
        self.save_parameters(**self.locals2params(locals()))
        super().__init__()
        if isinstance(config, str) and os.path.isfile(config):
            # Load from YAML
            config = load_yaml(config)
        self.conf = config
        self.mlflow_saved_model = mlflow_saved_model

    @monitor_exec
    def execute(self) -> Any:
        """Build a ``LightningCLI`` from the configuration and fit the model.

        MLflow is set up before the CLI is created and torn down afterwards,
        also when the CLI construction or the training raise an exception.
        """
        import lightning as L
        from lightning.pytorch.cli import LightningCLI

        from .mlflow import init_lightning_mlflow, teardown_lightning_mlflow

        init_lightning_mlflow(
            self.conf, tmp_dir="/tmp", registered_model_name=self.mlflow_saved_model
        )
        try:
            # The arguments of the calling script are hidden from the CLI
            # while it is built; they are restored even if the CLI fails.
            old_argv = sys.argv
            sys.argv = ["some_script_placeholder.py"]
            try:
                cli = LightningCLI(
                    args=self.conf,
                    model_class=L.LightningModule,
                    datamodule_class=L.LightningDataModule,
                    run=False,
                    save_config_kwargs={
                        "overwrite": True,
                        "config_filename": "pl-training.yml",
                    },
                    subclass_mode_model=True,
                    subclass_mode_data=True,
                )
            finally:
                sys.argv = old_argv
            cli.trainer.fit(cli.model, datamodule=cli.datamodule)
        finally:
            teardown_lightning_mlflow()
def _distributed_dataloader(dataloader: DataLoader, gwsize, grank):
    """Rebuild ``dataloader`` so that it samples through a DistributedSampler.

    Args:
        dataloader (DataLoader): dataloader whose dataset and settings are
            reused.
        gwsize: number of replicas given to the sampler.
        grank: rank given to the sampler.

    Returns:
        DataLoader: new dataloader over the same dataset, using a shuffling
        ``DistributedSampler`` and ``seed_worker`` as worker init function.
    """
    dataset = dataloader.dataset
    dist_sampler = DistributedSampler(
        dataset, num_replicas=gwsize, rank=grank, shuffle=True
    )

    # Settings carried over unchanged from the source dataloader
    inherited = (
        "batch_size",
        "num_workers",
        "collate_fn",
        "pin_memory",
        "drop_last",
        "timeout",
        "multiprocessing_context",
        "generator",
        "prefetch_factor",
        "persistent_workers",
        "pin_memory_device",
    )
    loader_kwargs = {attr: getattr(dataloader, attr) for attr in inherited}

    # The worker init function of the source dataloader is replaced by
    # seed_worker on purpose
    return DataLoader(
        dataset,
        sampler=dist_sampler,
        worker_init_fn=seed_worker,
        **loader_kwargs,
    )
[docs]
def distributed(func):
    """Decorator running a training function in a distributed setup.

    The decorated function must have a standard signature.
    Its first arguments must be:
    model, train_dataloader, validation_dataloader, device (in this order).

    Additional args or kwargs are allowed consistently with the signature
    of the decorated function.

    When CUDA is available, an NCCL process group is created, the model is
    wrapped in DistributedDataParallel and the process group is destroyed
    once the function returns. Without CUDA the function runs in a single
    process on CPU with the plain model. In both cases the dataloaders are
    rebuilt with a distributed sampler and the ``device`` given by the caller
    is replaced by the one computed here.

    Returns:
        Callable: wrapper returning whatever the decorated function returns.
    """
    import functools

    @functools.wraps(func)
    def dist_train(
        model, train_dataloader, validation_dataloader=None, device="cpu", *args, **kwargs
    ):
        use_cuda = torch.cuda.is_available()
        if use_cuda:
            dist.init_process_group(backend="nccl")
            lwsize = torch.cuda.device_count()  # local world size - per node
            gwsize = dist.get_world_size()  # global world size - per run
            grank = dist.get_rank()  # global rank - assign per run
            lrank = dist.get_rank() % lwsize  # local rank - assign per node
        else:
            gwsize = 1
            grank = 0
            lrank = 0

        device = torch.device("cuda" if use_cuda else "cpu", lrank)
        if use_cuda:
            torch.cuda.set_device(lrank)

        model = model.to(device)
        if use_cuda:
            # DDP requires an initialized process group, which only exists
            # on the CUDA path; wrapping on CPU would fail.
            model = DDP(model, device_ids=[device], output_device=device)

        train_dataloader = _distributed_dataloader(train_dataloader, gwsize, grank)
        if validation_dataloader is not None:
            validation_dataloader = _distributed_dataloader(
                validation_dataloader, gwsize, grank
            )

        try:
            # Forward the result instead of discarding it
            return func(
                model, train_dataloader, validation_dataloader, device, *args, **kwargs
            )
        finally:
            if use_cuda:
                dist.barrier()
                dist.destroy_process_group()

    return dist_train
[docs]
class RayTorchTrainer(Trainer):
"""A trainer class for distributed training and hyperparameter optimization
using Ray Train/ Tune and PyTorch.
Args:
config (Dict): A dictionary of configuration settings for the trainer.
strategy (Optional[Literal["ddp", "deepspeed"]]):
The distributed training strategy to use. Defaults to "ddp".
name (Optional[str]): Optional name for the trainer instance. Defaults to None.
logger (Optional[Logger]): Optional logger instance. Defaults to None.
"""
def __init__(
    self,
    config: Dict,
    strategy: Literal["ddp", "deepspeed"] = "ddp",
    name: str | None = None,
    logger: Logger | None = None,
    random_seed: int = 1234,
) -> None:
    """Set up strategy, Ray configurations and random number generator.

    Args:
        config (Dict): configuration settings for the trainer.
        strategy (Literal["ddp", "deepspeed"]): distributed training
            strategy. Defaults to "ddp".
        name (str | None): optional name of the trainer. Defaults to None.
        logger (Logger | None): optional logger. Defaults to None.
        random_seed (int): seed passed to ``set_seed``. Defaults to 1234.
    """
    super().__init__(name=name)
    self.logger = logger

    # The Ray modules must be bound before _set_configs runs: the config
    # setters read self.ray_train and self.ray_tune. Binding them afterwards
    # made every setter hit an AttributeError that was caught and printed,
    # leaving the Ray configurations unset.
    import ray.train
    import ray.tune

    self.ray_train = ray.train
    self.ray_tune = ray.tune

    self._set_strategy_and_init_ray(strategy)
    self._set_configs(config=config)
    self.torch_rng = set_seed(random_seed)
def _set_strategy_and_init_ray(self, strategy: str) -> None:
    """Select the distributed strategy object matching ``strategy``.

    Creating the strategy will initialize the ray backend.

    Args:
        strategy (str): name of the strategy for distributed training.
            Must be one of ["ddp", "deepspeed"].

    Raises:
        ValueError: If an unsupported strategy is provided.
    """
    # Reject unknown names first, then build the matching strategy
    if strategy not in ("ddp", "deepspeed"):
        raise ValueError(f"Unsupported strategy: {strategy}")

    if strategy == "deepspeed":
        self.strategy = RayDeepSpeedStrategy(backend="nccl")
        return
    self.strategy = RayDDPStrategy()
def _set_configs(self, config: Dict) -> None:
    """Store ``config`` and derive every Ray configuration from it.

    The scaling, tune, run and train loop configurations are set, in this
    order.

    Args:
        config (Dict): configuration of the trainer.
    """
    self.config = config
    for setter_name in (
        "_set_scaling_config",
        "_set_tune_config",
        "_set_run_config",
        "_set_train_loop_config",
    ):
        # Look up each setter right before calling it
        getattr(self, setter_name)()
@property
def device(self) -> str:
    """Device reported by the distributed strategy.

    Returns:
        str: Device string (e.g., "cuda:0").
    """
    current_device = self.strategy.device()
    return current_device
[docs]
def create_dataloaders(
self,
train_dataset: Dataset,
validation_dataset: Dataset | None = None,
test_dataset: Dataset | None = None,
batch_size: int = 1,
num_workers_dataloader: int = 4,
pin_memory: bool = False,
shuffle_train: bool | None = False,
shuffle_test: bool | None = False,
shuffle_validation: bool | None = False,
sampler: Union[Sampler, Iterable, None] = None,
collate_fn: Callable[[List], Any] | None = None,
) -> None:
"""Create data loaders for training, validation, and testing.
Args:
train_dataset (Dataset): The training dataset.
validation_dataset (Dataset, optional): The validation dataset. Defaults to None.
test_dataset (Dataset, optional): The test dataset. Defaults to None.
batch_size (int, optional): Batch size for data loaders. Defaults to 1.
shuffle_train (bool, optional): Whether to shuffle the training dataset.
Defaults to False.
shuffle_test (bool, optional): Whether to shuffle the test dataset.
Defaults to False.
shuffle_validation (bool, optional): Whether to shuffle the validation dataset.
Defaults to False.
sampler (Union[Sampler, Iterable, None], optional): Sampler for the datasets.
Defaults to None.
collate_fn (Callable[[List], Any], optional):
Function to collate data samples into batches. Defaults to None.
"""
self.train_dataloader = self.strategy.create_dataloader(
dataset=train_dataset,
batch_size=batch_size,
num_workers=num_workers_dataloader,
pin_memory=pin_memory,
generator=self.torch_rng,
shuffle=shuffle_train,
sampler=sampler,
collate_fn=collate_fn,
)
if validation_dataset is not None:
self.validation_dataloader = self.strategy.create_dataloader(
dataset=validation_dataset,
batch_size=batch_size,
num_workers=num_workers_dataloader,
pin_memory=pin_memory,
generator=self.torch_rng,
shuffle=shuffle_validation,
sampler=sampler,
collate_fn=collate_fn,
)
else:
self.validation_dataloader = None
if test_dataset is not None:
self.test_dataloader = self.strategy.create_dataloader(
dataset=test_dataset,
batch_size=batch_size,
num_workers=num_workers_dataloader,
pin_memory=pin_memory,
generator=self.torch_rng,
shuffle=shuffle_test,
sampler=sampler,
collate_fn=collate_fn,
)
else:
self.test_dataloader = None
[docs]
@monitor_exec
def execute(
    self,
    train_dataset: Dataset,
    validation_dataset: Dataset | None = None,
    test_dataset: Dataset | None = None,
) -> Tuple[Dataset, Dataset, Dataset, Any]:
    """Run the training function through a Ray Tune tuner.

    ``self.train`` receives the three datasets as ``data`` and is wrapped in
    a Ray ``TorchTrainer``, which is then fitted by a ``Tuner``.

    Args:
        train_dataset (Dataset): Training dataset.
        validation_dataset (Optional[Dataset], optional): Validation dataset.
            Defaults to None.
        test_dataset (Optional[Dataset], optional): Test dataset.
            Defaults to None.

    Returns:
        Tuple[Dataset, Dataset, Dataset, Any]:
            The three datasets followed by the result grid of the tuner.
    """
    import ray.train.torch

    datasets = [train_dataset, validation_dataset, test_dataset]
    train_fn = self.ray_tune.with_parameters(self.train, data=datasets)

    ray_trainer = ray.train.torch.TorchTrainer(
        train_fn,
        train_loop_config=self.train_loop_config,
        scaling_config=self.scaling_config,
        run_config=self.run_config,
    )

    tuner = self.ray_tune.Tuner(
        ray_trainer,
        param_space={"train_loop_config": self.train_loop_config},
        tune_config=self.tune_config,
    )
    return train_dataset, validation_dataset, test_dataset, tuner.fit()
[docs]
def set_epoch(self, epoch: int) -> None:
    """Forward ``epoch`` to the sampler of every available dataloader.

    Args:
        epoch (int): epoch number passed to ``sampler.set_epoch``.
    """
    self.train_dataloader.sampler.set_epoch(epoch)
    # Validation and test dataloaders are optional
    for loader_name in ("validation_dataloader", "test_dataloader"):
        loader = getattr(self, loader_name)
        if loader is None:
            continue
        loader.sampler.set_epoch(epoch)
def _set_tune_config(self) -> None:
    """Build ``self.tune_config`` from the ``tune_config`` section.

    Search algorithm and scheduler are created from the section through the
    helpers of the tuning module. ``metric`` defaults to "loss" and ``mode``
    to "min" when they are not configured.
    """
    from .tuning import get_raytune_scheduler, get_raytune_search_alg

    tune_config = self.config.get("tune_config", {})
    if not tune_config:
        print(
            "WARNING: Empty Tune Config configured. Using the default configuration with "
            "a single trial."
        )

    search_alg = get_raytune_search_alg(tune_config)
    scheduler = get_raytune_scheduler(tune_config)

    # Merge into one dictionary: passing **tune_config next to explicit
    # keywords raised "got multiple values for keyword argument" as soon as
    # the section itself contained one of these keys (e.g., "metric").
    tune_kwargs = {
        **tune_config,
        "search_alg": search_alg,
        "scheduler": scheduler,
        "metric": tune_config.get("metric", "loss"),
        "mode": tune_config.get("mode", "min"),
    }

    try:
        self.tune_config = self.ray_tune.TuneConfig(**tune_kwargs)
    except AttributeError as e:
        print(
            "Could not set Tune Config. Please ensure that you have passed the "
            "correct arguments for it. You can find more information for which "
            "arguments to set at "
            "https://docs.ray.io/en/latest/tune/api/doc/ray.tune.TuneConfig.html."
        )
        print(e)
def _set_scaling_config(self) -> None:
    """Build ``self.scaling_config`` from the ``scaling_config`` section.

    A warning is printed when the section is empty. An AttributeError raised
    while building the configuration is printed instead of propagated.
    """
    section = self.config.get("scaling_config", {})
    if not section:
        print("WARNING: No Scaling Config configured. Running trials non-distributed.")

    help_message = (
        "Could not set Scaling Config. Please ensure that you have passed the "
        "correct arguments for it. You can find more information for which "
        "arguments to set at "
        "https://docs.ray.io/en/latest/train/api/doc/ray.train.ScalingConfig.html"
    )
    try:
        self.scaling_config = self.ray_train.ScalingConfig(**section)
    except AttributeError as error:
        print(help_message)
        print(error)
def _set_run_config(self) -> None:
    """Build ``self.run_config`` from the ``run_config`` section.

    ``storage_path`` is resolved to an absolute path. When it is missing or
    empty, the default path 'ray_checkpoints' is used. The section stored in
    ``self.config`` is left untouched.
    """
    # Work on a copy: popping from the original section would remove
    # "storage_path" from the configuration given by the user.
    run_config = dict(self.config.get("run_config") or {})
    if not run_config:
        print("WARNING: No RunConfig provided. Assuming local or single-node execution.")

    # The default must be applied before building the Path: a Path object is
    # always truthy, and a missing key must not raise KeyError.
    storage_path = run_config.pop("storage_path", None)
    if not storage_path:
        print(
            "INFO: Empty storage path provided. Using default path 'ray_checkpoints'"
        )
        storage_path = "ray_checkpoints"
    storage_path = Path(storage_path).resolve()

    try:
        self.run_config = self.ray_train.RunConfig(**run_config, storage_path=storage_path)
    except AttributeError as e:
        print(
            "Could not set Run Config. Please ensure that you have passed the "
            "correct arguments for it. You can find more information for which "
            "arguments to set at "
            "https://docs.ray.io/en/latest/train/api/doc/ray.train.RunConfig.html"
        )
        print(e)
def _set_train_loop_config(self) -> None:
    """Turn the ``train_loop_config`` section into Ray Tune search spaces.

    Every dictionary entry of the section is replaced by the result of
    calling the Ray Tune function named by its "type" key with the remaining
    keys as arguments. Entries that are not dictionaries are kept as they
    are. An AttributeError raised during the conversion is printed instead
    of propagated.
    """
    loop_config = self.config.get("train_loop_config", {})
    self.train_loop_config = loop_config
    if not loop_config:
        print(
            "WARNING: No training_loop_config detected. "
            "If you want to tune any hyperparameters, make sure to define them here."
        )
        return

    float_keys = ("lower", "upper", "mean", "std")
    try:
        for name, param in loop_config.items():
            if isinstance(param, dict):
                # Convert specific keys to float if necessary
                for key in filter(param.__contains__, float_keys):
                    param[key] = float(param[key])

                space_name = param.pop("type")
                space_factory = getattr(self.ray_tune, space_name)
                param = space_factory(**param)
                self.train_loop_config[name] = param
    except AttributeError as error:
        print(
            f"{param} could not be set. Check that this parameter type is "
            "supported by Ray Tune at "
            "https://docs.ray.io/en/latest/tune/api/search_space.html"
        )
        print(error)
# TODO: Can I also log the checkpoint?
[docs]
def checkpoint_and_report(self, epoch, tuning_metrics, checkpointing_data=None):
    """Report metrics to Ray, attaching a checkpoint when one is due.

    A checkpoint is created when ``checkpointing_data`` is given and
    ``epoch`` is a multiple of the ``checkpoint_freq`` entry of
    ``self.config`` (1 when not set, i.e., every epoch).

    Args:
        epoch (int): current epoch; also used as checkpoint file name.
        tuning_metrics: metrics passed to ``ray.train.report``.
        checkpointing_data (optional): object saved with ``torch.save``.
            Defaults to None.
    """
    with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
        checkpoint = None
        # A checkpoint is due when the remainder is zero. Using the bare
        # remainder as flag inverted the condition, so that with the
        # default frequency of 1 no checkpoint was ever written.
        should_checkpoint = epoch % self.config.get("checkpoint_freq", 1) == 0

        if checkpointing_data and should_checkpoint:
            torch.save(checkpointing_data, os.path.join(temp_checkpoint_dir, str(epoch)))
            checkpoint = self.ray_train.Checkpoint.from_directory(temp_checkpoint_dir)

        # Report while the temporary directory still exists
        self.ray_train.report(tuning_metrics, checkpoint=checkpoint)
[docs]
def initialize_logger(self, hyperparams: Optional[Dict], rank):
    """Open the logger context for ``rank`` and save the hyperparameters.

    Nothing happens when the trainer has no logger.

    Args:
        hyperparams (Optional[Dict]): hyperparameters to save; skipped when
            empty or None.
        rank: rank passed to the logger context.
    """
    if self.logger:
        self.logger.create_logger_context(rank=rank)
        print(f"Logger initialized with rank {rank}")
        if hyperparams:
            self.logger.save_hyperparameters(hyperparams)
[docs]
def close_logger(self):
    """Destroy the logger context, if the trainer has a logger."""
    if not self.logger:
        return
    self.logger.destroy_logger_context()
[docs]
def log(
    self,
    item: Union[Any, List[Any]],
    identifier: Union[str, List[str]],
    kind: str = "metric",
    step: Optional[int] = None,
    batch_idx: Optional[int] = None,
    **kwargs,
) -> None:
    """Forward ``item`` to the logger of the trainer.

    When no logger is configured, an informative message is printed and
    nothing is logged.

    Args:
        item (Union[Any, List[Any]]): element to be logged (e.g., metric).
        identifier (Union[str, List[str]]): unique identifier for the
            element to log (e.g., name of a metric).
        kind (str, optional): type of the item to be logged.
            Defaults to 'metric'.
        step (Optional[int], optional): logging step. Defaults to None.
        batch_idx (Optional[int], optional): DataLoader batch counter
            (i.e., batch idx), if available. Defaults to None.
        **kwargs: additional keyword arguments forwarded to the logger.
    """
    if not self.logger:
        print(
            "INFO: The log method was called, but no logger was configured for this "
            "Trainer."
        )
        return

    self.logger.log(
        item=item,
        identifier=identifier,
        kind=kind,
        step=step,
        batch_idx=batch_idx,
        **kwargs,
    )