Source code for itwinai.torch.gan

# --------------------------------------------------------------------------------------
# Part of the interTwin Project: https://www.intertwin.eu/
#
# Created by: Matteo Bunino
#
# Credit:
# - Matteo Bunino <matteo.bunino@cern.ch> - CERN
# - Henry Mutegeki <henry.mutegeki@cern.ch> - CERN
# - Linus Eickhoff <linus.maximilian.eickhoff@cern.ch> - CERN
# --------------------------------------------------------------------------------------

"""Provides training logic for PyTorch models via Trainer classes."""

import logging
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Literal, Optional, Tuple

import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torchvision
import yaml
from ray.train import DataConfig, RunConfig, ScalingConfig
from ray.train.torch import TorchConfig
from ray.tune import TuneConfig
from torch.optim.lr_scheduler import LRScheduler
from torch.optim.optimizer import Optimizer
from torchmetrics import Metric
from torchmetrics.image.fid import FrechetInceptionDistance

from ..loggers import Logger
from .config import TrainingConfiguration
from .trainer import TorchTrainer

if TYPE_CHECKING:
    from ray.train.horovod import HorovodConfig

py_logger = logging.getLogger(__name__)


class GANTrainingConfiguration(TrainingConfiguration):
    """Hyperparameters for GAN training, on top of the base TrainingConfiguration.

    The generator and the discriminator each get their own optimizer and
    (optional) learning rate scheduler, configured through two parallel groups
    of fields.
    """

    # ---------------------------- Generator ----------------------------

    #: Optimizer algorithm for the generator. Defaults to 'adam'.
    optimizer_generator: Literal["adadelta", "adam", "adamw", "rmsprop", "sgd"] = "adam"
    #: Learning rate of the generator optimizer. Defaults to 1e-3.
    optim_generator_lr: float = 1e-3
    #: Momentum of the generator optimizer, for optimizers that take one
    #: (e.g., SGD). Defaults to 0.9.
    optim_generator_momentum: float = 0.9
    #: Betas of the generator optimizer, when an Adam-like optimizer is used.
    #: Defaults to (0.5, 0.999).
    optim_generator_betas: Tuple[float, float] = (0.5, 0.999)
    #: Weight decay of the generator optimizer. Defaults to 0.
    optim_generator_weight_decay: float = 0.0
    #: Learning rate scheduler algorithm for the generator optimizer.
    #: Defaults to None (no scheduler).
    lr_scheduler_generator: Optional[
        Literal["step", "multistep", "constant", "linear", "exponential", "polynomial"]
    ] = None
    #: Step size of the generator LR scheduler, for schedulers that take one.
    #: Defaults to 10 (epochs).
    lr_scheduler_generator_step_size: int | Iterable[int] = 10
    #: Multiplicative decay factor (gamma) of the generator LR scheduler.
    #: Usually this is used by the ExponentialLR. Defaults to 0.95.
    lr_scheduler_generator_gamma: float = 0.95

    # -------------------------- Discriminator --------------------------

    #: Optimizer algorithm for the discriminator. Defaults to 'adam'.
    optimizer_discriminator: Literal["adadelta", "adam", "adamw", "rmsprop", "sgd"] = "adam"
    #: Learning rate of the discriminator optimizer. Defaults to 1e-3.
    optim_discriminator_lr: float = 1e-3
    #: Momentum of the discriminator optimizer, for optimizers that take one
    #: (e.g., SGD). Defaults to 0.9.
    optim_discriminator_momentum: float = 0.9
    #: Betas of the discriminator optimizer, when an Adam-like optimizer is used.
    #: Defaults to (0.5, 0.999).
    optim_discriminator_betas: Tuple[float, float] = (0.5, 0.999)
    #: Weight decay of the discriminator optimizer. Defaults to 0.
    optim_discriminator_weight_decay: float = 0.0
    #: Learning rate scheduler algorithm for the discriminator optimizer.
    #: Defaults to None (no scheduler).
    lr_scheduler_discriminator: Optional[
        Literal["step", "multistep", "constant", "linear", "exponential", "polynomial"]
    ] = None
    #: Step size of the discriminator LR scheduler, for schedulers that take one.
    #: Defaults to 10 (epochs).
    lr_scheduler_discriminator_step_size: int | Iterable[int] = 10
    #: Multiplicative decay factor (gamma) of the discriminator LR scheduler.
    #: Usually this is used by the ExponentialLR. Defaults to 0.95.
    lr_scheduler_discriminator_gamma: float = 0.95

    # ------------------------------ Shared ------------------------------

    #: Classification criterion used for both the generator and the
    #: discriminator losses. Defaults to "bceloss".
    loss: str = "bceloss"
    #: Size of the random noise vector fed to the generator. Defaults to 100.
    z_dim: int = 100
[docs] class GANTrainer(TorchTrainer): """Trainer class for GAN models using pytorch. Args: config (Dict | TrainingConfiguration): training configuration containing hyperparameters. epochs (int): number of training epochs. discriminator (nn.Module): pytorch discriminator model to train GAN. generator (nn.Module): pytorch generator model to train GAN. strategy (Literal['ddp', 'deepspeed', 'horovod'], optional): distributed strategy. Defaults to 'ddp'. test_every (int | None, optional): run a test epoch every ``test_every`` epochs. Disabled if None. Defaults to None. random_seed (int | None, optional): set random seed for reproducibility. If None, the seed is not set. Defaults to None. logger (Logger | None, optional): logger for ML tracking. Defaults to None. metrics (Dict[str, Callable] | None, optional): map of torch metrics metrics. Defaults to None. checkpoints_location (str): path to checkpoints directory. Defaults to "checkpoints". checkpoint_every (int | None): save a checkpoint every ``checkpoint_every`` epochs. Disabled if None. Defaults to None. name (str | None, optional): trainer custom name. Defaults to None. profiling_wait_epochs (int): how many epochs to wait before starting the profiler. profiling_warmup_epochs (int): length of the profiler warmup phase in terms of number of epochs. ray_scaling_config (ScalingConfig, optional): scaling config for Ray Trainer. Defaults to None, ray_tune_config (TuneConfig, optional): tune config for Ray Tuner. Defaults to None. ray_run_config (RunConfig, optional): run config for Ray Trainer. Defaults to None. ray_search_space (Dict[str, Any], optional): search space for Ray Tuner. Defaults to None. ray_torch_config (TorchConfig, optional): torch configuration for Ray's TorchTrainer. Defaults to None. ray_data_config (DataConfig, optional): dataset configuration for Ray. Defaults to None. ray_horovod_config (HorovodConfig, optional): horovod configuration for Ray's HorovodTrainer. Defaults to None. 
from_checkpoint (str | Path, optional): path to checkpoint directory. Defaults to None. """ #: PyTorch generator to train. generator: nn.Module | None = None #: PyTorch discriminator to train. discriminator: nn.Module | None = None #: Classification loss criterion used in the generator and discriminator losses. loss: Callable | None = None #: Optimizer for the generator. optimizer_generator: Optimizer | None = None #: Optimizer for the discriminator. optimizer_discriminator: Optimizer | None = None #: Learning rate scheduler for the optimizer of the generator. lr_scheduler_generator: LRScheduler | None = None #: Learning rate scheduler for the optimizer of the discriminator. lr_scheduler_discriminator: LRScheduler | None = None def __init__( self, config: Dict | GANTrainingConfiguration, epochs: int, discriminator: nn.Module, generator: nn.Module, strategy: Literal["ddp", "deepspeed"] = "ddp", test_every: int | None = None, random_seed: int | None = None, logger: Logger | None = None, metrics: Dict[str, Metric] | None = None, checkpoints_location: str = "checkpoints", checkpoint_every: int | None = None, name: str | None = None, profiling_wait_epochs: int = 1, profiling_warmup_epochs: int = 2, ray_scaling_config: ScalingConfig | None = None, ray_tune_config: TuneConfig | None = None, ray_run_config: RunConfig | None = None, ray_search_space: Dict[str, Any] | None = None, ray_torch_config: TorchConfig | None = None, ray_data_config: DataConfig | None = None, ray_horovod_config: Optional["HorovodConfig"] = None, from_checkpoint: str | Path | None = None, **kwargs, ) -> None: super().__init__( config=config, epochs=epochs, model=None, strategy=strategy, test_every=test_every, random_seed=random_seed, logger=logger, metrics=metrics, checkpoints_location=checkpoints_location, checkpoint_every=checkpoint_every, name=name, profiling_wait_epochs=profiling_wait_epochs, profiling_warmup_epochs=profiling_warmup_epochs, ray_scaling_config=ray_scaling_config, 
ray_tune_config=ray_tune_config, ray_run_config=ray_run_config, ray_search_space=ray_search_space, ray_horovod_config=ray_horovod_config, from_checkpoint=from_checkpoint, ray_torch_config=ray_torch_config, ray_data_config=ray_data_config, **kwargs, ) self.save_parameters(**self.locals2params(locals())) self.discriminator = discriminator self.generator = generator if isinstance(config, dict): config = GANTrainingConfiguration(**config) self.config = config self.epoch = 0 # Initial training state -- can be resumed from a checkpoint self.discriminator_state_dict = None self.generator_state_dict = None self.optimizer_discriminator_state_dict = None self.optimizer_generator_state_dict = None self.lr_scheduler_generator_state_dict = None self.lr_scheduler_discriminator_state_dict = None def _optimizer_from_config(self) -> None: match self.config.optimizer_generator: case "adadelta": self.optimizer_generator = optim.Adadelta( self.generator.parameters(), lr=self.config.optim_generator_lr, weight_decay=self.config.optim_generator_weight_decay, ) case "adam": self.optimizer_generator = optim.Adam( self.generator.parameters(), lr=self.config.optim_generator_lr, betas=self.config.optim_generator_betas, weight_decay=self.config.optim_generator_weight_decay, ) case "adamw": self.optimizer_generator = optim.AdamW( self.generator.parameters(), lr=self.config.optim_generator_lr, betas=self.config.optim_generator_betas, weight_decay=self.config.optim_generator_weight_decay, ) case "rmsprop": self.optimizer_generator = optim.RMSprop( self.generator.parameters(), lr=self.config.optim_generator_lr, weight_decay=self.config.optim_generator_weight_decay, momentum=self.config.optim_generator_momentum, ) case "sgd": self.optimizer_generator = optim.SGD( self.generator.parameters(), lr=self.config.optim_generator_lr, weight_decay=self.config.optim_generator_weight_decay, momentum=self.config.optim_generator_momentum, ) case _: raise ValueError( "Unrecognized 
self.config.optimizer_generator! Check the docs for " "supported values and consider overriding " "create_model_loss_optimizer method for more flexibility." ) match self.config.optimizer_discriminator: case "adadelta": self.optimizer_discriminator = optim.Adadelta( self.discriminator.parameters(), lr=self.config.optim_discriminator_lr, weight_decay=self.config.optim_discriminator_weight_decay, ) case "adam": self.optimizer_discriminator = optim.Adam( self.discriminator.parameters(), lr=self.config.optim_discriminator_lr, betas=self.config.optim_discriminator_betas, weight_decay=self.config.optim_discriminator_weight_decay, ) case "adamw": self.optimizer_discriminator = optim.AdamW( self.discriminator.parameters(), lr=self.config.optim_discriminator_lr, betas=self.config.optim_discriminator_betas, weight_decay=self.config.optim_discriminator_weight_decay, ) case "rmsprop": self.optimizer_discriminator = optim.RMSprop( self.discriminator.parameters(), lr=self.config.optim_discriminator_lr, weight_decay=self.config.optim_discriminator_weight_decay, momentum=self.config.optim_discriminator_momentum, ) case "sgd": self.optimizer_discriminator = optim.SGD( self.discriminator.parameters(), lr=self.config.optim_discriminator_lr, weight_decay=self.config.optim_discriminator_weight_decay, momentum=self.config.optim_discriminator_momentum, ) case _: raise ValueError( "Unrecognized self.config.optimizer_discriminator! Check the docs for " "supported values and consider overriding " "create_model_loss_optimizer method for more flexibility." ) def _lr_scheduler_from_config(self) -> None: """Parse Lr scheduler from training config""" if self.config.lr_scheduler_generator: if not self.optimizer_generator: raise ValueError( "Trying to instantiate a LR scheduler but the optimizer_generator is None!" 
) match self.config.lr_scheduler_generator: case "constant": self.lr_scheduler_generator = lr_scheduler.ConstantLR( self.optimizer_generator ) case "polynomial": self.lr_scheduler_generator = lr_scheduler.PolynomialLR( self.optimizer_generator ) case "exponential": self.lr_scheduler_generator = lr_scheduler.ExponentialLR( self.optimizer_generator, gamma=self.config.lr_scheduler_generator_gamma, ) case "linear": self.lr_scheduler_generator = lr_scheduler.LinearLR( self.optimizer_generator ) case "multistep": self.lr_scheduler_generator = lr_scheduler.MultiStepLR( self.optimizer_generator, milestones=self.config.lr_scheduler_generator_step_size, ) case "step": self.lr_scheduler_generator = lr_scheduler.StepLR( self.optimizer_generator, step_size=self.config.lr_scheduler_generator_step_size, ) case _: raise ValueError( "Unrecognized self.config.lr_scheduler_generator! Check the docs for " "supported values and consider overriding " "create_model_loss_optimizer method for more flexibility." ) if self.config.lr_scheduler_discriminator: if not self.optimizer_discriminator: raise ValueError( "Trying to instantiate a LR scheduler but the optimizer_discriminator " "is None!" 
) match self.config.lr_scheduler_discriminator: case "constant": self.lr_scheduler_discriminator = lr_scheduler.ConstantLR( self.optimizer_discriminator ) case "polynomial": self.lr_scheduler_discriminator = lr_scheduler.PolynomialLR( self.optimizer_discriminator ) case "exponential": self.lr_scheduler_discriminator = lr_scheduler.ExponentialLR( self.optimizer_discriminator, gamma=self.config.lr_scheduler_discriminator_gamma, ) case "linear": self.lr_scheduler_discriminator = lr_scheduler.LinearLR( self.optimizer_discriminator ) case "multistep": self.lr_scheduler_discriminator = lr_scheduler.MultiStepLR( self.optimizer_discriminator, milestones=self.config.lr_scheduler_discriminator_step_size, ) case "step": self.lr_scheduler_discriminator = lr_scheduler.StepLR( self.optimizer_discriminator, step_size=self.config.lr_scheduler_discriminator_step_size, ) case _: raise ValueError( "Unrecognized self.config.lr_scheduler_discriminator! Check the " "docs for supported values and consider overriding " "create_model_loss_optimizer method for more flexibility." )
def create_model_loss_optimizer(self) -> None:
    """Instantiate a torch model, loss, optimizer, and LR scheduler using the
    configuration provided in the Trainer constructor.
    Generally a user-defined method.

    Steps, in order: restore model weights from a loaded checkpoint (if any),
    build optimizers and LR schedulers from the config, restore their state,
    set the loss criterion, convert batch norm layers to SyncBatchNorm, and
    finally hand both networks and their optimizers to the distributed
    strategy.

    Raises:
        ValueError: if the generator or the discriminator is None.
    """
    ###################################
    # Dear user, this is a method you #
    # may be interested to override!  #
    ###################################

    if self.generator is None or self.discriminator is None:
        raise ValueError(
            "self.generator or self.discriminator is None! "
            "Either pass it to the constructor, load a checkpoint, or "
            "override create_model_loss_optimizer method."
        )

    # Model weights may have been loaded from a checkpoint
    if self.generator_state_dict:
        self.generator.load_state_dict(self.generator_state_dict, strict=False)
    if self.discriminator_state_dict:
        self.discriminator.load_state_dict(self.discriminator_state_dict, strict=False)

    # Parse optimizers from training configuration
    # Optimizers can be changed with a custom one here!
    self._optimizer_from_config()

    # Parse LR schedulers from training configuration
    # LR schedulers can be changed with a custom one here!
    self._lr_scheduler_from_config()

    # IMPORTANT: optimizer states must be loaded after the learning rate
    # schedulers were already initialized by passing to them the optimizers.
    # Otherwise the optimizer state just loaded will be modified by the
    # lr scheduler.
    if self.optimizer_generator_state_dict:
        self.optimizer_generator.load_state_dict(self.optimizer_generator_state_dict)
    if self.optimizer_discriminator_state_dict:
        self.optimizer_discriminator.load_state_dict(
            self.optimizer_discriminator_state_dict
        )

    # Load LR scheduler states from checkpoint, when a scheduler is in use
    if self.lr_scheduler_generator_state_dict and self.lr_scheduler_generator:
        self.lr_scheduler_generator.load_state_dict(self.lr_scheduler_generator_state_dict)
    if self.lr_scheduler_discriminator_state_dict and self.lr_scheduler_discriminator:
        self.lr_scheduler_discriminator.load_state_dict(
            self.lr_scheduler_discriminator_state_dict
        )

    # Parse loss from training configuration
    # Loss can be changed with a custom one here!
    self._set_loss_from_config()
    self.criterion = self.loss

    # https://stackoverflow.com/a/67437077
    self.discriminator = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.discriminator)
    self.generator = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.generator)

    # IMPORTANT: model, optimizer, and scheduler need to be distributed from here on
    distribute_kwargs = self.get_default_distributed_kwargs()
    # The third value returned by the strategy is discarded: no LR scheduler
    # is passed to it here.
    self.discriminator, self.optimizer_discriminator, _ = self.strategy.distributed(
        self.discriminator, self.optimizer_discriminator, **distribute_kwargs
    )
    self.generator, self.optimizer_generator, _ = self.strategy.distributed(
        self.generator, self.optimizer_generator, **distribute_kwargs
    )
[docs] def save_checkpoint( self, name: str, best_validation_metric: torch.Tensor | None = None, checkpoints_root: str | Path | None = None, force: bool = False, ) -> str | None: """Save training checkpoint. Args: name (str): name of the checkpoint directory. best_validation_metric (torch.Tensor | None) : best validation loss throughout training so far (if available). checkpoints_root (str | None): path for root checkpoints dir. If None, uses ``self.checkpoints_location`` as base. force (bool): force checkpointign now. Returns: path to the checkpoint file or ``None`` when the checkpoint is not created. """ if not ( force or self.strategy.is_main_worker and self.checkpoint_every and (self.epoch + 1) % self.checkpoint_every == 0 ): # Do nothing and return return ckpt_dir = Path(checkpoints_root or self.checkpoints_location) / name ckpt_dir.mkdir(parents=True, exist_ok=True) # Save state (epoch, loss, optimizer, scheduler) state = { "epoch": self.epoch, # This could store the best validation loss "best_validation_metric": ( best_validation_metric.item() if best_validation_metric is not None else None ), "optimizer_generator_state_dict": self.optimizer_generator.state_dict(), "optimizer_discriminator_state_dict": self.optimizer_discriminator.state_dict(), "lr_scheduler_generator_state_dict": ( self.lr_scheduler_generator.state_dict() if self.lr_scheduler_generator is not None else None ), "lr_scheduler_discriminator_state_dict": ( self.lr_scheduler_discriminator.state_dict() if self.lr_scheduler_discriminator is not None else None ), "torch_rng_state": self.torch_rng.get_state(), "random_seed": self.random_seed, } state_path = ckpt_dir / "state.pt" torch.save(state, state_path) # Save PyTorch models separately # TODO: check that the state dict is stripped from any distributed info generator_path = ckpt_dir / "generator.pt" torch.save(self.generator.state_dict(), generator_path) discriminator_path = ckpt_dir / "discriminator.pt" 
torch.save(self.discriminator.state_dict(), discriminator_path) # Save Pydantic config as YAML config_path = ckpt_dir / "config.yaml" with config_path.open("w") as f: yaml.safe_dump(self.config.model_dump(), f) # Log each file with an appropriate identifier self.log(str(state_path), f"{name}_state", kind="artifact") self.log(str(generator_path), f"{name}_generator", kind="artifact") self.log(str(discriminator_path), f"{name}_discriminator", kind="artifact") self.log(str(config_path), f"{name}_config", kind="artifact") return str(ckpt_dir)
def _load_checkpoint(self, checkpoint_dir: str | Path) -> None: """Load checkpoint from path.""" checkpoint_dir = Path(checkpoint_dir) state = torch.load(checkpoint_dir / "state.pt") # Override initial training state self.generator_state_dict_state_dict = torch.load(checkpoint_dir / "generator.pt") self.discriminator_state_dict = torch.load(checkpoint_dir / "discriminator.pt") self.optimizer_generator_state_dict = state["optimizer_generator_state_dict"] self.optimizer_discriminator_state_dict = state["optimizer_discriminator_state_dict"] self.lr_scheduler_generator_state_dict = state["lr_scheduler_generator_state_dict"] self.lr_scheduler_discriminator_state_dict = state[ "lr_scheduler_discriminator_state_dict" ] self.torch_rng_state = state["torch_rng_state"] # Direct overrides (don't require further attention) self.random_seed = state["random_seed"] self.epoch = state["epoch"] + 1 # Start from next epoch if state["best_validation_metric"]: self.best_validation_metric = state["best_validation_metric"]
def train_epoch(self):
    """Run one training epoch over ``self.train_dataloader``.

    Every batch goes through ``self.train_step``. Once the dataloader is
    exhausted, the per-batch discriminator accuracy, generator loss and
    discriminator loss are averaged and logged against the current epoch, and
    a sample of generated images is logged.
    """
    self.discriminator.train()
    self.generator.train()

    generator_losses = []
    discriminator_losses = []
    discriminator_accuracies = []
    for batch_idx, (real_images, _) in enumerate(self.train_dataloader):
        step_result = self.train_step(real_images, batch_idx)
        loss_gen, loss_disc, accuracy_disc = step_result
        generator_losses.append(loss_gen)
        discriminator_losses.append(loss_disc)
        discriminator_accuracies.append(accuracy_disc)
        self.train_glob_step += 1

    # Aggregate and log the per-epoch averages, one metric at a time
    epoch_summaries = (
        ("disc_train_accuracy_per_epoch", discriminator_accuracies),
        ("gen_train_loss_per_epoch", generator_losses),
        ("disc_train_loss_per_epoch", discriminator_losses),
    )
    for identifier, batch_values in epoch_summaries:
        epoch_average = torch.stack(batch_values).mean()
        self.log(
            item=epoch_average.item(),
            identifier=identifier,
            kind="metric",
            step=self.epoch,
        )

    self.save_fake_generator_images()
def train_step(
    self, real_images: torch.Tensor, batch_idx: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Train step for GAN: one discriminator update followed by one generator update.

    Args:
        real_images (torch.Tensor): real images.
        batch_idx (int): batch index.

    Returns:
        torch.Tensor: loss of the generator
        torch.Tensor: loss of the discriminator
        torch.Tensor: accuracy of the discriminator, in [0, 1]
    """
    real_images = real_images.to(self.device)
    batch_size = real_images.size(0)
    # NOTE(review): labels are 1-D (batch,) vectors, so the discriminator is
    # presumably expected to return one probability per sample -- confirm.
    real_labels = torch.ones((batch_size,), dtype=torch.float, device=self.device)
    fake_labels = torch.zeros((batch_size,), dtype=torch.float, device=self.device)

    # Train Discriminator with real images
    output_real = self.discriminator(real_images)
    loss_disc_real = self.criterion(output_real, real_labels)

    # Generate fake images and train Discriminator
    noise = torch.randn(batch_size, self.config.z_dim, 1, 1, device=self.device)
    fake_images = self.generator(noise)
    # detach(): the discriminator update must not backpropagate into the generator
    output_fake = self.discriminator(fake_images.detach())
    loss_disc_fake = self.criterion(output_fake, fake_labels)

    loss_disc = (loss_disc_real + loss_disc_fake) / 2
    self.optimizer_discriminator.zero_grad()
    loss_disc.backward()
    self.optimizer_discriminator.step()

    # Discriminator accuracy: a real sample is correct when scored above 0.5,
    # a fake one when scored below 0.5. The two fractions are averaged so the
    # result stays in [0, 1].
    with torch.no_grad():
        real_correct = (output_real > 0.5).float().mean()
        fake_correct = (output_fake < 0.5).float().mean()
        accuracy_disc = (real_correct + fake_correct) / 2

    # Train Generator: it is rewarded when the discriminator labels fakes as real
    output_fake = self.discriminator(fake_images)
    loss_gen = self.criterion(output_fake, real_labels)
    self.optimizer_generator.zero_grad()
    loss_gen.backward()
    self.optimizer_generator.step()

    self.log(
        item=accuracy_disc,
        identifier="disc_train_accuracy_per_batch",
        kind="metric",
        step=self.train_glob_step,
        batch_idx=batch_idx,
    )
    self.log(
        item=loss_gen,
        identifier="gen_train_loss_per_batch",
        kind="metric",
        step=self.train_glob_step,
        batch_idx=batch_idx,
    )
    self.log(
        item=loss_disc,
        identifier="disc_train_loss_per_batch",
        kind="metric",
        step=self.train_glob_step,
        batch_idx=batch_idx,
    )

    return loss_gen, loss_disc, accuracy_disc
