4. GAN tutorial with PyTorch

4.1. Tutorial on itwinai TorchTrainer adapted for the distributed GAN using MNIST dataset

Author(s): Henry Mutegeki (CERN), Matteo Bunino (CERN), Jarl Sondre Sæther (CERN), Linus Eickhoff (CERN)

The code is adapted from this example. The tutorial focuses on the train.py file, which implements the distributed GAN use case.

4.1.1. Setup

First, from the root of this repository, build the environment containing PyTorch and DeepSpeed. Refer to the itwinai installation steps.

Then navigate to the project working directory:

cd tutorials/distributed-ml/torch-tutorial-GAN

4.1.2. Distributed training on a single node (interactive)

If you want to use SLURM in interactive mode, do the following:

# Allocate resources
$ salloc --partition=batch --nodes=1 --account=intertwin  --gres=gpu:4 --time=1:59:00
job ID is XXXX
# Get a shell in the compute node (if using SLURM)
$ srun --jobid XXXX --overlap --pty /bin/bash
# Now you are inside the compute node

# On JSC, you may need to load some modules
ml --force purge
ml Stages/2024 GCC OpenMPI CUDA/12 MPI-settings/CUDA Python HDF5 PnetCDF libaio mpi4py

# ...before activating the Python environment (adapt this to your env name/path)
source ../../../envAI_hdfml/bin/activate

For more info you can refer to this documentation page.

To launch the training with torch DDP use:

torchrun --standalone --nnodes=1 --nproc-per-node=gpu train.py --strategy ddp

# Optional -- from a SLURM login node:
srun --jobid XXXX --ntasks-per-node=1 torchrun --standalone --nnodes=1 --nproc-per-node=gpu train.py --strategy ddp
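
With --nproc-per-node=gpu, torchrun starts one worker process per visible GPU and tells each worker its position in the job through the RANK, LOCAL_RANK and WORLD_SIZE environment variables. The trainer is expected to consume these internally, so the following is only a minimal sketch of what each worker sees. The helper name describe_worker is hypothetical, and the fallback values assume a plain, non-distributed run.

```python
import os


def describe_worker(env=None):
    """Return (rank, local_rank, world_size) as seen by one worker process."""
    env = os.environ if env is None else env
    # torchrun exports these variables for every worker it spawns;
    # the fallbacks describe a plain, non-distributed `python train.py` run.
    rank = int(env.get("RANK", 0))
    local_rank = int(env.get("LOCAL_RANK", 0))
    world_size = int(env.get("WORLD_SIZE", 1))
    return rank, local_rank, world_size


if __name__ == "__main__":
    rank, local_rank, world_size = describe_worker()
    print(f"worker {rank} of {world_size} uses local GPU index {local_rank}")
```

On a single node RANK and LOCAL_RANK coincide; they differ only when the job spans several nodes.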

To launch the training with Microsoft DeepSpeed use:

deepspeed train.py --strategy deepspeed

# Optional -- from a SLURM login node:
srun --jobid XXXX --ntasks-per-node=1 deepspeed train.py --strategy deepspeed

4.1.3. Distributed training with SLURM (batch mode)

You can run your training with SLURM by using the itwinai SLURM Builder. Use the slurm_config.yaml file to specify your SLURM parameters and then preview your script with the following command:

itwinai generate-slurm -c slurm_config.yaml --no-save-script --no-submit-job

If you are happy with the script, you can then run it by omitting --no-submit-job:

itwinai generate-slurm -c slurm_config.yaml --no-save-script

If you want to store a copy of the script in a folder, then you can similarly omit --no-save-script:

itwinai generate-slurm -c slurm_config.yaml

4.1.4. Analyze the logs

Analyze the logs with MLflow:

itwinai mlflow-ui --path mllogs/mlflow

4.1.5. Distributed GAN Documentation

This guide explains how a simple Generative Adversarial Network (GAN) has been adapted to run in a distributed environment using the GANTrainer. Distributing the training makes it practical to train on larger datasets by spreading the work across multiple GPUs and nodes.

4.1.6. Overview

A Generative Adversarial Network consists of two key components:

  • Generator (G): Generates new data instances.

  • Discriminator (D): Evaluates them for authenticity, aiming to distinguish real instances from the fake ones generated by the Generator.

The training process involves iterative adjustments where the Generator tries to produce data indistinguishable from actual data, and the Discriminator improves its ability to detect fakes.
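
To make the two opposing goals concrete, the sketch below evaluates the binary cross-entropy losses commonly used for GAN training on hand-picked Discriminator outputs. It assumes the non-saturating form of the Generator loss; the probabilities and the function names are illustrative and are not taken from train.py.

```python
import math


def bce(prediction, target):
    """Binary cross-entropy for one probability and one 0/1 label."""
    return -(target * math.log(prediction) + (1 - target) * math.log(1 - prediction))


def discriminator_loss(d_real, d_fake):
    # D wants to output 1 on real images and 0 on generated ones.
    return bce(d_real, 1.0) + bce(d_fake, 0.0)


def generator_loss(d_fake):
    # G wants D to output 1 on generated images (non-saturating form).
    return bce(d_fake, 1.0)


# Case 1: D is right on both inputs, so its loss is low and G's loss is high.
print(round(discriminator_loss(d_real=0.9, d_fake=0.2), 3))
print(round(generator_loss(d_fake=0.2), 3))
# Case 2: G fools D (D outputs 0.8 on a fake), so G's loss drops and D's rises.
print(round(discriminator_loss(d_real=0.9, d_fake=0.8), 3))
print(round(generator_loss(d_fake=0.8), 3))
```

Whatever lowers one loss raises the other, which is why the two networks are updated in alternation rather than jointly.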

4.1.7. Steps to make a distributed GAN model

The code for all steps is in the attached Python file (train.py).

4.1.7.1. Step 1: Define Model Architecture

Both the Generator and the Discriminator are defined as PyTorch nn.Module subclasses. The Generator upsamples a latent vector with transposed convolutions, while the Discriminator downsamples an image with strided convolutions; both kinds of layer are well suited to image data.
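
A quick way to check such an architecture is to trace the spatial size through each layer. The sketch below does this with the standard output-size formulas for nn.ConvTranspose2d and nn.Conv2d (dilation 1, no output padding), using the (kernel, stride, padding) values from train.py. The helper names are made up for this illustration, and the latent vector is assumed to enter the Generator as a z_dim x 1 x 1 tensor.

```python
def conv_transpose_out(size, kernel, stride, padding):
    """Output size of nn.ConvTranspose2d (dilation=1, output_padding=0)."""
    return (size - 1) * stride - 2 * padding + kernel


def conv_out(size, kernel, stride, padding):
    """Output size of nn.Conv2d (dilation=1)."""
    return (size + 2 * padding - kernel) // stride + 1


# (kernel, stride, padding) per layer, copied from train.py
generator_layers = [(4, 1, 0), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1)]
discriminator_layers = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 0)]


def trace(size, layers, fn):
    """Apply the size formula layer by layer and record every intermediate size."""
    sizes = [size]
    for kernel, stride, padding in layers:
        size = fn(size, kernel, stride, padding)
        sizes.append(size)
    return sizes


# Generator: 1 x 1 latent input grows to a 64 x 64 image.
print(trace(1, generator_layers, conv_transpose_out))  # [1, 4, 8, 16, 32, 64]
# Discriminator: a 64 x 64 image shrinks to a single score.
print(trace(64, discriminator_layers, conv_out))  # [64, 32, 16, 8, 4, 1]
```

The 64 x 64 working resolution is why train.py applies transforms.Resize(64) to the 28 x 28 MNIST images.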

4.1.7.2. Step 2: Implement Distributed Training

The GANTrainer class extends the itwinai TorchTrainer class and handles the initialization of the models, the optimizers, and the distributed training strategy for the GAN. The extension is needed because a GAN consists of two neural network models, whereas TorchTrainer expects and handles a single model. GANTrainer also creates separate optimizers for the Generator and the Discriminator.
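
As an illustration of the idea only, the sketch below uses a hypothetical SingleModelTrainer base class (a stand-in, not the itwinai TorchTrainer API) and a subclass that owns two networks with one Adam optimizer each. The layer sizes, the Adam betas, and all class and method names are assumptions made for this example.

```python
import torch
import torch.nn as nn


class SingleModelTrainer:
    """Hypothetical stand-in for a trainer that manages exactly one model."""

    def __init__(self, model, lr):
        self.model = model
        self.lr = lr
        self.optimizer = None

    def create_optimizers(self):
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)


class TwoModelTrainer(SingleModelTrainer):
    """Hypothetical GAN-style trainer: two models, one optimizer each."""

    def __init__(self, generator, discriminator, lr_g, lr_d):
        # The base class only knows about one model, so nothing is passed up;
        # both networks are stored and managed by the subclass instead.
        super().__init__(model=None, lr=None)
        self.generator = generator
        self.discriminator = discriminator
        self.lr_g = lr_g
        self.lr_d = lr_d
        self.optimizer_g = None
        self.optimizer_d = None

    def create_optimizers(self):
        # Separate optimizers keep the two parameter sets independent, so a
        # Discriminator step never updates Generator weights and vice versa.
        self.optimizer_g = torch.optim.Adam(
            self.generator.parameters(), lr=self.lr_g, betas=(0.5, 0.999)
        )
        self.optimizer_d = torch.optim.Adam(
            self.discriminator.parameters(), lr=self.lr_d, betas=(0.5, 0.999)
        )


# Tiny placeholder networks, just to exercise the trainer
trainer = TwoModelTrainer(nn.Linear(8, 16), nn.Linear(16, 1), lr_g=1e-3, lr_d=1e-3)
trainer.create_optimizers()
print(type(trainer.optimizer_g).__name__, type(trainer.optimizer_d).__name__)
```

In train.py the corresponding learning rates are supplied through optim_generator_lr and optim_discriminator_lr in GANTrainingConfiguration.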

4.1.7.3. Step 3: Training and Validation Logic

The training alternates between updating the Discriminator using real and generated images and training the Generator to fool the Discriminator. Validation evaluates the performance of the Generator in deceiving the Discriminator.
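
The alternation can be sketched as a single training step. This is a minimal, CPU-only illustration under stated assumptions, not the GANTrainer implementation: the networks are tiny fully connected stand-ins for the convolutional models in train.py, the "real" batch is random placeholder data, and the names train_step, opt_g and opt_d are hypothetical.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
z_dim, data_dim, batch_size = 8, 16, 32

# Tiny fully connected stand-ins for the convolutional networks in train.py
netG = nn.Sequential(nn.Linear(z_dim, data_dim), nn.Tanh())
netD = nn.Sequential(nn.Linear(data_dim, 1), nn.Sigmoid())
opt_g = torch.optim.Adam(netG.parameters(), lr=1e-3)
opt_d = torch.optim.Adam(netD.parameters(), lr=1e-3)
criterion = nn.BCELoss()


def train_step(real):
    ones = torch.ones(real.size(0))
    zeros = torch.zeros(real.size(0))

    # 1) Discriminator update: real samples -> 1, generated samples -> 0.
    opt_d.zero_grad()
    fake = netG(torch.randn(real.size(0), z_dim))
    loss_real = criterion(netD(real).squeeze(1), ones)
    # detach() stops this loss from sending gradients into the Generator.
    loss_fake = criterion(netD(fake.detach()).squeeze(1), zeros)
    loss_d = loss_real + loss_fake
    loss_d.backward()
    opt_d.step()

    # 2) Generator update: make D label the same fakes as real.
    opt_g.zero_grad()
    loss_g = criterion(netD(fake).squeeze(1), ones)
    loss_g.backward()
    opt_g.step()
    return loss_d.item(), loss_g.item()


real_batch = torch.rand(batch_size, data_dim) * 2 - 1  # placeholder "real" data
loss_d, loss_g = train_step(real_batch)
print(f"loss_d={loss_d:.3f} loss_g={loss_g:.3f}")
```

The Generator step also leaves gradients on the Discriminator parameters; they are discarded by opt_d.zero_grad() at the start of the next step.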

4.1.7.4. Step 4: Visualization and Monitoring

Training progress is monitored through visualizations of loss metrics and image samples generated periodically.

4.1.8. Takeaways

This README describes the steps taken to adapt a GAN for distributed training, with the aim of improving efficiency and scalability on large datasets. The main lessons from this use case are:

  • The itwinai TorchTrainer can be adapted to use cases that do not fit the single-model pattern, such as a GAN with its two models.

  • Training models in a distributed environment may require some customization of the training logic, but it can substantially reduce training time.

  • Distributed training for GANs requires a large dataset to reduce the chances of overfitting to the smaller data splits created during the training phase.

  • Always ensure that the models, data, and results of a given process reside on the same device during training and validation in a distributed environment.
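
How much data each worker processes per epoch can be estimated with simple arithmetic. The sketch below assumes an even split of the dataset across workers (the default behavior of PyTorch's DistributedSampler), the 60,000-image MNIST training set, and that the default batch size of 128 from train.py is applied per worker. The function name is hypothetical.

```python
import math


def per_worker_load(dataset_size, world_size, batch_size):
    """Samples and batches each worker processes per epoch under an even split."""
    samples = math.ceil(dataset_size / world_size)
    batches = math.ceil(samples / batch_size)
    return samples, batches


for world_size in (1, 4, 8):
    samples, batches = per_worker_load(60_000, world_size, batch_size=128)
    print(f"{world_size} worker(s): {samples} samples, {batches} batches per epoch")
```

With 4 GPUs each worker handles 15,000 images per epoch instead of 60,000, which illustrates how quickly the per-worker split shrinks as workers are added.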

4.2. Python files

4.2.1. train.py

# --------------------------------------------------------------------------------------
# Part of the interTwin Project: https://www.intertwin.eu/
#
# Created by: Henry Mutegeki
#
# Credit:
# - Henry Mutegeki <henry.mutegeki@cern.ch> - CERN
# - Matteo Bunino <matteo.bunino@cern.ch> - CERN
# - Jarl Sondre Sæther <jarl.sondre.saether@cern.ch> - CERN
# --------------------------------------------------------------------------------------

import argparse

import torch
import torch.nn as nn
from torchvision import datasets, transforms

from itwinai.loggers import MLFlowLogger
from itwinai.torch.gan import GANTrainer, GANTrainingConfiguration


class Generator(nn.Module):
    def __init__(self, z_dim, g_hidden, image_channel):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            # input layer
            nn.ConvTranspose2d(z_dim, g_hidden * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(g_hidden * 8),
            nn.ReLU(True),
            # 1st hidden layer
            nn.ConvTranspose2d(g_hidden * 8, g_hidden * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(g_hidden * 4),
            nn.ReLU(True),
            # 2nd hidden layer
            nn.ConvTranspose2d(g_hidden * 4, g_hidden * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(g_hidden * 2),
            nn.ReLU(True),
            # 3rd hidden layer
            nn.ConvTranspose2d(g_hidden * 2, g_hidden, 4, 2, 1, bias=False),
            nn.BatchNorm2d(g_hidden),
            nn.ReLU(True),
            # output layer
            nn.ConvTranspose2d(g_hidden, image_channel, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, input):
        return self.main(input)


class Discriminator(nn.Module):
    def __init__(self, d_hidden, image_channel):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            # 1st layer
            nn.Conv2d(image_channel, d_hidden, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # 2nd layer
            nn.Conv2d(d_hidden, d_hidden * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(d_hidden * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # 3rd layer
            nn.Conv2d(d_hidden * 2, d_hidden * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(d_hidden * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # 4th layer
            nn.Conv2d(d_hidden * 4, d_hidden * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(d_hidden * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # output layer
            nn.Conv2d(d_hidden * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, input):
        return self.main(input).view(-1, 1).squeeze(1)


def main():
    parser = argparse.ArgumentParser(description="PyTorch MNIST GAN Example")
    parser.add_argument(
        "--batch-size",
        type=int,
        default=128,
        help="input batch size for training (default: 128)",
    )
    parser.add_argument(
        "--epochs", type=int, default=15, help="number of epochs to train (default: 15)"
    )
    parser.add_argument(
        "--strategy", type=str, default="ddp", help="distributed strategy (default=ddp)"
    )
    parser.add_argument(
        "--lr", type=float, default=0.001, help="learning rate (default: 0.001)"
    )
    parser.add_argument("--seed", type=int, default=1, help="random seed (default: 1)")
    parser.add_argument(
        "--ckpt-interval",
        type=int,
        default=2,
        help="checkpoint interval, passed to the trainer as checkpoint_every",
    )
    args = parser.parse_args()

    torch.manual_seed(args.seed)

    # Dataset
    transform = transforms.Compose(
        [
            transforms.Resize(64),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ]
    )

    train_dataset = datasets.MNIST("../data", train=True, download=True, transform=transform)
    validation_dataset = datasets.MNIST("../data", train=False, transform=transform)

    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find("Conv") != -1:
            m.weight.data.normal_(0.0, 0.02)
        elif classname.find("BatchNorm") != -1:
            m.weight.data.normal_(1.0, 0.02)
            m.bias.data.fill_(0)

    # Models
    netG = Generator(z_dim=100, g_hidden=64, image_channel=1)
    netG.apply(weights_init)
    netD = Discriminator(d_hidden=64, image_channel=1)
    netD.apply(weights_init)

    # Training configuration
    training_config = GANTrainingConfiguration(
        batch_size=args.batch_size,
        optim_generator_lr=args.lr,
        optim_discriminator_lr=args.lr,
        z_dim=100,
    )

    # Logger
    logger = MLFlowLogger(experiment_name="Distributed GAN MNIST", log_freq=10)

    # Trainer
    trainer = GANTrainer(
        config=training_config,
        epochs=args.epochs,
        discriminator=netD,
        generator=netG,
        strategy=args.strategy,
        random_seed=args.seed,
        logger=logger,
        checkpoint_every=args.ckpt_interval,
    )

    # Launch training
    _, _, _, trained_model = trainer.execute(train_dataset, validation_dataset, None)


if __name__ == "__main__":
    main()