4. GAN tutorial with PyTorch

4.1. Tutorial on itwinai TorchTrainer adapted for a distributed GAN using the MNIST dataset

Author(s): Henry Mutegeki (CERN), Matteo Bunino (CERN), Jarl Sondre Sæther (CERN)

The code is adapted from this example. A simple non-distributed GAN is provided in simpleGAN.py as a baseline, but this tutorial focuses mainly on train.py, which implements the distributed GAN use case.

4.1.1. Setup

First, from the root of this repository, build the environment containing PyTorch and DeepSpeed. Refer to the itwinai installation steps.

Then navigate to the project working directory:

cd tutorials/distributed-ml/torch-tutorial-GAN

4.1.2. Distributed training on a single node (interactive)

If you want to use SLURM in interactive mode, do the following:

# Allocate resources
$ salloc --partition=batch --nodes=1 --account=intertwin --gres=gpu:4 --time=1:59:00
job ID is XXXX
# Get a shell in the compute node (if using SLURM)
$ srun --jobid XXXX --overlap --pty /bin/bash
# Now you are inside the compute node

# On JSC, you may need to load some modules
ml --force purge
ml Stages/2024 GCC OpenMPI CUDA/12 MPI-settings/CUDA Python HDF5 PnetCDF libaio mpi4py

# ...before activating the Python environment (adapt this to your env name/path)
source ../../../envAI_hdfml/bin/activate

For more info you can refer to this documentation page.

To launch the training with torch DDP use:

torchrun --standalone --nnodes=1 --nproc-per-node=gpu train.py --strategy ddp

# Optional -- from a SLURM login node:
srun --jobid XXXX --ntasks-per-node=1 torchrun --standalone --nnodes=1 --nproc-per-node=gpu train.py --strategy ddp
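
For orientation: torchrun starts one worker process per GPU and tells each process its position through the RANK, LOCAL_RANK and WORLD_SIZE environment variables. The sketch below shows how a script could read them. The helper name read_torchrun_env is hypothetical and not part of train.py. In this tutorial the selected itwinai strategy is expected to handle this setup for you.

```python
import os
from typing import Mapping


def read_torchrun_env(env: Mapping[str, str] = os.environ) -> dict:
    """Return the process layout that torchrun exports to each worker.

    Falls back to a single-process layout when the variables are absent,
    for example when the script is started with plain `python`.
    """
    return {
        "rank": int(env.get("RANK", 0)),              # global index of this process
        "local_rank": int(env.get("LOCAL_RANK", 0)),  # index of this process on its node
        "world_size": int(env.get("WORLD_SIZE", 1)),  # total number of processes
    }


if __name__ == "__main__":
    print(read_torchrun_env())
```

With --nproc-per-node=gpu on a node with 4 GPUs, each of the 4 processes would see a different rank and the same world size.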

To launch the training with Microsoft DeepSpeed use:

deepspeed train.py --strategy deepspeed

# Optional -- from a SLURM login node:
srun --jobid XXXX --ntasks-per-node=1 deepspeed train.py --strategy deepspeed

4.1.3. Distributed training with SLURM (batch mode)

You can run your training with SLURM by using the itwinai SLURM Builder. Use the slurm_config.yaml file to specify your SLURM parameters and then preview your script with the following command:

itwinai generate-slurm -c slurm_config.yaml --no-save-script --no-submit-job

If you are happy with the script, you can then run it by omitting --no-submit-job:

itwinai generate-slurm -c slurm_config.yaml --no-save-script

If you want to store a copy of the script in a folder, then you can similarly omit --no-save-script:

itwinai generate-slurm -c slurm_config.yaml

4.1.4. Analyze the logs

Analyze the logs with MLflow:

itwinai mlflow-ui --path mllogs/mlflow

4.1.5. Distributed GAN Documentation

This guide explains how a simple Generative Adversarial Network (GAN) was adapted to run in a distributed environment using the GANTrainer. Distributing the training makes it practical to train on larger datasets by leveraging distributed computing resources.

4.1.6. Overview

A Generative Adversarial Network consists of two key components:

  • Generator (G): Generates new data instances.

  • Discriminator (D): Evaluates them for authenticity, aiming to distinguish real instances from the fake ones generated by the Generator.

The training process involves iterative adjustments where the Generator tries to produce data indistinguishable from actual data, and the Discriminator improves its ability to detect fakes.
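
This adversarial game is usually written as the minimax objective from the original GAN formulation. Here x is a real sample, z is a latent noise vector, and D outputs the probability that its input is real. The distribution names p_data and p_z are notation introduced for this note and do not appear in the tutorial code:

```latex
\min_G \max_D V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}}\left[\log D(x)\right]
  + \mathbb{E}_{z \sim p_z}\left[\log\left(1 - D(G(z))\right)\right]
```

In practice, implementations such as simpleGAN.py train the Generator by labelling generated images as real under a binary cross-entropy loss, which amounts to maximizing log D(G(z)) rather than minimizing the second term directly.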

4.1.7. Steps to make a distributed GAN model

The code for all steps can be seen in the attached Python file.

4.1.7.1. Step 1: Define Model Architecture

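Both the Generator and the Discriminator are defined as subclasses of PyTorch’s nn.Module. The Generator stacks transposed convolutions (nn.ConvTranspose2d) to upsample a latent vector into a 64×64 image, while the Discriminator stacks strided convolutions (nn.Conv2d) to reduce a 64×64 image to a single probability.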
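
To see why the images are resized to 64×64, it can help to trace the spatial size through each layer. The sketch below applies the standard output-size formulas for nn.Conv2d and nn.ConvTranspose2d (dilation 1, no output padding) to the kernel, stride and padding values used in train.py. The helper names are made up for this illustration.

```python
def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Spatial output size of nn.Conv2d (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_transpose_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Spatial output size of nn.ConvTranspose2d (dilation 1, no output padding)."""
    return (size - 1) * stride - 2 * padding + kernel


# (kernel, stride, padding) of each layer, copied from train.py
GENERATOR_LAYERS = [(4, 1, 0), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1)]
DISCRIMINATOR_LAYERS = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 0)]


def trace(size, layers, formula):
    """Return the spatial size before the first layer and after every layer."""
    sizes = [size]
    for kernel, stride, padding in layers:
        sizes.append(formula(sizes[-1], kernel, stride, padding))
    return sizes


generator_sizes = trace(1, GENERATOR_LAYERS, conv_transpose_out)
discriminator_sizes = trace(64, DISCRIMINATOR_LAYERS, conv_out)
print(generator_sizes)      # [1, 4, 8, 16, 32, 64]
print(discriminator_sizes)  # [64, 32, 16, 8, 4, 1]
```

The Generator grows a 1×1 latent input to 64×64, and the Discriminator shrinks a 64×64 image back to 1×1, so the two networks are mirror images of each other.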

4.1.7.2. Step 2: Implement Distributed Training

The GANTrainer class extends the itwinai TorchTrainer class and handles the initialization of the models, the optimizers, and the distributed training strategy for the GAN. The extension is needed because a GAN comprises two neural networks, whereas TorchTrainer expects and handles a single model. GANTrainer therefore also creates separate optimizers for the Generator and the Discriminator.
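
The two-model, two-optimizer pattern can be sketched in plain PyTorch as follows. This is a hypothetical illustration, not the actual GANTrainer code: the class name TwoModelTrainerSketch, its arguments and the Adam settings (taken from simpleGAN.py) are assumptions, and the distributed strategy is left out entirely.

```python
import torch.nn as nn
import torch.optim as optim


class TwoModelTrainerSketch:
    """Hypothetical illustration only; this is not the itwinai GANTrainer."""

    def __init__(
        self,
        generator: nn.Module,
        discriminator: nn.Module,
        lr_generator: float = 2e-4,
        lr_discriminator: float = 2e-4,
    ):
        self.generator = generator
        self.discriminator = discriminator
        # One optimizer per network, so each can be stepped independently
        self.optimizer_g = optim.Adam(
            generator.parameters(), lr=lr_generator, betas=(0.5, 0.999)
        )
        self.optimizer_d = optim.Adam(
            discriminator.parameters(), lr=lr_discriminator, betas=(0.5, 0.999)
        )

    def owned_parameters(self) -> dict:
        """Count the parameters managed by each optimizer."""

        def count(optimizer):
            return sum(
                p.numel() for group in optimizer.param_groups for p in group["params"]
            )

        return {
            "generator": count(self.optimizer_g),
            "discriminator": count(self.optimizer_d),
        }


# Tiny stand-in networks keep the example fast; the real models are convolutional
trainer = TwoModelTrainerSketch(nn.Linear(8, 4), nn.Linear(4, 1))
print(trainer.owned_parameters())
```

Keeping the optimizers separate matters because the Discriminator update must not change the Generator's weights, and vice versa.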

4.1.7.3. Step 3: Training and Validation Logic

The training alternates between updating the Discriminator using real and generated images and training the Generator to fool the Discriminator. Validation evaluates the performance of the Generator in deceiving the Discriminator.
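
A validation step along these lines could look like the sketch below. It is written in plain PyTorch under the assumption that binary cross-entropy is used, as in simpleGAN.py. The function name validation_step and the tiny stand-in networks are hypothetical, and the metrics GANTrainer actually reports may differ.

```python
import torch
import torch.nn as nn


@torch.no_grad()  # validation must not accumulate gradients
def validation_step(generator, discriminator, real_images, z_dim, criterion=None):
    """Return (discriminator loss, generator loss) for one validation batch."""
    if criterion is None:
        criterion = nn.BCELoss()
    device = real_images.device
    batch_size = real_images.size(0)
    real_labels = torch.ones(batch_size, device=device)
    fake_labels = torch.zeros(batch_size, device=device)

    # Generate one fake image per real image
    noise = torch.randn(batch_size, z_dim, 1, 1, device=device)
    fake_images = generator(noise)

    d_real = discriminator(real_images).view(-1)
    d_fake = discriminator(fake_images).view(-1)

    # Discriminator: real images should score 1, generated images 0
    d_loss = criterion(d_real, real_labels) + criterion(d_fake, fake_labels)
    # Generator: a low loss means its images are being scored as real
    g_loss = criterion(d_fake, real_labels)
    return d_loss.item(), g_loss.item()


# Tiny stand-in networks on 8x8 images, only to make the sketch runnable
gen = nn.Sequential(nn.ConvTranspose2d(4, 1, 8), nn.Tanh())  # (N,4,1,1) -> (N,1,8,8)
disc = nn.Sequential(nn.Conv2d(1, 1, 8), nn.Sigmoid())       # (N,1,8,8) -> (N,1,1,1)
d_loss, g_loss = validation_step(gen, disc, torch.rand(16, 1, 8, 8), z_dim=4)
print(f"validation d_loss={d_loss:.4f} g_loss={g_loss:.4f}")
```

A falling generator loss on held-out data indicates that generated images are increasingly being classified as real.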

4.1.7.4. Step 4: Visualization and Monitoring

Training progress is monitored through visualizations of loss metrics and image samples generated periodically.

4.1.8. Takeaways

This README describes the steps taken to adapt a GAN for distributed training, with the aim of improving efficiency and scalability when training on large-scale datasets. From this use case we learn the following:

  • The itwinai TorchTrainer can easily be adapted to specialized use cases, such as a GAN that has two models.

  • Training models in a distributed environment may require some customization of the training logic, but it can substantially reduce training time.

  • Distributed training for GANs requires a large dataset to reduce the chances of overfitting to the smaller data splits created during the training phase.

  • Always ensure that the models, data, and results for a given process are on the same device during training and validation in a distributed environment.
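
To make the point about smaller data splits concrete, the sketch below computes how many MNIST training samples each worker would see per epoch. It assumes the dataset is sharded evenly across workers (as PyTorch's DistributedSampler does) and that the batch size applies per worker. Both are assumptions about the setup, not statements taken from train.py.

```python
import math

MNIST_TRAIN_SIZE = 60_000  # number of images in the MNIST training split
BATCH_SIZE = 128           # default --batch-size in train.py


def samples_per_worker(dataset_size: int, world_size: int) -> int:
    """Samples each worker sees per epoch under even sharding."""
    return math.ceil(dataset_size / world_size)


for world_size in (1, 4, 16):
    shard = samples_per_worker(MNIST_TRAIN_SIZE, world_size)
    steps = math.ceil(shard / BATCH_SIZE)
    print(f"{world_size:>2} workers: {shard} samples and {steps} steps per worker per epoch")
```

With 16 workers each process sees only 3750 images per epoch, which is why small datasets gain little from heavy scaling.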

4.2. Python files

4.2.1. train.py

# --------------------------------------------------------------------------------------
# Part of the interTwin Project: https://www.intertwin.eu/
#
# Created by: Henry Mutegeki
#
# Credit:
# - Henry Mutegeki <henry.mutegeki@cern.ch> - CERN
# - Matteo Bunino <matteo.bunino@cern.ch> - CERN
# - Jarl Sondre Sæther <jarl.sondre.saether@cern.ch> - CERN
# --------------------------------------------------------------------------------------

import argparse

import torch
import torch.nn as nn
from torchvision import datasets, transforms

from itwinai.loggers import MLFlowLogger
from itwinai.torch.gan import GANTrainer, GANTrainingConfiguration


class Generator(nn.Module):
    def __init__(self, z_dim, g_hidden, image_channel):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            # input layer
            nn.ConvTranspose2d(z_dim, g_hidden * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(g_hidden * 8),
            nn.ReLU(True),
            # 1st hidden layer
            nn.ConvTranspose2d(g_hidden * 8, g_hidden * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(g_hidden * 4),
            nn.ReLU(True),
            # 2nd hidden layer
            nn.ConvTranspose2d(g_hidden * 4, g_hidden * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(g_hidden * 2),
            nn.ReLU(True),
            # 3rd hidden layer
            nn.ConvTranspose2d(g_hidden * 2, g_hidden, 4, 2, 1, bias=False),
            nn.BatchNorm2d(g_hidden),
            nn.ReLU(True),
            # output layer
            nn.ConvTranspose2d(g_hidden, image_channel, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, input):
        return self.main(input)


class Discriminator(nn.Module):
    def __init__(self, d_hidden, image_channel):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            # 1st layer
            nn.Conv2d(image_channel, d_hidden, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # 2nd layer
            nn.Conv2d(d_hidden, d_hidden * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(d_hidden * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # 3rd layer
            nn.Conv2d(d_hidden * 2, d_hidden * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(d_hidden * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # 4th layer
            nn.Conv2d(d_hidden * 4, d_hidden * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(d_hidden * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # output layer
            nn.Conv2d(d_hidden * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, input):
        return self.main(input).view(-1, 1).squeeze(1)


def main():
    parser = argparse.ArgumentParser(description="PyTorch MNIST GAN Example")
    parser.add_argument(
        "--batch-size",
        type=int,
        default=128,
        help="input batch size for training (default: 128)",
    )
    parser.add_argument(
        "--epochs", type=int, default=15, help="number of epochs to train (default: 15)"
    )
    parser.add_argument(
        "--strategy", type=str, default="ddp", help="distributed strategy (default=ddp)"
    )
    parser.add_argument(
        "--lr", type=float, default=0.001, help="learning rate (default: 0.001)"
    )
    parser.add_argument("--seed", type=int, default=1, help="random seed (default: 1)")
    parser.add_argument(
        "--ckpt-interval",
        type=int,
        default=2,
        help="checkpoint interval, passed to GANTrainer as checkpoint_every (default: 2)",
    )
    args = parser.parse_args()

    torch.manual_seed(args.seed)

    # Dataset
    transform = transforms.Compose(
        [
            transforms.Resize(64),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ]
    )

    train_dataset = datasets.MNIST("../data", train=True, download=True, transform=transform)
    validation_dataset = datasets.MNIST("../data", train=False, transform=transform)

    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find("Conv") != -1:
            m.weight.data.normal_(0.0, 0.02)
        elif classname.find("BatchNorm") != -1:
            m.weight.data.normal_(1.0, 0.02)
            m.bias.data.fill_(0)

    # Models
    netG = Generator(z_dim=100, g_hidden=64, image_channel=1)
    netG.apply(weights_init)
    netD = Discriminator(d_hidden=64, image_channel=1)
    netD.apply(weights_init)

    # Training configuration
    training_config = GANTrainingConfiguration(
        batch_size=args.batch_size,
        optim_generator_lr=args.lr,
        optim_discriminator_lr=args.lr,
        z_dim=100,
    )

    # Logger
    logger = MLFlowLogger(experiment_name="Distributed GAN MNIST", log_freq=10)

    # Trainer
    trainer = GANTrainer(
        config=training_config,
        epochs=args.epochs,
        discriminator=netD,
        generator=netG,
        strategy=args.strategy,
        random_seed=args.seed,
        logger=logger,
        checkpoint_every=args.ckpt_interval,
    )

    # Launch training
    _, _, _, trained_model = trainer.execute(train_dataset, validation_dataset, None)


if __name__ == "__main__":
    main()

4.2.2. simpleGAN.py

# --------------------------------------------------------------------------------------
# Part of the interTwin Project: https://www.intertwin.eu/
#
# Created by: Henry Mutegeki
#
# Credit:
# - Henry Mutegeki <henry.mutegeki@cern.ch> - CERN
# - Matteo Bunino <matteo.bunino@cern.ch> - CERN
# --------------------------------------------------------------------------------------


"""This script shows a simple example on how to train a GAN without using itwinai."""

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils

DATA_PATH = "../data"
BATCH_SIZE = 128
IMAGE_CHANNEL = 1
Z_DIM = 100
G_HIDDEN = 64
X_DIM = 64
D_HIDDEN = 64
EPOCH_NUM = 10
REAL_LABEL = 1
FAKE_LABEL = 0
lr = 2e-4
seed = 1

USE_CUDA = False
CUDA = torch.cuda.is_available() and USE_CUDA
device = torch.device("cuda:0" if CUDA else "cpu")

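print("PyTorch version: {}".format(torch.__version__))
# Seed the default generator so weight init and noise use the configured seed
torch.manual_seed(seed)
if CUDA:
    print("CUDA version: {}\n".format(torch.version.cuda))
    torch.cuda.manual_seed(seed)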

cudnn.benchmark = True

# Data preprocessing
dataset = dset.MNIST(
    root=DATA_PATH,
    download=False,
    transform=transforms.Compose(
        [transforms.Resize(X_DIM), transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
    ),
)

# Dataloader
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2
)


def weights_init(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find("BatchNorm") != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            # input layer
            nn.ConvTranspose2d(Z_DIM, G_HIDDEN * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(G_HIDDEN * 8),
            nn.ReLU(True),
            # 1st hidden layer
            nn.ConvTranspose2d(G_HIDDEN * 8, G_HIDDEN * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(G_HIDDEN * 4),
            nn.ReLU(True),
            # 2nd hidden layer
            nn.ConvTranspose2d(G_HIDDEN * 4, G_HIDDEN * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(G_HIDDEN * 2),
            nn.ReLU(True),
            # 3rd hidden layer
            nn.ConvTranspose2d(G_HIDDEN * 2, G_HIDDEN, 4, 2, 1, bias=False),
            nn.BatchNorm2d(G_HIDDEN),
            nn.ReLU(True),
            # output layer
            nn.ConvTranspose2d(G_HIDDEN, IMAGE_CHANNEL, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, input):
        return self.main(input)


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            # 1st layer
            nn.Conv2d(IMAGE_CHANNEL, D_HIDDEN, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # 2nd layer
            nn.Conv2d(D_HIDDEN, D_HIDDEN * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(D_HIDDEN * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # 3rd layer
            nn.Conv2d(D_HIDDEN * 2, D_HIDDEN * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(D_HIDDEN * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # 4th layer
            nn.Conv2d(D_HIDDEN * 4, D_HIDDEN * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(D_HIDDEN * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # output layer
            nn.Conv2d(D_HIDDEN * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, input):
        return self.main(input).view(-1, 1).squeeze(1)


# Create the generator
netG = Generator().to(device)
netG.apply(weights_init)
print(netG)


# Create the discriminator
netD = Discriminator().to(device)
netD.apply(weights_init)
print(netD)


# Initialize BCELoss function
criterion = nn.BCELoss()

# Create batch of latent vectors to visualize the progression of the generator
viz_noise = torch.randn(BATCH_SIZE, Z_DIM, 1, 1, device=device)

# Setup Adam optimizers for both G and D
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(0.5, 0.999))


# Training Loop
def train_GAN_model(EPOCH_NUM, netD, netG, optimizerG, optimizerD, dataloader, criterion):
    img_list = []
    G_losses = []
    D_losses = []
    iters = 0
    for epoch in range(EPOCH_NUM):
        for i, data in enumerate(dataloader, 0):
            # (1) Update the discriminator with real data
            netD.zero_grad()
            # Format batch
            real_cpu = data[0].to(device)
            b_size = real_cpu.size(0)
            label = torch.full((b_size,), REAL_LABEL, dtype=torch.float, device=device)
            # Forward pass real batch through D
            output = netD(real_cpu).view(-1)
            # Calculate loss on all-real batch
            errD_real = criterion(output, label)
            # Calculate gradients for D in backward pass
            errD_real.backward()
            D_x = output.mean().item()

            # (2) Update the discriminator with fake data
            # Generate batch of latent vectors
            noise = torch.randn(b_size, Z_DIM, 1, 1, device=device)
            # Generate fake image batch with G
            fake = netG(noise)
            label.fill_(FAKE_LABEL)
            # Classify all fake batch with D
            output = netD(fake.detach()).view(-1)
            # Calculate D's loss on the all-fake batch
            errD_fake = criterion(output, label)
            # Calculate the gradients for this batch, accumulated (summed)
            # with previous gradients
            errD_fake.backward()
            D_G_z1 = output.mean().item()
            # Compute error of D as sum over the fake and the real batches
            errD = errD_real + errD_fake
            # Update D
            optimizerD.step()

            # (3) Update the generator with fake data
            netG.zero_grad()
            label.fill_(REAL_LABEL)  # fake labels are real for generator cost
            # Since we just updated D, perform another forward pass of
            # all-fake batch through D
            output = netD(fake).view(-1)
            # Calculate G's loss based on this output
            errG = criterion(output, label)
            # Calculate gradients for G
            errG.backward()
            D_G_z2 = output.mean().item()
            # Update G
            optimizerG.step()

            # Output training stats
            if i % 50 == 0:
                print(
                    (
                        "[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G:"
                        " %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f"
                    )
                    % (
                        epoch,
                        EPOCH_NUM,
                        i,
                        len(dataloader),
                        errD.item(),
                        errG.item(),
                        D_x,
                        D_G_z1,
                        D_G_z2,
                    )
                )

            # Save Losses for plotting later
            G_losses.append(errG.item())
            D_losses.append(errD.item())

            # Check how the generator is doing
            if (iters % 500 == 0) or ((epoch == EPOCH_NUM - 1) and (i == len(dataloader) - 1)):
                with torch.no_grad():
                    fake = netG(viz_noise).detach().cpu()
                img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

            iters += 1

    plt.figure(figsize=(10, 5))
    plt.title("Generator and Discriminator Loss During Training")
    plt.plot(G_losses, label="G")
    plt.plot(D_losses, label="D")
    plt.xlabel("iterations")
    plt.ylabel("Loss")
    plt.legend()
    plt.show()
    # plt.savefig('simpleGANlearning_curve.png')
    # Grab a batch of real images from the dataloader
    real_batch = next(iter(dataloader))
    # Plot the real images
    plt.figure(figsize=(15, 15))
    plt.subplot(1, 2, 1)
    plt.axis("off")
    plt.title("Real Images")
    plt.imshow(
        np.transpose(
            vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu(),
            (1, 2, 0),
        )
    )
    # plt.savefig('simpleGANreal_image.png')
    # Plot the fake images from the last epoch
    plt.subplot(1, 2, 2)
    plt.axis("off")
    plt.title("Fake Images")
    plt.imshow(np.transpose(img_list[-1], (1, 2, 0)))
    plt.show()
    # plt.savefig('simpleGANfake_image.png')


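# Guard the entry point so DataLoader worker processes can import this module safely
if __name__ == "__main__":
    train_GAN_model(EPOCH_NUM, netD, netG, optimizerG, optimizerD, dataloader, criterion)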