3DGAN

This section covers the CERN use case, which uses the PyTorch Lightning framework for training and evaluation. Below you can find instructions to run the CERN use case and its main scripts:

itwinai x 3DGAN

First of all, from the repository root, create a torch environment:

make torch-gpu

Now, install custom requirements for 3DGAN:

micromamba activate ./.venv-pytorch
cd use-cases/3dgan
pip install -r requirements.txt

NOTE: the Python commands below are assumed to be executed from within the micromamba virtual environment.

Training

From the repository root, launch training using itwinai and the training configuration:

cd use-cases/3dgan
itwinai exec-pipeline --config config.yaml --pipe-key training_pipeline

# Or, better, start one process per available GPU with torchrun:
micromamba run -p ../../.venv-pytorch/ torchrun --nproc_per_node gpu \
   itwinai exec-pipeline --config config.yaml --pipe-key training_pipeline

If you set a local path as the tracking URI, you can visualize the logs with MLflow by running the following in the terminal:

micromamba run -p ../../.venv-pytorch mlflow ui --backend-store-uri LOCAL_TRACKING_URI

Then select the “3DGAN” experiment.

Inference

Disclaimer: the following is preliminary and not 100% ML/scientifically sound.

1. As the inference dataset, we can reuse the training/validation dataset, for instance the one downloaded from the Google Drive folder (if the dataset root folder is not present, the dataset is downloaded automatically). The inference dataset is a set of H5 files stored inside the exp_data sub-folders:

├── exp_data
│   ├── data
│   │   ├── file_0.h5
│   │   ├── file_1.h5
...
│   │   ├── file_N.h5

2. As the model, if a pre-trained checkpoint is not available, we can create a dummy version of it with:

python create_inference_sample.py

3. Run the inference command. This generates a 3dgan-generated-data folder containing generated particle traces in the form of torch tensors (.pth files) and 3D scatter plots (.jpg images).

itwinai exec-pipeline --config config.yaml --pipe-key inference_pipeline
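Before running inference, it can help to check that the H5 files from step 1 are where the pipeline expects them. A minimal sketch, assuming the exp_data/data/file_*.h5 layout shown in step 1; list_h5_files is a hypothetical helper, not part of itwinai:

```python
from pathlib import Path


def list_h5_files(dataset_root: str) -> list[Path]:
    """Return every .h5 file below <dataset_root>/exp_data, sorted by path.

    Hypothetical helper (not part of itwinai) that assumes the
    exp_data/data/file_*.h5 layout described in step 1.
    """
    exp_data = Path(dataset_root) / "exp_data"
    if not exp_data.is_dir():
        # Per step 1, the pipeline may download the dataset itself in this case
        raise FileNotFoundError(f"{exp_data} does not exist")
    # rglob descends into every sub-folder, e.g. exp_data/data/
    return sorted(exp_data.rglob("*.h5"))
```

For example, len(list_h5_files(".")) run from the dataset root tells you how many input files the inference step will see.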

The inference execution produces a folder called 3dgan-generated-data containing the generated 3D particle trajectories (overwritten if already there). Each generated 3D image is stored both as a torch tensor (.pth) and as a 3D scatter plot (.jpg):

├── 3dgan-generated-data
│   ├── energy=1.296749234199524&angle=1.272539496421814.pth
│   ├── energy=1.296749234199524&angle=1.272539496421814.jpg
...
│   ├── energy=1.664689540863037&angle=1.4906378984451294.pth
│   ├── energy=1.664689540863037&angle=1.4906378984451294.jpg

However, if aggregate_predictions in the ParticleImagesSaver step is set to True, only one pickled file is generated inside the 3dgan-generated-data folder. Note that multiple inference calls create new files under the 3dgan-generated-data folder.
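When predictions are not aggregated, each file name encodes the energy and angle used to condition the generated sample. Assuming every name follows the energy=<float>&angle=<float>.<ext> pattern shown above, a small sketch (hypothetical helpers, not part of itwinai) can recover those values and pair each .pth tensor with its .jpg plot:

```python
from collections import defaultdict
from pathlib import Path


def parse_generated_name(path: str) -> tuple[float, float, str]:
    """Split 'energy=<E>&angle=<A>.<ext>' into (energy, angle, extension)."""
    p = Path(path)
    # Path.stem drops only the final suffix, so the dots inside the floats survive
    fields = dict(item.split("=", 1) for item in p.stem.split("&"))
    return float(fields["energy"]), float(fields["angle"]), p.suffix


def group_outputs(folder: str) -> dict[tuple[float, float], list[str]]:
    """Map each (energy, angle) pair to the extensions found for it."""
    groups: dict[tuple[float, float], list[str]] = defaultdict(list)
    for p in sorted(Path(folder).iterdir()):
        # Skip anything that does not follow the naming pattern
        if p.is_file() and p.name.startswith("energy="):
            energy, angle, ext = parse_generated_name(p.name)
            groups[(energy, angle)].append(ext)
    return dict(groups)
```

A pair whose list lacks either ".pth" or ".jpg" points to an incomplete output.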

With configuration fields overridden at runtime:

# Override variables
export CERN_DATA_ROOT="../.."  # data root
export TMP_DATA_ROOT=$CERN_DATA_ROOT
export CERN_CODE_ROOT="." # where code and configuration are stored
export MAX_DATA_SAMPLES=20000 # max dataset size
export BATCH_SIZE=1024 # increase to fill up GPU memory
export NUM_WORKERS_DL=4 # num worker processes used by the dataloader to pre-fetch data
export AGGREGATE_PREDS="true" # write predictions in a single file
export ACCELERATOR="gpu" # choose "cpu" or "gpu"
export STRATEGY="auto" # distributed strategy
export DEVICES="0," # GPU devices list


itwinai exec-pipeline --print-config --config $CERN_CODE_ROOT/config.yaml \
   --pipe-key inference_pipeline \
   -o dataset_location=$CERN_DATA_ROOT/exp_data \
   -o logs_dir=$TMP_DATA_ROOT/ml_logs/mlflow_logs \
   -o distributed_strategy=$STRATEGY \
   -o devices=$DEVICES \
   -o hw_accelerators=$ACCELERATOR \
   -o checkpoints_path=$TMP_DATA_ROOT/checkpoints \
   -o inference_model_uri=$CERN_CODE_ROOT/3dgan-inference.pth \
   -o max_dataset_size=$MAX_DATA_SAMPLES \
   -o batch_size=$BATCH_SIZE \
   -o num_workers_dataloader=$NUM_WORKERS_DL \
   -o inference_results_location=$TMP_DATA_ROOT/3dgan-generated-data \
   -o aggregate_predictions=$AGGREGATE_PREDS
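If you drive the pipeline from Python instead of a shell script, building the argument list explicitly avoids shell quoting and escaping pitfalls around $ and backslashes. A minimal sketch that uses only the flags shown above (--print-config, --config, --pipe-key, -o); build_exec_pipeline_cmd is a hypothetical helper, and values are formatted with str(), so pass strings such as "true" where the configuration expects them:

```python
from typing import Mapping


def build_exec_pipeline_cmd(config: str, pipe_key: str,
                            overrides: Mapping[str, object],
                            print_config: bool = False) -> list[str]:
    """Assemble an `itwinai exec-pipeline` argument list.

    Emits one `-o key=value` pair per override, in insertion order.
    Hypothetical helper, not part of itwinai.
    """
    cmd = ["itwinai", "exec-pipeline"]
    if print_config:
        cmd.append("--print-config")
    cmd += ["--config", config, "--pipe-key", pipe_key]
    for key, value in overrides.items():
        cmd += ["-o", f"{key}={value}"]
    return cmd
```

Passing the result to subprocess.run(cmd, check=True) runs it without a shell, and shlex.join(cmd) gives a printable shell line for logging.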

Docker image

Build the image from the project root with:

# Local
docker buildx build -t itwinai:0.0.1-3dgan-0.1 -f use-cases/3dgan/Dockerfile .

# Ghcr.io
docker buildx build -t ghcr.io/intertwin-eu/itwinai:0.0.1-3dgan-0.1 -f use-cases/3dgan/Dockerfile .
docker push ghcr.io/intertwin-eu/itwinai:0.0.1-3dgan-0.1

You can run inference from any directory containing a sample of H5 files (in a folder called exp_data/):

├── $PWD
│   ├── exp_data
│   │   ├── data
│   │   │   ├── file_0.h5
│   │   │   ├── file_1.h5
...
│   │   │   ├── file_N.h5

docker run -it --rm --name running-inference -v "$PWD":/tmp/data ghcr.io/intertwin-eu/itwinai:0.0.1-3dgan-0.1

This command will store the results in a folder called 3dgan-generated-data:

├── $PWD
│   ├── 3dgan-generated-data
│   │   ├── energy=1.296749234199524&angle=1.272539496421814.pth
│   │   ├── energy=1.296749234199524&angle=1.272539496421814.jpg
...
│   │   ├── energy=1.664689540863037&angle=1.4906378984451294.pth
│   │   ├── energy=1.664689540863037&angle=1.4906378984451294.jpg

To override fields in the configuration file at runtime, you can use the -o flag. Example: -o path.to.config.element=NEW_VALUE.
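The -o flag addresses a configuration element by its dotted path. The sketch below illustrates those semantics on a plain nested dict; it is not itwinai's implementation (which may, for example, convert value types or handle lists differently), and apply_override / parse_override are hypothetical names:

```python
import copy
from typing import Any


def parse_override(arg: str) -> tuple[str, str]:
    """Split 'path.to.element=VALUE' at the first '=' (VALUE may contain '=')."""
    key, sep, value = arg.partition("=")
    if not sep or not key:
        raise ValueError(f"expected KEY=VALUE, got {arg!r}")
    return key, value


def apply_override(config: dict, dotted_key: str, value: Any) -> dict:
    """Return a copy of `config` with the element at `dotted_key` set to `value`."""
    new_config = copy.deepcopy(config)  # leave the caller's config untouched
    *parents, leaf = dotted_key.split(".")
    node = new_config
    for key in parents:
        # Walk down the nested mappings and fail loudly on a typo in the path
        if key not in node or not isinstance(node[key], dict):
            raise KeyError(f"'{key}' not found while resolving '{dotted_key}'")
        node = node[key]
    node[leaf] = value
    return new_config
```

For example, apply_override(cfg, *parse_override("a.b=1")) returns a copy of cfg with cfg["a"]["b"] set to the string "1"; values stay strings in this sketch.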

Below is a complete example showing how to override the default configuration by setting some environment variables:

# Override variables
export CERN_DATA_ROOT="/usr/data"
export CERN_CODE_ROOT="/usr/src/app"
export MAX_DATA_SAMPLES=10 # max dataset size
export BATCH_SIZE=64 # increase to fill up GPU memory
export NUM_WORKERS_DL=4 # num worker processes used by the dataloader to pre-fetch data
export AGGREGATE_PREDS="true" # write predictions in a single file
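export ACCELERATOR="gpu" # choose "cpu" or "gpu"
export TMP_DATA_ROOT=$CERN_DATA_ROOT # where logs, checkpoints and results are written
export STRATEGY="auto" # distributed strategy
export DEVICES="0," # GPU devices list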

docker run -it --rm --name running-inference \
-v "$PWD":/usr/data ghcr.io/intertwin-eu/itwinai:0.0.1-3dgan-0.1 \
/bin/bash -c "itwinai exec-pipeline \
   --print-config --config $CERN_CODE_ROOT/config.yaml \
   --pipe-key inference_pipeline \
   -o dataset_location=$CERN_DATA_ROOT/exp_data \
   -o logs_dir=$TMP_DATA_ROOT/ml_logs/mlflow_logs \
   -o distributed_strategy=$STRATEGY \
   -o devices=$DEVICES \
   -o hw_accelerators=$ACCELERATOR \
   -o checkpoints_path=$TMP_DATA_ROOT/checkpoints \
   -o inference_model_uri=$CERN_CODE_ROOT/3dgan-inference.pth \
   -o max_dataset_size=$MAX_DATA_SAMPLES \
   -o batch_size=$BATCH_SIZE \
   -o num_workers_dataloader=$NUM_WORKERS_DL \
   -o inference_results_location=$TMP_DATA_ROOT/3dgan-generated-data \
   -o aggregate_predictions=$AGGREGATE_PREDS "

How to fully exploit GPU resources

Keeping the example above as a reference, increase BATCH_SIZE as much as possible (just below the point where “out of memory” errors occur), and make sure that ACCELERATOR="gpu". To collect meaningful performance data, also use a large enough dataset by increasing MAX_DATA_SAMPLES. Since each H5 file contains roughly 5k items, setting MAX_DATA_SAMPLES=10000 should be enough to use all the items in each input H5 file.

You can try:

export MAX_DATA_SAMPLES=10000 # max dataset size
export BATCH_SIZE=1024 # increase to fill up GPU memory
export ACCELERATOR="gpu"
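To see why the dataset size matters for GPU utilization, here is a rough sizing sketch. It assumes MAX_DATA_SAMPLES caps the total number of items loaded and that the last incomplete batch is kept (drop_last=False, the PyTorch DataLoader default); batch_plan is a hypothetical helper:

```python
def batch_plan(max_samples: int, batch_size: int) -> dict:
    """Batches produced by `max_samples` items at `batch_size` (incomplete last batch kept)."""
    if max_samples <= 0 or batch_size <= 0:
        raise ValueError("max_samples and batch_size must be positive")
    full, rest = divmod(max_samples, batch_size)
    total = full + (1 if rest else 0)
    return {
        "full_batches": full,
        "last_batch_size": rest,  # 0 means every batch is full
        "total_batches": total,
        # Average fraction of BATCH_SIZE actually filled per batch
        "mean_fill": max_samples / (total * batch_size),
    }
```

With the Docker example's MAX_DATA_SAMPLES=10 and BATCH_SIZE=64, the whole run is a single batch filled to about 16%, while the values suggested above give 10 batches, 9 of them full.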

Singularity

Run the Docker container with Singularity:

singularity run --nv -B "$PWD":/usr/data docker://ghcr.io/intertwin-eu/itwinai:0.0.1-3dgan-0.1 /bin/bash -c \
"cd /usr/src/app && itwinai exec-pipeline --config config.yaml --pipe-key inference_pipeline"

Example with overrides (as above for Docker):

# Override variables
export CERN_DATA_ROOT="/usr/data"
export CERN_CODE_ROOT="/usr/src/app"
export MAX_DATA_SAMPLES=10 # max dataset size
export BATCH_SIZE=64 # increase to fill up GPU memory
export NUM_WORKERS_DL=4 # num worker processes used by the dataloader to pre-fetch data
export AGGREGATE_PREDS="true" # write predictions in a single file
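export ACCELERATOR="gpu" # choose "cpu" or "gpu"
export TMP_DATA_ROOT=$CERN_DATA_ROOT # where logs, checkpoints and results are written
export STRATEGY="auto" # distributed strategy
export DEVICES="0," # GPU devices list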

singularity run --nv -B "$PWD":/usr/data docker://ghcr.io/intertwin-eu/itwinai:0.0.1-3dgan-0.1 /bin/bash -c \
"cd /usr/src/app && itwinai exec-pipeline \
   --print-config --config $CERN_CODE_ROOT/config.yaml \
   --pipe-key inference_pipeline \
   -o dataset_location=$CERN_DATA_ROOT/exp_data \
   -o logs_dir=$TMP_DATA_ROOT/ml_logs/mlflow_logs \
   -o distributed_strategy=$STRATEGY \
   -o devices=$DEVICES \
   -o hw_accelerators=$ACCELERATOR \
   -o checkpoints_path=$TMP_DATA_ROOT/checkpoints \
   -o inference_model_uri=$CERN_CODE_ROOT/3dgan-inference.pth \
   -o max_dataset_size=$MAX_DATA_SAMPLES \
   -o batch_size=$BATCH_SIZE \
   -o num_workers_dataloader=$NUM_WORKERS_DL \
   -o inference_results_location=$TMP_DATA_ROOT/3dgan-generated-data \
   -o aggregate_predictions=$AGGREGATE_PREDS "
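The hard-coded sizes in model.py are tied together by convolution arithmetic: the Generator's nn.Linear(latent_dim, 5184) output is reshaped to (8, 9, 9, 8), and the Discriminator's two nn.Linear(19152, 1) heads expect a fixed flattened size. The following stdlib-only sketch reproduces that arithmetic without PyTorch, as a sanity check when changing a kernel or padding. It assumes stride 1 and no dilation (the defaults used in model.py), tracks only the (D, H, W) dimensions, and the helper names are illustrative:

```python
Shape = tuple[int, int, int]  # (D, H, W); channels are tracked separately


def conv3d(shape: Shape, kernel: Shape, pad: Shape = (0, 0, 0)) -> Shape:
    """Output (D, H, W) of a stride-1, dilation-1 nn.Conv3d."""
    return tuple(s + 2 * p - k + 1 for s, k, p in zip(shape, kernel, pad))


def const_pad3d(shape: Shape, padding: tuple[int, ...]) -> Shape:
    """nn.ConstantPad3d order: (W_left, W_right, H_top, H_bottom, D_front, D_back)."""
    d, h, w = shape
    w_l, w_r, h_t, h_b, d_f, d_b = padding
    return (d + d_f + d_b, h + h_t + h_b, w + w_l + w_r)


def generator_output_shape() -> Shape:
    shape = (9, 9, 8)                     # 5184 = 8 channels * 9 * 9 * 8
    shape = tuple(6 * s for s in shape)   # Upsample(scale_factor=(6, 6, 6))
    # (kernel, padding) of conv1..conv5, each followed by its ConstantPad3d
    blocks = [
        ((6, 6, 8), (1, 1, 2, 2, 2, 2)),
        ((4, 4, 6), (1, 1, 2, 2, 2, 2)),
        ((4, 4, 6), (1, 1, 2, 2, 2, 2)),
        ((4, 4, 6), (0, 0, 1, 1, 1, 1)),
        ((3, 3, 5), (0, 0, 1, 1, 1, 1)),
    ]
    for kernel, padding in blocks:
        shape = const_pad3d(conv3d(shape, kernel), padding)
    shape = conv3d(shape, (3, 3, 3))      # conv6
    return conv3d(shape, (2, 2, 2))       # conv7 (1 output channel)


def discriminator_flat_features(shape: Shape = (51, 51, 25)) -> int:
    shape = conv3d(shape, (5, 6, 6), pad=(2, 3, 3))  # conv1
    shape = const_pad3d(shape, (1, 1, 0, 0, 0, 0))   # pad1
    shape = conv3d(shape, (5, 6, 6))                 # conv2
    shape = const_pad3d(shape, (1, 1, 0, 0, 0, 0))   # pad2
    shape = conv3d(shape, (5, 6, 6))                 # conv3
    shape = conv3d(shape, (5, 6, 6))                 # conv4
    shape = tuple(s // 2 for s in shape)             # AvgPool3d((2, 2, 2))
    channels = 8                                     # conv4 out_channels
    return channels * shape[0] * shape[1] * shape[2]
```

Under these assumptions, generator_output_shape() returns (51, 51, 25), and feeding that shape to discriminator_flat_features() gives 8 × 19 × 18 × 7 = 19152, matching the Linear layers.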

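In model.py, training_step records each loss twice: compute_global_loss returns the four weighted terms (real/fake BCE, energy, angle, ecal), while the values stored for the epoch summary are the total followed by each term divided by its weight, which is what on_train_epoch_end prints. A minimal sketch of that bookkeeping, with hypothetical helper names:

```python
def report_losses(weighted: list[float], weights: list[float]) -> list[float]:
    """Turn weighted per-output losses into [total, unweighted_1, ..., unweighted_n]."""
    if len(weighted) != len(weights):
        raise ValueError("expected one weight per loss component")
    total = sum(weighted)  # the quantity that is back-propagated
    return [total] + [loss / w for loss, w in zip(weighted, weights)]


def average_reports(first: list[float], second: list[float]) -> list[float]:
    """Element-wise mean of two reports (real/fake batches, or two generator passes)."""
    return [(a + b) / 2 for a, b in zip(first, second)]
```

For example, with the default weights (3, 0.1, 25, 0.1), weighted terms [3.0, 0.2, 50.0, 0.5] are reported as a total of about 53.7 followed by the unweighted values [1.0, 2.0, 2.0, 5.0].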
model.py

import sys
import os
# import pickle
from collections import defaultdict
import math
from typing import Any

import torch
import torch.nn as nn
import torch.nn.functional as F
import lightning as pl
import numpy as np


class Generator(nn.Module):
    def __init__(self, latent_dim):  # img_shape
        super().__init__()
        # self.img_shape = img_shape
        self.latent_dim = latent_dim

        self.l1 = nn.Linear(self.latent_dim, 5184)
        self.up1 = nn.Upsample(
            scale_factor=(6, 6, 6),
            mode='trilinear',
            align_corners=False
        )
        self.conv1 = nn.Conv3d(
            in_channels=8, out_channels=8,
            kernel_size=(6, 6, 8), padding=0
        )
        nn.init.kaiming_uniform_(self.conv1.weight)
        # num_features is the number of channels (see doc)
        self.bn1 = nn.BatchNorm3d(num_features=8, eps=1e-6)
        self.pad1 = nn.ConstantPad3d((1, 1, 2, 2, 2, 2), 0)

        self.conv2 = nn.Conv3d(
            in_channels=8, out_channels=6,
            kernel_size=(4, 4, 6), padding=0
        )
        nn.init.kaiming_uniform_(self.conv2.weight)
        self.bn2 = nn.BatchNorm3d(num_features=6, eps=1e-6)
        self.pad2 = nn.ConstantPad3d((1, 1, 2, 2, 2, 2), 0)

        self.conv3 = nn.Conv3d(
            in_channels=6, out_channels=6,
            kernel_size=(4, 4, 6), padding=0
        )
        nn.init.kaiming_uniform_(self.conv3.weight)
        self.bn3 = nn.BatchNorm3d(num_features=6, eps=1e-6)
        self.pad3 = nn.ConstantPad3d((1, 1, 2, 2, 2, 2), 0)

        self.conv4 = nn.Conv3d(
            in_channels=6, out_channels=6,
            kernel_size=(4, 4, 6), padding=0
        )
        nn.init.kaiming_uniform_(self.conv4.weight)
        self.bn4 = nn.BatchNorm3d(num_features=6, eps=1e-6)
        self.pad4 = nn.ConstantPad3d((0, 0, 1, 1, 1, 1), 0)

        self.conv5 = nn.Conv3d(
            in_channels=6, out_channels=6,
            kernel_size=(3, 3, 5), padding=0
        )
        nn.init.kaiming_uniform_(self.conv5.weight)
        self.bn5 = nn.BatchNorm3d(num_features=6, eps=1e-6)
        self.pad5 = nn.ConstantPad3d((0, 0, 1, 1, 1, 1), 0)

        self.conv6 = nn.Conv3d(
            in_channels=6, out_channels=6,
            kernel_size=(3, 3, 3), padding=0
        )
        nn.init.kaiming_uniform_(self.conv6.weight)

        self.conv7 = nn.Conv3d(
            in_channels=6, out_channels=1,
            kernel_size=(2, 2, 2), padding=0
        )
        nn.init.xavier_normal_(self.conv7.weight)

    def forward(self, z):
        img = self.l1(z)
        img = img.view(-1, 8, 9, 9, 8)
        img = self.up1(img)
        img = self.conv1(img)
        img = F.relu(img)
        img = self.bn1(img)
        img = self.pad1(img)

        img = self.conv2(img)
        img = F.relu(img)
        img = self.bn2(img)
        img = self.pad2(img)

        img = self.conv3(img)
        img = F.relu(img)
        img = self.bn3(img)
        img = self.pad3(img)

        img = self.conv4(img)
        img = F.relu(img)
        img = self.bn4(img)
        img = self.pad4(img)

        img = self.conv5(img)
        img = F.relu(img)
        img = self.bn5(img)
        img = self.pad5(img)

        img = self.conv6(img)
        img = F.relu(img)

        img = self.conv7(img)
        img = F.relu(img)

        return img


class Discriminator(nn.Module):
    def __init__(self, power):
        super().__init__()

        self.power = power

        self.conv1 = nn.Conv3d(
            in_channels=1, out_channels=16,
            kernel_size=(5, 6, 6), padding=(2, 3, 3)
        )
        self.drop1 = nn.Dropout(0.2)
        self.pad1 = nn.ConstantPad3d((1, 1, 0, 0, 0, 0), 0)

        self.conv2 = nn.Conv3d(
            in_channels=16, out_channels=8,
            kernel_size=(5, 6, 6), padding=0
        )
        self.bn1 = nn.BatchNorm3d(num_features=8, eps=1e-6)
        self.drop2 = nn.Dropout(0.2)
        self.pad2 = nn.ConstantPad3d((1, 1, 0, 0, 0, 0), 0)

        self.conv3 = nn.Conv3d(
            in_channels=8, out_channels=8,
            kernel_size=(5, 6, 6), padding=0
        )
        self.bn2 = nn.BatchNorm3d(num_features=8, eps=1e-6)
        self.drop3 = nn.Dropout(0.2)

        self.conv4 = nn.Conv3d(
            in_channels=8, out_channels=8,
            kernel_size=(5, 6, 6), padding=0
        )
        self.bn3 = nn.BatchNorm3d(num_features=8, eps=1e-6)
        self.drop4 = nn.Dropout(0.2)

        self.avgpool = nn.AvgPool3d((2, 2, 2))
        self.flatten = nn.Flatten()

        # The input features for the Linear layer need to be calculated based
        # on the output shape from the previous layers.
        self.fakeout = nn.Linear(19152, 1)
        self.auxout = nn.Linear(19152, 1)  # The same as above for this layer.

    # calculate sum of intensities
    def ecal_sum(self, image, daxis):
        total = torch.sum(image, dim=daxis)
        return total

    # angle calculation
    def ecal_angle(self, image, daxis1):
        image = torch.squeeze(image, dim=daxis1)  # squeeze along channel axis

        # get shapes
        x_shape = image.shape[1]
        y_shape = image.shape[2]
        z_shape = image.shape[3]
        sumtot = torch.sum(image, dim=(1, 2, 3))  # sum of events

        # get 1. where event sum is 0 and 0 elsewhere
        amask = torch.where(sumtot == 0.0, torch.ones_like(
            sumtot), torch.zeros_like(sumtot))
        # masked_events = torch.sum(amask)  # counting zero sum events

        # ref denotes barycenter as that is our reference point
        x_ref = torch.sum(torch.sum(image, dim=(2, 3))
                          * (torch.arange(x_shape, device=image.device,
                             dtype=torch.float32).unsqueeze(0) + 0.5),
                          dim=1,)  # sum for x position * x index
        y_ref = torch.sum(
            torch.sum(image, dim=(1, 3))
            * (torch.arange(y_shape, device=image.device,
               dtype=torch.float32).unsqueeze(0) + 0.5),
            dim=1,)
        z_ref = torch.sum(
            torch.sum(image, dim=(1, 2))
            * (torch.arange(z_shape, device=image.device,
               dtype=torch.float32).unsqueeze(0) + 0.5),
            dim=1,)

        # return max position if sumtot=0 and divide by sumtot otherwise
        x_ref = torch.where(
            sumtot == 0.0, torch.ones_like(x_ref), x_ref / sumtot)
        y_ref = torch.where(
            sumtot == 0.0, torch.ones_like(y_ref), y_ref / sumtot)
        z_ref = torch.where(
            sumtot == 0.0, torch.ones_like(z_ref), z_ref / sumtot)

        # reshape
        x_ref = x_ref.unsqueeze(1)
        y_ref = y_ref.unsqueeze(1)
        z_ref = z_ref.unsqueeze(1)

        sumz = torch.sum(image, dim=(1, 2))  # sum for x,y planes going along z

        # Get 0 where sum along z is 0 and 1 elsewhere
        zmask = torch.where(sumz == 0.0, torch.zeros_like(
            sumz), torch.ones_like(sumz))

        x = torch.arange(x_shape, device=image.device).unsqueeze(
            0)  # x indexes
        x = (x.unsqueeze(2).float()) + 0.5
        y = torch.arange(y_shape, device=image.device).unsqueeze(
            0)  # y indexes
        y = (y.unsqueeze(2).float()) + 0.5

        # barycenter for each z position
        x_mid = torch.sum(torch.sum(image, dim=2) * x, dim=1)
        y_mid = torch.sum(torch.sum(image, dim=1) * y, dim=1)

        x_mid = torch.where(sumz == 0.0, torch.zeros_like(
            sumz), x_mid / sumz)  # if sum != 0 then divide by sum
        y_mid = torch.where(sumz == 0.0, torch.zeros_like(
            sumz), y_mid / sumz)  # if sum != 0 then divide by sum

        # Angle Calculations
        z = (torch.arange(
            z_shape,
            device=image.device,
            dtype=torch.float32
            # Make an array of z indexes for all events
        ) + 0.5) * torch.ones_like(z_ref)

        # projection from z axis with stability check
        zproj = torch.sqrt(
            torch.max(
                (x_mid - x_ref) ** 2.0 + (z - z_ref) ** 2.0,
                torch.tensor(
                    [torch.finfo(torch.float32).eps]
                ).to(x_mid.device)
            )
        )
        # torch.finfo(torch.float32).eps))
        # to avoid divide by zero for zproj =0
        m = torch.where(zproj == 0.0, torch.zeros_like(
            zproj), (y_mid - y_ref) / zproj)
        m = torch.where(z < z_ref, -1 * m, m)  # sign inversion
        ang = (math.pi / 2.0) - torch.atan(m)  # angle correction
        zmask = torch.where(zproj == 0.0, torch.zeros_like(zproj), zmask)
        ang = ang * zmask  # place zero where zsum is zero
        ang = ang * z  # weighted by position
        sumz_tot = z * zmask  # removing indexes with 0 energies or angles

        # zunmasked = K.sum(zmask, axis=1) # used for simple mean
        # Mean does not include positions where zsum=0
        # ang = K.sum(ang, axis=1)/zunmasked

        # sum ( measured * weights)/sum(weights)
        ang = torch.sum(ang, dim=1) / torch.sum(sumz_tot, dim=1)
        # Place 100 for measured angle where no energy is deposited in events
        ang = torch.where(amask == 0.0, ang, 100.0 * torch.ones_like(ang))
        ang = ang.unsqueeze(1)
        return ang

    def forward(self, x):
        z = self.conv1(x)
        z = F.leaky_relu(z)
        z = self.drop1(z)
        z = self.pad1(z)

        z = self.conv2(z)
        z = F.leaky_relu(z)
        z = self.bn1(z)
        z = self.drop2(z)
        z = self.pad2(z)

        z = self.conv3(z)
        z = F.leaky_relu(z)
        z = self.bn2(z)
        z = self.drop3(z)

        z = self.conv4(z)
        z = F.leaky_relu(z)
        z = self.bn3(z)
        z = self.drop4(z)
        z = self.avgpool(z)
        z = self.flatten(z)

        # generation output that says fake/real
        fake = torch.sigmoid(self.fakeout(z))
        aux = self.auxout(z)  # auxiliary output
        inv_image = x.pow(1.0 / self.power)
        ang = self.ecal_angle(inv_image, 1)  # angle calculation
        ecal = self.ecal_sum(inv_image, (2, 3, 4))  # sum of energies

        return fake, aux, ang, ecal


class ThreeDGAN(pl.LightningModule):
    def __init__(
        self,
        latent_size=256,
        loss_weights=[3, 0.1, 25, 0.1],
        power=0.85,
        lr=0.001,
        checkpoints_dir: str = '.'
        # checkpoint_path: str = '3Dgan.pth'
    ):
        super().__init__()
        self.save_hyperparameters()
        self.automatic_optimization = False

        self.latent_size = latent_size
        self.loss_weights = loss_weights
        self.lr = lr
        self.power = power
        self.checkpoints_dir = checkpoints_dir
        os.makedirs(self.checkpoints_dir, exist_ok=True)

        self.generator = Generator(self.latent_size)
        self.discriminator = Discriminator(self.power)

        self.epoch_gen_loss = []
        self.epoch_disc_loss = []
        self.disc_epoch_test_loss = []
        self.gen_epoch_test_loss = []
        self.index = 0
        self.train_history = defaultdict(list)
        self.test_history = defaultdict(list)
        # self.pklfile = checkpoint_path
        # checkpoint_dir = os.path.dirname(checkpoint_path)
        # os.makedirs(checkpoint_dir, exist_ok=True)

    def BitFlip(self, x, prob=0.05):
        """
        Flips a single bit according to a certain probability.

        Args:
            x (list): list of bits to be flipped
            prob (float): probability of flipping one bit

        Returns:
            list: List of flipped bits

        """
        x = np.array(x)
        selection = np.random.uniform(0, 1, x.shape) < prob
        x[selection] = 1 * np.logical_not(x[selection])
        return x

    def mean_absolute_percentage_error(self, y_true, y_pred):
        return torch.mean(torch.abs((y_true - y_pred) / (y_true + 1e-7))) * 100

    def compute_global_loss(
        self,
        labels,
        predictions,
        loss_weights=(3, 0.1, 25, 0.1)
    ):
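        # Can be initialized outside.
        # The discriminator already applies torch.sigmoid to its real/fake
        # output, so plain BCE is used: BCEWithLogitsLoss would apply a
        # second sigmoid on top of it
        binary_crossentropy_object = nn.BCELoss(reduction='none')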
        # there is no equivalent in pytorch for
        # tf.keras.losses.MeanAbsolutePercentageError --> using the
        # custom "mean_absolute_percentage_error" above!
        mean_absolute_percentage_error_object1 = \
            self.mean_absolute_percentage_error(predictions[1], labels[1])
        mean_absolute_percentage_error_object2 = \
            self.mean_absolute_percentage_error(predictions[3], labels[3])
        mae_object = nn.L1Loss(reduction='none')

        binary_example_loss = binary_crossentropy_object(
            predictions[0], labels[0]) * loss_weights[0]

        # mean_example_loss_1 = mean_absolute_percentage_error_object(
        # predictions[1], labels[1]) * loss_weights[1]
        mean_example_loss_1 = \
            mean_absolute_percentage_error_object1 * loss_weights[1]

        mae_example_loss = mae_object(
            predictions[2], labels[2]) * loss_weights[2]

        # mean_example_loss_2 = mean_absolute_percentage_error_object(
        # predictions[3], labels[3]) * loss_weights[3]
        mean_example_loss_2 = \
            mean_absolute_percentage_error_object2 * loss_weights[3]

        binary_loss = binary_example_loss.mean()
        mean_loss_1 = mean_example_loss_1.mean()
        mae_loss = mae_example_loss.mean()
        mean_loss_2 = mean_example_loss_2.mean()

        return [binary_loss, mean_loss_1, mae_loss, mean_loss_2]

    def forward(self, z):
        return self.generator(z)

    def training_step(self, batch, batch_idx):
        image_batch, energy_batch, ang_batch, ecal_batch = \
            batch['X'], batch['Y'], batch['ang'], batch['ecal']

        image_batch = image_batch.permute(0, 4, 1, 2, 3)

        image_batch = image_batch.to(self.device)
        energy_batch = energy_batch.to(self.device)
        ang_batch = ang_batch.to(self.device)
        ecal_batch = ecal_batch.to(self.device)

        optimizer_discriminator, optimizer_generator = self.optimizers()
        batch_size = energy_batch.shape[0]

        noise = torch.randn(
            (batch_size, self.latent_size - 2),
            dtype=torch.float32,
            device=self.device
        )
        # print(f'Energy elements: {energy_batch.numel} {energy_batch.shape}')
        # print(f'Angle elements: {ang_batch.numel} {ang_batch.shape}')
        generator_ip = torch.cat(
            (energy_batch.view(-1, 1), ang_batch.view(-1, 1), noise),
            dim=1
        )
        generated_images = self.generator(generator_ip)

        # Train discriminator first on real batch
        fake_batch = self.BitFlip(np.ones(batch_size).astype(np.float32))
        fake_batch = torch.tensor([[el] for el in fake_batch]).to(self.device)
        labels = [fake_batch, energy_batch, ang_batch, ecal_batch]

        predictions = self.discriminator(image_batch)
        # print("calculating real_batch_loss...")
        real_batch_loss = self.compute_global_loss(
            labels, predictions, self.loss_weights)
        self.log("real_batch_loss", sum(real_batch_loss),
                 prog_bar=True, on_step=True, on_epoch=True, sync_dist=True)
        # print("real batch disc train")
        # the following 3 lines correspond in tf version to:
        # gradients = tape.gradient(real_batch_loss,
        # discriminator.trainable_variables)
        # optimizer_discriminator.apply_gradients(zip(gradients,
        #  discriminator.trainable_variables)) in Tensorflow
        optimizer_discriminator.zero_grad()
        self.manual_backward(sum(real_batch_loss))
        # sum(real_batch_loss).backward()
        # real_batch_loss.backward()
        optimizer_discriminator.step()

        # Train discriminator on the fake batch
        fake_batch = self.BitFlip(np.zeros(batch_size).astype(np.float32))
        fake_batch = torch.tensor([[el] for el in fake_batch]).to(self.device)
        labels = [fake_batch, energy_batch, ang_batch, ecal_batch]

        predictions = self.discriminator(generated_images)

        fake_batch_loss = self.compute_global_loss(
            labels, predictions, self.loss_weights)
        self.log("fake_batch_loss", sum(fake_batch_loss),
                 prog_bar=True, on_step=True, on_epoch=True, sync_dist=True)
        # print("fake batch disc train")
        # the following 3 lines correspond to
        # gradients = tape.gradient(fake_batch_loss,
        # discriminator.trainable_variables)
        # optimizer_discriminator.apply_gradients(zip(gradients,
        # discriminator.trainable_variables)) in Tensorflow
        optimizer_discriminator.zero_grad()
        self.manual_backward(sum(fake_batch_loss))
        # sum(fake_batch_loss).backward()
        optimizer_discriminator.step()

        # avg_disc_loss = (sum(real_batch_loss) + sum(fake_batch_loss)) / 2

        trick = np.ones(batch_size).astype(np.float32)
        fake_batch = torch.tensor([[el] for el in trick]).to(self.device)
        labels = [fake_batch, energy_batch.view(-1, 1), ang_batch, ecal_batch]

        gen_losses_train = []
        # Train generator twice using combined model
        for _ in range(2):
            noise = torch.randn(
                (batch_size, self.latent_size - 2)).to(self.device)
            generator_ip = torch.cat(
                (energy_batch.view(-1, 1), ang_batch.view(-1, 1), noise),
                dim=1
            )

            generated_images = self.generator(generator_ip)
            predictions = self.discriminator(generated_images)

            loss = self.compute_global_loss(
                labels, predictions, self.loss_weights)
            self.log("gen_loss", sum(loss), prog_bar=True,
                     on_step=True, on_epoch=True, sync_dist=True)
            # print("gen train")
            optimizer_generator.zero_grad()
            self.manual_backward(sum(loss))
            # sum(loss).backward()
            optimizer_generator.step()

            for el in loss:
                gen_losses_train.append(el)

        avg_generator_loss = sum(gen_losses_train) / len(gen_losses_train)
        self.log("generator_loss", avg_generator_loss.item(),
                 prog_bar=True, on_step=True, on_epoch=True, sync_dist=True)
        # avg_generator_loss = [(a + b) / 2 for a, b in zip(*gen_losses_train)]
        # self.log("generator_loss", sum(avg_generator_loss), prog_bar=True,
        # on_step=True, on_epoch=True, sync_dist=True)

        gen_losses = []
        # I'm not returning anything as in pl you do not return anything when
        # you back-propagate manually
        # return_loss = real_batch_loss
        real_batch_loss = [real_batch_loss[0], real_batch_loss[1],
                           real_batch_loss[2], real_batch_loss[3]]
        fake_batch_loss = [fake_batch_loss[0], fake_batch_loss[1],
                           fake_batch_loss[2], fake_batch_loss[3]]
        gen_batch_loss = [gen_losses_train[0], gen_losses_train[1],
                          gen_losses_train[2], gen_losses_train[3]]
        gen_losses.append(gen_batch_loss)
        gen_batch_loss = [gen_losses_train[4], gen_losses_train[5],
                          gen_losses_train[6], gen_losses_train[7]]
        gen_losses.append(gen_batch_loss)

        real_batch_loss = [el.cpu().detach().numpy() for el in real_batch_loss]
        real_batch_loss_total_loss = np.sum(real_batch_loss)
        new_real_batch_loss = [real_batch_loss_total_loss]
        for i_weights in range(len(real_batch_loss)):
            new_real_batch_loss.append(
                real_batch_loss[i_weights] / self.loss_weights[i_weights])
        real_batch_loss = new_real_batch_loss

        fake_batch_loss = [el.cpu().detach().numpy() for el in fake_batch_loss]
        fake_batch_loss_total_loss = np.sum(fake_batch_loss)
        new_fake_batch_loss = [fake_batch_loss_total_loss]
        for i_weights in range(len(fake_batch_loss)):
            new_fake_batch_loss.append(
                fake_batch_loss[i_weights] / self.loss_weights[i_weights])
        fake_batch_loss = new_fake_batch_loss

        # if ecal sum has 100% loss(generating empty events) then end
        # the training
        if fake_batch_loss[3] == 100.0 and self.index > 10:
            # print("Empty image with Ecal loss equal to 100.0 "
            #       f"for {self.index} batch")
            torch.save(self.generator.state_dict(), os.path.join(
                self.checkpoints_dir, "generator_weights.pth"))
            torch.save(self.discriminator.state_dict(), os.path.join(
                       self.checkpoints_dir, "discriminator_weights.pth"))
            # print("real_batch_loss", real_batch_loss)
            # print("fake_batch_loss", fake_batch_loss)
            sys.exit()

        # append mean of discriminator loss for real and fake events
        self.epoch_disc_loss.append(
            [(a + b) / 2 for a, b in zip(real_batch_loss, fake_batch_loss)])

        gen_losses[0] = [el.cpu().detach().numpy() for el in gen_losses[0]]
        gen_losses_total_loss = np.sum(gen_losses[0])
        new_gen_losses = [gen_losses_total_loss]
        for i_weights in range(len(gen_losses[0])):
            new_gen_losses.append(
                gen_losses[0][i_weights] / self.loss_weights[i_weights])
        gen_losses[0] = new_gen_losses

        gen_losses[1] = [el.cpu().detach().numpy() for el in gen_losses[1]]
        gen_losses_total_loss = np.sum(gen_losses[1])
        new_gen_losses = [gen_losses_total_loss]
        for i_weights in range(len(gen_losses[1])):
            new_gen_losses.append(
                gen_losses[1][i_weights] / self.loss_weights[i_weights])
        gen_losses[1] = new_gen_losses

        generator_loss = [(a + b) / 2 for a, b in zip(*gen_losses)]

        self.epoch_gen_loss.append(generator_loss)

        # # MB: verify weight synchronization among workers
        # # Ref: https://github.com/Lightning-AI/lightning/issues/9237
        # disc_w = self.discriminator.conv1.weight.reshape(-1)[0:5]
        # gen_w = self.generator.conv1.weight.reshape(-1)[0:5]
        # print(f"DISC w: {disc_w}")
        # print(f"GEN w: {gen_w}")

        # self.index += 1 #this might be moved after test cycle

        # logging of gen and disc loss done by Trainer
        # self.log('epoch_gen_loss', self.epoch_gen_loss, on_step=True,
        #  on_epoch=True, sync_dist=True)
        # self.log('epoch_disc_loss', self.epoch_disc_loss, on_step=True,
        #  on_epoch=True, sync_dist=True)

        # return avg_disc_loss + avg_generator_loss

    def on_train_epoch_end(self):  # outputs
        discriminator_train_loss = np.mean(
            np.array(self.epoch_disc_loss), axis=0)
        generator_train_loss = np.mean(np.array(self.epoch_gen_loss), axis=0)

        self.train_history["generator"].append(generator_train_loss)
        self.train_history["discriminator"].append(discriminator_train_loss)

        print("-" * 65)
        ROW_FMT = (
            "{0:<20s} | {1:<4.2f} | {2:<10.2f} | "
            "{3:<10.2f}| {4:<10.2f} | {5:<10.2f}")
        print(ROW_FMT.format("generator (train)",
              *self.train_history["generator"][-1]))
        print(ROW_FMT.format("discriminator (train)",
              *self.train_history["discriminator"][-1]))

        torch.save(self.generator.state_dict(), os.path.join(
            self.checkpoints_dir, "generator_weights.pth"))
        torch.save(self.discriminator.state_dict(), os.path.join(
            self.checkpoints_dir, "discriminator_weights.pth"))

        # with open(self.pklfile, "wb") as f:
        #     pickle.dump({"train": self.train_history,
        #                 "test": self.test_history}, f)

        # pickle.dump({"train": self.train_history}, open(self.pklfile, "wb"))
        print("train-loss:" + str(self.train_history["generator"][-1][0]))

    def validation_step(self, batch, batch_idx):
        image_batch, energy_batch, ang_batch, ecal_batch = batch[
            'X'], batch['Y'], batch['ang'], batch['ecal']

        image_batch = image_batch.permute(0, 4, 1, 2, 3)

        image_batch = image_batch.to(self.device)
        energy_batch = energy_batch.to(self.device)
        ang_batch = ang_batch.to(self.device)
        ecal_batch = ecal_batch.to(self.device)

        batch_size = energy_batch.shape[0]

        # Generate Fake events with same energy and angle as data batch
        noise = torch.randn(
            (batch_size, self.latent_size - 2),
            dtype=torch.float32,
            device=self.device
        )

        generator_ip = torch.cat(
            (energy_batch.view(-1, 1), ang_batch.view(-1, 1), noise), dim=1)
        generated_images = self.generator(generator_ip)

        # concatenate to fake and real batches
        X = torch.cat((image_batch, generated_images), dim=0)

        # y = np.array([1] * batch_size \
        # + [0] * batch_size).astype(np.float32)
        y = torch.tensor([1] * batch_size + [0] *
                         batch_size, dtype=torch.float32).to(self.device)
        y = y.view(-1, 1)

        ang = torch.cat((ang_batch, ang_batch), dim=0)
        ecal = torch.cat((ecal_batch, ecal_batch), dim=0)
        aux_y = torch.cat((energy_batch, energy_batch), dim=0)

        # y = [[el] for el in y]
        labels = [y, aux_y, ang, ecal]

        # Calculate discriminator loss
        disc_eval = self.discriminator(X)
        disc_eval_loss = self.compute_global_loss(
            labels, disc_eval, self.loss_weights)

        # Calculate generator loss
        # Generator targets: generated images should be classified as real (1)
        fake_batch = torch.ones(
            (batch_size, 1), dtype=torch.float32, device=self.device)
        labels = [fake_batch, energy_batch, ang_batch, ecal_batch]

        generated_images = self.generator(generator_ip)
        gen_eval = self.discriminator(generated_images)
        gen_eval_loss = self.compute_global_loss(
            labels, gen_eval, self.loss_weights)

        self.log('val_discriminator_loss', sum(
            disc_eval_loss), on_epoch=True, prog_bar=True, sync_dist=True)
        self.log('val_generator_loss', sum(gen_eval_loss),
                 on_epoch=True, prog_bar=True, sync_dist=True)

        disc_test_loss = [disc_eval_loss[0], disc_eval_loss[1],
                          disc_eval_loss[2], disc_eval_loss[3]]
        gen_test_loss = [gen_eval_loss[0], gen_eval_loss[1],
                         gen_eval_loss[2], gen_eval_loss[3]]

        # Configure the loss so it is equal to the original values
        disc_eval_loss = [el.cpu().detach().numpy() for el in disc_test_loss]
        disc_eval_loss_total_loss = np.sum(disc_eval_loss)
        new_disc_eval_loss = [disc_eval_loss_total_loss]
        for i_weights in range(len(disc_eval_loss)):
            new_disc_eval_loss.append(
                disc_eval_loss[i_weights] / self.loss_weights[i_weights])
        disc_eval_loss = new_disc_eval_loss

        gen_eval_loss = [el.cpu().detach().numpy() for el in gen_test_loss]
        gen_eval_loss_total_loss = np.sum(gen_eval_loss)
        new_gen_eval_loss = [gen_eval_loss_total_loss]
        for i_weights in range(len(gen_eval_loss)):
            new_gen_eval_loss.append(
                gen_eval_loss[i_weights] / self.loss_weights[i_weights])
        gen_eval_loss = new_gen_eval_loss

        self.index += 1
        # evaluate discriminator loss
        self.disc_epoch_test_loss.append(disc_eval_loss)
        # evaluate generator loss
        self.gen_epoch_test_loss.append(gen_eval_loss)

    def on_validation_epoch_end(self):
        discriminator_test_loss = np.mean(
            np.array(self.disc_epoch_test_loss), axis=0)
        generator_test_loss = np.mean(
            np.array(self.gen_epoch_test_loss), axis=0)

        self.test_history["generator"].append(generator_test_loss)
        self.test_history["discriminator"].append(discriminator_test_loss)

        print("-" * 65)
        ROW_FMT = (
            "{0:<20s} | {1:<4.2f} | {2:<10.2f} | "
            "{3:<10.2f}| {4:<10.2f} | {5:<10.2f}")
        print(ROW_FMT.format("generator (test)",
              *self.test_history["generator"][-1]))
        print(ROW_FMT.format("discriminator (test)",
              *self.test_history["discriminator"][-1]))

        # # save loss dict to pkl file
        # with open(self.pklfile, "wb") as f:
        #     pickle.dump({"train": self.train_history,
        #                 "test": self.test_history}, f)
        # pickle.dump({"test": self.test_history}, open(self.pklfile, "wb"))
        # print("train-loss:" + str(self.train_history["generator"][-1][0]))

    def predict_step(
        self,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0
    ) -> Any:
        energy_batch, ang_batch = batch['Y'], batch['ang']

        energy_batch = energy_batch.to(self.device)
        ang_batch = ang_batch.to(self.device)

        # Generate Fake events with same energy and angle as data batch
        noise = torch.randn(
            (energy_batch.shape[0], self.latent_size - 2),
            dtype=torch.float32,
            device=self.device
        )

        # print(f"Reshape energy: {energy_batch.view(-1, 1).shape}")
        # print(f"Reshape angle: {ang_batch.view(-1, 1).shape}")
        # print(f"Noise: {noise.shape}")

        generator_ip = torch.cat(
            [energy_batch.view(-1, 1), ang_batch.view(-1, 1), noise],
            dim=1
        )
        # print(f"Generator input: {generator_ip.shape}")
        generated_images = self.generator(generator_ip)
        # print(f"Generated batch size {generated_images.shape}")
        return {'images': generated_images,
                'energies': energy_batch,
                'angles': ang_batch}

    def configure_optimizers(self):
        lr = self.lr

        optimizer_discriminator = torch.optim.RMSprop(
            self.discriminator.parameters(),
            lr
        )
        optimizer_generator = torch.optim.RMSprop(
            self.generator.parameters(),
            lr
        )
        return [optimizer_discriminator, optimizer_generator], []
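The loss bookkeeping in training_step and validation_step turns the four loss terms returned by compute_global_loss into a row of five numbers: the total, followed by each term divided by its weight. The pure-Python sketch below reproduces that arithmetic and the epoch-level averaging, then formats the row with the same ROW_FMT string used in on_train_epoch_end. It assumes (compute_global_loss is not shown here) that the terms are already multiplied by loss_weights. The loss values are invented for illustration and are not real training output.

```python
from typing import List, Sequence

# Same column layout as ROW_FMT in on_train_epoch_end
ROW_FMT = (
    "{0:<20s} | {1:<4.2f} | {2:<10.2f} | "
    "{3:<10.2f}| {4:<10.2f} | {5:<10.2f}")


def unweight_losses(weighted: Sequence[float],
                    loss_weights: Sequence[float]) -> List[float]:
    """Return [total, term_0 / w_0, term_1 / w_1, ...].

    Assumption: the input terms are already multiplied by loss_weights.
    The total is the sum of the weighted terms, while each following
    entry divides a weighted term by its weight to recover the raw value.
    """
    total = sum(weighted)
    return [total] + [w_loss / w for w_loss, w in zip(weighted, loss_weights)]


def average_rows(rows: Sequence[Sequence[float]]) -> List[float]:
    """Column-wise mean, like np.mean(np.array(rows), axis=0)."""
    return [sum(col) / len(col) for col in zip(*rows)]


loss_weights = [3, 0.1, 25, 0.1]  # values from config.yaml
# Illustrative weighted terms for two steps (not real training output)
step_1 = unweight_losses([1.5, 0.2, 5.0, 0.1], loss_weights)
step_2 = unweight_losses([1.2, 0.3, 2.5, 0.3], loss_weights)
epoch_row = average_rows([step_1, step_2])
print(ROW_FMT.format("generator (train)", *epoch_row))
```

Because the first column is the sum of weighted terms, it is not the sum of the four un-weighted columns that follow it.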

trainer.py

import os
import sys
from typing import Union, Dict, Optional, Any

import torch
from torch import Tensor
import lightning as pl
from lightning.pytorch.cli import LightningCLI

from itwinai.components import Trainer, Predictor, monitor_exec
from itwinai.serialization import ModelLoader
from itwinai.torch.inference import TorchModelLoader
from itwinai.torch.type import Batch
from itwinai.utils import load_yaml
from itwinai.torch.mlflow import (
    init_lightning_mlflow,
    teardown_lightning_mlflow
)


from model import ThreeDGAN
from dataloader import ParticlesDataModule


class Lightning3DGANTrainer(Trainer):
    def __init__(self, config: Union[Dict, str], exp_root: str = '.'):
        self.save_parameters(**self.locals2params(locals()))
        super().__init__()
        if isinstance(config, str) and os.path.isfile(config):
            # Load from YAML
            config = load_yaml(config)
        self.conf = config
        self.exp_root = exp_root

    @monitor_exec
    def execute(self) -> Any:
        init_lightning_mlflow(
            self.conf,
            tmp_dir=os.path.join(self.exp_root, '.tmp'),
            registered_model_name='3dgan-lite'
        )
        old_argv = sys.argv
        sys.argv = ['some_script_placeholder.py']
        cli = LightningCLI(
            args=self.conf,
            model_class=ThreeDGAN,
            datamodule_class=ParticlesDataModule,
            run=False,
            save_config_kwargs={
                "overwrite": True,
                "config_filename": "pl-training.yml",
            },
            subclass_mode_model=True,
            subclass_mode_data=True,
        )
        sys.argv = old_argv
        cli.trainer.fit(cli.model, datamodule=cli.datamodule)
        teardown_lightning_mlflow()


class LightningModelLoader(TorchModelLoader):
    """Loads a torch lightning model from somewhere.

    Args:
        model_uri (str): Can be a path on local filesystem
            or an mlflow 'locator' in the form:
            'mlflow+MLFLOW_TRACKING_URI+RUN_ID+ARTIFACT_PATH'
    """

    def __call__(self) -> pl.LightningModule:
        """Loads model from model URI.

        Raises:
            ValueError: if the model URI is not recognized
                or the model is not found.

        Returns:
            pl.LightningModule: torch lightning module.
        """
        # TODO: improve
        # # Load best model
        # loaded_model = cli.model.load_from_checkpoint(
        #     ckpt_path,
        #     lightning_conf['model']['init_args']
        # )
        return super().__call__()


class Lightning3DGANPredictor(Predictor):

    def __init__(
        self,
        model: Union[ModelLoader, pl.LightningModule],
        config: Union[Dict, str],
        name: Optional[str] = None
    ):
        self.save_parameters(**self.locals2params(locals()))
        super().__init__(model, name)
        if isinstance(config, str) and os.path.isfile(config):
            # Load from YAML
            config = load_yaml(config)
        self.conf = config

    @monitor_exec
    def execute(
        self,
        datamodule: Optional[pl.LightningDataModule] = None,
        model: Optional[pl.LightningModule] = None
    ) -> Dict[str, Tensor]:
        old_argv = sys.argv
        sys.argv = ['some_script_placeholder.py']
        cli = LightningCLI(
            args=self.conf,
            model_class=ThreeDGAN,
            datamodule_class=ParticlesDataModule,
            run=False,
            save_config_kwargs={
                "overwrite": True,
                "config_filename": "pl-training.yml",
            },
            subclass_mode_model=True,
            subclass_mode_data=True,
        )
        sys.argv = old_argv

        # Override config file with inline arguments, if given
        if datamodule is None:
            datamodule = cli.datamodule
        if model is None:
            model = cli.model

        predictions = cli.trainer.predict(model, datamodule=datamodule)

        # Transpose predictions into images, energies and angles
        images = torch.cat(
            [self.transform_predictions(pred['images'])
             for pred in predictions])
        energies = torch.cat([pred['energies'] for pred in predictions])
        angles = torch.cat([pred['angles'] for pred in predictions])

        predictions_dict = dict()
        for img, en, ang in zip(images, energies, angles):
            sample_key = f"energy={en.item()}&angle={ang.item()}"
            predictions_dict[sample_key] = img

        return predictions_dict

    def transform_predictions(self, batch: Batch) -> Batch:
        """
        Post-process the predictions of the torch model.
        """
        return batch.squeeze(1)
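Both Lightning3DGANTrainer.execute and Lightning3DGANPredictor.execute temporarily replace sys.argv so that LightningCLI does not try to parse the arguments of the itwinai command itself. A small context manager can make that swap exception-safe, restoring the original sys.argv even if building the CLI fails. The sketch below is a hypothetical helper (patched_argv is not part of itwinai or Lightning); in the trainer it would wrap the LightningCLI(...) call.

```python
import sys
from contextlib import contextmanager
from typing import Iterator, List


@contextmanager
def patched_argv(argv: List[str]) -> Iterator[None]:
    """Temporarily replace sys.argv, restoring it on exit or on error."""
    old_argv = sys.argv
    sys.argv = list(argv)
    try:
        yield
    finally:
        # Runs on normal exit and when the body raises
        sys.argv = old_argv


# Usage sketch: inside the block, argument parsers only see the placeholder
with patched_argv(["some_script_placeholder.py"]):
    print(sys.argv)  # ['some_script_placeholder.py']
```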

saver.py

from typing import Dict
import os
import shutil
import pickle
import random

import torch
from torch import Tensor
import matplotlib.pyplot as plt
import numpy as np

from itwinai.components import Saver, monitor_exec


class ParticleImagesSaver(Saver):
    """Saves generated particle trajectories to disk."""

    def __init__(
        self,
        save_dir: str = '3dgan-generated',
        aggregate_predictions: bool = False
    ) -> None:
        self.save_parameters(**self.locals2params(locals()))
        super().__init__()
        self.save_dir = save_dir
        self.aggregate_predictions = aggregate_predictions

    @monitor_exec
    def execute(self, generated_images: Dict[str, Tensor]) -> None:
        """Saves generated images to disk.

        Args:
            generated_images (Dict[str, Tensor]): maps unique item ID to
                the generated image.
        """
        if self.aggregate_predictions:
            os.makedirs(self.save_dir, exist_ok=True)
            sparse_generated_images = dict()
            for name, res in generated_images.items():
                sparse_generated_images[name] = res.to_sparse()
            del generated_images
            with open(self._random_file(), 'wb') as fp:
                pickle.dump(sparse_generated_images, fp)
        else:
            if os.path.exists(self.save_dir):
                shutil.rmtree(self.save_dir)
            os.makedirs(self.save_dir)
            # Save as torch tensor and jpg image
            for img_id, img in generated_images.items():
                img_path = os.path.join(self.save_dir, img_id)
                torch.save(img, img_path + '.pth')
                self._save_image(img, img_id, img_path + '.jpg')

    def _random_file(self, extension: str = 'pkl') -> str:
        fname = "%032x.%s" % (random.getrandbits(128), extension)
        fpath = os.path.join(self.save_dir, fname)
        while os.path.exists(fpath):
            fname = "%032x.%s" % (random.getrandbits(128), extension)
            fpath = os.path.join(self.save_dir, fname)
        return fpath

    def _save_image(
        self,
        img: Tensor,
        img_idx: str,
        img_path: str,
        center: bool = True
    ) -> None:
        """Converts a 3D tensor to a 3D scatter plot and saves it
        to disk as jpg image.
        """
        x_offset = img.shape[0] // 2 if center else 0
        y_offset = img.shape[1] // 2 if center else 0
        z_offset = img.shape[2] // 2 if center else 0

        # Convert tensor dimension IDs to coordinates
        x_coords = []
        y_coords = []
        z_coords = []
        values = []

        for x in range(img.shape[0]):
            for y in range(img.shape[1]):
                for z in range(img.shape[2]):
                    if img[x, y, z] > 0.0:
                        x_coords.append(x - x_offset)
                        y_coords.append(y - y_offset)
                        z_coords.append(z - z_offset)
                        values.append(img[x, y, z])

        # import plotly.graph_objects as go
        # normalize_intensity_by = 1
        # trace = go.Scatter3d(
        #     x=x_coords,
        #     y=y_coords,
        #     z=z_coords,
        #     mode='markers',
        #     marker_symbol='square',
        #     marker_color=[
        #         f"rgba(0,0,255,{i*100//normalize_intensity_by/10})"
        #         for i in values],
        # )
        # fig = go.Figure()
        # fig.add_trace(trace)
        # fig.write_image(img_path)

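        values = np.array(values)
        # 0-1 scaling, skipped when all intensities are equal (e.g. a single
        # non-zero voxel) to avoid a division by zero
        value_range = values.max() - values.min() if values.size else 0.0
        if value_range > 0:
            values = (values - values.min()) / value_range
        else:
            values = np.ones_like(values)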
        fig = plt.figure()
        ax = fig.add_subplot(projection='3d')
        ax.scatter(x_coords, y_coords, z_coords, alpha=values)
        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Z')

        # Extract energy and angle from idx
        en, ang = img_idx.split('&')
        en = en[7:]
        ang = ang[6:]
        ax.set_title(f"Energy: {en} - Angle: {ang}")
        fig.savefig(img_path)
        # Release the figure: one figure is created per generated sample
        plt.close(fig)
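Lightning3DGANPredictor.execute names each generated sample with a key of the form energy=<value>&angle=<value>, and ParticleImagesSaver._save_image recovers the two values by slicing off the "energy=" (7 characters) and "angle=" (6 characters) prefixes. The helpers below (hypothetical names, not part of the use-case code) sketch the same contract with an explicit split on "=", which does not depend on the prefix lengths.

```python
from typing import Tuple


def make_sample_key(energy: float, angle: float) -> str:
    """Build the key used by the predictor for one generated sample."""
    return f"energy={energy}&angle={angle}"


def parse_sample_key(key: str) -> Tuple[str, str]:
    """Split 'energy=<e>&angle=<a>' back into its two string values."""
    fields = dict(part.split("=", 1) for part in key.split("&"))
    return fields["energy"], fields["angle"]


key = make_sample_key(1.5, 1.2)
print(key)                    # energy=1.5&angle=1.2
print(parse_sample_key(key))  # ('1.5', '1.2')
```

The same key is also used as the file name of the saved .pth and .jpg files, so it should not contain path separators.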

dataloader.py

from typing import Optional
import os
from lightning.pytorch.utilities.types import EVAL_DATALOADERS

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
import lightning as pl
import glob
import h5py
import gdown

from itwinai.components import DataGetter, monitor_exec


class Lightning3DGANDownloader(DataGetter):
    def __init__(
        self,
        data_path: str,
        data_url: Optional[str] = None,
        name: Optional[str] = None,
    ) -> None:
        self.save_parameters(**self.locals2params(locals()))
        super().__init__(name)
        self.data_path = data_path
        self.data_url = data_url

    @monitor_exec
    def execute(self):
        # Nothing to do if no target path is given or the data is present
        if self.data_path is None or os.path.exists(self.data_path):
            return
        if self.data_url is None:
            print("WARNING! Data URL is None. "
                  "Skipping dataset downloading")
            return
        # Download data
        gdown.download_folder(
            url=self.data_url, quiet=False,
            output=self.data_path,
            # verify=False
        )


class ParticlesDataset(Dataset):
    def __init__(self, datapath: str, max_samples: Optional[int] = None):
        self.datapath = datapath
        self.max_samples = max_samples
        self.data = dict()

        self.fetch_data()

    def __len__(self):
        return len(self.data["X"])

    def __getitem__(self, idx):
        return {"X": self.data["X"][idx], "Y": self.data["Y"][idx],
                "ang": self.data["ang"][idx], "ecal": self.data["ecal"][idx]}

    def fetch_data(self) -> None:

        print("Searching in :", self.datapath)
        files = sorted(glob.glob(os.path.join(
            self.datapath, '**/*.h5'), recursive=True))
        print("Found {} files. ".format(len(files)))
        if len(files) == 0:
            raise RuntimeError(f"No H5 files found at '{self.datapath}'!")

        # concatenated_datasets = []
        # for datafile in files:
        #     f = h5py.File(datafile, 'r')
        #     dataset = self.GetDataAngleParallel(f)
        #     concatenated_datasets.append(dataset)
        #     # Initialize result dictionary
        #     result = {key: [] for key in concatenated_datasets[0].keys()}
        #     for d in concatenated_datasets:
        #         for key in result.keys():
        #             result[key].extend(d[key])
        # return result

        for datafile in files:
            # Read and preprocess one H5 file, closing it afterwards
            with h5py.File(datafile, 'r') as f:
                dataset = self.GetDataAngleParallel(f)
            for field, vals_array in dataset.items():
                if self.data.get(field) is not None:
                    # Append the new samples along the first axis
                    self.data[field] = np.concatenate(
                        [self.data[field], vals_array], axis=0)
                else:
                    self.data[field] = vals_array

            # Stop loading data, if self.max_samples reached
            if (self.max_samples is not None
                    and len(self.data[field]) >= self.max_samples):
                for field, vals_array in self.data.items():
                    self.data[field] = vals_array[:self.max_samples]
                break

    def GetDataAngleParallel(
        self,
        dataset,
        xscale=1,
        xpower=0.85,
        yscale=100,
        angscale=1,
        angtype="theta",
        thresh=1e-4,
        daxis=-1
    ):
        """Preprocess function for the dataset

        Args:
            dataset (h5py.File): Open H5 file containing the "ECAL",
                "energy" and angle fields.
            xscale (int, optional): Value to scale the ECAL values.
                Defaults to 1.
            xpower (float, optional): Exponent applied to the ECAL values.
                Defaults to 0.85.
            yscale (int, optional): Divisor applied to the energy values.
                Defaults to 100.
            angscale (int, optional): Value to scale the angle values
                (currently unused). Defaults to 1.
            angtype (str, optional): Which type of angle to use.
                Defaults to "theta".
            thresh (float, optional): ECAL values below this threshold
                are set to zero. Defaults to 1e-4.
            daxis (int, optional): Axis to expand values. Defaults to -1.

        Returns:
          Dict: Dictionary containing the preprocessed dataset, with keys
            "X", "Y", "ang" and "ecal".
        """
        X = np.array(dataset.get("ECAL")) * xscale
        Y = np.array(dataset.get("energy")) / yscale
        X[X < thresh] = 0
        X = X.astype(np.float32)
        Y = Y.astype(np.float32)
        ecal = np.sum(X, axis=(1, 2, 3))
        indexes = np.where(ecal > 10.0)
        X = X[indexes]
        Y = Y[indexes]
        if angtype in dataset:
            ang = np.array(dataset.get(angtype))[indexes]
        else:
            # ang = gan.measPython(X)
            raise KeyError(f"Angle field '{angtype}' not found in H5 file")
        X = np.expand_dims(X, axis=daxis)
        ecal = ecal[indexes]
        ecal = np.expand_dims(ecal, axis=daxis)
        if xpower != 1.0:
            X = np.power(X, xpower)

        Y = np.array([[el] for el in Y])
        ang = np.array([[el] for el in ang])
        ecal = np.array([[el] for el in ecal])

        final_dataset = {"X": X, "Y": Y, "ang": ang, "ecal": ecal}

        return final_dataset


class ParticlesDataModule(pl.LightningDataModule):
    def __init__(
            self,
            datapath: str,
            batch_size: int,
            num_workers: int = 4,
            max_samples: Optional[int] = None
    ) -> None:
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.datapath = datapath
        self.max_samples = max_samples

    def setup(self, stage: Optional[str] = None):
        # make assignments here (val/train/test split)
        # called on every process in DDP

        if stage == 'fit' or stage is None:
            self.dataset = ParticlesDataset(
                self.datapath,
                max_samples=self.max_samples
            )
            dataset_length = len(self.dataset)
            split_point = int(dataset_length * 0.9)
            self.train_dataset, self.val_dataset = \
                torch.utils.data.random_split(
                    self.dataset, [split_point, dataset_length - split_point])

        if stage == 'predict':
            # TODO: inference dataset should be different in that it
            # does not contain images!
            self.predict_dataset = ParticlesDataset(
                self.datapath,
                max_samples=self.max_samples
            )

        # if stage == 'test' or stage is None:
            # self.test_dataset = MyDataset(self.data_dir, train=False)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, num_workers=self.num_workers,
                          batch_size=self.batch_size, drop_last=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, num_workers=self.num_workers,
                          batch_size=self.batch_size, drop_last=True)

    def predict_dataloader(self) -> EVAL_DATALOADERS:
        # Keep the last incomplete batch so that every sample is generated
        return DataLoader(self.predict_dataset, num_workers=self.num_workers,
                          batch_size=self.batch_size, drop_last=False)

    # def test_dataloader(self):
        # return DataLoader(self.test_dataset, batch_size=self.batch_size)
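GetDataAngleParallel applies a fixed sequence of steps to every H5 file: scale the cells by xscale and zero out those below thresh, compute the total deposited energy per event (ecal), keep only events with ecal > 10, raise the remaining cell values to xpower, and divide the primary energy by yscale. The pure-Python sketch below replays those steps on events represented as flat lists of cell values, a simplification of the real 3D ECAL arrays. preprocess_events and its min_ecal parameter are hypothetical names, and the input numbers are invented for illustration.

```python
from typing import Dict, List


def preprocess_events(
    events: List[List[float]],
    energies: List[float],
    angles: List[float],
    xscale: float = 1,
    xpower: float = 0.85,
    yscale: float = 100,
    thresh: float = 1e-4,
    min_ecal: float = 10.0,
) -> Dict[str, List]:
    """Mimic the main steps of GetDataAngleParallel on flat events."""
    out: Dict[str, List] = {"X": [], "Y": [], "ang": [], "ecal": []}
    for cells, energy, angle in zip(events, energies, angles):
        # Scale cell values and zero out those below the threshold
        cells = [c * xscale for c in cells]
        cells = [c if c >= thresh else 0.0 for c in cells]
        # Total deposited energy, computed before the power transform
        ecal = sum(cells)
        if ecal <= min_ecal:
            continue  # drop low-energy events, as in `ecal > 10.0`
        out["X"].append([c ** xpower for c in cells])
        out["Y"].append(energy / yscale)
        out["ang"].append(angle)
        out["ecal"].append(ecal)
    return out


# Two toy events: the second one deposits too little energy and is dropped
result = preprocess_events(
    events=[[8.0, 4.0, 0.00005], [3.0, 2.0, 1.0]],
    energies=[150.0, 200.0],
    angles=[1.2, 1.5],
)
print(result["Y"], result["ecal"])  # [1.5] [12.0]
```

Note that ecal is the sum before the power transform, so it is not the sum of the transformed X values.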

config.yaml

This YAML file defines the pipeline configuration for the CERN use case.

# Main configurations
dataset_location: exp_data/
dataset_url: https://drive.google.com/drive/folders/1uPpz0tquokepptIfJenTzGpiENfo2xRX
hw_accelerators: auto
distributed_strategy: auto #ddp_find_unused_parameters_true
devices: auto #[0]
checkpoints_path: checkpoints
logs_dir: ml_logs
mlflow_tracking_uri: https://131.154.99.166.myip.cloud.infn.it
batch_size: 4
num_workers_dataloader: 0
max_epochs: 2
max_dataset_size: 48
random_seed: 4231162351
inference_results_location: 3dgan-generated-data/
inference_model_uri: 3dgan-inference.pth
aggregate_predictions: false

# Dataloading step is common and can be reused
dataloading_step:
  class_path: dataloader.Lightning3DGANDownloader
  init_args:
    data_path: ${dataset_location} # Set to null to skip dataset download
    data_url: ${dataset_url}

# AI workflows
training_pipeline:
  class_path: itwinai.pipeline.Pipeline
  init_args:
    steps:
      dataloading_step: ${dataloading_step}

      training_step:
        class_path: trainer.Lightning3DGANTrainer
        init_args:
          exp_root: ${logs_dir}
          # PyTorch Lightning config for training
          config:
            seed_everything: ${random_seed}
            trainer:
              accelerator: ${hw_accelerators}
              accumulate_grad_batches: 1
              barebones: false
              benchmark: null
              callbacks:
                - class_path: lightning.pytorch.callbacks.early_stopping.EarlyStopping
                  init_args:
                    monitor: val_generator_loss
                    patience: 2
                - class_path: lightning.pytorch.callbacks.lr_monitor.LearningRateMonitor
                  init_args:
                    logging_interval: step
                - class_path: lightning.pytorch.callbacks.ModelCheckpoint
                  init_args:
                    dirpath: ${checkpoints_path}
                    filename: best-checkpoint
                    mode: min
                    monitor: val_generator_loss
                    save_top_k: 1
                    verbose: true
              check_val_every_n_epoch: 1
              default_root_dir: null
              detect_anomaly: false
              deterministic: null
              devices: ${devices}
              enable_checkpointing: true
              enable_model_summary: null
              enable_progress_bar: null
              fast_dev_run: false
              gradient_clip_algorithm: null
              gradient_clip_val: null
              inference_mode: true
              limit_predict_batches: null
              limit_test_batches: null
              limit_train_batches: null
              limit_val_batches: null
              log_every_n_steps: 1
              logger:
                - class_path: lightning.pytorch.loggers.CSVLogger
                  init_args:
                    name: 3DGAN
                    save_dir: ${logs_dir}
                - class_path: lightning.pytorch.loggers.MLFlowLogger
                  init_args:
                    experiment_name: 3DGAN
                    save_dir: null #ml_logs/mlflow_logs
                    tracking_uri: ${mlflow_tracking_uri}
                    log_model: all
              max_epochs: ${max_epochs}
              max_time: null
              min_epochs: null
              min_steps: null
              num_sanity_val_steps: null
              overfit_batches: 0.0
              plugins: null
              profiler: null
              reload_dataloaders_every_n_epochs: 0
              strategy: ${distributed_strategy}
              sync_batchnorm: false
              use_distributed_sampler: true
              val_check_interval: null

            # Lightning Model configuration
            model:
              class_path: model.ThreeDGAN
              init_args:
                latent_size: 256
                loss_weights: [3, 0.1, 25, 0.1]
                power: 0.85
                lr: 0.001
                checkpoints_dir: ${checkpoints_path}

            # Lightning data module configuration
            data:
              class_path: dataloader.ParticlesDataModule
              init_args:
                datapath: ${dataset_location}
                batch_size: ${batch_size}
                num_workers: ${num_workers_dataloader}
                max_samples: ${max_dataset_size}

inference_pipeline:
  class_path: itwinai.pipeline.Pipeline
  init_args:
    steps:
      dataloading_step: ${dataloading_step}

      inference_step:
        class_path: trainer.Lightning3DGANPredictor
        init_args:
          model:
            class_path: trainer.LightningModelLoader
            init_args:
              model_uri: ${inference_model_uri}

          # PyTorch Lightning config for inference
          config:
            seed_everything: ${random_seed}
            trainer:
              accelerator: ${hw_accelerators}
              accumulate_grad_batches: 1
              barebones: false
              benchmark: null
              check_val_every_n_epoch: 1
              default_root_dir: null
              detect_anomaly: false
              deterministic: null
              devices: ${devices}
              enable_checkpointing: true
              enable_model_summary: null
              enable_progress_bar: null
              fast_dev_run: false
              gradient_clip_algorithm: null
              gradient_clip_val: null
              inference_mode: true
              limit_predict_batches: null
              limit_test_batches: null
              limit_train_batches: null
              limit_val_batches: null
              log_every_n_steps: 2
              logger: 
                # - class_path: lightning.pytorch.loggers.CSVLogger
                #   init_args:
                #     save_dir: ml_logs/csv_logs
                class_path: lightning.pytorch.loggers.MLFlowLogger
                init_args:
                  experiment_name: 3DGAN
                  save_dir: ${logs_dir}
                  log_model: all
              max_epochs: ${max_epochs}
              max_steps: 20
              max_time: null
              min_epochs: null
              min_steps: null
              num_sanity_val_steps: null
              overfit_batches: 0.0
              plugins: null
              profiler: null
              reload_dataloaders_every_n_epochs: 0
              strategy: ${distributed_strategy}
              sync_batchnorm: false
              use_distributed_sampler: true
              val_check_interval: null

            # Lightning Model configuration
            model:
              class_path: model.ThreeDGAN
              init_args:
                latent_size: 256
                loss_weights: [3, 0.1, 25, 0.1]
                power: 0.85
                lr: 0.001
                checkpoints_dir: ${checkpoints_path}

            # Lightning data module configuration
            data:
              class_path: dataloader.ParticlesDataModule
              init_args:
                datapath: ${dataset_location}
                batch_size: ${batch_size} #1024
                num_workers: ${num_workers_dataloader} #4
                max_samples: ${max_dataset_size} #null, 10000

      saver_step:
        class_path: saver.ParticleImagesSaver
        init_args:
          save_dir: ${inference_results_location}
          aggregate_predictions: ${aggregate_predictions}
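With the defaults above (max_dataset_size: 48, batch_size: 4), ParticlesDataModule.setup keeps 90% of the samples for training and the rest for validation, and the training and validation dataloaders use drop_last=True. The sketch below works out how many batches that yields for a single process. It ignores the per-rank sharding added by use_distributed_sampler: true under DDP, and it assumes at least 48 events survive the ecal > 10 filter; batches_per_epoch is a hypothetical helper.

```python
from typing import Dict


def batches_per_epoch(n_samples: int, batch_size: int,
                      train_fraction: float = 0.9) -> Dict[str, int]:
    """Batch counts for a single process with drop_last=True."""
    # Same rule as ParticlesDataModule.setup: int(len * 0.9) for training
    n_train = int(n_samples * train_fraction)
    n_val = n_samples - n_train
    return {
        "n_train": n_train,
        "n_val": n_val,
        # drop_last=True discards the final incomplete batch
        "train_batches": n_train // batch_size,
        "val_batches": n_val // batch_size,
    }


print(batches_per_epoch(n_samples=48, batch_size=4))
# {'n_train': 43, 'n_val': 5, 'train_batches': 10, 'val_batches': 1}
```

With these settings 3 training samples and 1 validation sample are skipped in every epoch, which is acceptable for a smoke test but worth keeping in mind when shrinking max_dataset_size further.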

create_inference_sample.py

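This script creates a dummy (untrained) 3DGAN model checkpoint, saved as 3dgan-inference.pth by default, to test the inference pipeline when no pre-trained checkpoint is available.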

"""Create a dummy (untrained) 3DGAN checkpoint for testing inference."""

import argparse
import os
import torch
from model import ThreeDGAN


def create_checkpoint(
    root: str = '.',
    ckpt_name: str = "3dgan-inference.pth"
):
    ckpt_path = os.path.join(root, ckpt_name)
    net = ThreeDGAN()
    torch.save(net, ckpt_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--root", type=str, default='.')
    parser.add_argument("--ckpt-name", type=str, default="3dgan-inference.pth")
    args = parser.parse_args()
    create_checkpoint(**vars(args))
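create_inference_sample.py forwards its parsed arguments straight to create_checkpoint through **vars(args). This works because argparse turns the --ckpt-name flag into the attribute ckpt_name, which matches the keyword argument of the function. The snippet below parses an explicit argument list, instead of the real command line, to show the resulting keyword arguments and checkpoint path; build_parser is a hypothetical helper that mirrors the parser in the script.

```python
import argparse
import os


def build_parser() -> argparse.ArgumentParser:
    """Same options as the __main__ block of create_inference_sample.py."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--root", type=str, default='.')
    parser.add_argument("--ckpt-name", type=str,
                        default="3dgan-inference.pth")
    return parser


args = build_parser().parse_args(["--root", "checkpoints"])
kwargs = vars(args)
print(kwargs)  # {'root': 'checkpoints', 'ckpt_name': '3dgan-inference.pth'}
print(os.path.join(kwargs["root"], kwargs["ckpt_name"]))
```

Running the script without arguments writes ./3dgan-inference.pth, which matches inference_model_uri in config.yaml when inference is launched from the same directory.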

Dockerfile

FROM nvcr.io/nvidia/pytorch:23.09-py3
# FROM python:3.11

WORKDIR /usr/src/app

# Install itwinai
COPY pyproject.toml ./
COPY src ./src
RUN pip install --upgrade pip \
    && pip install --no-cache-dir lightning \
    && pip install --no-cache-dir .

# Install additional requirements first, so that this layer is cached,
# then add the 3DGAN use case files
COPY use-cases/3dgan/requirements.txt ./
RUN pip install --no-cache-dir -r requirements.txt
COPY use-cases/3dgan/* ./

# Create non-root user
RUN groupadd -g 10001 dotnet \
    && useradd -m -u 10000 -g dotnet dotnet \
    && chown -R dotnet:dotnet /usr/src/app
USER dotnet:dotnet

# ENTRYPOINT [ "itwinai", "exec-pipeline" ]
# CMD [ "--config", "config.yaml" ]

startscript

#!/bin/bash

# general configuration of the job
#SBATCH --job-name=PrototypeTest
#SBATCH --account=intertwin
#SBATCH --mail-user=
#SBATCH --mail-type=ALL
#SBATCH --output=job.out
#SBATCH --error=job.err
#SBATCH --time=00:30:00

# configure node and process count on the CM
#SBATCH --partition=batch
#SBATCH --nodes=2
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=4
#SBATCH --gpus-per-node=4

#SBATCH --exclusive

# gres options have to be disabled for deepv
#SBATCH --gres=gpu:4

# load modules
ml --force purge
ml Stages/2023 StdEnv/2023 NVHPC/23.1 OpenMPI/4.1.4 cuDNN/8.6.0.163-CUDA-11.7 Python/3.10.4 HDF5 libaio/0.3.112 GCC/11.3.0

# shellcheck source=/dev/null
source ~/.bashrc

# ON LOGIN NODE download datasets:
# ../../.venv-pytorch/bin/itwinai exec-pipeline --config config.yaml --pipe-key training_pipeline --steps dataloading_step
source ../../.venv-pytorch/bin/activate
srun itwinai exec-pipeline --config config.yaml --pipe-key training_pipeline
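With srun, SLURM starts nodes times ntasks-per-node copies of the command (two with the settings above) and exports environment variables such as SLURM_PROCID, SLURM_NTASKS, SLURM_LOCALID and SLURM_NODEID to each of them; PyTorch Lightning's SLURM integration relies on variables like these to assign ranks. The sketch below is a hypothetical helper, not part of itwinai or Lightning, that reads them with single-process defaults. It can be handy for checking what each task sees when debugging a job.

```python
import os
from typing import Dict, Mapping, Optional


def slurm_task_info(env: Optional[Mapping[str, str]] = None) -> Dict[str, int]:
    """Read the rank layout that srun exports to each task.

    Falls back to a single-process layout when the variables are missing,
    e.g. when the script is run outside of a SLURM job.
    """
    env = os.environ if env is None else env
    return {
        "global_rank": int(env.get("SLURM_PROCID", 0)),
        "world_size": int(env.get("SLURM_NTASKS", 1)),
        "local_rank": int(env.get("SLURM_LOCALID", 0)),
        "node_rank": int(env.get("SLURM_NODEID", 0)),
    }


# Example: what the second task of a 2-node, 1-task-per-node job would see
print(slurm_task_info({"SLURM_PROCID": "1", "SLURM_NTASKS": "2",
                       "SLURM_LOCALID": "0", "SLURM_NODEID": "1"}))
```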

This section covers the CERN use case integration with interLink using itwinai. The following files are integral to this use case: