3DGAN

This section covers the CERN use case, which uses the PyTorch Lightning framework for training and evaluation. Below you can find instructions for running the CERN use case and its scripts:

itwinai x 3DGAN

First of all, from the repository root, create a torch environment following the installation instructions.

Then install the requirements specific to this use case, listed in the requirements.txt file. For example:

source .venv-pytorch/bin/activate
cd use-cases/3dgan
pip install -r requirements.txt

[!NOTE] The Python commands below are assumed to be executed from within the virtual environment.

Training

Make sure you are in the use-cases/3dgan folder. Before you can start training, you have to download the data by running the dataloading step:

itwinai exec-pipeline --config config.yaml --pipe-key training_pipeline --steps dataloading_step

Now you can launch training using itwinai and the provided training configuration config.yaml:

itwinai exec-pipeline --config config.yaml --pipe-key training_pipeline

The command above runs training on a single worker. For distributed ML training you have two options: interactive (launch from a terminal) or batch (launch from a SLURM job script).

[!WARNING] Before running distributed ML, make sure that the distributed strategy used by PyTorch Lightning is set to ddp_find_unused_parameters_true. You can do this by setting distributed_strategy: ddp_find_unused_parameters_true in config.yaml.

To learn more about SLURM, see our SLURM cheatsheet.

Distributed training on a single node (interactive)

If you want to use SLURM in interactive mode, do the following:

# Allocate resources (on JSC)
$ salloc --partition=batch --nodes=1 --account=intertwin  --gres=gpu:4 --time=1:59:00
job ID is XXXX
# Get a shell in the compute node (if using SLURM)
$ srun --jobid XXXX --overlap --pty /bin/bash
# Now you are inside the compute node

# On JSC, you may need to load some modules
ml --force purge
ml Stages/2024 GCC OpenMPI CUDA/12 MPI-settings/CUDA Python HDF5 PnetCDF libaio mpi4py

# ...before activating the Python environment (adapt this to your env name/path)
source ../../envAI_hdfml/bin/activate

To launch training with PyTorch DDP, use:

torchrun --standalone --nnodes=1 --nproc-per-node=gpu \
    $(which itwinai) exec-pipeline --config config.yaml --pipe-key training_pipeline

# Alternatively, from a SLURM login node:
srun --jobid XXXX --ntasks-per-node=1 torchrun --standalone --nnodes=1 --nproc-per-node=gpu \
    $(which itwinai) exec-pipeline --config config.yaml --pipe-key training_pipeline
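When launched through torchrun, every worker process receives its position in the job through the RANK, LOCAL_RANK and WORLD_SIZE environment variables. If you want to check what each worker sees before starting a long run, a minimal sketch like the following can help (describe_worker is a hypothetical helper, not part of itwinai; outside torchrun it falls back to a single-process setup):

```python
import os


def describe_worker() -> dict:
    """Return the distributed coordinates of the current process.

    torchrun exports RANK, LOCAL_RANK and WORLD_SIZE to every worker it
    spawns. When the script runs without torchrun these variables are
    absent, so we fall back to a single-process setup.
    """
    return {
        "rank": int(os.environ.get("RANK", 0)),
        "local_rank": int(os.environ.get("LOCAL_RANK", 0)),
        "world_size": int(os.environ.get("WORLD_SIZE", 1)),
    }


if __name__ == "__main__":
    info = describe_worker()
    # local_rank is the index of this worker on its node
    print(f"worker {info['rank']}/{info['world_size']} "
          f"(local rank {info['local_rank']})")
```

Saved as a script (any name works) and launched with the same torchrun flags as above in place of $(which itwinai) exec-pipeline ..., each worker should print a distinct rank.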

Distributed training with SLURM (batch mode)

Unlike the interactive approach, batch mode allows you to use more than one compute node, so you can scale distributed ML training to larger resources.

Remember that compute nodes on JSC have no internet connection, so if your script tries to access the internet it will fail. If needed, make sure to download the datasets from the SLURM login node before launching the job.

# Launch a SLURM batch job (on JSC)
sbatch slurm.jsc.sh

# Launch a SLURM batch job (on Vega)
sbatch slurm.vega.sh

# Check the job in the SLURM queue
squeue -u YOUR_USERNAME

# Check the job status
sacct -j JOBID

Job’s stdout is usually saved to job.out and its stderr is saved to job.err.

Visualize the results of training

Depending on the logging service that you are using, there are different ways to inspect the logs generated during ML training.

To visualize the logs generated with MLflow, if you set a local path as the tracking URI, run the following in the terminal:

mlflow ui --backend-store-uri LOCAL_TRACKING_URI

Then select the "3DGAN" experiment.

Inference

  1. As the inference dataset we can reuse the training/validation dataset, for instance the one downloaded from the Google Drive folder (if the dataset root folder is not present, the dataset will be downloaded). The inference dataset is a set of H5 files stored inside exp_data sub-folders:

    ├── exp_data
    │   ├── data
    │   │   ├── file_0.h5
    │   │   ├── file_1.h5
    ...
    │   │   ├── file_N.h5
    
  2. As the model, if a pre-trained checkpoint is not available, we can create a dummy one with:

    python create_inference_sample.py
    
  3. Run the inference command. This will generate a 3dgan-generated-data folder containing generated particle traces in the form of torch tensors (.pth files) and 3D scatter plots (.jpg images).

    itwinai exec-pipeline --config config.yaml --pipe-key inference_pipeline
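Before running inference, it can be worth checking that the H5 files are where step 1 expects them. A minimal sketch using only the standard library (list_h5_files is a hypothetical helper, not part of itwinai), assuming the exp_data layout shown in step 1:

```python
from pathlib import Path


def list_h5_files(dataset_root: str) -> list[Path]:
    """Return every .h5 file below ``dataset_root`` in a stable order.

    Assumes the layout exp_data/<subfolder>/file_N.h5; any nesting depth
    is accepted because rglob searches recursively.
    """
    root = Path(dataset_root)
    if not root.is_dir():
        raise FileNotFoundError(f"dataset root not found: {root}")
    return sorted(root.rglob("*.h5"))
```

For example, len(list_h5_files("exp_data")) gives the number of input files, which, combined with the roughly 5k items per file mentioned further down, hints at a sensible value for max_dataset_size.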
    

The inference execution will produce a folder called 3dgan-generated-data containing generated 3D particle trajectories (overwritten if already there). Each generated 3D image is stored both as a torch tensor (.pth) and 3D scatter plot (.jpg):

├── 3dgan-generated-data
│   ├── energy=1.296749234199524&angle=1.272539496421814.pth
│   ├── energy=1.296749234199524&angle=1.272539496421814.jpg
...
│   ├── energy=1.664689540863037&angle=1.4906378984451294.pth
│   ├── energy=1.664689540863037&angle=1.4906378984451294.jpg

However, if aggregate_predictions in the ParticleImagesSaver step is set to True, only one pickled file will be generated inside the 3dgan-generated-data folder. Note that multiple inference calls will create new files under the 3dgan-generated-data folder.
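When aggregate_predictions is not set, the conditioning values of each generated sample are recorded only in its file name (energy=...&angle=...). A minimal sketch for recovering them and pairing each .pth tensor with its .jpg plot, assuming exactly the naming shown above (both helpers are hypothetical; loading the tensors themselves would still require torch.load):

```python
from pathlib import Path


def parse_generated_name(path: str) -> tuple[float, float]:
    """Extract (energy, angle) from a name like 'energy=E&angle=A.pth'.

    The file extension (.pth or .jpg) is stripped before parsing.
    """
    fields = dict(item.split("=", 1) for item in Path(path).stem.split("&"))
    return float(fields["energy"]), float(fields["angle"])


def group_outputs(folder: str) -> dict[tuple[float, float], dict[str, Path]]:
    """Map each (energy, angle) pair to its files, keyed by extension."""
    pairs: dict[tuple[float, float], dict[str, Path]] = {}
    for path in sorted(Path(folder).glob("energy=*")):
        key = parse_generated_name(path.name)
        pairs.setdefault(key, {})[path.suffix] = path
    return pairs
```

Since the values are parsed back into floats, the result can be used, for example, to check that the generated samples cover the energy and angle ranges you expect.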

With field overrides:

# Override variables
export CERN_DATA_ROOT="../.."  # data root
export TMP_DATA_ROOT=$CERN_DATA_ROOT
export CERN_CODE_ROOT="." # where code and configuration are stored
export MAX_DATA_SAMPLES=20000 # max dataset size
export BATCH_SIZE=1024 # increase to fill up GPU memory
export NUM_WORKERS_DL=4 # num worker processes used by the dataloader to pre-fetch data
export AGGREGATE_PREDS="true" # write predictions in a single file
export ACCELERATOR="gpu" # choose "cpu" or "gpu"
export STRATEGY="auto" # distributed strategy
export DEVICES="0," # GPU devices list


itwinai exec-pipeline --print-config --config $CERN_CODE_ROOT/config.yaml \
    --pipe-key inference_pipeline \
    -o dataset_location=$CERN_DATA_ROOT/exp_data \
    -o logs_dir=$TMP_DATA_ROOT/ml_logs/mlflow_logs \
    -o distributed_strategy=$STRATEGY \
    -o devices=$DEVICES \
    -o hw_accelerators=$ACCELERATOR \
    -o checkpoints_path=$TMP_DATA_ROOT/checkpoints \
    -o inference_model_uri=$CERN_CODE_ROOT/3dgan-inference.pth \
    -o max_dataset_size=$MAX_DATA_SAMPLES \
    -o batch_size=$BATCH_SIZE \
    -o num_workers_dataloader=$NUM_WORKERS_DL \
    -o inference_results_location=$TMP_DATA_ROOT/3dgan-generated-data \
    -o aggregate_predictions=$AGGREGATE_PREDS

Docker image

Build the image from the project root with:

# Local
docker buildx build -t itwinai:0.0.1-3dgan-0.1 -f use-cases/3dgan/Dockerfile .

# Ghcr.io
docker buildx build -t ghcr.io/intertwin-eu/itwinai:0.0.1-3dgan-0.1 -f use-cases/3dgan/Dockerfile .
docker push ghcr.io/intertwin-eu/itwinai:0.0.1-3dgan-0.1

You can run inference from wherever a sample of H5 files is available (a folder called exp_data/):

├── $PWD
│   ├── exp_data
│   │   ├── data
│   │   │   ├── file_0.h5
│   │   │   ├── file_1.h5
...
│   │   │   ├── file_N.h5

docker run -it --rm --name running-inference -v "$PWD":/tmp/data ghcr.io/intertwin-eu/itwinai:0.0.1-3dgan-0.1

This command will store the results in a folder called 3dgan-generated-data:

├── $PWD
│   ├── 3dgan-generated-data
│   │   ├── energy=1.296749234199524&angle=1.272539496421814.pth
│   │   ├── energy=1.296749234199524&angle=1.272539496421814.jpg
...
│   │   ├── energy=1.664689540863037&angle=1.4906378984451294.pth
│   │   ├── energy=1.664689540863037&angle=1.4906378984451294.jpg

To override fields in the configuration file at runtime, you can use the -o flag. Example: -o path.to.config.element=NEW_VALUE.
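Conceptually, each -o flag walks the dotted path into the nested configuration and replaces the value at its end; the overrides used in this section have no dots, so they replace top-level keys. The sketch below illustrates that idea on a plain dictionary. It is not itwinai's implementation (apply_override is hypothetical), and a real loader would likely also parse value types and validate keys, whereas here values stay strings:

```python
def apply_override(config: dict, override: str) -> None:
    """Apply one 'path.to.key=value' override to a nested dict in place.

    Missing intermediate sections are created; the value is stored as a
    string without any type conversion.
    """
    dotted_key, sep, value = override.partition("=")
    if not sep:
        raise ValueError(f"expected KEY=VALUE, got {override!r}")
    *parents, leaf = dotted_key.split(".")
    node = config
    for name in parents:
        node = node.setdefault(name, {})
    node[leaf] = value
```

For instance, apply_override(cfg, "batch_size=1024") replaces the top-level batch_size, while a dotted key reaches into nested sections.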

Please find below a complete example showing how to override the default configuration by setting some environment variables:

# Override variables
export CERN_DATA_ROOT="/usr/data"
export TMP_DATA_ROOT=$CERN_DATA_ROOT # where logs, checkpoints and results are written
export CERN_CODE_ROOT="/usr/src/app"
export MAX_DATA_SAMPLES=10 # max dataset size
export BATCH_SIZE=64 # increase to fill up GPU memory
export NUM_WORKERS_DL=4 # num worker processes used by the dataloader to pre-fetch data
export AGGREGATE_PREDS="true" # write predictions in a single file
export ACCELERATOR="gpu" # choose "cpu" or "gpu"
export STRATEGY="auto" # distributed strategy
export DEVICES="0," # GPU devices list

docker run -it --rm --name running-inference \
-v "$PWD":/usr/data ghcr.io/intertwin-eu/itwinai:0.0.1-3dgan-0.1 \
/bin/bash -c "itwinai exec-pipeline \
    --print-config --config $CERN_CODE_ROOT/config.yaml \
    --pipe-key inference_pipeline \
    -o dataset_location=$CERN_DATA_ROOT/exp_data \
    -o logs_dir=$TMP_DATA_ROOT/ml_logs/mlflow_logs \
    -o distributed_strategy=$STRATEGY \
    -o devices=$DEVICES \
    -o hw_accelerators=$ACCELERATOR \
    -o checkpoints_path=$TMP_DATA_ROOT/checkpoints \
    -o inference_model_uri=$CERN_CODE_ROOT/3dgan-inference.pth \
    -o max_dataset_size=$MAX_DATA_SAMPLES \
    -o batch_size=$BATCH_SIZE \
    -o num_workers_dataloader=$NUM_WORKERS_DL \
    -o inference_results_location=$TMP_DATA_ROOT/3dgan-generated-data \
    -o aggregate_predictions=$AGGREGATE_PREDS "

How to fully exploit GPU resources

Keeping the example above as a reference, increase the value of BATCH_SIZE as much as possible (just below the point where "out of memory" errors occur). Also, make sure that ACCELERATOR="gpu", and increase MAX_DATA_SAMPLES so that the dataset is large enough to collect meaningful performance data. Consider that each H5 file contains roughly 5k items, so setting MAX_DATA_SAMPLES=10000 should be enough to use all items in each input H5 file.

You can try:

export MAX_DATA_SAMPLES=10000 # max dataset size
export BATCH_SIZE=1024 # increase to fill up GPU memory
export ACCELERATOR="gpu" # choose "cpu" or "gpu"
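Finding the largest BATCH_SIZE by hand means repeating runs until just before the out-of-memory error. The usual way to automate this is to double the batch size until it fails and then bisect between the last success and the first failure. A sketch of that search, assuming a hypothetical fits(batch_size) probe that you would implement yourself (for example by running a few inference batches and returning False on torch.cuda.OutOfMemoryError):

```python
from typing import Callable


def largest_fitting_batch_size(fits: Callable[[int], bool],
                               start: int = 64,
                               limit: int = 65536) -> int:
    """Find the largest batch size for which ``fits(batch_size)`` is True.

    Doubles the batch size until ``fits`` fails (or ``limit`` is exceeded),
    then bisects between the last success and the first failure.
    """
    if not fits(start):
        raise ValueError(f"even batch_size={start} does not fit")
    low, high = start, start * 2
    # Exponential phase: grow until the first failure.
    while high <= limit and fits(high):
        low, high = high, high * 2
    if high > limit:
        return low
    # Binary search: `low` always fits, `high` never does.
    while high - low > 1:
        mid = (low + high) // 2
        if fits(mid):
            low = mid
        else:
            high = mid
    return low
```

Memory use can vary between batches, so you may want to keep some headroom below the value this returns.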

Singularity

Run the Docker container with Singularity:

singularity run --nv -B "$PWD":/usr/data docker://ghcr.io/intertwin-eu/itwinai:0.0.1-3dgan-0.1 /bin/bash -c \
"cd /usr/src/app && itwinai exec-pipeline --config config.yaml --pipe-key inference_pipeline"

Example with overrides (as above for Docker):

# Override variables
export CERN_DATA_ROOT="/usr/data"
export TMP_DATA_ROOT=$CERN_DATA_ROOT # where logs, checkpoints and results are written
export CERN_CODE_ROOT="/usr/src/app"
export MAX_DATA_SAMPLES=10 # max dataset size
export BATCH_SIZE=64 # increase to fill up GPU memory
export NUM_WORKERS_DL=4 # num worker processes used by the dataloader to pre-fetch data
export AGGREGATE_PREDS="true" # write predictions in a single file
export ACCELERATOR="gpu" # choose "cpu" or "gpu"
export STRATEGY="auto" # distributed strategy
export DEVICES="0," # GPU devices list

singularity run --nv -B "$PWD":/usr/data docker://ghcr.io/intertwin-eu/itwinai:0.0.1-3dgan-0.1 /bin/bash -c \
"cd /usr/src/app && itwinai exec-pipeline \
    --print-config --config $CERN_CODE_ROOT/config.yaml \
    --pipe-key inference_pipeline \
    -o dataset_location=$CERN_DATA_ROOT/exp_data \
    -o logs_dir=$TMP_DATA_ROOT/ml_logs/mlflow_logs \
    -o distributed_strategy=$STRATEGY \
    -o devices=$DEVICES \
    -o hw_accelerators=$ACCELERATOR \
    -o checkpoints_path=$TMP_DATA_ROOT/checkpoints \
    -o inference_model_uri=$CERN_CODE_ROOT/3dgan-inference.pth \
    -o max_dataset_size=$MAX_DATA_SAMPLES \
    -o batch_size=$BATCH_SIZE \
    -o num_workers_dataloader=$NUM_WORKERS_DL \
    -o inference_results_location=$TMP_DATA_ROOT/3dgan-generated-data \
    -o aggregate_predictions=$AGGREGATE_PREDS "

model.py
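The Generator and Discriminator classes below hard-code two sizes: the 5184 outputs of the first Linear layer, reshaped to 8 channels of 9 x 9 x 8 voxels, and the 19152 input features of the fakeout and auxout layers. As a sanity check, here is a pure-Python sketch (the helper names are hypothetical and not part of model.py) that repeats the shape arithmetic for stride-1 convolutions, nn.ConstantPad3d and nn.AvgPool3d, leaving out the batch and channel axes. It assumes the discriminator receives images of shape (1, 51, 51, 25), as produced by the permute in training_step, and it can be rerun if you change a kernel or padding:

```python
Shape = tuple[int, int, int]  # (D, H, W), batch and channel axes omitted


def conv3d(shape: Shape, kernel: Shape, padding: Shape = (0, 0, 0)) -> Shape:
    """Output shape of a stride-1, dilation-1 nn.Conv3d."""
    return tuple(s + 2 * p - k + 1 for s, k, p in zip(shape, kernel, padding))


def pad3d(shape: Shape, pad: tuple[int, ...]) -> Shape:
    """nn.ConstantPad3d order: (W_left, W_right, H_top, H_bottom, D_front, D_back)."""
    d, h, w = shape
    return (d + pad[4] + pad[5], h + pad[2] + pad[3], w + pad[0] + pad[1])


def generator_output_shape() -> Shape:
    assert 8 * 9 * 9 * 8 == 5184             # l1 output, viewed as (8, 9, 9, 8)
    shape = tuple(6 * s for s in (9, 9, 8))  # nn.Upsample(scale_factor=6)
    shape = pad3d(conv3d(shape, (6, 6, 8)), (1, 1, 2, 2, 2, 2))
    for _ in range(2):                       # conv2/pad2, conv3/pad3
        shape = pad3d(conv3d(shape, (4, 4, 6)), (1, 1, 2, 2, 2, 2))
    shape = pad3d(conv3d(shape, (4, 4, 6)), (0, 0, 1, 1, 1, 1))
    shape = pad3d(conv3d(shape, (3, 3, 5)), (0, 0, 1, 1, 1, 1))
    shape = conv3d(shape, (3, 3, 3))
    return conv3d(shape, (2, 2, 2))


def discriminator_flat_features(shape: Shape = (51, 51, 25), channels: int = 8) -> int:
    shape = pad3d(conv3d(shape, (5, 6, 6), (2, 3, 3)), (1, 1, 0, 0, 0, 0))
    shape = pad3d(conv3d(shape, (5, 6, 6)), (1, 1, 0, 0, 0, 0))
    shape = conv3d(conv3d(shape, (5, 6, 6)), (5, 6, 6))
    d, h, w = (s // 2 for s in shape)        # nn.AvgPool3d((2, 2, 2)) floors
    return channels * d * h * w
```

With the layer settings below, the generator emits 51 x 51 x 25 images and the discriminator flattens 8 x 19 x 18 x 7 = 19152 features, matching the hard-coded Linear sizes.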

import sys
import os
# import pickle
from collections import defaultdict
import math
from typing import Any

import torch
import torch.nn as nn
import torch.nn.functional as F
import lightning as pl
import numpy as np

from itwinai.loggers import Logger as BaseItwinaiLogger


class Generator(nn.Module):
    def __init__(self, latent_dim):  # img_shape
        super().__init__()
        # self.img_shape = img_shape
        self.latent_dim = latent_dim

        self.l1 = nn.Linear(self.latent_dim, 5184)
        self.up1 = nn.Upsample(
            scale_factor=(6, 6, 6),
            mode='trilinear',
            align_corners=False
        )
        self.conv1 = nn.Conv3d(
            in_channels=8, out_channels=8,
            kernel_size=(6, 6, 8), padding=0
        )
        nn.init.kaiming_uniform_(self.conv1.weight)
        # num_features is the number of channels (see doc)
        self.bn1 = nn.BatchNorm3d(num_features=8, eps=1e-6)
        self.pad1 = nn.ConstantPad3d((1, 1, 2, 2, 2, 2), 0)

        self.conv2 = nn.Conv3d(
            in_channels=8, out_channels=6,
            kernel_size=(4, 4, 6), padding=0
        )
        nn.init.kaiming_uniform_(self.conv2.weight)
        self.bn2 = nn.BatchNorm3d(num_features=6, eps=1e-6)
        self.pad2 = nn.ConstantPad3d((1, 1, 2, 2, 2, 2), 0)

        self.conv3 = nn.Conv3d(
            in_channels=6, out_channels=6,
            kernel_size=(4, 4, 6), padding=0
        )
        nn.init.kaiming_uniform_(self.conv3.weight)
        self.bn3 = nn.BatchNorm3d(num_features=6, eps=1e-6)
        self.pad3 = nn.ConstantPad3d((1, 1, 2, 2, 2, 2), 0)

        self.conv4 = nn.Conv3d(
            in_channels=6, out_channels=6,
            kernel_size=(4, 4, 6), padding=0
        )
        nn.init.kaiming_uniform_(self.conv4.weight)
        self.bn4 = nn.BatchNorm3d(num_features=6, eps=1e-6)
        self.pad4 = nn.ConstantPad3d((0, 0, 1, 1, 1, 1), 0)

        self.conv5 = nn.Conv3d(
            in_channels=6, out_channels=6,
            kernel_size=(3, 3, 5), padding=0
        )
        nn.init.kaiming_uniform_(self.conv5.weight)
        self.bn5 = nn.BatchNorm3d(num_features=6, eps=1e-6)
        self.pad5 = nn.ConstantPad3d((0, 0, 1, 1, 1, 1), 0)

        self.conv6 = nn.Conv3d(
            in_channels=6, out_channels=6,
            kernel_size=(3, 3, 3), padding=0
        )
        nn.init.kaiming_uniform_(self.conv6.weight)

        self.conv7 = nn.Conv3d(
            in_channels=6, out_channels=1,
            kernel_size=(2, 2, 2), padding=0
        )
        nn.init.xavier_normal_(self.conv7.weight)

    def forward(self, z):
        img = self.l1(z)
        img = img.view(-1, 8, 9, 9, 8)
        img = self.up1(img)
        img = self.conv1(img)
        img = F.relu(img)
        img = self.bn1(img)
        img = self.pad1(img)

        img = self.conv2(img)
        img = F.relu(img)
        img = self.bn2(img)
        img = self.pad2(img)

        img = self.conv3(img)
        img = F.relu(img)
        img = self.bn3(img)
        img = self.pad3(img)

        img = self.conv4(img)
        img = F.relu(img)
        img = self.bn4(img)
        img = self.pad4(img)

        img = self.conv5(img)
        img = F.relu(img)
        img = self.bn5(img)
        img = self.pad5(img)

        img = self.conv6(img)
        img = F.relu(img)

        img = self.conv7(img)
        img = F.relu(img)

        return img


class Discriminator(nn.Module):
    def __init__(self, power):
        super().__init__()

        self.power = power

        self.conv1 = nn.Conv3d(
            in_channels=1, out_channels=16,
            kernel_size=(5, 6, 6), padding=(2, 3, 3)
        )
        self.drop1 = nn.Dropout(0.2)
        self.pad1 = nn.ConstantPad3d((1, 1, 0, 0, 0, 0), 0)

        self.conv2 = nn.Conv3d(
            in_channels=16, out_channels=8,
            kernel_size=(5, 6, 6), padding=0
        )
        self.bn1 = nn.BatchNorm3d(num_features=8, eps=1e-6)
        self.drop2 = nn.Dropout(0.2)
        self.pad2 = nn.ConstantPad3d((1, 1, 0, 0, 0, 0), 0)

        self.conv3 = nn.Conv3d(
            in_channels=8, out_channels=8,
            kernel_size=(5, 6, 6), padding=0
        )
        self.bn2 = nn.BatchNorm3d(num_features=8, eps=1e-6)
        self.drop3 = nn.Dropout(0.2)

        self.conv4 = nn.Conv3d(
            in_channels=8, out_channels=8,
            kernel_size=(5, 6, 6), padding=0
        )
        self.bn3 = nn.BatchNorm3d(num_features=8, eps=1e-6)
        self.drop4 = nn.Dropout(0.2)

        self.avgpool = nn.AvgPool3d((2, 2, 2))
        self.flatten = nn.Flatten()

        # The input features for the Linear layer need to be calculated based
        # on the output shape from the previous layers.
        self.fakeout = nn.Linear(19152, 1)
        self.auxout = nn.Linear(19152, 1)  # The same as above for this layer.

    # calculate sum of intensities
    def ecal_sum(self, image, daxis):
        sum = torch.sum(image, dim=daxis)
        return sum

    # angle calculation
    def ecal_angle(self, image, daxis1):
        image = torch.squeeze(image, dim=daxis1)  # squeeze along channel axis

        # get shapes
        x_shape = image.shape[1]
        y_shape = image.shape[2]
        z_shape = image.shape[3]
        sumtot = torch.sum(image, dim=(1, 2, 3))  # sum of events

        # get 1. where event sum is 0 and 0 elsewhere
        amask = torch.where(sumtot == 0.0, torch.ones_like(
            sumtot), torch.zeros_like(sumtot))
        # masked_events = torch.sum(amask)  # counting zero sum events

        # ref denotes barycenter as that is our reference point
        x_ref = torch.sum(torch.sum(image, dim=(2, 3))
                          * (torch.arange(x_shape, device=image.device,
                             dtype=torch.float32).unsqueeze(0) + 0.5),
                          dim=1,)  # sum for x position * x index
        y_ref = torch.sum(
            torch.sum(image, dim=(1, 3))
            * (torch.arange(y_shape, device=image.device,
               dtype=torch.float32).unsqueeze(0) + 0.5),
            dim=1,)
        z_ref = torch.sum(
            torch.sum(image, dim=(1, 2))
            * (torch.arange(z_shape, device=image.device,
               dtype=torch.float32).unsqueeze(0) + 0.5),
            dim=1,)

        # return max position if sumtot=0 and divide by sumtot otherwise
        x_ref = torch.where(
            sumtot == 0.0, torch.ones_like(x_ref), x_ref / sumtot)
        y_ref = torch.where(
            sumtot == 0.0, torch.ones_like(y_ref), y_ref / sumtot)
        z_ref = torch.where(
            sumtot == 0.0, torch.ones_like(z_ref), z_ref / sumtot)

        # reshape
        x_ref = x_ref.unsqueeze(1)
        y_ref = y_ref.unsqueeze(1)
        z_ref = z_ref.unsqueeze(1)

        sumz = torch.sum(image, dim=(1, 2))  # sum for x,y planes going along z

        # Get 0 where sum along z is 0 and 1 elsewhere
        zmask = torch.where(sumz == 0.0, torch.zeros_like(
            sumz), torch.ones_like(sumz))

        x = torch.arange(x_shape, device=image.device).unsqueeze(
            0)  # x indexes
        x = (x.unsqueeze(2).float()) + 0.5
        y = torch.arange(y_shape, device=image.device).unsqueeze(
            0)  # y indexes
        y = (y.unsqueeze(2).float()) + 0.5

        # barycenter for each z position
        x_mid = torch.sum(torch.sum(image, dim=2) * x, dim=1)
        y_mid = torch.sum(torch.sum(image, dim=1) * y, dim=1)

        x_mid = torch.where(sumz == 0.0, torch.zeros_like(
            sumz), x_mid / sumz)  # if sum != 0 then divide by sum
        y_mid = torch.where(sumz == 0.0, torch.zeros_like(
            sumz), y_mid / sumz)  # if sum != 0 then divide by sum

        # Angle Calculations
        z = (torch.arange(
            z_shape,
            device=image.device,
            dtype=torch.float32
            # Make an array of z indexes for all events
        ) + 0.5) * torch.ones_like(z_ref)

        # projection from z axis with stability check
        zproj = torch.sqrt(
            torch.max(
                (x_mid - x_ref) ** 2.0 + (z - z_ref) ** 2.0,
                torch.tensor(
                    [torch.finfo(torch.float32).eps]
                ).to(x_mid.device)
            )
        )
        # torch.finfo(torch.float32).eps))
        # to avoid divide by zero for zproj =0
        m = torch.where(zproj == 0.0, torch.zeros_like(
            zproj), (y_mid - y_ref) / zproj)
        m = torch.where(z < z_ref, -1 * m, m)  # sign inversion
        ang = (math.pi / 2.0) - torch.atan(m)  # angle correction
        zmask = torch.where(zproj == 0.0, torch.zeros_like(zproj), zmask)
        ang = ang * zmask  # place zero where zsum is zero
        ang = ang * z  # weighted by position
        sumz_tot = z * zmask  # removing indexes with 0 energies or angles

        # zunmasked = K.sum(zmask, axis=1) # used for simple mean
        # Mean does not include positions where zsum=0
        # ang = K.sum(ang, axis=1)/zunmasked

        # sum ( measured * weights)/sum(weights)
        ang = torch.sum(ang, dim=1) / torch.sum(sumz_tot, dim=1)
        # Place 100 for measured angle where no energy is deposited in events
        ang = torch.where(amask == 0.0, ang, 100.0 * torch.ones_like(ang))
        ang = ang.unsqueeze(1)
        return ang

    def forward(self, x):
        z = self.conv1(x)
        z = F.leaky_relu(z)
        z = self.drop1(z)
        z = self.pad1(z)

        z = self.conv2(z)
        z = F.leaky_relu(z)
        z = self.bn1(z)
        z = self.drop2(z)
        z = self.pad2(z)

        z = self.conv3(z)
        z = F.leaky_relu(z)
        z = self.bn2(z)
        z = self.drop3(z)

        z = self.conv4(z)
        z = F.leaky_relu(z)
        z = self.bn3(z)
        z = self.drop4(z)
        z = self.avgpool(z)
        z = self.flatten(z)

        # generation output that says fake/real
        fake = torch.sigmoid(self.fakeout(z))
        aux = self.auxout(z)  # auxiliary output
        inv_image = x.pow(1.0 / self.power)
        ang = self.ecal_angle(inv_image, 1)  # angle calculation
        ecal = self.ecal_sum(inv_image, (2, 3, 4))  # sum of energies

        return fake, aux, ang, ecal


class ThreeDGAN(pl.LightningModule):
    def __init__(
        self,
        latent_size=256,
        loss_weights=[3, 0.1, 25, 0.1],
        power=0.85,
        lr=0.001,
        checkpoints_dir: str = '.',
        provenance_verbose: bool = False
    ):
        super().__init__()
        self.save_hyperparameters()
        self.automatic_optimization = False

        self.latent_size = latent_size
        self.loss_weights = loss_weights
        self.lr = lr
        self.power = power
        self.checkpoints_dir = checkpoints_dir
        self.provenance_verbose = provenance_verbose
        os.makedirs(self.checkpoints_dir, exist_ok=True)

        self.generator = Generator(self.latent_size)
        self.discriminator = Discriminator(self.power)

        self.epoch_gen_loss = []
        self.epoch_disc_loss = []
        self.disc_epoch_test_loss = []
        self.gen_epoch_test_loss = []
        self.index = 0
        self.train_history = defaultdict(list)
        self.test_history = defaultdict(list)
        # self.pklfile = checkpoint_path
        # checkpoint_dir = os.path.dirname(checkpoint_path)
        # os.makedirs(checkpoint_dir, exist_ok=True)

    @property
    def itwinai_logger(self) -> BaseItwinaiLogger:
        try:
            itwinai_logger = self.trainer.itwinai_logger
        except AttributeError:
            print("WARNING: itwinai_logger attribute not set "
                  f"in {self.__class__.__name__}")
            itwinai_logger = None
        return itwinai_logger

    def on_fit_start(self) -> None:
        if self.itwinai_logger:
            # Log hyper-parameters
            self.itwinai_logger.save_hyperparameters(self.hparams)

    def BitFlip(self, x, prob=0.05):
        """
        Flips each bit independently with probability ``prob``. Used to add
        label noise to the real/fake targets of the discriminator.

        Args:
            x (array-like): bits (0/1 values) to be flipped
            prob (float): probability of flipping each bit

        Returns:
            np.ndarray: array with the selected bits flipped

        """
        x = np.array(x)
        selection = np.random.uniform(0, 1, x.shape) < prob
        x[selection] = 1 * np.logical_not(x[selection])
        return x

    def mean_absolute_percentage_error(self, y_true, y_pred):
        return torch.mean(torch.abs((y_true - y_pred) / (y_true + 1e-7))) * 100

    def compute_global_loss(
        self,
        labels,
        predictions,
        loss_weights=(3, 0.1, 25, 0.1)
    ):
        # Can be initialized outside
        binary_crossentropy_object = nn.BCEWithLogitsLoss(reduction='none')
        # there is no equivalent in pytorch for
        # tf.keras.losses.MeanAbsolutePercentageError --> using the
        # custom "mean_absolute_percentage_error" above, whose signature is
        # (y_true, y_pred): labels go first so the error is relative to them
        mean_absolute_percentage_error_object1 = \
            self.mean_absolute_percentage_error(labels[1], predictions[1])
        mean_absolute_percentage_error_object2 = \
            self.mean_absolute_percentage_error(labels[3], predictions[3])
        mae_object = nn.L1Loss(reduction='none')

        binary_example_loss = binary_crossentropy_object(
            predictions[0], labels[0]) * loss_weights[0]

        # mean_example_loss_1 = mean_absolute_percentage_error_object(
        # predictions[1], labels[1]) * loss_weights[1]
        mean_example_loss_1 = \
            mean_absolute_percentage_error_object1 * loss_weights[1]

        mae_example_loss = mae_object(
            predictions[2], labels[2]) * loss_weights[2]

        # mean_example_loss_2 = mean_absolute_percentage_error_object(
        # predictions[3], labels[3]) * loss_weights[3]
        mean_example_loss_2 = \
            mean_absolute_percentage_error_object2 * loss_weights[3]

        binary_loss = binary_example_loss.mean()
        mean_loss_1 = mean_example_loss_1.mean()
        mae_loss = mae_example_loss.mean()
        mean_loss_2 = mean_example_loss_2.mean()

        return [binary_loss, mean_loss_1, mae_loss, mean_loss_2]

    def forward(self, z):
        return self.generator(z)

    def training_step(self, batch, batch_idx):
        image_batch, energy_batch, ang_batch, ecal_batch = \
            batch['X'], batch['Y'], batch['ang'], batch['ecal']

        image_batch = image_batch.permute(0, 4, 1, 2, 3)

        image_batch = image_batch.to(self.device)
        energy_batch = energy_batch.to(self.device)
        ang_batch = ang_batch.to(self.device)
        ecal_batch = ecal_batch.to(self.device)

        optimizer_discriminator, optimizer_generator = self.optimizers()
        batch_size = energy_batch.shape[0]

        noise = torch.randn(
            (batch_size, self.latent_size - 2),
            dtype=torch.float32,
            device=self.device
        )
        # print(f'Energy elements: {energy_batch.numel} {energy_batch.shape}')
        # print(f'Angle elements: {ang_batch.numel} {ang_batch.shape}')
        generator_ip = torch.cat(
            (energy_batch.view(-1, 1), ang_batch.view(-1, 1), noise),
            dim=1
        )
        generated_images = self.generator(generator_ip)

        # Train discriminator first on real batch
        fake_batch = self.BitFlip(np.ones(batch_size).astype(np.float32))
        fake_batch = torch.tensor([[el] for el in fake_batch]).to(self.device)
        labels = [fake_batch, energy_batch, ang_batch, ecal_batch]

        predictions = self.discriminator(image_batch)
        # print("calculating real_batch_loss...")
        real_batch_loss = self.compute_global_loss(
            labels, predictions, self.loss_weights)
        if self.itwinai_logger:
            self.itwinai_logger.log(
                item=sum(real_batch_loss),
                identifier="real_batch_loss",
                kind='metric',
                step=self.global_step,
                batch_idx=batch_idx,
                context='training'
            )

        # self.log("real_batch_loss", sum(real_batch_loss),
        #          prog_bar=True, on_step=True, on_epoch=True, sync_dist=True)
        # print("real batch disc train")
        # the following 3 lines correspond in tf version to:
        # gradients = tape.gradient(real_batch_loss,
        # discriminator.trainable_variables)
        # optimizer_discriminator.apply_gradients(zip(gradients,
        #  discriminator.trainable_variables)) in Tensorflow
        optimizer_discriminator.zero_grad()
        self.manual_backward(sum(real_batch_loss))
        # sum(real_batch_loss).backward()
        # real_batch_loss.backward()
        optimizer_discriminator.step()

        # Train discriminator on the fake batch
        fake_batch = self.BitFlip(np.zeros(batch_size).astype(np.float32))
        fake_batch = torch.tensor([[el] for el in fake_batch]).to(self.device)
        labels = [fake_batch, energy_batch, ang_batch, ecal_batch]

        predictions = self.discriminator(generated_images)

        fake_batch_loss = self.compute_global_loss(
            labels, predictions, self.loss_weights)
        # self.log("fake_batch_loss", sum(fake_batch_loss),
        #          prog_bar=True, on_step=True, on_epoch=True, sync_dist=True)

        if self.itwinai_logger:
            self.itwinai_logger.log(
                item=sum(fake_batch_loss),
                identifier="fake_batch_loss",
                kind='metric',
                step=self.global_step,
                batch_idx=batch_idx,
                context='training'
            )

        # print("fake batch disc train")
        # the following 3 lines correspond to
        # gradients = tape.gradient(fake_batch_loss,
        # discriminator.trainable_variables)
        # optimizer_discriminator.apply_gradients(zip(gradients,
        # discriminator.trainable_variables)) in Tensorflow
        optimizer_discriminator.zero_grad()
        self.manual_backward(sum(fake_batch_loss))
        # sum(fake_batch_loss).backward()
        optimizer_discriminator.step()

        # avg_disc_loss = (sum(real_batch_loss) + sum(fake_batch_loss)) / 2

        trick = np.ones(batch_size).astype(np.float32)
        fake_batch = torch.tensor([[el] for el in trick]).to(self.device)
        labels = [fake_batch, energy_batch.view(-1, 1), ang_batch, ecal_batch]

        gen_losses_train = []
        # Train generator twice using combined model
        for _ in range(2):
            noise = torch.randn(
                (batch_size, self.latent_size - 2)).to(self.device)
            generator_ip = torch.cat(
                (energy_batch.view(-1, 1), ang_batch.view(-1, 1), noise),
                dim=1
            )

            generated_images = self.generator(generator_ip)
            predictions = self.discriminator(generated_images)

            loss = self.compute_global_loss(
                labels, predictions, self.loss_weights)
            # self.log("gen_loss", sum(loss), prog_bar=True,
            #          on_step=True, on_epoch=True, sync_dist=True)

            if self.itwinai_logger:
                self.itwinai_logger.log(
                    item=sum(loss),
                    identifier="gen_loss",
                    kind='metric',
                    step=self.global_step,
                    batch_idx=batch_idx,
                    context='training'
                )

            # print("gen train")
            optimizer_generator.zero_grad()
            self.manual_backward(sum(loss))
            # sum(loss).backward()
            optimizer_generator.step()

            for el in loss:
                gen_losses_train.append(el)

        avg_generator_loss = sum(gen_losses_train) / len(gen_losses_train)
        # self.log("generator_loss", avg_generator_loss.item(),
        #          prog_bar=True, on_step=True, on_epoch=True, sync_dist=True)

        if self.itwinai_logger:
            self.itwinai_logger.log(
                item=avg_generator_loss.item(),
                identifier="generator_loss",
                kind='metric',
                step=self.global_step,
                batch_idx=batch_idx,
                context='training'
            )
            # Log provenance information
            if self.provenance_verbose:
                # Log provenance at every training step
                self._log_provenance(context='training')

        # avg_generator_loss = [(a + b) / 2 for a, b in zip(*gen_losses_train)]
        # self.log("generator_loss", sum(avg_generator_loss), prog_bar=True,
        # on_step=True, on_epoch=True, sync_dist=True)

        gen_losses = []
        # Nothing is returned: with manual optimization in Lightning,
        # training_step does not need to return the loss
        # return_loss = real_batch_loss
        real_batch_loss = [real_batch_loss[0], real_batch_loss[1],
                           real_batch_loss[2], real_batch_loss[3]]
        fake_batch_loss = [fake_batch_loss[0], fake_batch_loss[1],
                           fake_batch_loss[2], fake_batch_loss[3]]
        gen_batch_loss = [gen_losses_train[0], gen_losses_train[1],
                          gen_losses_train[2], gen_losses_train[3]]
        gen_losses.append(gen_batch_loss)
        gen_batch_loss = [gen_losses_train[4], gen_losses_train[5],
                          gen_losses_train[6], gen_losses_train[7]]
        gen_losses.append(gen_batch_loss)

        real_batch_loss = [el.cpu().detach().numpy() for el in real_batch_loss]
        real_batch_loss_total_loss = np.sum(real_batch_loss)
        new_real_batch_loss = [real_batch_loss_total_loss]
        for i_weights in range(len(real_batch_loss)):
            new_real_batch_loss.append(
                real_batch_loss[i_weights] / self.loss_weights[i_weights])
        real_batch_loss = new_real_batch_loss

        fake_batch_loss = [el.cpu().detach().numpy() for el in fake_batch_loss]
        fake_batch_loss_total_loss = np.sum(fake_batch_loss)
        new_fake_batch_loss = [fake_batch_loss_total_loss]
        for i_weights in range(len(fake_batch_loss)):
            new_fake_batch_loss.append(
                fake_batch_loss[i_weights] / self.loss_weights[i_weights])
        fake_batch_loss = new_fake_batch_loss

        # If the ECAL sum loss is 100% (i.e., the generator produces empty
        # events), stop the training
        if fake_batch_loss[3] == 100.0 and self.index > 10:
            # print("Empty image with Ecal loss equal to 100.0 "
            #       f"for {self.index} batch")
            torch.save(self.generator.state_dict(), os.path.join(
                self.checkpoints_dir, "generator_weights.pth"))
            torch.save(self.discriminator.state_dict(), os.path.join(
                       self.checkpoints_dir, "discriminator_weights.pth"))
            if self.itwinai_logger:
                self.itwinai_logger.log(
                    item=os.path.join(self.checkpoints_dir,
                                      "generator_weights.pth"),
                    identifier='final_generator_weights',
                    kind='artifact',
                    context='training'
                )
                self.itwinai_logger.log(
                    item=os.path.join(self.checkpoints_dir,
                                      "discriminator_weights.pth"),
                    identifier='final_discriminator_weights',
                    kind='artifact',
                    context='training'
                )

            # print("real_batch_loss", real_batch_loss)
            # print("fake_batch_loss", fake_batch_loss)
            sys.exit()

        # append mean of discriminator loss for real and fake events
        self.epoch_disc_loss.append(
            [(a + b) / 2 for a, b in zip(real_batch_loss, fake_batch_loss)])

        gen_losses[0] = [el.cpu().detach().numpy() for el in gen_losses[0]]
        gen_losses_total_loss = np.sum(gen_losses[0])
        new_gen_losses = [gen_losses_total_loss]
        for i_weights in range(len(gen_losses[0])):
            new_gen_losses.append(
                gen_losses[0][i_weights] / self.loss_weights[i_weights])
        gen_losses[0] = new_gen_losses

        gen_losses[1] = [el.cpu().detach().numpy() for el in gen_losses[1]]
        gen_losses_total_loss = np.sum(gen_losses[1])
        new_gen_losses = [gen_losses_total_loss]
        for i_weights in range(len(gen_losses[1])):
            new_gen_losses.append(
                gen_losses[1][i_weights] / self.loss_weights[i_weights])
        gen_losses[1] = new_gen_losses

        generator_loss = [(a + b) / 2 for a, b in zip(*gen_losses)]

        self.epoch_gen_loss.append(generator_loss)

        # # MB: verify weight synchronization among workers
        # # Ref: https://github.com/Lightning-AI/lightning/issues/9237
        # disc_w = self.discriminator.conv1.weight.reshape(-1)[0:5]
        # gen_w = self.generator.conv1.weight.reshape(-1)[0:5]
        # print(f"DISC w: {disc_w}")
        # print(f"GEN w: {gen_w}")

        # self.index += 1 #this might be moved after test cycle

        # logging of gen and disc loss done by Trainer
        # self.log('epoch_gen_loss', self.epoch_gen_loss, on_step=True,
        #          on_epoch=True, sync_dist=True)
        # self.log('epoch_disc_loss', self.epoch_disc_loss, on_step=True,
        #          on_epoch=True, sync_dist=True)

        # return avg_disc_loss + avg_generator_loss

    def on_train_epoch_end(self):

        if not self.provenance_verbose:
            # Log provenance only at the end of an epoch
            self._log_provenance(context='training')

        discriminator_train_loss = np.mean(
            np.array(self.epoch_disc_loss), axis=0)
        generator_train_loss = np.mean(np.array(self.epoch_gen_loss), axis=0)

        self.train_history["generator"].append(generator_train_loss)
        self.train_history["discriminator"].append(discriminator_train_loss)

        print("-" * 65)
        ROW_FMT = (
            "{0:<20s} | {1:<4.2f} | {2:<10.2f} | "
            "{3:<10.2f}| {4:<10.2f} | {5:<10.2f}")
        print(ROW_FMT.format("generator (train)",
              *self.train_history["generator"][-1]))
        print(ROW_FMT.format("discriminator (train)",
              *self.train_history["discriminator"][-1]))

        torch.save(self.generator.state_dict(), os.path.join(
            self.checkpoints_dir, "generator_weights.pth"))
        torch.save(self.discriminator.state_dict(), os.path.join(
            self.checkpoints_dir, "discriminator_weights.pth"))

        if self.itwinai_logger:
            self.itwinai_logger.log(
                item=os.path.join(self.checkpoints_dir,
                                  "generator_weights.pth"),
                identifier='ckpts/generator_weights_epoch_' +
                str(self.current_epoch),
                kind='artifact',
                context='training'
            )
            self.itwinai_logger.log(
                item=self.generator,
                identifier='generator_epoch_' + str(self.current_epoch),
                kind='model',
                context='training'
            )
            self.itwinai_logger.log(
                item=os.path.join(self.checkpoints_dir,
                                  "discriminator_weights.pth"),
                identifier='ckpts/discriminator_weights_epoch_' +
                str(self.current_epoch),
                kind='artifact',
                context='training'
            )

        # with open(self.pklfile, "wb") as f:
        #     pickle.dump({"train": self.train_history,
        #                 "test": self.test_history}, f)

        # pickle.dump({"train": self.train_history}, open(self.pklfile, "wb"))
        print("train-loss:" + str(self.train_history["generator"][-1][0]))

    def validation_step(self, batch, batch_idx):
        image_batch, energy_batch, ang_batch, ecal_batch = batch[
            'X'], batch['Y'], batch['ang'], batch['ecal']

        image_batch = image_batch.permute(0, 4, 1, 2, 3)

        image_batch = image_batch.to(self.device)
        energy_batch = energy_batch.to(self.device)
        ang_batch = ang_batch.to(self.device)
        ecal_batch = ecal_batch.to(self.device)

        batch_size = energy_batch.shape[0]

        # Generate Fake events with same energy and angle as data batch
        noise = torch.randn(
            (batch_size, self.latent_size - 2),
            dtype=torch.float32,
            device=self.device
        )

        generator_ip = torch.cat(
            (energy_batch.view(-1, 1), ang_batch.view(-1, 1), noise), dim=1)
        generated_images = self.generator(generator_ip)

        # concatenate to fake and real batches
        X = torch.cat((image_batch, generated_images), dim=0)

        # y = np.array([1] * batch_size \
        # + [0] * batch_size).astype(np.float32)
        y = torch.tensor([1] * batch_size + [0] *
                         batch_size, dtype=torch.float32).to(self.device)
        y = y.view(-1, 1)

        ang = torch.cat((ang_batch, ang_batch), dim=0)
        ecal = torch.cat((ecal_batch, ecal_batch), dim=0)
        aux_y = torch.cat((energy_batch, energy_batch), dim=0)

        # y = [[el] for el in y]
        labels = [y, aux_y, ang, ecal]

        # Calculate discriminator loss
        disc_eval = self.discriminator(X)
        disc_eval_loss = self.compute_global_loss(
            labels, disc_eval, self.loss_weights)

        # Calculate generator loss
        trick = np.ones(batch_size).astype(np.float32)
        fake_batch = torch.tensor([[el] for el in trick]).to(self.device)
        # fake_batch = [[el] for el in trick]
        labels = [fake_batch, energy_batch, ang_batch, ecal_batch]

        generated_images = self.generator(generator_ip)
        gen_eval = self.discriminator(generated_images)
        gen_eval_loss = self.compute_global_loss(
            labels, gen_eval, self.loss_weights)

        if self.itwinai_logger:
            self.itwinai_logger.log(
                item=sum(disc_eval_loss),
                identifier="val_discriminator_loss",
                kind='metric',
                step=self.global_step,
                batch_idx=batch_idx,
                context='validation'
            )
            self.itwinai_logger.log(
                item=sum(gen_eval_loss),
                identifier="val_generator_loss",
                kind='metric',
                step=self.global_step,
                batch_idx=batch_idx,
                context='validation'
            )
            # Log provenance information
            if self.provenance_verbose:
                # Log provenance at every validation step
                self._log_provenance(context='validation')

        # self.log('val_discriminator_loss', sum(
        #     disc_eval_loss), on_epoch=True, prog_bar=True, sync_dist=True)
        # self.log('val_generator_loss', sum(gen_eval_loss),
        #          on_epoch=True, prog_bar=True, sync_dist=True)

        disc_test_loss = [disc_eval_loss[0], disc_eval_loss[1],
                          disc_eval_loss[2], disc_eval_loss[3]]
        gen_test_loss = [gen_eval_loss[0], gen_eval_loss[1],
                         gen_eval_loss[2], gen_eval_loss[3]]

        # Undo the loss weighting: store the total loss first, followed by
        # each component divided by its weight
        disc_eval_loss = [el.cpu().detach().numpy() for el in disc_test_loss]
        disc_eval_loss_total_loss = np.sum(disc_eval_loss)
        new_disc_eval_loss = [disc_eval_loss_total_loss]
        for i_weights in range(len(disc_eval_loss)):
            new_disc_eval_loss.append(
                disc_eval_loss[i_weights] / self.loss_weights[i_weights])
        disc_eval_loss = new_disc_eval_loss

        gen_eval_loss = [el.cpu().detach().numpy() for el in gen_test_loss]
        gen_eval_loss_total_loss = np.sum(gen_eval_loss)
        new_gen_eval_loss = [gen_eval_loss_total_loss]
        for i_weights in range(len(gen_eval_loss)):
            new_gen_eval_loss.append(
                gen_eval_loss[i_weights] / self.loss_weights[i_weights])
        gen_eval_loss = new_gen_eval_loss

        self.index += 1
        # evaluate discriminator loss
        self.disc_epoch_test_loss.append(disc_eval_loss)
        # evaluate generator loss
        self.gen_epoch_test_loss.append(gen_eval_loss)

    def _log_provenance(self, context: str):

        if self.itwinai_logger:
            # Some provenance metrics
            self.itwinai_logger.log(
                item=self.current_epoch,
                identifier="epoch",
                kind='metric',
                step=self.current_epoch,
                context=context)
            self.itwinai_logger.log(
                item=self,
                identifier=f"model_version_{self.current_epoch}",
                kind='model_version',
                step=self.current_epoch,
                context=context)
            self.itwinai_logger.log(
                item=None, identifier=None,
                kind='system',
                step=self.current_epoch,
                context=context)
            self.itwinai_logger.log(
                item=None, identifier=None,
                kind='carbon',
                step=self.current_epoch,
                context=context)
            self.itwinai_logger.log(
                item=None, identifier="train_epoch_time",
                kind='execution_time',
                step=self.current_epoch,
                context=context)

    def on_validation_epoch_end(self):

        if not self.provenance_verbose:
            # Log provenance only at the end of an epoch
            self._log_provenance(context='validation')

        discriminator_test_loss = np.mean(
            np.array(self.disc_epoch_test_loss), axis=0)
        generator_test_loss = np.mean(
            np.array(self.gen_epoch_test_loss), axis=0)

        self.test_history["generator"].append(generator_test_loss)
        self.test_history["discriminator"].append(discriminator_test_loss)

        print("-" * 65)
        ROW_FMT = (
            "{0:<20s} | {1:<4.2f} | {2:<10.2f} | "
            "{3:<10.2f}| {4:<10.2f} | {5:<10.2f}")
        print(ROW_FMT.format("generator (test)",
              *self.test_history["generator"][-1]))
        print(ROW_FMT.format("discriminator (test)",
              *self.test_history["discriminator"][-1]))

        # # save loss dict to pkl file
        # with open(self.pklfile, "wb") as f:
        #     pickle.dump({"train": self.train_history,
        #                 "test": self.test_history}, f)
        # pickle.dump({"test": self.test_history}, open(self.pklfile, "wb"))
        # print("train-loss:" + str(self.train_history["generator"][-1][0]))

    def predict_step(
        self,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0
    ) -> Any:
        energy_batch, ang_batch = batch['Y'], batch['ang']

        energy_batch = energy_batch.to(self.device)
        ang_batch = ang_batch.to(self.device)

        # Generate Fake events with same energy and angle as data batch
        noise = torch.randn(
            (energy_batch.shape[0], self.latent_size - 2),
            dtype=torch.float32,
            device=self.device
        )

        # print(f"Reshape energy: {energy_batch.view(-1, 1).shape}")
        # print(f"Reshape angle: {ang_batch.view(-1, 1).shape}")
        # print(f"Noise: {noise.shape}")

        generator_ip = torch.cat(
            [energy_batch.view(-1, 1), ang_batch.view(-1, 1), noise],
            dim=1
        )
        # print(f"Generator input: {generator_ip.shape}")
        generated_images = self.generator(generator_ip)
        # print(f"Generated batch size {generated_images.shape}")
        return {'images': generated_images,
                'energies': energy_batch,
                'angles': ang_batch}

    def configure_optimizers(self):
        lr = self.lr

        optimizer_discriminator = torch.optim.RMSprop(
            self.discriminator.parameters(),
            lr
        )
        optimizer_generator = torch.optim.RMSprop(
            self.generator.parameters(),
            lr
        )

        if self.itwinai_logger:
            self.itwinai_logger.log(
                optimizer_discriminator,
                'optimizer_discriminator',
                kind='torch')
            self.itwinai_logger.log(
                optimizer_generator,
                'optimizer_generator',
                kind='torch')

        return [optimizer_discriminator, optimizer_generator], []
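Both training_step and validation_step repeat the same post-processing: the weighted loss components are moved to NumPy, their sum is stored as the total, and each component is divided by its entry in self.loss_weights so that the per-component values are unweighted; in training, the real-batch and fake-batch results are then averaged element-wise. The sketch below factors that pattern into plain-Python helpers. The names `unweight_losses` and `average_losses` are hypothetical (they do not exist in the use case), and the loss values and weights are made up so that the arithmetic is exact.

```python
from typing import List, Sequence


def unweight_losses(
    weighted: Sequence[float], weights: Sequence[float]
) -> List[float]:
    """Return [total, l_0 / w_0, ..., l_n / w_n].

    Hypothetical helper: the total is the sum of the *weighted* components,
    while each following entry is a component with its weight divided out.
    """
    if len(weighted) != len(weights):
        raise ValueError("expected exactly one weight per loss component")
    total = sum(weighted)
    return [total] + [loss / w for loss, w in zip(weighted, weights)]


def average_losses(a: Sequence[float], b: Sequence[float]) -> List[float]:
    """Element-wise mean of two loss lists, e.g. real and fake batches."""
    return [(x + y) / 2 for x, y in zip(a, b)]


# Made-up weighted losses and weights, chosen so the arithmetic is exact
real = unweight_losses([4.0, 1.0, 2.0, 0.5], [2.0, 0.5, 4.0, 0.25])
fake = unweight_losses([2.0, 1.0, 2.0, 0.5], [2.0, 0.5, 4.0, 0.25])
print(real)                        # [7.5, 2.0, 2.0, 0.5, 2.0]
print(average_losses(real, fake))  # [6.5, 1.5, 2.0, 0.5, 2.0]
```

A single helper like this would keep the four copies of the rescaling loop (real, fake, and the two generator passes) consistent with each other.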

trainer.py

import os
import sys
from typing import Union, Dict, Optional, Any
import tempfile
import yaml

import torch
from torch import Tensor
import lightning as pl
from lightning.pytorch.cli import LightningCLI
from lightning.pytorch import Trainer as LightningTrainer

from itwinai.components import Trainer, Predictor, monitor_exec
from itwinai.serialization import ModelLoader
from itwinai.torch.inference import TorchModelLoader
from itwinai.torch.type import Batch
from itwinai.utils import load_yaml
# from itwinai.torch.mlflow import (
#     init_lightning_mlflow,
#     teardown_lightning_mlflow
# )
from itwinai.loggers import Logger, _EmptyLogger


from model import ThreeDGAN
from dataloader import ParticlesDataModule


class Lightning3DGANTrainer(Trainer):
    def __init__(
        self,
        config: Union[Dict, str],
        itwinai_logger: Optional[Logger] = None
    ):
        self.save_parameters(**self.locals2params(locals()))
        super().__init__()
        if isinstance(config, str) and os.path.isfile(config):
            # Load from YAML
            config = load_yaml(config)
        self.conf = config
        self.itwinai_logger = (
            itwinai_logger if itwinai_logger else _EmptyLogger()
        )

    @monitor_exec
    def execute(self) -> Any:

        # Parse lightning configuration
        old_argv = sys.argv
        sys.argv = ['some_script_placeholder.py']
        cli = LightningCLI(
            args=self.conf,
            model_class=ThreeDGAN,
            datamodule_class=ParticlesDataModule,
            trainer_class=LightningTrainer,
            run=False,
            save_config_kwargs={
                "overwrite": True,
                "config_filename": "pl-training.yml",
            },
            subclass_mode_model=True,
            subclass_mode_data=True,
        )
        sys.argv = old_argv

        # Get current worker rank (assuming torchrun launcher)
        global_rank = int(os.getenv('RANK', 0))

        with self.itwinai_logger.start_logging(rank=global_rank):
            # Set the logger into the LightningTrainer
            cli.trainer.itwinai_logger = self.itwinai_logger

            # Start training
            cli.trainer.fit(cli.model, datamodule=cli.datamodule)

            self._log_config(self.itwinai_logger)
            self.itwinai_logger.log(
                cli.trainer.train_dataloader,
                "train_dataloader",
                kind='torch'
            )
            self.itwinai_logger.log(
                cli.trainer.val_dataloaders,
                "val_dataloader",
                kind='torch'
            )

    def _log_config(self, logger: Logger):
        with tempfile.TemporaryDirectory(dir='/tmp') as tmp_dir:
            local_yaml_path = os.path.join(tmp_dir, 'pl-conf.yaml')
            with open(local_yaml_path, 'w') as outfile:
                yaml.dump(self.conf, outfile, default_flow_style=False)
            logger.log(local_yaml_path, 'lightning-config', kind='artifact')


class LightningModelLoader(TorchModelLoader):
    """Loads a PyTorch Lightning model from a local path or from MLflow.

    Args:
        model_uri (str): Can be a path on local filesystem
            or an mlflow 'locator' in the form:
            'mlflow+MLFLOW_TRACKING_URI+RUN_ID+ARTIFACT_PATH'
    """

    def __call__(self) -> pl.LightningModule:
        """Loads model from model URI.

        Raises:
            ValueError: if the model URI is not recognized
                or the model is not found.

        Returns:
            pl.LightningModule: torch lightning module.
        """
        # TODO: improve
        # # Load best model
        # loaded_model = cli.model.load_from_checkpoint(
        #     ckpt_path,
        #     lightning_conf['model']['init_args']
        # )
        return super().__call__()


class Lightning3DGANPredictor(Predictor):

    def __init__(
        self,
        model: Union[ModelLoader, pl.LightningModule],
        config: Union[Dict, str],
        name: Optional[str] = None
    ):
        self.save_parameters(**self.locals2params(locals()))
        super().__init__(model, name)
        if isinstance(config, str) and os.path.isfile(config):
            # Load from YAML
            config = load_yaml(config)
        self.conf = config

    @monitor_exec
    def execute(
        self,
        datamodule: Optional[pl.LightningDataModule] = None,
        model: Optional[pl.LightningModule] = None
    ) -> Dict[str, Tensor]:
        old_argv = sys.argv
        sys.argv = ['some_script_placeholder.py']
        cli = LightningCLI(
            args=self.conf,
            model_class=ThreeDGAN,
            datamodule_class=ParticlesDataModule,
            run=False,
            save_config_kwargs={
                "overwrite": True,
                "config_filename": "pl-training.yml",
            },
            subclass_mode_model=True,
            subclass_mode_data=True,
        )
        sys.argv = old_argv

        # Override config file with inline arguments, if given
        if datamodule is None:
            datamodule = cli.datamodule
        if model is None:
            model = cli.model

        predictions = cli.trainer.predict(model, datamodule=datamodule)

        # Transpose predictions into images, energies and angles
        images = torch.cat(list(map(
            lambda pred: self.transform_predictions(
                pred['images']), predictions
        )))
        energies = torch.cat(list(map(
            lambda pred: pred['energies'], predictions
        )))
        angles = torch.cat(list(map(
            lambda pred: pred['angles'], predictions
        )))

        predictions_dict = dict()
        for img, en, ang in zip(images, energies, angles):
            sample_key = f"energy={en.item()}&angle={ang.item()}"
            predictions_dict[sample_key] = img

        return predictions_dict

    def transform_predictions(self, batch: Batch) -> Batch:
        """
        Post-process the predictions of the torch model.
        """
        return batch.squeeze(1)
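Both Lightning3DGANTrainer.execute and Lightning3DGANPredictor.execute temporarily replace sys.argv with a placeholder before building the LightningCLI, presumably so that the arguments of the itwinai command line are not picked up by Lightning, and restore it afterwards. If the CLI construction raises, however, the original sys.argv is never restored. A context manager with try/finally is one way to make the restore unconditional; the sketch below uses only the standard library, and `placeholder_argv` is a hypothetical name, not an itwinai or Lightning API.

```python
import sys
from contextlib import contextmanager
from typing import Iterator, List


@contextmanager
def placeholder_argv(
    placeholder: str = "some_script_placeholder.py",
) -> Iterator[List[str]]:
    """Temporarily replace sys.argv, restoring it even if an error occurs."""
    old_argv = sys.argv
    sys.argv = [placeholder]
    try:
        yield sys.argv
    finally:
        sys.argv = old_argv


# Usage sketch: the LightningCLI(...) call would go inside the with-block
original = sys.argv
with placeholder_argv() as argv:
    print(argv)  # ['some_script_placeholder.py']
print(sys.argv is original)  # True
```

Inside execute, the three manual lines around the LightningCLI call could then become a single `with placeholder_argv():` block.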

saver.py

from typing import Dict
import os
import shutil
import pickle
import random

import torch
from torch import Tensor
import matplotlib.pyplot as plt
import numpy as np

from itwinai.components import Saver, monitor_exec


class ParticleImagesSaver(Saver):
    """Saves generated particle shower images to disk."""

    def __init__(
        self,
        save_dir: str = '3dgan-generated',
        aggregate_predictions: bool = False
    ) -> None:
        self.save_parameters(**self.locals2params(locals()))
        super().__init__()
        self.save_dir = save_dir
        self.aggregate_predictions = aggregate_predictions

    @monitor_exec
    def execute(self, generated_images: Dict[str, Tensor]) -> None:
        """Saves generated images to disk.

        Args:
            generated_images (Dict[str, Tensor]): maps unique item ID to
                the generated image.
        """
        if self.aggregate_predictions:
            os.makedirs(self.save_dir, exist_ok=True)
            sparse_generated_images = dict()
            for name, res in generated_images.items():
                sparse_generated_images[name] = res.to_sparse()
            del generated_images
            with open(self._random_file(), 'wb') as fp:
                pickle.dump(sparse_generated_images, fp)
        else:
            if os.path.exists(self.save_dir):
                shutil.rmtree(self.save_dir)
            os.makedirs(self.save_dir)
            # Save as torch tensor and jpg image
            for img_id, img in generated_images.items():
                img_path = os.path.join(self.save_dir, img_id)
                torch.save(img, img_path + '.pth')
                self._save_image(img, img_id, img_path + '.jpg')

    def _random_file(self, extension: str = 'pkl') -> str:
        fname = "%032x.%s" % (random.getrandbits(128), extension)
        fpath = os.path.join(self.save_dir, fname)
        while os.path.exists(fpath):
            fname = "%032x.%s" % (random.getrandbits(128), extension)
            fpath = os.path.join(self.save_dir, fname)
        return fpath

    def _save_image(
        self,
        img: Tensor,
        img_idx: str,
        img_path: str,
        center: bool = True
    ) -> None:
        """Converts a 3D tensor to a 3D scatter plot and saves it
        to disk as jpg image.
        """
        x_offset = img.shape[0] // 2 if center else 0
        y_offset = img.shape[1] // 2 if center else 0
        z_offset = img.shape[2] // 2 if center else 0

        # Convert tensor dimension IDs to coordinates
        x_coords = []
        y_coords = []
        z_coords = []
        values = []

        for x in range(img.shape[0]):
            for y in range(img.shape[1]):
                for z in range(img.shape[2]):
                    if img[x, y, z] > 0.0:
                        x_coords.append(x - x_offset)
                        y_coords.append(y - y_offset)
                        z_coords.append(z - z_offset)
                        values.append(img[x, y, z])

        # import plotly.graph_objects as go
        # normalize_intensity_by = 1
        # trace = go.Scatter3d(
        #     x=x_coords,
        #     y=y_coords,
        #     z=z_coords,
        #     mode='markers',
        #     marker_symbol='square',
        #     marker_color=[
        #         f"rgba(0,0,255,{i*100//normalize_intensity_by/10})"
        #         for i in values],
        # )
        # fig = go.Figure()
        # fig.add_trace(trace)
        # fig.write_image(img_path)

        values = np.array(values)
        # 0-1 scaling
        values = (values - values.min()) / (values.max() - values.min())
        fig = plt.figure()
        ax = fig.add_subplot(projection='3d')
        ax.scatter(x_coords, y_coords, z_coords, alpha=values)
        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Z')

        # Extract energy and angle from idx
        en, ang = img_idx.split('&')
        en = en[7:]
        ang = ang[6:]
        ax.set_title(f"Energy: {en} - Angle: {ang}")
        fig.savefig(img_path)
        # Release the figure to avoid accumulating open figures in memory
        plt.close(fig)
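Lightning3DGANPredictor.execute identifies each generated image with a key of the form `energy=<value>&angle=<value>`, and ParticleImagesSaver._save_image recovers the two values by slicing off fixed-length prefixes (`en[7:]`, `ang[6:]`). The sketch below shows the same round trip with explicit prefix checks, so that a malformed key fails loudly instead of producing a truncated plot title. `make_sample_key` and `parse_sample_key` are hypothetical helpers, not part of itwinai. Note also that, because the key encodes only energy and angle, two samples with identical values share a key, and the later image replaces the earlier one in `predictions_dict`.

```python
from typing import Tuple

ENERGY_PREFIX = "energy="
ANGLE_PREFIX = "angle="


def make_sample_key(energy: float, angle: float) -> str:
    """Build a key in the same format used by the predictor."""
    return f"{ENERGY_PREFIX}{energy}&{ANGLE_PREFIX}{angle}"


def parse_sample_key(key: str) -> Tuple[str, str]:
    """Split a sample key back into its energy and angle strings."""
    energy_part, sep, angle_part = key.partition("&")
    if (not sep or not energy_part.startswith(ENERGY_PREFIX)
            or not angle_part.startswith(ANGLE_PREFIX)):
        raise ValueError(f"unexpected sample key: {key!r}")
    return energy_part[len(ENERGY_PREFIX):], angle_part[len(ANGLE_PREFIX):]


key = make_sample_key(1.5, 0.25)
print(key)                    # energy=1.5&angle=0.25
print(parse_sample_key(key))  # ('1.5', '0.25')
```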

dataloader.py

from typing import Optional
import os
from lightning.pytorch.utilities.types import EVAL_DATALOADERS

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
import lightning as pl
import glob
import h5py
import gdown

from itwinai.components import DataGetter, monitor_exec
from itwinai.loggers import Logger as BaseItwinaiLogger


class Lightning3DGANDownloader(DataGetter):
    def __init__(
        self,
        data_path: str,
        data_url: Optional[str] = None,
        name: Optional[str] = None,
    ) -> None:
        self.save_parameters(**self.locals2params(locals()))
        super().__init__(name)
        self.data_path = data_path
        self.data_url = data_url

    @monitor_exec
    def execute(self):
        # Download data
        if not os.path.exists(self.data_path):
            if self.data_url is None:
                print("WARNING! Data URL is None. "
                      "Skipping dataset downloading")
                return

            gdown.download_folder(
                url=self.data_url, quiet=False,
                output=self.data_path,
                # verify=False
            )


class ParticlesDataset(Dataset):
    def __init__(self, datapath: str, max_samples: Optional[int] = None):
        self.datapath = datapath
        self.max_samples = max_samples
        self.data = dict()

        self.fetch_data()

    def __len__(self):
        return len(self.data["X"])

    def __getitem__(self, idx):
        return {"X": self.data["X"][idx], "Y": self.data["Y"][idx],
                "ang": self.data["ang"][idx], "ecal": self.data["ecal"][idx]}

    def fetch_data(self) -> None:

        print("Searching in:", self.datapath)
        files = sorted(glob.glob(os.path.join(
            self.datapath, '**/*.h5'), recursive=True))
        print("Found {} files. ".format(len(files)))
        if len(files) == 0:
            raise RuntimeError(f"No H5 files found at '{self.datapath}'!")

        # concatenated_datasets = []
        # for datafile in files:
        #     f = h5py.File(datafile, 'r')
        #     dataset = self.GetDataAngleParallel(f)
        #     concatenated_datasets.append(dataset)
        #     # Initialize result dictionary
        #     result = {key: [] for key in concatenated_datasets[0].keys()}
        #     for d in concatenated_datasets:
        #         for key in result.keys():
        #             result[key].extend(d[key])
        # return result

        for datafile in files:
            # Open each file in a context manager so that it is closed
            with h5py.File(datafile, 'r') as f:
                dataset = self.GetDataAngleParallel(f)
            for field, vals_array in dataset.items():
                if self.data.get(field) is not None:
                    # Append the new samples along the first (sample) axis
                    self.data[field] = np.concatenate(
                        [self.data[field], vals_array], axis=0)
                else:
                    self.data[field] = vals_array

            # Stop loading data, if self.max_samples reached
            if (self.max_samples is not None
                    and len(self.data["X"]) >= self.max_samples):
                for field, vals_array in self.data.items():
                    self.data[field] = vals_array[:self.max_samples]
                break

    def GetDataAngleParallel(
        self,
        dataset,
        xscale=1,
        xpower=0.85,
        yscale=100,
        angscale=1,
        angtype="theta",
        thresh=1e-4,
        daxis=-1
    ):
        """Preprocess one HDF5 file of the dataset.

        Events whose total ECAL energy is not above 10.0 are discarded.

        Args:
            dataset (h5py.File): Open HDF5 file with the "ECAL", "energy"
                and angle datasets.
            xscale (int, optional): Factor multiplying the ECAL values.
                Defaults to 1.
            xpower (float, optional): Exponent applied to the ECAL values.
                Defaults to 0.85.
            yscale (int, optional): Divisor applied to the energy values.
                Defaults to 100.
            angscale (int, optional): Scale for the angle values (currently
                unused). Defaults to 1.
            angtype (str, optional): Name of the angle dataset to read.
                Defaults to "theta".
            thresh (float, optional): ECAL values below this threshold are
                set to zero. Defaults to 1e-4.
            daxis (int, optional): Axis along which X and ecal are expanded.
                Defaults to -1.

        Returns:
            Dict: Preprocessed dataset with keys "X", "Y", "ang" and "ecal".
        """
        X = np.array(dataset.get("ECAL")) * xscale
        Y = np.array(dataset.get("energy")) / yscale
        X[X < thresh] = 0
        X = X.astype(np.float32)
        Y = Y.astype(np.float32)
        ecal = np.sum(X, axis=(1, 2, 3))
        indexes = np.where(ecal > 10.0)
        X = X[indexes]
        Y = Y[indexes]
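        if angtype in dataset:
            ang = np.array(dataset.get(angtype))[indexes]
        else:
            # Computing the angle from the images (gan.measPython) is not
            # ported, so fail explicitly instead of with a NameError below
            raise KeyError(f"Angle dataset '{angtype}' not found in file")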
        X = np.expand_dims(X, axis=daxis)
        ecal = ecal[indexes]
        ecal = np.expand_dims(ecal, axis=daxis)
        if xpower != 1.0:
            X = np.power(X, xpower)

        Y = np.array([[el] for el in Y])
        ang = np.array([[el] for el in ang])
        ecal = np.array([[el] for el in ecal])

        final_dataset = {"X": X, "Y": Y, "ang": ang, "ecal": ecal}

        return final_dataset


class ParticlesDataModule(pl.LightningDataModule):
    def __init__(
            self,
            datapath: str,
            batch_size: int,
            num_workers: int = 4,
            max_samples: Optional[int] = None,
            train_proportion: float = 0.9
    ) -> None:
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.datapath = datapath
        self.max_samples = max_samples
        self.train_proportion = train_proportion

    @property
    def itwinai_logger(self) -> BaseItwinaiLogger:
        try:
            itwinai_logger = self.trainer.itwinai_logger
        except AttributeError:
            print("WARNING: itwinai_logger attribute not set "
                  f"in {self.__class__.__name__}")
            itwinai_logger = None
        return itwinai_logger

    def setup(self, stage: Optional[str] = None):
        # make assignments here (val/train/test split)
        # called on every process in DDP

        if stage == 'fit' or stage is None:
            self.dataset = ParticlesDataset(
                self.datapath,
                max_samples=self.max_samples
            )
            dataset_length = len(self.dataset)
            split_point = int(dataset_length * self.train_proportion)
            self.train_dataset, self.val_dataset = \
                torch.utils.data.random_split(
                    self.dataset, [split_point, dataset_length - split_point])

        if stage == 'predict':
            # TODO: inference dataset should be different in that it
            # does not contain images!
            self.predict_dataset = ParticlesDataset(
                self.datapath,
                max_samples=self.max_samples
            )

        # if stage == 'test' or stage is None:
        #     self.test_dataset = MyDataset(self.data_dir, train=False)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, num_workers=self.num_workers,
                          batch_size=self.batch_size, drop_last=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, num_workers=self.num_workers,
                          batch_size=self.batch_size, drop_last=True)

    def predict_dataloader(self) -> EVAL_DATALOADERS:
        return DataLoader(self.predict_dataset, num_workers=self.num_workers,
                          batch_size=self.batch_size, drop_last=True)

    # def test_dataloader(self):
        # return DataLoader(self.test_dataset, batch_size=self.batch_size)
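To see what the preprocessing method above produces, the sketch below re-implements its core steps as a standalone function and runs it on a tiny synthetic dataset. The function name `preprocess_ecal`, the `xscale=1.0` default and the toy array sizes are hypothetical; the sketch also skips the `xpower` step and returns the auxiliary arrays with shape (N, 1), so treat it as an illustration rather than a drop-in replacement.

```python
import numpy as np


def preprocess_ecal(dataset, xscale=1.0, yscale=100.0, thresh=1e-4,
                    angtype="theta", daxis=-1):
    """Hypothetical standalone sketch of the preprocessing steps above.

    Omits the ``xpower`` step and simplifies the auxiliary arrays to (N, 1).
    """
    # Scale the calorimeter images and the primary energies
    X = np.asarray(dataset["ECAL"], dtype=np.float32) * xscale
    Y = np.asarray(dataset["energy"], dtype=np.float32) / yscale
    # Zero out cell deposits below the threshold
    X[X < thresh] = 0
    # Total deposited energy per event, summed over the three spatial axes
    ecal = X.sum(axis=(1, 2, 3))
    # Keep only events whose total deposit exceeds 10
    keep = ecal > 10.0
    X, Y, ecal = X[keep], Y[keep], ecal[keep]
    ang = np.asarray(dataset[angtype], dtype=np.float32)[keep]
    # Add a trailing channel axis: (N, x, y, z) -> (N, x, y, z, 1)
    X = np.expand_dims(X, axis=daxis)
    return {"X": X, "Y": Y[:, None], "ang": ang[:, None],
            "ecal": ecal[:, None]}


# Toy input: three events on a 4x4x2 grid (sizes are arbitrary)
toy = {
    "ECAL": np.stack([
        np.ones((4, 4, 2)),         # total 32.0 -> kept
        np.full((4, 4, 2), 0.1),    # total ~3.2 -> dropped
        np.full((4, 4, 2), 1e-5),   # below thresh, total 0 -> dropped
    ]),
    "energy": np.array([100.0, 200.0, 300.0]),
    "theta": np.array([1.0, 1.1, 1.2]),
}
out = preprocess_ecal(toy)
print({k: v.shape for k, v in out.items()})
# {'X': (1, 4, 4, 2, 1), 'Y': (1, 1), 'ang': (1, 1), 'ecal': (1, 1)}
```

Only the first event survives: the second deposits about 3.2 in total, below the cut of 10, and every cell of the third falls below `thresh` and is zeroed.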

config.yaml

This YAML file defines the training and inference pipeline configurations for the CERN use case.

# Main configurations
dataset_location: exp_data/
dataset_url: https://drive.google.com/drive/folders/1ooUIfkhpokvwh4-7qPxX084N7-LgqnIL # https://drive.google.com/drive/folders/1uPpz0tquokepptIfJenTzGpiENfo2xRX
hw_accelerators: auto
distributed_strategy: ddp_find_unused_parameters_true #deepspeed auto horovod
devices: auto #[0]
checkpoints_path: checkpoints
logs_dir: ml_logs
mlflow_tracking_uri: mlruns # https://131.154.99.166.myip.cloud.infn.it
batch_size: 4
train_dataset_proportion: 0.7
num_workers_dataloader: 0
max_epochs: 2
max_dataset_size: 48
random_seed: 4231162351
inference_results_location: 3dgan-generated-data/
inference_model_uri: 3dgan-inference.pth
aggregate_predictions: false
num_nodes: 1
provenance_verbose: true

# Dataloading step is common and can be reused
dataloading_step:
  class_path: dataloader.Lightning3DGANDownloader
  init_args:
    data_path: ${dataset_location} # Set to null to skip dataset download
    data_url: ${dataset_url}

# AI workflows
training_pipeline:
  class_path: itwinai.pipeline.Pipeline
  init_args:
    steps:
      dataloading_step: ${dataloading_step}

      training_step:
        class_path: trainer.Lightning3DGANTrainer
        init_args:
          # NOTE: before pushing to the repo, disable logging to prevent slowing down unit tests
          # itwinai_logger:
          #   class_path: itwinai.loggers.LoggersCollection
          #   init_args:
          #     loggers: 
          #       # - class_path: itwinai.loggers.ConsoleLogger
          #       #   init_args:
          #       #     log_freq: 100
          #       # - class_path: itwinai.loggers.MLFlowLogger
          #       #   init_args:
          #       #     experiment_name: 3DGAN
          #       #     log_freq: batch
          #       - class_path: itwinai.loggers.Prov4MLLogger
          #         init_args:
          #           provenance_save_dir: mllogs/prov_logs
          #           experiment_name: 3DGAN
          #           log_freq: batch
          #           log_on_workers: -1
          #       # - class_path: itwinai.loggers.WandBLogger
          #       #   init_args:
          #       #     log_freq: batch

          # Pytorch lightning config for training
          config:
            seed_everything: ${random_seed}
            trainer:
              accelerator: ${hw_accelerators}
              accumulate_grad_batches: 1
              barebones: false
              benchmark: null
              # callbacks:
              #   - class_path: lightning.pytorch.callbacks.early_stopping.EarlyStopping
              #     init_args:
              #       monitor: val_generator_loss
              #       patience: 2
              #   - class_path: lightning.pytorch.callbacks.lr_monitor.LearningRateMonitor
              #     init_args:
              #       logging_interval: step
              #   - class_path: lightning.pytorch.callbacks.ModelCheckpoint
              #     init_args:
              #       dirpath: ${checkpoints_path}
              #       filename: best-checkpoint
              #       mode: min
              #       monitor: val_generator_loss
              #       save_top_k: 1
              #       verbose: true
              check_val_every_n_epoch: 1
              default_root_dir: null
              detect_anomaly: false
              deterministic: null
              devices: ${devices}
              num_nodes: ${num_nodes}
              enable_checkpointing: true
              enable_model_summary: null
              enable_progress_bar: null
              fast_dev_run: false
              gradient_clip_algorithm: null
              gradient_clip_val: null
              inference_mode: true
              limit_predict_batches: null
              limit_test_batches: null
              limit_train_batches: null
              limit_val_batches: null
              log_every_n_steps: 1
              # logger:
              #   - class_path: lightning.pytorch.loggers.CSVLogger
              #     init_args:
              #       name: 3DGAN
              #       save_dir: ${logs_dir}
              #   - class_path: lightning.pytorch.loggers.MLFlowLogger
              #     init_args:
              #       experiment_name: 3DGAN
              #       save_dir: null #ml_logs/mlflow_logs
              #       tracking_uri: ${mlflow_tracking_uri}
              #       log_model: all
              max_epochs: ${max_epochs}
              max_time: null
              min_epochs: null
              min_steps: null
              num_sanity_val_steps: null
              overfit_batches: 0.0
              plugins: null
              profiler: null
              reload_dataloaders_every_n_epochs: 0
              strategy: ${distributed_strategy}
              sync_batchnorm: false
              use_distributed_sampler: true
              val_check_interval: null

            # Lightning Model configuration
            model:
              class_path: model.ThreeDGAN
              init_args:
                latent_size: 256
                loss_weights: [3, 0.1, 25, 0.1]
                power: 0.85
                lr: 0.001
                checkpoints_dir: ${checkpoints_path}
                provenance_verbose: ${provenance_verbose}

            # Lightning data module configuration
            data:
              class_path: dataloader.ParticlesDataModule
              init_args:
                datapath: ${dataset_location}
                batch_size: ${batch_size}
                num_workers: ${num_workers_dataloader}
                max_samples: ${max_dataset_size}
                train_proportion: ${train_dataset_proportion}

inference_pipeline:
  class_path: itwinai.pipeline.Pipeline
  init_args:
    steps:
      dataloading_step: ${dataloading_step}

      inference_step:
        class_path: trainer.Lightning3DGANPredictor
        init_args:
          model:
            class_path: trainer.LightningModelLoader
            init_args:
              model_uri: ${inference_model_uri}

          # PyTorch Lightning config for inference
          config:
            seed_everything: ${random_seed}
            trainer:
              accelerator: ${hw_accelerators}
              accumulate_grad_batches: 1
              barebones: false
              benchmark: null
              check_val_every_n_epoch: 1
              default_root_dir: null
              detect_anomaly: false
              deterministic: null
              devices: ${devices}
              enable_checkpointing: true
              enable_model_summary: null
              enable_progress_bar: null
              fast_dev_run: false
              gradient_clip_algorithm: null
              gradient_clip_val: null
              inference_mode: true
              limit_predict_batches: null
              limit_test_batches: null
              limit_train_batches: null
              limit_val_batches: null
              log_every_n_steps: 2
              logger: 
                # - class_path: lightning.pytorch.loggers.CSVLogger
                #   init_args:
                #     save_dir: ml_logs/csv_logs
                class_path: lightning.pytorch.loggers.MLFlowLogger
                init_args:
                  experiment_name: 3DGAN
                  save_dir: ${logs_dir}
                  log_model: all
              max_epochs: ${max_epochs}
              max_steps: 20
              max_time: null
              min_epochs: null
              min_steps: null
              num_sanity_val_steps: null
              overfit_batches: 0.0
              plugins: null
              profiler: null
              reload_dataloaders_every_n_epochs: 0
              strategy: ${distributed_strategy}
              sync_batchnorm: false
              use_distributed_sampler: true
              val_check_interval: null

            # Lightning Model configuration
            model:
              class_path: model.ThreeDGAN
              init_args:
                latent_size: 256
                loss_weights: [3, 0.1, 25, 0.1]
                power: 0.85
                lr: 0.001
                checkpoints_dir: ${checkpoints_path}

            # Lightning data module configuration
            data:
              class_path: dataloader.ParticlesDataModule
              init_args:
                datapath: ${dataset_location}
                batch_size: ${batch_size} #1024
                num_workers: ${num_workers_dataloader} #4
                max_samples: ${max_dataset_size} #null, 10000

      saver_step:
        class_path: saver.ParticleImagesSaver
        init_args:
          save_dir: ${inference_results_location}
          aggregate_predictions: ${aggregate_predictions}
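The `${...}` references above tie nested settings to the top-level keys, which is presumably why the SLURM scripts can pass `-o num_nodes=$SLURM_NNODES` once and have it reach `trainer.num_nodes`. The plain-Python sketch below mimics that resolution; it is a simplified assumption (only whole-value, top-level references and `key=value` overrides on top-level keys), not itwinai's actual parser, and the helper names and the `/scratch/gan` path are hypothetical.

```python
import re
from typing import Any

# Matches a value that is exactly one reference, e.g. "${num_nodes}"
_REF = re.compile(r"^\$\{([^}]+)\}$")


def _parse(value: str) -> Any:
    """Turn numeric override strings into ints; keep everything else as text."""
    try:
        return int(value)
    except ValueError:
        return value


def apply_overrides(config: dict, overrides: list) -> dict:
    """Apply hypothetical 'key=value' overrides to top-level keys."""
    merged = dict(config)
    for item in overrides:
        key, _, value = item.partition("=")
        merged[key] = _parse(value)
    return merged


def resolve(node: Any, root: dict) -> Any:
    """Recursively replace '${key}' values with root[key]."""
    if isinstance(node, dict):
        return {k: resolve(v, root) for k, v in node.items()}
    if isinstance(node, list):
        return [resolve(v, root) for v in node]
    if isinstance(node, str):
        match = _REF.match(node)
        if match:
            return resolve(root[match.group(1)], root)
    return node


# A trimmed-down excerpt of config.yaml
config = {
    "num_nodes": 1,
    "dataset_location": "exp_data/",
    "training_step": {
        "trainer": {"num_nodes": "${num_nodes}"},
        "data": {"datapath": "${dataset_location}"},
    },
}
merged = apply_overrides(config, ["num_nodes=2", "dataset_location=/scratch/gan"])
resolved = resolve(merged, merged)
print(resolved["training_step"])
# {'trainer': {'num_nodes': 2}, 'data': {'datapath': '/scratch/gan'}}
```

Applying the overrides before resolving the references is what lets a single top-level value propagate to every place that refers to it.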

create_inference_sample.py

This script creates a sample checkpoint for testing the CERN use case inference pipeline: it saves an untrained ThreeDGAN model to 3dgan-inference.pth.

"""Create a checkpoint of an untrained ThreeDGAN model for inference tests."""

import argparse
import os
import torch
from model import ThreeDGAN


def create_checkpoint(
    root: str = '.',
    ckpt_name: str = "3dgan-inference.pth"
):
    ckpt_path = os.path.join(root, ckpt_name)
    net = ThreeDGAN()
    torch.save(net, ckpt_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--root", type=str, default='.')
    parser.add_argument("--ckpt-name", type=str, default="3dgan-inference.pth")
    args = parser.parse_args()
    create_checkpoint(**vars(args))
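The script's default output name matches `inference_model_uri: 3dgan-inference.pth` in config.yaml, so running it from use-cases/3dgan presumably produces the (randomly initialized) model file that the inference pipeline loads through `LightningModelLoader`. The `**vars(args)` call works because argparse converts `--ckpt-name` into the attribute `ckpt_name`, which matches the keyword parameter of `create_checkpoint`. The standalone snippet below, with hypothetical argument values, demonstrates that mapping.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--root", type=str, default=".")
parser.add_argument("--ckpt-name", type=str, default="3dgan-inference.pth")

# argparse stores "--ckpt-name" as the attribute "ckpt_name", so vars(args)
# lines up with the keyword parameters of create_checkpoint(root, ckpt_name)
args = parser.parse_args(["--root", "/tmp", "--ckpt-name", "my-model.pth"])
print(vars(args))  # {'root': '/tmp', 'ckpt_name': 'my-model.pth'}
```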

Dockerfile

FROM nvcr.io/nvidia/pytorch:23.09-py3
# FROM python:3.11

WORKDIR /usr/src/app

# Install itwinai
COPY pyproject.toml ./
COPY src ./src
RUN pip install --upgrade pip \
    && pip install --no-cache-dir lightning \
    && pip install --no-cache-dir .

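# Install the additional requirements of the 3DGAN use case before copying
# the remaining files, so that this layer stays cached when only code changes
COPY use-cases/3dgan/requirements.txt ./
RUN pip install --no-cache-dir -r requirements.txt

# Add 3DGAN use case files
COPY use-cases/3dgan/* ./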

# Create non-root user
RUN groupadd -g 10001 jovyan \
    && useradd -m -u 10000 -g jovyan jovyan \
    && chown -R jovyan:jovyan /usr/src/app
USER jovyan:jovyan

# ENTRYPOINT [ "itwinai", "exec-pipeline" ]
# CMD [ "--config", "pipeline.yaml" ]

SLURM job script for JSC (HDFML system)

#!/bin/bash

# SLURM jobscript for JSC systems

# general configuration of the job
#SBATCH --job-name=PrototypeTest
#SBATCH --account=intertwin
#SBATCH --mail-user=
#SBATCH --mail-type=ALL
#SBATCH --output=job.out
#SBATCH --error=job.err
#SBATCH --time=00:30:00

# configure node and process count on the CM
#SBATCH --partition=batch
#SBATCH --nodes=2
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=4
#SBATCH --gpus-per-node=4

#SBATCH --exclusive

# gres options have to be disabled for deepv
#SBATCH --gres=gpu:4

# load modules
ml --force purge
ml Stages/2024 GCC CUDA/12 cuDNN Python 
# ml Stages/2024 GCC OpenMPI CUDA/12 cuDNN MPI-settings/CUDA
# ml Python CMake HDF5 PnetCDF libaio mpi4py

# shellcheck source=/dev/null
source ~/.bashrc

# Activate the environment
source ../../envAI_hdfml/bin/activate

GAN_DATASET="exp_data" #"/p/scratch/intertwin/datasets/cern/"

# launch training
TRAINING_CMD="$(which itwinai) exec-pipeline --config config.yaml --pipe-key training_pipeline \
                -o num_nodes=$SLURM_NNODES \
                -o dataset_location=$GAN_DATASET "

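# The rendezvous endpoint is the first node of the allocation; on JSC systems,
# the "i" suffix appended to its hostname selects the InfiniBand interface
srun --cpu-bind=none --ntasks-per-node=1 \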
    bash -c "torchrun \
    --log_dir='logs_torchrun' \
    --nnodes=$SLURM_NNODES \
    --nproc_per_node=$SLURM_GPUS_PER_NODE \
    --rdzv_id=$SLURM_JOB_ID \
    --rdzv_conf=is_host=\$(((SLURM_NODEID)) && echo 0 || echo 1) \
    --rdzv_backend=c10d \
    --rdzv_endpoint='$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)'i:29500 \
    $TRAINING_CMD "
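The `torchrun` options in the script above boil down to a little arithmetic. The following Python sketch mirrors it under stated assumptions: `is_rendezvous_host` reproduces the `is_host` shell expression, `world_size` and `global_rank` follow torchrun's usual rank layout for nodes that all run the same number of processes, and `rendezvous_endpoint` rebuilds the endpoint string. All function names and the hostname `node001` are hypothetical, and reading the `i` suffix as JSC's InfiniBand hostname convention is an assumption worth checking against the JSC documentation.

```python
def is_rendezvous_host(slurm_node_id: int) -> int:
    """Python mirror of `$(((SLURM_NODEID)) && echo 0 || echo 1)`."""
    # ((x)) succeeds only for non-zero x, so only node 0 prints 1
    return 1 if slurm_node_id == 0 else 0


def world_size(nnodes: int, nproc_per_node: int) -> int:
    """Total number of training processes launched by torchrun."""
    return nnodes * nproc_per_node


def global_rank(group_rank: int, nproc_per_node: int, local_rank: int) -> int:
    """Global rank when every node runs the same number of processes.

    group_rank is assigned by the c10d rendezvous and is not guaranteed
    to equal SLURM_NODEID.
    """
    return group_rank * nproc_per_node + local_rank


def rendezvous_endpoint(first_host: str, port: int = 29500,
                        infiniband_suffix: str = "i") -> str:
    """Endpoint string as built in the JSC script (pass "" for Vega)."""
    return f"{first_host}{infiniband_suffix}:{port}"


# Example: 2 nodes x 4 GPUs, as requested in the #SBATCH header
print([is_rendezvous_host(n) for n in range(2)])  # [1, 0]
print(world_size(2, 4))                           # 8
print([global_rank(1, 4, r) for r in range(4)])   # [4, 5, 6, 7]
print(rendezvous_endpoint("node001"))             # node001i:29500
```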

SLURM job script for Vega Supercomputer (GPU partition)

#!/bin/bash

# SLURM jobscript for Vega systems

# Job configuration
#SBATCH --job-name=3dgan_training
#SBATCH --account=s24r05-03-users
#SBATCH --mail-user=
#SBATCH --mail-type=ALL
#SBATCH --output=job.out
#SBATCH --error=job.err
#SBATCH --time=01:00:00

# Resources allocation
#SBATCH --partition=gpu
#SBATCH --nodes=2
#SBATCH --gpus-per-node=4
#SBATCH --cpus-per-gpu=4
#SBATCH --ntasks-per-node=1
# SBATCH --mem-per-gpu=10G
#SBATCH --exclusive

# gres options have to be disabled for deepv
#SBATCH --gres=gpu:4

echo "DEBUG: SLURM_SUBMIT_DIR: $SLURM_SUBMIT_DIR"
echo "DEBUG: SLURM_JOB_ID: $SLURM_JOB_ID"
echo "DEBUG: SLURM_JOB_NODELIST: $SLURM_JOB_NODELIST"
echo "DEBUG: SLURM_NNODES: $SLURM_NNODES"
echo "DEBUG: SLURM_NTASKS: $SLURM_NTASKS"
echo "DEBUG: SLURM_TASKS_PER_NODE: $SLURM_TASKS_PER_NODE"
echo "DEBUG: SLURM_SUBMIT_HOST: $SLURM_SUBMIT_HOST"
echo "DEBUG: SLURMD_NODENAME: $SLURMD_NODENAME"
echo "DEBUG: CUDA_VISIBLE_DEVICES: $CUDA_VISIBLE_DEVICES"

# ml --force purge
# ml Python CMake/3.24.3-GCCcore-11.3.0 mpi4py OpenMPI CUDA/11.7
# ml GCCcore/11.3.0 NCCL/2.12.12-GCCcore-11.3.0-CUDA-11.7.0 cuDNN

ml Python
module unload OpenSSL

source ~/.bashrc

# Activate the environment
source ../../.venv-pytorch/bin/activate

GAN_DATASET="exp_data" #"/ceph/hpc/data/st2301-itwin-users/egarciagarcia"

# launch training
TRAINING_CMD="$(which itwinai) exec-pipeline --config config.yaml --pipe-key training_pipeline \
                -o num_nodes=$SLURM_NNODES \
                -o dataset_location=$GAN_DATASET "

srun --cpu-bind=none --ntasks-per-node=1 \
    bash -c "torchrun \
    --log_dir='logs_torchrun' \
    --nnodes=$SLURM_NNODES \
    --nproc_per_node=$SLURM_GPUS_PER_NODE \
    --rdzv_id=$SLURM_JOB_ID \
    --rdzv_conf=is_host=\$(((SLURM_NODEID)) && echo 0 || echo 1) \
    --rdzv_backend=c10d \
    --rdzv_endpoint='$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)':29500 \
    $TRAINING_CMD "
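To estimate how much work each process gets with the default config.yaml values, the sketch below combines the train/validation split rule from `ParticlesDataModule.setup()` with the `drop_last=True` dataloaders. It assumes the dataset yields exactly `max_dataset_size` = 48 samples (the energy cut in the preprocessing may leave fewer) and that the sampler Lightning inserts when `use_distributed_sampler: true` keeps PyTorch's default `drop_last=False`; the helper names are hypothetical.

```python
import math


def split_sizes(n_samples: int, train_proportion: float) -> tuple:
    """Train/validation sizes using the rule in ParticlesDataModule.setup()."""
    split_point = int(n_samples * train_proportion)
    return split_point, n_samples - split_point


def batches_per_rank(n_samples: int, world_size: int, batch_size: int) -> int:
    """Batches each rank iterates over in one epoch (assumptions in the text)."""
    # DistributedSampler with drop_last=False gives every rank
    # ceil(n / world_size) samples, padding by repeating indices
    per_rank = math.ceil(n_samples / world_size)
    # DataLoader(drop_last=True) discards the last incomplete batch
    return per_rank // batch_size


n_train, n_val = split_sizes(48, 0.7)  # max_dataset_size, train_dataset_proportion
print(n_train, n_val)                  # 33 15
for ws in (1, 8):                      # one GPU vs. 2 nodes x 4 GPUs
    print(ws, batches_per_rank(n_train, ws, 4), batches_per_rank(n_val, ws, 4))
# 1 8 3
# 8 1 0
```

Under these assumptions, a run on 2 nodes x 4 GPUs, as in the job scripts above, performs one training step per epoch and gets no validation batch on any rank, so you may want to raise `max_dataset_size` or lower `batch_size` for multi-node tests. With DDP, each optimizer step averages gradients over `batch_size` x world size = 4 x 8 = 32 samples.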