4. GAN tutorial with PyTorch
4.1. Tutorial on the itwinai TorchTrainer adapted for a distributed GAN using the MNIST dataset
Author(s): Henry Mutegeki (CERN), Matteo Bunino (CERN), Jarl Sondre Sæther (CERN)
The code is adapted from
this example.
A simple, non-distributed GAN model is provided in simpleGAN.py as a baseline,
but this tutorial focuses mainly on train.py, which implements the distributed
GAN use case.
4.1.1. Setup
First, from the root of this repository, build the environment containing PyTorch and DeepSpeed, following the itwinai installation steps.
Then navigate to the project working directory:
cd tutorials/distributed-ml/torch-tutorial-GAN
4.1.2. Distributed training on a single node (interactive)
If you want to use SLURM in interactive mode, do the following:
# Allocate resources
$ salloc --partition=batch --nodes=1 --account=intertwin --gres=gpu:4 --time=1:59:00
# SLURM prints the ID of the allocated job (XXXX below)
# Get a shell in the compute node (if using SLURM)
$ srun --jobid XXXX --overlap --pty /bin/bash
# Now you are inside the compute node
# On JSC, you may need to load some modules
ml --force purge
ml Stages/2024 GCC OpenMPI CUDA/12 MPI-settings/CUDA Python HDF5 PnetCDF libaio mpi4py
# ...before activating the Python environment (adapt this to your env name/path)
source ../../../envAI_hdfml/bin/activate
For more information, refer to this documentation page.
To launch the training with PyTorch DDP, use:
torchrun --standalone --nnodes=1 --nproc-per-node=gpu train.py --strategy ddp
# Optional -- from a SLURM login node:
srun --jobid XXXX --ntasks-per-node=1 torchrun --standalone --nnodes=1 --nproc-per-node=gpu train.py --strategy ddp
To launch the training with Microsoft DeepSpeed, use:
deepspeed train.py --strategy deepspeed
# Optional -- from a SLURM login node:
srun --jobid XXXX --ntasks-per-node=1 deepspeed train.py --strategy deepspeed
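Both launchers start one process per GPU and tell each process where it sits in the job through environment variables: torchrun documents RANK, LOCAL_RANK and WORLD_SIZE, and the DeepSpeed launcher exports the same names. itwinai's strategies read these internally, so you do not need to; the minimal sketch below (hypothetical helper names, plain Python) only illustrates what the launchers provide and how a per-process device could be derived from it.

```python
import os


def dist_env() -> tuple[int, int, int]:
    """Return (rank, local_rank, world_size) as exported by the launcher.

    Falls back to a single-process run when the variables are absent.
    """
    rank = int(os.environ.get("RANK", "0"))
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    return rank, local_rank, world_size


def device_for(local_rank: int, cuda_available: bool) -> str:
    """Hypothetical helper: pin each process to the GPU matching its local rank."""
    return f"cuda:{local_rank}" if cuda_available else "cpu"
```

With --nproc-per-node=gpu on the 4-GPU node requested above, the four processes would see LOCAL_RANK 0 to 3 and WORLD_SIZE 4, so each one ends up on its own GPU.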
4.1.3. Distributed training with SLURM (batch mode)
You can run your training with SLURM by using the itwinai SLURM Builder. Use the
slurm_config.yaml file to specify your SLURM parameters and then preview your script
with the following command:
itwinai generate-slurm -c slurm_config.yaml --no-save-script --no-submit-job
If you are happy with the script, you can then run it by omitting --no-submit-job:
itwinai generate-slurm -c slurm_config.yaml --no-save-script
If you want to store a copy of the script in a folder, then you can similarly omit
--no-save-script:
itwinai generate-slurm -c slurm_config.yaml
4.1.4. Analyze the logs
Analyze the logs with MLflow:
itwinai mlflow-ui --path mllogs/mlflow
4.1.5. Distributed GAN Documentation
This guide explains how a simple Generative Adversarial Network (GAN) has been adapted
to run in a distributed environment using the GANTrainer. This adaptation enables more
efficient training on larger datasets by leveraging distributed computing resources.
4.1.6. Overview
A Generative Adversarial Network consists of two key components:
Generator (G): Generates new data instances.
Discriminator (D): Evaluates them for authenticity, aiming to distinguish real instances from the fake ones generated by the Generator.
The training process involves iterative adjustments where the Generator tries to produce data indistinguishable from actual data, and the Discriminator improves its ability to detect fakes.
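For reference, this alternating game is usually written as the standard GAN min-max objective, where $D(x)$ is the Discriminator's estimated probability that $x$ is real and $G(z)$ maps a noise vector $z$ to an image:

```latex
\min_G \max_D \; V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big]
  + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]
```

In practice, simpleGAN.py trains the Generator with the common non-saturating variant: it labels generated images as real, so minimizing the binary cross-entropy amounts to maximizing $\mathbb{E}_{z}[\log D(G(z))]$ instead of minimizing $\mathbb{E}_{z}[\log(1 - D(G(z)))]$.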
4.1.7. Steps to make a distributed GAN model
The code for all steps can be found in train.py, which is listed in full in the Python files section below.
4.1.7.1. Step 1: Define Model Architecture
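Both the Generator and the Discriminator are defined as subclasses of PyTorch's nn.Module.
Their architectures follow the DCGAN design: the Generator stacks transposed convolutions
that upsample a latent vector into an image, and the Discriminator stacks strided
convolutions that reduce an image to a single real/fake probability. Convolutional layers
are well suited to processing image data.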
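The layer hyperparameters in train.py (kernel 4, stride 2, padding 1, with stride 1 and padding 0 at the ends) are what make the two networks fit MNIST resized to 64x64. The sketch below applies the output-size formulas from the PyTorch documentation for nn.ConvTranspose2d and nn.Conv2d (assuming dilation 1 and no output padding, as in train.py) to trace the spatial size layer by layer. It is plain Python, so it runs without PyTorch; the helper names are illustrative.

```python
def conv_transpose_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Spatial output size of nn.ConvTranspose2d (dilation=1, output_padding=0)."""
    return (size - 1) * stride - 2 * padding + kernel


def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Spatial output size of nn.Conv2d (dilation=1)."""
    return (size + 2 * padding - kernel) // stride + 1


# (kernel, stride, padding) of each layer, in the order used in train.py
GENERATOR_LAYERS = [(4, 1, 0)] + [(4, 2, 1)] * 4
DISCRIMINATOR_LAYERS = [(4, 2, 1)] * 4 + [(4, 1, 0)]


def trace(size: int, layers, out_fn) -> list[int]:
    """Return the spatial size after each layer."""
    sizes = []
    for kernel, stride, padding in layers:
        size = out_fn(size, kernel, stride, padding)
        sizes.append(size)
    return sizes


print(trace(1, GENERATOR_LAYERS, conv_transpose_out))  # [4, 8, 16, 32, 64]
print(trace(64, DISCRIMINATOR_LAYERS, conv_out))        # [32, 16, 8, 4, 1]
```

So the Generator turns a 100x1x1 latent vector into a 1x64x64 image, and the Discriminator reduces a 64x64 image to a 1x1 score that forward() flattens to one probability per image. This is why both scripts resize MNIST's 28x28 digits to 64.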
4.1.7.2. Step 2: Implement Distributed Training
The GANTrainer class extends the itwinai TorchTrainer class and handles the initialization
of the models, the optimizers, and the distributed training strategy for the GAN. This is
needed because a GAN comprises two neural network models, whereas TorchTrainer expects and
handles a single model. GANTrainer therefore also creates a separate optimizer for each of
the Generator and the Discriminator; in train.py their learning rates are set through the
optim_generator_lr and optim_discriminator_lr fields of GANTrainingConfiguration.
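Under a data-parallel strategy such as DDP, each process holds a full replica of both models, computes gradients on its own slice of the data, and the gradients are averaged across processes before each optimizer step. The toy sketch below illustrates why this averaging reproduces the full-batch gradient when every process gets the same number of samples. It is plain Python with a one-parameter model standing in for a network, and the function names are hypothetical; it is not how itwinai or DDP implement the all-reduce.

```python
def grad_mse(w: float, xs: list[float], ys: list[float]) -> float:
    """Gradient w.r.t. w of mean((w * x - y) ** 2) over the given samples."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def allreduce_mean(values: list[float]) -> float:
    """What a gradient all-reduce followed by division by the world size yields."""
    return sum(values) / len(values)


xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]
w = 1.0
world_size = 2

# Each "rank" sees every world_size-th sample, like a strided data split
shard_grads = [grad_mse(w, xs[r::world_size], ys[r::world_size]) for r in range(world_size)]
print(shard_grads)                  # [-10.0, -20.0]
print(allreduce_mean(shard_grads))  # -15.0
print(grad_mse(w, xs, ys))          # -15.0, same as the full-batch gradient
```

In the GAN case the same averaging would apply separately to the Generator's and the Discriminator's parameters, each of which has its own optimizer.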
4.1.7.3. Step 3: Training and Validation Logic
The training alternates between updating the Discriminator using real and generated images and training the Generator to fool the Discriminator. Validation evaluates the performance of the Generator in deceiving the Discriminator.
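The losses behind this alternation are binary cross-entropy terms, as in simpleGAN.py: the Discriminator is trained with label 1 for real images and 0 for generated ones, while the Generator is trained with label 1 for its own images. The plain-Python sketch below computes these batch-mean losses from Discriminator outputs, mirroring nn.BCELoss with its default mean reduction (ignoring its clamping of the log). The input probabilities are made-up values for illustration.

```python
import math


def bce(probs: list[float], target: int) -> float:
    """Mean binary cross-entropy of probabilities against a constant 0/1 label."""
    if target == 1:
        terms = [-math.log(p) for p in probs]
    else:
        terms = [-math.log(1.0 - p) for p in probs]
    return sum(terms) / len(terms)


def discriminator_loss(d_real: list[float], d_fake: list[float]) -> float:
    # errD = errD_real + errD_fake: real images labelled 1, fakes labelled 0
    return bce(d_real, 1) + bce(d_fake, 0)


def generator_loss(d_fake: list[float]) -> float:
    # errG: fakes labelled 1, so G is rewarded when D(G(z)) moves toward 1
    return bce(d_fake, 1)


d_real = [0.9, 0.9]  # D(x): the Discriminator is fairly sure real images are real
d_fake = [0.2, 0.2]  # D(G(z)): and fairly sure generated images are fake
print(round(discriminator_loss(d_real, d_fake), 4))  # 0.3285
print(round(generator_loss(d_fake), 4))              # 1.6094
```

A low Discriminator loss paired with a high Generator loss, as here, means the Discriminator is currently winning; as the Generator improves, D(G(z)) rises and errG falls.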
4.1.7.4. Step 4: Visualization and Monitoring
Training progress is monitored through visualizations of loss metrics and image samples generated periodically.
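To display generated samples, the [-1, 1] range produced by the Generator's Tanh has to be mapped back to [0, 1]. Both scripts normalize MNIST with Normalize((0.5,), (0.5,)), i.e. x' = (x - 0.5) / 0.5, which maps pixel values from [0, 1] to [-1, 1]; simpleGAN.py instead relies on vutils.make_grid(..., normalize=True), which rescales by the tensor's own min and max. The plain-Python sketch below (hypothetical helper names) shows both mappings on single values and short lists.

```python
def normalize(x: float, mean: float = 0.5, std: float = 0.5) -> float:
    """Same arithmetic as transforms.Normalize((0.5,), (0.5,)) on one pixel."""
    return (x - mean) / std


def denormalize(y: float, mean: float = 0.5, std: float = 0.5) -> float:
    """Inverse of normalize: maps Tanh outputs in [-1, 1] back to [0, 1]."""
    return y * std + mean


def min_max_rescale(values: list[float]) -> list[float]:
    """Shift values to [0, 1] using their own min and max, like make_grid(normalize=True)."""
    lo, hi = min(values), max(values)
    return [(v - lo) / max(hi - lo, 1e-5) for v in values]


print([normalize(p) for p in (0.0, 0.5, 1.0)])     # [-1.0, 0.0, 1.0]
print([denormalize(t) for t in (-1.0, 0.0, 1.0)])  # [0.0, 0.5, 1.0]
print(min_max_rescale([-0.5, 0.0, 0.5]))           # [0.0, 0.5, 1.0]
```

The difference matters when samples do not span the full range: denormalize keeps absolute intensities, whereas min-max rescaling stretches the contrast of whatever range the batch happens to cover.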
4.1.8. Takeaways
This tutorial describes the steps taken to adapt a GAN for distributed training, with the aim of improving efficiency and scalability when training on large datasets. From this use case we learn the following:
The itwinai TorchTrainer can easily be adapted to unusual use cases, such as a GAN that has two models. Training models in a distributed environment may require some customization of the training logic, but it can bring substantial performance improvements.
Distributed training for GANs requires a large dataset to reduce the chances of overfitting to the smaller data splits created during the training phase.
Always ensure that the models, data, and results of a given process are on the same device during training and validation in a distributed environment.
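The remark about smaller data splits can be made concrete: in data-parallel training each process iterates over only its own shard of the dataset. The sketch below mirrors the default index assignment of PyTorch's torch.utils.data.DistributedSampler with shuffle=False and drop_last=False: pad the index list so it divides evenly, then give rank r every world_size-th index starting at r. Whether itwinai's strategies use exactly these settings, and whether batch_size in train.py is a per-process value, are assumptions here.

```python
import math


def shard_indices(n: int, world_size: int, rank: int) -> list[int]:
    """Dataset indices for `rank`, mirroring DistributedSampler(shuffle=False, drop_last=False)."""
    num_samples = math.ceil(n / world_size)  # samples per rank
    total_size = num_samples * world_size
    indices = list(range(n))
    padding = total_size - n
    # Repeat leading indices so every rank gets the same number of samples
    if padding <= n:
        indices += indices[:padding]
    else:
        indices += (indices * math.ceil(padding / n))[:padding]
    # Strided split: rank r takes indices r, r + world_size, r + 2 * world_size, ...
    return indices[rank:total_size:world_size]


def batches_per_epoch(n: int, world_size: int, batch_size: int) -> int:
    """Optimizer steps per epoch on each rank, keeping the last partial batch."""
    return math.ceil(math.ceil(n / world_size) / batch_size)


print(shard_indices(10, 4, 2))           # [2, 6, 0]
print(len(shard_indices(60000, 4, 0)))   # 15000
print(batches_per_epoch(60000, 4, 128))  # 118
print(batches_per_epoch(60000, 1, 128))  # 469
```

With MNIST's 60,000 training images on the 4 GPUs requested in the salloc example, each process would see 15,000 images and take 118 steps of 128 per epoch, versus 469 on a single GPU. With a small dataset these shards shrink quickly, which is the concern behind the takeaway about data splits.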
4.2. Python files
4.2.1. train.py
# --------------------------------------------------------------------------------------
# Part of the interTwin Project: https://www.intertwin.eu/
#
# Created by: Henry Mutegeki
#
# Credit:
# - Henry Mutegeki <henry.mutegeki@cern.ch> - CERN
# - Matteo Bunino <matteo.bunino@cern.ch> - CERN
# - Jarl Sondre Sæther <jarl.sondre.saether@cern.ch> - CERN
# --------------------------------------------------------------------------------------
import argparse
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from itwinai.loggers import MLFlowLogger
from itwinai.torch.gan import GANTrainer, GANTrainingConfiguration


class Generator(nn.Module):
    def __init__(self, z_dim, g_hidden, image_channel):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            # input layer
            nn.ConvTranspose2d(z_dim, g_hidden * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(g_hidden * 8),
            nn.ReLU(True),
            # 1st hidden layer
            nn.ConvTranspose2d(g_hidden * 8, g_hidden * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(g_hidden * 4),
            nn.ReLU(True),
            # 2nd hidden layer
            nn.ConvTranspose2d(g_hidden * 4, g_hidden * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(g_hidden * 2),
            nn.ReLU(True),
            # 3rd hidden layer
            nn.ConvTranspose2d(g_hidden * 2, g_hidden, 4, 2, 1, bias=False),
            nn.BatchNorm2d(g_hidden),
            nn.ReLU(True),
            # output layer
            nn.ConvTranspose2d(g_hidden, image_channel, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, input):
        return self.main(input)


class Discriminator(nn.Module):
    def __init__(self, d_hidden, image_channel):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            # 1st layer
            nn.Conv2d(image_channel, d_hidden, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # 2nd layer
            nn.Conv2d(d_hidden, d_hidden * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(d_hidden * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # 3rd layer
            nn.Conv2d(d_hidden * 2, d_hidden * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(d_hidden * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # 4th layer
            nn.Conv2d(d_hidden * 4, d_hidden * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(d_hidden * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # output layer
            nn.Conv2d(d_hidden * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, input):
        return self.main(input).view(-1, 1).squeeze(1)


def main():
    parser = argparse.ArgumentParser(description="PyTorch MNIST GAN Example")
    parser.add_argument(
        "--batch-size",
        type=int,
        default=128,
        help="input batch size for training (default: 128)",
    )
    parser.add_argument(
        "--epochs", type=int, default=15, help="number of epochs to train (default: 15)"
    )
    parser.add_argument(
        "--strategy", type=str, default="ddp", help="distributed strategy (default=ddp)"
    )
    parser.add_argument(
        "--lr", type=float, default=0.001, help="learning rate (default: 0.001)"
    )
    parser.add_argument("--seed", type=int, default=1, help="random seed (default: 1)")
    parser.add_argument(
        "--ckpt-interval",
        type=int,
        default=2,
        help="checkpointing interval, passed to GANTrainer's checkpoint_every (default: 2)",
    )
    # The DeepSpeed launcher appends --local_rank to the script arguments;
    # accept it so that argparse does not reject it (the value is unused here)
    parser.add_argument("--local_rank", type=int, default=-1, help="set by the DeepSpeed launcher")
    args = parser.parse_args()
    torch.manual_seed(args.seed)

    # Dataset: resize MNIST to 64x64 and scale pixels to [-1, 1] to match the Tanh output
    transform = transforms.Compose(
        [
            transforms.Resize(64),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ]
    )
    train_dataset = datasets.MNIST("../data", train=True, download=True, transform=transform)
    validation_dataset = datasets.MNIST("../data", train=False, transform=transform)

    def weights_init(m):
        # DCGAN-style init: conv weights ~ N(0, 0.02), BatchNorm weights ~ N(1, 0.02), bias 0
        classname = m.__class__.__name__
        if classname.find("Conv") != -1:
            m.weight.data.normal_(0.0, 0.02)
        elif classname.find("BatchNorm") != -1:
            m.weight.data.normal_(1.0, 0.02)
            m.bias.data.fill_(0)

    # Models
    netG = Generator(z_dim=100, g_hidden=64, image_channel=1)
    netG.apply(weights_init)
    netD = Discriminator(d_hidden=64, image_channel=1)
    netD.apply(weights_init)

    # Training configuration
    training_config = GANTrainingConfiguration(
        batch_size=args.batch_size,
        optim_generator_lr=args.lr,
        optim_discriminator_lr=args.lr,
        z_dim=100,
    )

    # Logger
    logger = MLFlowLogger(experiment_name="Distributed GAN MNIST", log_freq=10)

    # Trainer
    trainer = GANTrainer(
        config=training_config,
        epochs=args.epochs,
        discriminator=netD,
        generator=netG,
        strategy=args.strategy,
        random_seed=args.seed,
        logger=logger,
        checkpoint_every=args.ckpt_interval,
    )

    # Launch training
    _, _, _, trained_model = trainer.execute(train_dataset, validation_dataset, None)


if __name__ == "__main__":
    main()
4.2.2. simpleGAN.py
# --------------------------------------------------------------------------------------
# Part of the interTwin Project: https://www.intertwin.eu/
#
# Created by: Henry Mutegeki
#
# Credit:
# - Henry Mutegeki <henry.mutegeki@cern.ch> - CERN
# - Matteo Bunino <matteo.bunino@cern.ch> - CERN
# --------------------------------------------------------------------------------------
"""This script shows a simple example on how to train a GAN without using itwinai."""
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
DATA_PATH = "../data"
BATCH_SIZE = 128
IMAGE_CHANNEL = 1
Z_DIM = 100
G_HIDDEN = 64
X_DIM = 64
D_HIDDEN = 64
EPOCH_NUM = 10
REAL_LABEL = 1
FAKE_LABEL = 0
lr = 2e-4
seed = 1
USE_CUDA = False
CUDA = torch.cuda.is_available() and USE_CUDA
device = torch.device("cuda:0" if CUDA else "cpu")
print("PyTorch version: {}".format(torch.__version__))
# Seed the random number generators (CPU and CUDA) for reproducibility
torch.manual_seed(seed)
if CUDA:
    print("CUDA version: {}\n".format(torch.version.cuda))
    torch.cuda.manual_seed(seed)
    cudnn.benchmark = True
# Data preprocessing: resize to 64x64 and scale pixels to [-1, 1]
dataset = dset.MNIST(
    root=DATA_PATH,
    download=True,  # fetch MNIST if it is not already in DATA_PATH
    transform=transforms.Compose(
        [transforms.Resize(X_DIM), transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
    ),
)
# Dataloader
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2
)


def weights_init(m):
    # DCGAN-style init: conv weights ~ N(0, 0.02), BatchNorm weights ~ N(1, 0.02), bias 0
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find("BatchNorm") != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            # input layer
            nn.ConvTranspose2d(Z_DIM, G_HIDDEN * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(G_HIDDEN * 8),
            nn.ReLU(True),
            # 1st hidden layer
            nn.ConvTranspose2d(G_HIDDEN * 8, G_HIDDEN * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(G_HIDDEN * 4),
            nn.ReLU(True),
            # 2nd hidden layer
            nn.ConvTranspose2d(G_HIDDEN * 4, G_HIDDEN * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(G_HIDDEN * 2),
            nn.ReLU(True),
            # 3rd hidden layer
            nn.ConvTranspose2d(G_HIDDEN * 2, G_HIDDEN, 4, 2, 1, bias=False),
            nn.BatchNorm2d(G_HIDDEN),
            nn.ReLU(True),
            # output layer
            nn.ConvTranspose2d(G_HIDDEN, IMAGE_CHANNEL, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, input):
        return self.main(input)


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            # 1st layer
            nn.Conv2d(IMAGE_CHANNEL, D_HIDDEN, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # 2nd layer
            nn.Conv2d(D_HIDDEN, D_HIDDEN * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(D_HIDDEN * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # 3rd layer
            nn.Conv2d(D_HIDDEN * 2, D_HIDDEN * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(D_HIDDEN * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # 4th layer
            nn.Conv2d(D_HIDDEN * 4, D_HIDDEN * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(D_HIDDEN * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # output layer
            nn.Conv2d(D_HIDDEN * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, input):
        return self.main(input).view(-1, 1).squeeze(1)

# Create the generator
netG = Generator().to(device)
netG.apply(weights_init)
print(netG)
# Create the discriminator
netD = Discriminator().to(device)
netD.apply(weights_init)
print(netD)
# Initialize BCELoss function
criterion = nn.BCELoss()
# Create batch of latent vectors to visualize the progression of the generator
viz_noise = torch.randn(BATCH_SIZE, Z_DIM, 1, 1, device=device)
# Setup Adam optimizers for both G and D
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(0.5, 0.999))
# Training Loop
def train_GAN_model(EPOCH_NUM, netD, netG, optimizerG, optimizerD, dataloader, criterion):
    img_list = []
    G_losses = []
    D_losses = []
    iters = 0
    for epoch in range(EPOCH_NUM):
        for i, data in enumerate(dataloader, 0):
            # (1) Update the discriminator with real data
            netD.zero_grad()
            # Format batch
            real_cpu = data[0].to(device)
            b_size = real_cpu.size(0)
            label = torch.full((b_size,), REAL_LABEL, dtype=torch.float, device=device)
            # Forward pass real batch through D
            output = netD(real_cpu).view(-1)
            # Calculate loss on all-real batch
            errD_real = criterion(output, label)
            # Calculate gradients for D in backward pass
            errD_real.backward()
            D_x = output.mean().item()
            # (2) Update the discriminator with fake data
            # Generate batch of latent vectors
            noise = torch.randn(b_size, Z_DIM, 1, 1, device=device)
            # Generate fake image batch with G
            fake = netG(noise)
            label.fill_(FAKE_LABEL)
            # Classify all fake batch with D
            output = netD(fake.detach()).view(-1)
            # Calculate D's loss on the all-fake batch
            errD_fake = criterion(output, label)
            # Calculate the gradients for this batch, accumulated (summed)
            # with previous gradients
            errD_fake.backward()
            D_G_z1 = output.mean().item()
            # Compute error of D as sum over the fake and the real batches
            errD = errD_real + errD_fake
            # Update D
            optimizerD.step()
            # (3) Update the generator with fake data
            netG.zero_grad()
            label.fill_(REAL_LABEL)  # fake labels are real for generator cost
            # Since we just updated D, perform another forward pass of
            # all-fake batch through D
            output = netD(fake).view(-1)
            # Calculate G's loss based on this output
            errG = criterion(output, label)
            # Calculate gradients for G
            errG.backward()
            D_G_z2 = output.mean().item()
            # Update G
            optimizerG.step()
            # Output training stats
            if i % 50 == 0:
                print(
                    "[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f"
                    % (
                        epoch,
                        EPOCH_NUM,
                        i,
                        len(dataloader),
                        errD.item(),
                        errG.item(),
                        D_x,
                        D_G_z1,
                        D_G_z2,
                    )
                )
            # Save Losses for plotting later
            G_losses.append(errG.item())
            D_losses.append(errD.item())
            # Check how the generator is doing
            if (iters % 500 == 0) or ((epoch == EPOCH_NUM - 1) and (i == len(dataloader) - 1)):
                with torch.no_grad():
                    fake = netG(viz_noise).detach().cpu()
                img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
            iters += 1

    plt.figure(figsize=(10, 5))
    plt.title("Generator and Discriminator Loss During Training")
    plt.plot(G_losses, label="G")
    plt.plot(D_losses, label="D")
    plt.xlabel("iterations")
    plt.ylabel("Loss")
    plt.legend()
    plt.show()
    # plt.savefig('simpleGANlearning_curve.png')
    # Grab a batch of real images from the dataloader
    real_batch = next(iter(dataloader))
    # Plot the real images
    plt.figure(figsize=(15, 15))
    plt.subplot(1, 2, 1)
    plt.axis("off")
    plt.title("Real Images")
    plt.imshow(
        np.transpose(
            vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu(),
            (1, 2, 0),
        )
    )
    # plt.savefig('simpleGANreal_image.png')
    # Plot the fake images from the last epoch
    plt.subplot(1, 2, 2)
    plt.axis("off")
    plt.title("Fake Images")
    plt.imshow(np.transpose(img_list[-1], (1, 2, 0)))
    plt.show()
    # plt.savefig('simpleGANfake_image.png')


# Guard the entry point so DataLoader worker processes do not re-run training
if __name__ == "__main__":
    train_GAN_model(EPOCH_NUM, netD, netG, optimizerG, optimizerD, dataloader, criterion)