# --------------------------------------------------------------------------------------
# Part of the interTwin Project: https://www.intertwin.eu/
#
# Created by: Matteo Bunino
#
# Credit:
# - Matteo Bunino <matteo.bunino@cern.ch> - CERN
# - Henry Mutegeki <henry.mutegeki@cern.ch> - CERN
# - Linus Eickhoff <linus.maximilian.eickhoff@cern.ch> - CERN
# --------------------------------------------------------------------------------------
"""Provides training logic for PyTorch models via Trainer classes."""
import logging
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Literal, Optional, Tuple

import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torchvision
import yaml
from ray.train import DataConfig, RunConfig, ScalingConfig
from ray.train.torch import TorchConfig
from ray.tune import TuneConfig
from torch.optim.lr_scheduler import LRScheduler
from torch.optim.optimizer import Optimizer
from torchmetrics import Metric
from torchmetrics.image.fid import FrechetInceptionDistance

from ..loggers import Logger
from .config import TrainingConfiguration
from .trainer import TorchTrainer

if TYPE_CHECKING:
    from ray.train.horovod import HorovodConfig
# Module-level logger, named after this module via ``__name__``.
py_logger = logging.getLogger(name=__name__)
class GANTrainingConfiguration(TrainingConfiguration):
    """Configuration object for training a GAN. Extends the base TrainingConfiguration."""

    # ---------------------------- Generator ----------------------------
    #: Name of the optimizer to use for the generator. Defaults to 'adam'.
    optimizer_generator: Literal["adadelta", "adam", "adamw", "rmsprop", "sgd"] = "adam"
    #: Learning rate used by the optimizer for the generator. Defaults to 1e-3.
    optim_generator_lr: float = 1e-3
    #: Momentum used by some optimizers (SGD, RMSprop) for the generator. Defaults to 0.9.
    optim_generator_momentum: float = 0.9
    #: Betas of the Adam/AdamW optimizer (if used) for the generator.
    #: Defaults to (0.5, 0.999).
    optim_generator_betas: Tuple[float, float] = (0.5, 0.999)
    #: Weight decay parameter for the optimizer for the generator. Defaults to 0.
    optim_generator_weight_decay: float = 0.0
    #: Learning rate scheduler algorithm for the generator optimizer.
    #: Defaults to None (not used).
    lr_scheduler_generator: (
        Literal["step", "multistep", "constant", "linear", "exponential", "polynomial"] | None
    ) = None
    #: Learning rate scheduler step size, if needed by the scheduler. Defaults to 10 (epochs).
    #: Passed as ``step_size`` to "step" and as ``milestones`` to "multistep"; the
    #: latter (torch ``MultiStepLR``) expects an iterable of epochs, not a single int.
    lr_scheduler_generator_step_size: int | Iterable[int] = 10
    #: Multiplicative decay factor (gamma) of the learning rate scheduler.
    #: Usually this is used by the ExponentialLR. Defaults to 0.95.
    lr_scheduler_generator_gamma: float = 0.95

    # -------------------------- Discriminator --------------------------
    #: Name of the optimizer to use for the discriminator. Defaults to 'adam'.
    optimizer_discriminator: Literal["adadelta", "adam", "adamw", "rmsprop", "sgd"] = "adam"
    #: Learning rate used by the optimizer for the discriminator. Defaults to 1e-3.
    optim_discriminator_lr: float = 1e-3
    #: Momentum used by some optimizers (SGD, RMSprop) for the discriminator.
    #: Defaults to 0.9.
    optim_discriminator_momentum: float = 0.9
    #: Betas of the Adam/AdamW optimizer (if used) for the discriminator.
    #: Defaults to (0.5, 0.999).
    optim_discriminator_betas: Tuple[float, float] = (0.5, 0.999)
    #: Weight decay parameter for the optimizer for the discriminator. Defaults to 0.
    optim_discriminator_weight_decay: float = 0.0
    #: Learning rate scheduler algorithm for the discriminator optimizer.
    #: Defaults to None (not used).
    lr_scheduler_discriminator: (
        Literal["step", "multistep", "constant", "linear", "exponential", "polynomial"] | None
    ) = None
    #: Learning rate scheduler step size, if needed by the scheduler. Defaults to 10 (epochs).
    #: Passed as ``step_size`` to "step" and as ``milestones`` to "multistep"; the
    #: latter (torch ``MultiStepLR``) expects an iterable of epochs, not a single int.
    lr_scheduler_discriminator_step_size: int | Iterable[int] = 10
    #: Multiplicative decay factor (gamma) of the learning rate scheduler.
    #: Usually this is used by the ExponentialLR. Defaults to 0.95.
    lr_scheduler_discriminator_gamma: float = 0.95

    # ------------------------------ Shared ------------------------------
    #: Classification criterion to be used for generator and discriminator losses.
    #: Defaults to "bceloss".
    loss: str = "bceloss"
    #: Generator input size (random noise size): noise is sampled with shape
    #: ``(batch, z_dim, 1, 1)``. Defaults to 100.
    z_dim: int = 100