def validation_epoch(self, fid_features: int = 2048) -> torch.Tensor:
    """Run one validation epoch over ``self.validation_dataloader``.

    Per-batch generator and discriminator accuracies are averaged and logged
    against the current epoch, together with the FID score accumulated over
    the whole validation set.

    Args:
        fid_features (int, optional): number of features for the InceptionV3
            model used by the FID metric. Defaults to 2048.

    Returns:
        torch.Tensor: FID score that is returned by the FID metric.
    """
    generator_accuracies = []
    discriminator_accuracies = []

    self.discriminator.eval()
    self.generator.eval()

    fid = FrechetInceptionDistance(feature=fid_features, normalize=True)
    # FID is known to be unstable with float32
    # (https://lightning.ai/docs/torchmetrics/stable/image/frechet_inception_distance.html)
    fid.set_dtype(torch.float64)
    # The metric has to live on the same device as the GAN
    fid = fid.to(self.device)

    for batch_idx, (real_images, _) in enumerate(self.validation_dataloader):
        step_result = self.validation_step(real_images, batch_idx, fid)
        accuracy_gen, accuracy_disc = step_result
        generator_accuracies.append(accuracy_gen)
        discriminator_accuracies.append(accuracy_disc)
        self.validation_glob_step += 1

    # Aggregate the discriminator accuracy and compute the FID score
    # (InceptionV3-based) before anything is logged
    mean_disc_accuracy = torch.stack(discriminator_accuracies).mean()
    fid_score = fid.compute()

    self.log(
        item=mean_disc_accuracy.item(),
        identifier="disc_valid_accuracy_epoch",
        kind="metric",
        step=self.epoch,
    )

    mean_gen_accuracy = torch.stack(generator_accuracies).mean()
    self.log(
        item=mean_gen_accuracy.item(),
        identifier="gen_valid_accuracy_epoch",
        kind="metric",
        step=self.epoch,
    )

    self.log(
        item=fid_score.item(),
        identifier="gen_valid_fid_score_epoch",
        kind="metric",
        step=self.epoch,
    )

    return fid_score
