3. Using the itwinai TorchTrainer Class

The code used in this tutorial is adapted from this example.

The itwinai TorchTrainer class works as a wrapper that manages most aspects of training. It facilitates distributed machine learning and allows for extensive customization by subclassing and overriding the desired methods.
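To make the subclass-and-override idea concrete, here is a minimal, stdlib-only sketch of the general pattern. `MiniTrainer`, `MaxLossTrainer`, and the `train_step`/`on_epoch_end`/`run` hooks are hypothetical names chosen for illustration. They are not the itwinai TorchTrainer API, so check the TorchTrainer source or API reference for the methods it actually exposes.

```python
# Hypothetical, minimal trainer illustrating the subclass-and-override pattern.
# MiniTrainer and its hook names are illustrative only; they are NOT the
# itwinai TorchTrainer API.


class MiniTrainer:
    """Base class that owns the training loop and exposes overridable hooks."""

    def __init__(self, epochs: int) -> None:
        self.epochs = epochs
        self.history: list[float] = []

    def train_step(self, batch: list[float]) -> float:
        # Default behaviour: use the batch mean as a stand-in "loss".
        return sum(batch) / len(batch)

    def on_epoch_end(self, epoch: int, loss: float) -> None:
        # Default hook records the mean loss of the epoch.
        self.history.append(loss)

    def run(self, batches: list[list[float]]) -> list[float]:
        # The loop itself is fixed; subclasses only change the hooks.
        for epoch in range(self.epochs):
            losses = [self.train_step(b) for b in batches]
            self.on_epoch_end(epoch, sum(losses) / len(losses))
        return self.history


class MaxLossTrainer(MiniTrainer):
    """Overrides only the step and reuses the rest of the loop unchanged."""

    def train_step(self, batch: list[float]) -> float:
        return max(batch)


batches = [[1.0, 3.0], [2.0, 4.0]]
print(MiniTrainer(epochs=1).run(batches))     # [2.5]
print(MaxLossTrainer(epochs=2).run(batches))  # [3.5, 3.5]
```

This is the template-method pattern: the base class fixes the overall control flow, and a subclass overrides only the steps it needs to change.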

You can find all the associated code in the GitHub repository.

3.1. Setting Up the Training Script

The following is an outline of how you can set up the training script:

from itwinai.torch.trainer import TorchTrainer

# Create dataset as usual
train_dataset = ...

# Create model as usual
model = ...

trainer = TorchTrainer(config={}, model=model, strategy="ddp")

_, _, _, trained_model = trainer.execute(train_dataset, ...)

3.2. Launching Distributed Training

To launch the training across multiple workers (for example, several GPUs, possibly spread over multiple nodes), you can use torchrun, which lets the processes communicate with each other. If you are on a system that uses SLURM, you can combine srun and torchrun to start the processes on different nodes as well. Here is an example of how you could do this, assuming your code is in train.py:

srun --cpu-bind=none --ntasks-per-node=1 \
    bash -c "torchrun \
    --nnodes=2 \
    --nproc_per_node=4 \
    --rdzv_id=151152 \
    --rdzv_conf=is_host=\$(((SLURM_NODEID)) && echo 0 || echo 1) \
    --rdzv_backend=c10d \
    --rdzv_endpoint='$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)'i:29500 \
    train.py"
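As a rough sanity check of what the command above starts, the sketch below reproduces its arithmetic in plain Python: 2 nodes with 4 processes each give a world size of 8, and the `is_host` expression marks only the first node (SLURM_NODEID 0) as the host of the c10d rendezvous. torchrun hands each process its ranks through the `RANK`, `LOCAL_RANK`, and `WORLD_SIZE` environment variables. The node-by-node layout shown here assumes homogeneous nodes and that node ranks follow SLURM node order, which the rendezvous does not strictly guarantee. The helper names are hypothetical.

```python
# Sketch of the process layout produced by the srun + torchrun command,
# ASSUMING homogeneous nodes and node ranks that follow SLURM_NODEID order.

NNODES = 2          # --nnodes
NPROC_PER_NODE = 4  # --nproc_per_node


def is_host(slurm_nodeid: int) -> int:
    # Mirrors `$(((SLURM_NODEID)) && echo 0 || echo 1)`: the arithmetic test
    # succeeds for any non-zero id, so only node 0 gets is_host=1.
    return 0 if slurm_nodeid != 0 else 1


def global_rank(node_rank: int, local_rank: int) -> int:
    # Ranks are laid out node by node: node 0 holds 0..3, node 1 holds 4..7.
    return node_rank * NPROC_PER_NODE + local_rank


world_size = NNODES * NPROC_PER_NODE
layout = {
    node: [global_rank(node, local) for local in range(NPROC_PER_NODE)]
    for node in range(NNODES)
}
print(world_size)                           # 8
print(layout)                               # {0: [0, 1, 2, 3], 1: [4, 5, 6, 7]}
print([is_host(n) for n in range(NNODES)])  # [1, 0]
```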

3.3. Complete TorchTrainer Example

Below is a complete example of how to use the TorchTrainer to train a model on the MNIST dataset, which is also available on GitHub here. You can run it locally using

python train.py

or in a distributed manner, as explained in the section above. To analyze the resulting MLflow logs, you can use the following command:

itwinai mlflow-ui --path mllogs/mlflow

Note

You might have to change the port or the host, depending on which system you are on. If you are running this on a remote server and want to port-forward the UI to your local machine, set the host to 0.0.0.0 using --host. For more details, look up how to forward ports over SSH.

Here you can see the contents of train.py:

# --------------------------------------------------------------------------------------
# Part of the interTwin Project: https://www.intertwin.eu/
#
# Created by: Matteo Bunino
#
# Credit:
# - Matteo Bunino <matteo.bunino@cern.ch> - CERN
# - Jarl Sondre Sæther <jarl.sondre.saether@cern.ch> - CERN
# --------------------------------------------------------------------------------------

"""Adapted from: https://github.com/pytorch/examples/blob/main/mnist/main.py"""

import argparse

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics
from torchvision import datasets, transforms

from itwinai.loggers import MLFlowLogger
from itwinai.torch.config import TrainingConfiguration
from itwinai.torch.trainer import TorchTrainer


# Step 1: set up your neural network architecture
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        return x


def main():
    # Step 2 (optional): Parse your arguments from the command line
    parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
    parser.add_argument(
        "--batch-size",
        type=int,
        default=64,
        help="input batch size for training (default: 64)",
    )
    parser.add_argument(
        "--epochs", type=int, default=14, help="number of epochs to train (default: 14)"
    )
    parser.add_argument(
        "--strategy", type=str, default="ddp", help="distributed strategy (default: ddp)"
    )
    parser.add_argument(
        "--lr", type=float, default=1.0, help="learning rate (default: 1.0)"
    )
    parser.add_argument("--seed", type=int, default=1, help="random seed (default: 1)")
    parser.add_argument(
        "--ckpt-interval",
        type=int,
        default=10,
        help="checkpointing interval, passed to TorchTrainer as checkpoint_every "
        "(default: 10)",
    )
    args = parser.parse_args()

    # Step 3: Create your datasets
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    train_dataset = datasets.MNIST(
        "../data", train=True, download=True, transform=transform
    )
    validation_dataset = datasets.MNIST("../data", train=False, transform=transform)

    # Step 4: Configure your model and your training configuration
    model = Net()

    training_config = TrainingConfiguration(
        batch_size=args.batch_size,
        optim_lr=args.lr,
        optimizer="adadelta",
        loss="cross_entropy",
    )

    # Step 5 (optional): Configure a logger and some metrics
    logger = MLFlowLogger(experiment_name="mnist-tutorial", log_freq=10)

    metrics = {
        "accuracy": torchmetrics.Accuracy(task="multiclass", num_classes=10),
        "precision": torchmetrics.Precision(task="multiclass", num_classes=10),
    }

    # Step 6: Create your Trainer
    trainer = TorchTrainer(
        config=training_config,
        model=model,
        metrics=metrics,
        logger=logger,
        strategy=args.strategy,
        epochs=args.epochs,
        random_seed=args.seed,
        checkpoint_every=args.ckpt_interval,
    )

    # Step 7: Launch your training
    _, _, _, trained_model = trainer.execute(train_dataset, validation_dataset, None)


if __name__ == "__main__":
    main()
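The `9216` input features of `fc1` follow directly from the input size. The sketch below assumes 1x28x28 MNIST images and uses hypothetical helper names (`conv2d_out`, `max_pool2d_out`) to track the spatial size through `conv1`, `conv2`, and the max-pooling layer, using the standard output-size formula for convolutions with dilation 1 and pooling with its default stride.

```python
# Shape bookkeeping for Net, ASSUMING 1x28x28 MNIST inputs.
# conv2d_out and max_pool2d_out are hypothetical helpers for illustration.


def conv2d_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    # Output-size formula for nn.Conv2d with dilation=1.
    return (size + 2 * padding - kernel) // stride + 1


def max_pool2d_out(size: int, kernel: int) -> int:
    # F.max_pool2d(x, k) defaults to stride=k and no padding.
    return (size - kernel) // kernel + 1


side = 28
side = conv2d_out(side, kernel=3)      # conv1: 28 -> 26
side = conv2d_out(side, kernel=3)      # conv2: 26 -> 24
side = max_pool2d_out(side, kernel=2)  # max pool: 24 -> 12
channels = 64                          # out_channels of conv2
flat_features = channels * side * side
print(side, flat_features)  # 12 9216
```

If you change the input resolution or the convolution settings, recompute this value and update `nn.Linear(9216, 128)` accordingly.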