[docs]
class GANTrainer(TorchTrainer):
"""Trainer class for GAN models using pytorch.
Args:
config (Dict | TrainingConfiguration): training configuration
containing hyperparameters.
epochs (int): number of training epochs.
discriminator (nn.Module): pytorch discriminator model to train GAN.
generator (nn.Module): pytorch generator model to train GAN.
strategy (Literal['ddp', 'deepspeed', 'horovod'], optional):
distributed strategy. Defaults to 'ddp'.
test_every (int | None, optional): run a test epoch
every ``test_every`` epochs. Disabled if None. Defaults to None.
random_seed (int | None, optional): set random seed for
reproducibility. If None, the seed is not set. Defaults to None.
logger (Logger | None, optional): logger for ML tracking.
Defaults to None.
metrics (Dict[str, Callable] | None, optional): map of torch metrics
metrics. Defaults to None.
checkpoints_location (str): path to checkpoints directory.
Defaults to "checkpoints".
checkpoint_every (int | None): save a checkpoint every
``checkpoint_every`` epochs. Disabled if None. Defaults to None.
name (str | None, optional): trainer custom name. Defaults to None.
profiling_wait_epochs (int): how many epochs to wait before starting
the profiler.
profiling_warmup_epochs (int): length of the profiler warmup phase in terms of
number of epochs.
ray_scaling_config (ScalingConfig, optional): scaling config for Ray Trainer.
Defaults to None,
ray_tune_config (TuneConfig, optional): tune config for Ray Tuner.
Defaults to None.
ray_run_config (RunConfig, optional): run config for Ray Trainer.
Defaults to None.
ray_search_space (Dict[str, Any], optional): search space for Ray Tuner.
Defaults to None.
ray_torch_config (TorchConfig, optional): torch configuration for Ray's TorchTrainer.
Defaults to None.
ray_data_config (DataConfig, optional): dataset configuration for Ray.
Defaults to None.
ray_horovod_config (HorovodConfig, optional): horovod configuration for Ray's
HorovodTrainer. Defaults to None.
from_checkpoint (str | Path, optional): path to checkpoint directory. Defaults to None.
"""
#: PyTorch generator to train.
generator: nn.Module | None = None
#: PyTorch discriminator to train.
discriminator: nn.Module | None = None
#: Classification loss criterion used in the generator and discriminator losses.
loss: Callable | None = None
#: Optimizer for the generator.
optimizer_generator: Optimizer | None = None
#: Optimizer for the discriminator.
optimizer_discriminator: Optimizer | None = None
#: Learning rate scheduler for the optimizer of the generator.
lr_scheduler_generator: LRScheduler | None = None
#: Learning rate scheduler for the optimizer of the discriminator.
lr_scheduler_discriminator: LRScheduler | None = None
def __init__(
self,
config: Dict | GANTrainingConfiguration,
epochs: int,
discriminator: nn.Module,
generator: nn.Module,
strategy: Literal["ddp", "deepspeed"] = "ddp",
test_every: int | None = None,
random_seed: int | None = None,
logger: Logger | None = None,
metrics: Dict[str, Metric] | None = None,
checkpoints_location: str = "checkpoints",
checkpoint_every: int | None = None,
name: str | None = None,
profiling_wait_epochs: int = 1,
profiling_warmup_epochs: int = 2,
ray_scaling_config: ScalingConfig | None = None,
ray_tune_config: TuneConfig | None = None,
ray_run_config: RunConfig | None = None,
ray_search_space: Dict[str, Any] | None = None,
ray_torch_config: TorchConfig | None = None,
ray_data_config: DataConfig | None = None,
ray_horovod_config: Optional["HorovodConfig"] = None,
from_checkpoint: str | Path | None = None,
**kwargs,
) -> None:
super().__init__(
config=config,
epochs=epochs,
model=None,
strategy=strategy,
test_every=test_every,
random_seed=random_seed,
logger=logger,
metrics=metrics,
checkpoints_location=checkpoints_location,
checkpoint_every=checkpoint_every,
name=name,
profiling_wait_epochs=profiling_wait_epochs,
profiling_warmup_epochs=profiling_warmup_epochs,
ray_scaling_config=ray_scaling_config,
ray_tune_config=ray_tune_config,
ray_run_config=ray_run_config,
ray_search_space=ray_search_space,
ray_horovod_config=ray_horovod_config,
from_checkpoint=from_checkpoint,
ray_torch_config=ray_torch_config,
ray_data_config=ray_data_config,
**kwargs,
)
self.save_parameters(**self.locals2params(locals()))
self.discriminator = discriminator
self.generator = generator
if isinstance(config, dict):
config = GANTrainingConfiguration(**config)
self.config = config
self.epoch = 0
# Initial training state -- can be resumed from a checkpoint
self.discriminator_state_dict = None
self.generator_state_dict = None
self.optimizer_discriminator_state_dict = None
self.optimizer_generator_state_dict = None
self.lr_scheduler_generator_state_dict = None
self.lr_scheduler_discriminator_state_dict = None
def _optimizer_from_config(self) -> None:
match self.config.optimizer_generator:
case "adadelta":
self.optimizer_generator = optim.Adadelta(
self.generator.parameters(),
lr=self.config.optim_generator_lr,
weight_decay=self.config.optim_generator_weight_decay,
)
case "adam":
self.optimizer_generator = optim.Adam(
self.generator.parameters(),
lr=self.config.optim_generator_lr,
betas=self.config.optim_generator_betas,
weight_decay=self.config.optim_generator_weight_decay,
)
case "adamw":
self.optimizer_generator = optim.AdamW(
self.generator.parameters(),
lr=self.config.optim_generator_lr,
betas=self.config.optim_generator_betas,
weight_decay=self.config.optim_generator_weight_decay,
)
case "rmsprop":
self.optimizer_generator = optim.RMSprop(
self.generator.parameters(),
lr=self.config.optim_generator_lr,
weight_decay=self.config.optim_generator_weight_decay,
momentum=self.config.optim_generator_momentum,
)
case "sgd":
self.optimizer_generator = optim.SGD(
self.generator.parameters(),
lr=self.config.optim_generator_lr,
weight_decay=self.config.optim_generator_weight_decay,
momentum=self.config.optim_generator_momentum,
)
case _:
raise ValueError(
"Unrecognized self.config.optimizer_generator! Check the docs for "
"supported values and consider overriding "
"create_model_loss_optimizer method for more flexibility."
)
match self.config.optimizer_discriminator:
case "adadelta":
self.optimizer_discriminator = optim.Adadelta(
self.discriminator.parameters(),
lr=self.config.optim_discriminator_lr,
weight_decay=self.config.optim_discriminator_weight_decay,
)
case "adam":
self.optimizer_discriminator = optim.Adam(
self.discriminator.parameters(),
lr=self.config.optim_discriminator_lr,
betas=self.config.optim_discriminator_betas,
weight_decay=self.config.optim_discriminator_weight_decay,
)
case "adamw":
self.optimizer_discriminator = optim.AdamW(
self.discriminator.parameters(),
lr=self.config.optim_discriminator_lr,
betas=self.config.optim_discriminator_betas,
weight_decay=self.config.optim_discriminator_weight_decay,
)
case "rmsprop":
self.optimizer_discriminator = optim.RMSprop(
self.discriminator.parameters(),
lr=self.config.optim_discriminator_lr,
weight_decay=self.config.optim_discriminator_weight_decay,
momentum=self.config.optim_discriminator_momentum,
)
case "sgd":
self.optimizer_discriminator = optim.SGD(
self.discriminator.parameters(),
lr=self.config.optim_discriminator_lr,
weight_decay=self.config.optim_discriminator_weight_decay,
momentum=self.config.optim_discriminator_momentum,
)
case _:
raise ValueError(
"Unrecognized self.config.optimizer_discriminator! Check the docs for "
"supported values and consider overriding "
"create_model_loss_optimizer method for more flexibility."
)
def _lr_scheduler_from_config(self) -> None:
"""Parse Lr scheduler from training config"""
if self.config.lr_scheduler_generator:
if not self.optimizer_generator:
raise ValueError(
"Trying to instantiate a LR scheduler but the optimizer_generator is None!"
)
match self.config.lr_scheduler_generator:
case "constant":
self.lr_scheduler_generator = lr_scheduler.ConstantLR(
self.optimizer_generator
)
case "polynomial":
self.lr_scheduler_generator = lr_scheduler.PolynomialLR(
self.optimizer_generator
)
case "exponential":
self.lr_scheduler_generator = lr_scheduler.ExponentialLR(
self.optimizer_generator,
gamma=self.config.lr_scheduler_generator_gamma,
)
case "linear":
self.lr_scheduler_generator = lr_scheduler.LinearLR(
self.optimizer_generator
)
case "multistep":
self.lr_scheduler_generator = lr_scheduler.MultiStepLR(
self.optimizer_generator,
milestones=self.config.lr_scheduler_generator_step_size,
)
case "step":
self.lr_scheduler_generator = lr_scheduler.StepLR(
self.optimizer_generator,
step_size=self.config.lr_scheduler_generator_step_size,
)
case _:
raise ValueError(
"Unrecognized self.config.lr_scheduler_generator! Check the docs for "
"supported values and consider overriding "
"create_model_loss_optimizer method for more flexibility."
)
if self.config.lr_scheduler_discriminator:
if not self.optimizer_discriminator:
raise ValueError(
"Trying to instantiate a LR scheduler but the optimizer_discriminator "
"is None!"
)
match self.config.lr_scheduler_discriminator:
case "constant":
self.lr_scheduler_discriminator = lr_scheduler.ConstantLR(
self.optimizer_discriminator
)
case "polynomial":
self.lr_scheduler_discriminator = lr_scheduler.PolynomialLR(
self.optimizer_discriminator
)
case "exponential":
self.lr_scheduler_discriminator = lr_scheduler.ExponentialLR(
self.optimizer_discriminator,
gamma=self.config.lr_scheduler_discriminator_gamma,
)
case "linear":
self.lr_scheduler_discriminator = lr_scheduler.LinearLR(
self.optimizer_discriminator
)
case "multistep":
self.lr_scheduler_discriminator = lr_scheduler.MultiStepLR(
self.optimizer_discriminator,
milestones=self.config.lr_scheduler_discriminator_step_size,
)
case "step":
self.lr_scheduler_discriminator = lr_scheduler.StepLR(
self.optimizer_discriminator,
step_size=self.config.lr_scheduler_discriminator_step_size,
)
case _:
raise ValueError(
"Unrecognized self.config.lr_scheduler_discriminator! Check the "
"docs for supported values and consider overriding "
"create_model_loss_optimizer method for more flexibility."
)
[docs]
def create_model_loss_optimizer(self) -> None:
"""Instantiate a torch model, loss, optimizer, and LR scheduler using the
configuration provided in the Trainer constructor.
Generally a user-defined method.
"""
###################################
# Dear user, this is a method you #
# may be interested to override! #
###################################
# Model, optimizer, and lr scheduler may have already been loaded from a checkpoint
if self.generator is None or self.discriminator is None:
raise ValueError(
"self.generator or self.discrimintaor is None! "
"Either pass it to the constructor, load a checkpoint, or "
"override create_model_loss_optimizer method."
)
if self.generator_state_dict:
# Load generator from checkpoint
self.generator.load_state_dict(self.generator_state_dict, strict=False)
if self.discriminator_state_dict:
# Load discriminator from checkpoint
self.discriminator.load_state_dict(self.discriminator_state_dict, strict=False)
# Parse optimizers from training configuration
# Optimizers can be changed with a custom one here!
self._optimizer_from_config()
# Parse LR schedulers from training configuration
# LR schedulers can be changed with a custom one here!
self._lr_scheduler_from_config()
if self.optimizer_generator_state_dict:
# Load optimizer state from checkpoint
# IMPORTANT: this must be after the learning rate scheduler was already initialized
# by passing to it the optimizer. Otherwise the optimizer state just loaded will
# be modified by the lr scheduler.
self.optimizer_generator.load_state_dict(self.optimizer_generator_state_dict)
if self.optimizer_discriminator_state_dict:
# Load optimizer state from checkpoint
# IMPORTANT: this must be after the learning rate scheduler was already initialized
# by passing to it the optimizer. Otherwise the optimizer state just loaded will
# be modified by the lr scheduler.
self.optimizer_discriminator.load_state_dict(
self.optimizer_discriminator_state_dict
)
if self.lr_scheduler_generator_state_dict and self.lr_scheduler_generator:
# Load LR scheduler state from checkpoint
self.lr_scheduler_generator.load_state_dict(self.lr_scheduler_generator_state_dict)
if self.lr_scheduler_discriminator_state_dict and self.lr_scheduler_discriminator:
# Load LR scheduler state from checkpoint
self.lr_scheduler_discriminator.load_state_dict(
self.lr_scheduler_discriminator_state_dict
)
# Parse loss from training configuration
# Loss can be change with a custom one here!
self._set_loss_from_config()
self.criterion = self.loss
# if not self.optimizer_discriminator:
# self.optimizer_discriminator = optim.Adam(
# self.discriminator.parameters(), lr=self.config.lr, betas=(0.5, 0.999)
# )
# if not self.optimizer_generator:
# self.optimizer_generator = optim.Adam(
# self.generator.parameters(), lr=self.config.lr, betas=(0.5, 0.999)
# )
# self.criterion = nn.BCELoss()
# https://stackoverflow.com/a/67437077
self.discriminator = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.discriminator)
self.generator = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.generator)
# IMPORTANT: model, optimizer, and scheduler need to be distributed from here on
distribute_kwargs = self.get_default_distributed_kwargs()
# Distribute discriminator and its optimizer
self.discriminator, self.optimizer_discriminator, _ = self.strategy.distributed(
self.discriminator, self.optimizer_discriminator, **distribute_kwargs
)
self.generator, self.optimizer_generator, _ = self.strategy.distributed(
self.generator, self.optimizer_generator, **distribute_kwargs
)
[docs]
def save_checkpoint(
self,
name: str,
best_validation_metric: torch.Tensor | None = None,
checkpoints_root: str | Path | None = None,
force: bool = False,
) -> str | None:
"""Save training checkpoint.
Args:
name (str): name of the checkpoint directory.
best_validation_metric (torch.Tensor | None) : best validation loss throughout
training so far (if available).
checkpoints_root (str | None): path for root checkpoints dir. If None, uses
``self.checkpoints_location`` as base.
force (bool): force checkpointign now.
Returns:
path to the checkpoint file or ``None`` when the checkpoint is not created.
"""
if not (
force
or self.strategy.is_main_worker
and self.checkpoint_every
and (self.epoch + 1) % self.checkpoint_every == 0
):
# Do nothing and return
return
ckpt_dir = Path(checkpoints_root or self.checkpoints_location) / name
ckpt_dir.mkdir(parents=True, exist_ok=True)
# Save state (epoch, loss, optimizer, scheduler)
state = {
"epoch": self.epoch,
# This could store the best validation loss
"best_validation_metric": (
best_validation_metric.item() if best_validation_metric is not None else None
),
"optimizer_generator_state_dict": self.optimizer_generator.state_dict(),
"optimizer_discriminator_state_dict": self.optimizer_discriminator.state_dict(),
"lr_scheduler_generator_state_dict": (
self.lr_scheduler_generator.state_dict()
if self.lr_scheduler_generator is not None
else None
),
"lr_scheduler_discriminator_state_dict": (
self.lr_scheduler_discriminator.state_dict()
if self.lr_scheduler_discriminator is not None
else None
),
"torch_rng_state": self.torch_rng.get_state(),
"random_seed": self.random_seed,
}
state_path = ckpt_dir / "state.pt"
torch.save(state, state_path)
# Save PyTorch models separately
# TODO: check that the state dict is stripped from any distributed info
generator_path = ckpt_dir / "generator.pt"
torch.save(self.generator.state_dict(), generator_path)
discriminator_path = ckpt_dir / "discriminator.pt"
torch.save(self.discriminator.state_dict(), discriminator_path)
# Save Pydantic config as YAML
config_path = ckpt_dir / "config.yaml"
with config_path.open("w") as f:
yaml.safe_dump(self.config.model_dump(), f)
# Log each file with an appropriate identifier
self.log(str(state_path), f"{name}_state", kind="artifact")
self.log(str(generator_path), f"{name}_generator", kind="artifact")
self.log(str(discriminator_path), f"{name}_discriminator", kind="artifact")
self.log(str(config_path), f"{name}_config", kind="artifact")
return str(ckpt_dir)
def _load_checkpoint(self, checkpoint_dir: str | Path) -> None:
"""Load checkpoint from path."""
checkpoint_dir = Path(checkpoint_dir)
state = torch.load(checkpoint_dir / "state.pt")
# Override initial training state
self.generator_state_dict_state_dict = torch.load(checkpoint_dir / "generator.pt")
self.discriminator_state_dict = torch.load(checkpoint_dir / "discriminator.pt")
self.optimizer_generator_state_dict = state["optimizer_generator_state_dict"]
self.optimizer_discriminator_state_dict = state["optimizer_discriminator_state_dict"]
self.lr_scheduler_generator_state_dict = state["lr_scheduler_generator_state_dict"]
self.lr_scheduler_discriminator_state_dict = state[
"lr_scheduler_discriminator_state_dict"
]
self.torch_rng_state = state["torch_rng_state"]
# Direct overrides (don't require further attention)
self.random_seed = state["random_seed"]
self.epoch = state["epoch"] + 1 # Start from next epoch
if state["best_validation_metric"]:
self.best_validation_metric = state["best_validation_metric"]
[docs]
def train_epoch(self):
self.discriminator.train()
self.generator.train()
gen_train_losses = []
disc_train_losses = []
disc_train_accuracy = []
for batch_idx, (real_images, _) in enumerate(self.train_dataloader):
loss_gen, loss_disc, accuracy_disc = self.train_step(real_images, batch_idx)
gen_train_losses.append(loss_gen)
disc_train_losses.append(loss_disc)
disc_train_accuracy.append(accuracy_disc)
self.train_glob_step += 1
# Aggregate and log losses and accuracy
avg_disc_accuracy = torch.mean(torch.stack(disc_train_accuracy))
self.log(
item=avg_disc_accuracy.item(),
identifier="disc_train_accuracy_per_epoch",
kind="metric",
step=self.epoch,
)
avg_gen_loss = torch.mean(torch.stack(gen_train_losses))
self.log(
item=avg_gen_loss.item(),
identifier="gen_train_loss_per_epoch",
kind="metric",
step=self.epoch,
)
avg_disc_loss = torch.mean(torch.stack(disc_train_losses))
self.log(
item=avg_disc_loss.item(),
identifier="disc_train_loss_per_epoch",
kind="metric",
step=self.epoch,
)
self.save_fake_generator_images()
[docs]
def train_step(
self, real_images: torch.Tensor, batch_idx: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""train step for GAN.
Args:
real_images (torch.Tensor): real images.
batch_idx (int): batch index.
Returns:
torch.Tensor: loss of the discriminator
torch.Tensor: loss of the generator
torch.Tensor: accuracy of the discriminator
"""
real_images = real_images.to(self.device)
batch_size = real_images.size(0)
real_labels = torch.ones((batch_size,), dtype=torch.float, device=self.device)
fake_labels = torch.zeros((batch_size,), dtype=torch.float, device=self.device)
# Train Discriminator with real images
output_real = self.discriminator(real_images)
loss_disc_real = self.criterion(output_real, real_labels)
# Generate fake images and train Discriminator
noise = torch.randn(batch_size, self.config.z_dim, 1, 1, device=self.device)
fake_images = self.generator(noise)
output_fake = self.discriminator(fake_images.detach())
loss_disc_fake = self.criterion(output_fake, fake_labels)
loss_disc = (loss_disc_real + loss_disc_fake) / 2
self.optimizer_discriminator.zero_grad()
loss_disc.backward()
self.optimizer_discriminator.step()
accuracy = ((output_real > 0.5).float() == real_labels).float().mean() + (
(output_fake < 0.5).float() == fake_labels
).float().mean()
accuracy_disc = accuracy.mean()
# Train Generator
output_fake = self.discriminator(fake_images)
loss_gen = self.criterion(output_fake, real_labels)
self.optimizer_generator.zero_grad()
loss_gen.backward()
self.optimizer_generator.step()
self.log(
item=accuracy_disc,
identifier="disc_train_accuracy_per_batch",
kind="metric",
step=self.train_glob_step,
batch_idx=batch_idx,
)
self.log(
item=loss_gen,
identifier="gen_train_loss_per_batch",
kind="metric",
step=self.train_glob_step,
batch_idx=batch_idx,
)
self.log(
item=loss_disc,
identifier="disc_train_loss_per_batch",
kind="metric",
step=self.train_glob_step,
batch_idx=batch_idx,
)
return loss_gen, loss_disc, accuracy_disc
[docs]
def validation_epoch(self, fid_features: int = 2048) -> torch.Tensor:
"""Validation epoch for GAN.
Args:
fid_features (int, optional): number of features for InceptionV3 modela.
Defaults to 2048.
Returns:
torch.Tensor: FID score that is returned by the FID metric.
"""
gen_validation_accuracy = []
disc_validation_accuracy = []
self.discriminator.eval()
self.generator.eval()
fid = FrechetInceptionDistance(feature=fid_features, normalize=True)
# known to be unstable with float32
# (https://lightning.ai/docs/torchmetrics/stable/image/frechet_inception_distance.html)
fid.set_dtype(torch.float64)
# Move FID to the same device as GAN
fid = fid.to(self.device)
for batch_idx, (real_images, _) in enumerate(self.validation_dataloader):
accuracy_gen, accuracy_disc = self.validation_step(real_images, batch_idx, fid)
gen_validation_accuracy.append(accuracy_gen)
disc_validation_accuracy.append(accuracy_disc)
self.validation_glob_step += 1
# Aggregate and log metrics
disc_validation_accuracy = torch.mean(torch.stack(disc_validation_accuracy))
# Compute FID score using InceptionV3 model
fid_score = fid.compute()
self.log(
item=disc_validation_accuracy.item(),
identifier="disc_valid_accuracy_epoch",
kind="metric",
step=self.epoch,
)
gen_validation_accuracy = torch.mean(torch.stack(gen_validation_accuracy))
self.log(
item=gen_validation_accuracy.item(),
identifier="gen_valid_accuracy_epoch",
kind="metric",
step=self.epoch,
)
self.log(
item=fid_score.item(),
identifier="gen_valid_fid_score_epoch",
kind="metric",
step=self.epoch,
)
return fid_score
[docs]
def validation_step(
self, real_images: torch.Tensor, batch_idx: int, fid: FrechetInceptionDistance
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Validation step for GAN.
Args:
real_images (torch.Tensor): real images.
batch_idx (int): batch index.
fid (FrechetInceptionDistance): FID metric.
Returns:
torch.Tensor: accuracy of the generator
torch.Tensor: accuracy of the discriminator
"""
real_images = real_images.to(self.device)
batch_size = real_images.size(0)
real_labels = torch.ones((batch_size,), dtype=torch.float, device=self.device)
fake_labels = torch.zeros((batch_size,), dtype=torch.float, device=self.device)
# Validate with real images
output_real = self.discriminator(real_images)
# Generate and validate fake images
noise = torch.randn(batch_size, self.config.z_dim, 1, 1, device=self.device)
with torch.no_grad():
fake_images = self.generator(noise)
output_fake = self.discriminator(fake_images.detach())
# Generator's attempt to fool the discriminator
accuracy_gen = ((output_fake > 0.5).float() == real_labels).float().mean()
# Calculate total discriminator loss and accuracy
accuracy = ((output_real > 0.5).float() == real_labels).float().mean() + (
(output_fake < 0.5).float() == fake_labels
).float().mean()
accuracy_disc = accuracy / 2
# convert to 3 channel images for inceptionV3 model (which is used for FID) and
# use float64 for FID
real_images = real_images.repeat(1, 3, 1, 1).to(torch.float64)
fake_images = fake_images.repeat(1, 3, 1, 1).to(torch.float64)
fid.update(real_images, real=True)
fid.update(fake_images, real=False)
# Does not log FID score per batch, because it is computed on the whole validation set
# Per batch logging of FID would be too noisy
self.log(
item=accuracy_gen.item(),
identifier="gen_valid_accuracy_per_batch",
kind="metric",
step=self.validation_glob_step,
batch_idx=batch_idx,
)
self.log(
item=accuracy_disc.item(),
identifier="disc_valid_accuracy_per_batch",
kind="metric",
step=self.validation_glob_step,
batch_idx=batch_idx,
)
return accuracy_gen, accuracy_disc
[docs]
def save_fake_generator_images(self):
"""Plot and save fake images from generator"""
import matplotlib.pyplot as plt
import numpy as np
self.generator.eval()
noise = torch.randn(64, self.config.z_dim, 1, 1, device=self.device)
fake_images = self.generator(noise)
fake_images_grid = torchvision.utils.make_grid(fake_images, normalize=True)
fig, ax = plt.subplots(figsize=(8, 8))
ax.set_axis_off()
ax.set_title(f"Fake images for epoch {self.epoch}")
ax.imshow(np.transpose(fake_images_grid.cpu().numpy(), (1, 2, 0)))
self.log(
item=fig,
identifier=f"fake_images_epoch_{self.epoch}.png",
kind="figure",
step=self.epoch,
)