def validation_step(
    self, real_images: torch.Tensor, batch_idx: int, fid: FrechetInceptionDistance
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Validation step for GAN.

    Scores a batch of real images and an equally sized batch of generated
    images with the discriminator, logs the resulting accuracies, and feeds
    both batches to the FID metric.

    Args:
        real_images (torch.Tensor): real images.
        batch_idx (int): batch index.
        fid (FrechetInceptionDistance): FID metric, updated in place.

    Returns:
        torch.Tensor: accuracy of the generator (fraction of fake images that
            the discriminator scores as real)
        torch.Tensor: accuracy of the discriminator, in [0, 1]
    """
    real_images = real_images.to(self.device)
    batch_size = real_images.size(0)

    # Nothing is optimized during validation: skip autograd bookkeeping for
    # every forward pass, not only the generator's.
    with torch.no_grad():
        # Validate with real images
        output_real = self.discriminator(real_images)

        # Generate and validate fake images
        noise = torch.randn(batch_size, self.config.z_dim, 1, 1, device=self.device)
        fake_images = self.generator(noise)
        output_fake = self.discriminator(fake_images)

        # Generator's attempt to fool the discriminator: fakes scored as real
        accuracy_gen = (output_fake > 0.5).float().mean()

        # Discriminator accuracy: a real sample is correct when scored above
        # 0.5, a fake one when scored below 0.5; average the two fractions.
        real_correct = (output_real > 0.5).float().mean()
        fake_correct = (output_fake < 0.5).float().mean()
        accuracy_disc = (real_correct + fake_correct) / 2

    # The InceptionV3 model used for FID needs 3-channel images: replicate
    # the channel of single-channel images, and leave other inputs as they
    # are. float64 is used for FID.
    if real_images.size(1) == 1:
        real_images = real_images.repeat(1, 3, 1, 1)
    if fake_images.size(1) == 1:
        fake_images = fake_images.repeat(1, 3, 1, 1)
    real_images = real_images.to(torch.float64)
    fake_images = fake_images.to(torch.float64)

    fid.update(real_images, real=True)
    fid.update(fake_images, real=False)

    # Does not log FID score per batch, because it is computed on the whole validation set
    # Per batch logging of FID would be too noisy
    self.log(
        item=accuracy_gen.item(),
        identifier="gen_valid_accuracy_per_batch",
        kind="metric",
        step=self.validation_glob_step,
        batch_idx=batch_idx,
    )
    self.log(
        item=accuracy_disc.item(),
        identifier="disc_valid_accuracy_per_batch",
        kind="metric",
        step=self.validation_glob_step,
        batch_idx=batch_idx,
    )

    return accuracy_gen, accuracy_disc
def save_fake_generator_images(self):
    """Plot a grid of 64 generated images and log it as a figure.

    Note: the generator is switched to eval mode and is left in that mode.
    """
    import matplotlib.pyplot as plt
    import numpy as np

    self.generator.eval()
    noise = torch.randn(64, self.config.z_dim, 1, 1, device=self.device)
    # Images are only plotted: no autograd graph is needed for them
    with torch.no_grad():
        fake_images = self.generator(noise)
    fake_images_grid = torchvision.utils.make_grid(fake_images, normalize=True)

    fig, ax = plt.subplots(figsize=(8, 8))
    try:
        ax.set_axis_off()
        ax.set_title(f"Fake images for epoch {self.epoch}")
        # make_grid returns (C, H, W) while imshow expects (H, W, C)
        ax.imshow(np.transpose(fake_images_grid.cpu().numpy(), (1, 2, 0)))
        self.log(
            item=fig,
            identifier=f"fake_images_epoch_{self.epoch}.png",
            kind="figure",
            step=self.epoch,
        )
    finally:
        # pyplot keeps every figure alive until it is closed explicitly:
        # without this, one figure would be leaked per epoch.
        plt.close(fig)