Virgo

The code is adapted from this notebook available on the Virgo use case’s repository.

To learn more about the interTwin Virgo noise detector use case and its digital twin (DT), see the published deliverables D4.2, D7.2 and D7.4.

Installation

Before continuing, install the required libraries in the pre-existing itwinai environment.

pip install -r requirements.txt

Training

You can run the whole pipeline in one shot, including dataset generation, or you can start from the second step (after the synthetic dataset has been generated).

itwinai exec-pipeline --config config.yaml --pipe-key training_pipeline

# Run from the second step (use python-like slicing syntax).
# In this case, the dataset is loaded from "data/Image_dataset_synthetic_64x64.pkl"
itwinai exec-pipeline --config config.yaml --pipe-key training_pipeline --steps 1:
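The `--steps 1:` argument is described above as Python-like slicing. The sketch below only illustrates that semantics on the four steps listed in config.yaml; it is not itwinai's implementation, and `parse_slice` and `select_steps` are hypothetical helpers introduced for this example.

```python
# Illustration of Python-like slicing applied to pipeline steps.
# NOTE: parse_slice/select_steps are hypothetical, not part of itwinai.

STEPS = [
    "TimeSeriesDatasetGenerator",
    "TimeSeriesDatasetSplitter",
    "TimeSeriesProcessor",
    "NoiseGeneratorTrainer",
]


def parse_slice(spec: str) -> slice:
    """Turn a string such as '1:' or '0:2' into a slice object."""
    if ":" not in spec:
        raise ValueError(f"expected a slice like '1:', got {spec!r}")
    # Empty fields (as in '1:') become None, exactly as in steps[1:].
    bounds = [int(p) if p.strip() else None for p in spec.split(":")]
    return slice(*bounds)


def select_steps(steps: list, spec: str) -> list:
    """Return the sub-list of steps selected by the slice string."""
    return steps[parse_slice(spec)]


# '1:' skips the dataset generator and keeps the remaining three steps.
print(select_steps(STEPS, "1:"))
```

With `1:` the generator step is skipped, which is why the splitter then has to load the image dataset from disk.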

Launch distributed training with SLURM using the dedicated slurm.sh job script:

# Distributed training with torch DistributedDataParallel
PYTHON_VENV="../../envAI_hdfml"
DIST_MODE="ddp"
RUN_NAME="ddp-virgo"
TRAINING_CMD="$PYTHON_VENV/bin/itwinai exec-pipeline --config config.yaml --steps 1: --pipe-key training_pipeline -o strategy=ddp"
sbatch --export=ALL,DIST_MODE="$DIST_MODE",RUN_NAME="$RUN_NAME",TRAINING_CMD="$TRAINING_CMD",PYTHON_VENV="$PYTHON_VENV" \
    slurm.sh

…and check the results in job.out and job.err log files.

To understand how to use all the distributed strategies supported by itwinai, check the content of runall.sh:

bash runall.sh

> [!WARNING]
> The file train.py is deprecated and is not the suggested way to launch training.
> It is kept only to document an intermediate step of the integration
> of the use case into itwinai.

When using the MLflow logger, you can visualize the logs in the MLflow UI:

mlflow ui --backend-store-uri mllogs/mlflow

# In background
mlflow ui --backend-store-uri mllogs/mlflow > /dev/null 2>&1 &

config.yaml

# General configuration
data_root: data
epochs: 2
batch_size: 20
strategy: ddp
checkpoint_path: checkpoints/epoch_{}.pth

training_pipeline:
  class_path: itwinai.pipeline.Pipeline
  init_args:
    steps:
      - class_path: data.TimeSeriesDatasetGenerator
        init_args:
          data_root: ${data_root}
      - class_path: data.TimeSeriesDatasetSplitter
        init_args:
          train_proportion: 0.9
          rnd_seed: 42
          images_dataset: ${data_root}/Image_dataset_synthetic_64x64.pkl
      - class_path: data.TimeSeriesProcessor
      - class_path: trainer.NoiseGeneratorTrainer
        init_args:
          generator: simple #unet
          batch_size: ${batch_size}
          num_epochs: ${epochs}
          strategy: ${strategy}
          checkpoint_path: ${checkpoint_path}
          random_seed: 17
          logger:
            class_path: itwinai.loggers.LoggersCollection
            init_args:
              loggers:
                - class_path: itwinai.loggers.ConsoleLogger
                  init_args:
                    log_freq: 100
                - class_path: itwinai.loggers.MLFlowLogger
                  init_args:
                    experiment_name: Noise simulator (Virgo)
                    log_freq: batch 
                - class_path: itwinai.loggers.WandBLogger
                  init_args:
                    log_freq: batch
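
The `${data_root}`, `${epochs}` and similar placeholders in config.yaml refer to the keys of the "General configuration" block. The following sketch shows one way such references could be resolved. It is an illustration only: the real resolution is done by the configuration parser that itwinai relies on, and `resolve` is a hypothetical helper.

```python
import re

# Matches placeholders such as ${data_root}.
_REF = re.compile(r"\$\{(\w+)\}")


def resolve(value, root):
    """Recursively replace ${key} with root[key] (hypothetical helper)."""
    if isinstance(value, str):
        whole = _REF.fullmatch(value)
        if whole:
            # The whole string is one reference: keep the original type.
            return root[whole.group(1)]
        # Otherwise substitute the reference inside a larger string.
        return _REF.sub(lambda m: str(root[m.group(1)]), value)
    if isinstance(value, dict):
        return {k: resolve(v, root) for k, v in value.items()}
    if isinstance(value, list):
        return [resolve(v, root) for v in value]
    return value


# A reduced version of config.yaml, written as a Python dict.
config = {
    "data_root": "data",
    "epochs": 2,
    "splitter": {
        "images_dataset": "${data_root}/Image_dataset_synthetic_64x64.pkl"
    },
    "trainer": {"num_epochs": "${epochs}"},
}

resolved = resolve(config, config)
print(resolved["splitter"]["images_dataset"])
print(resolved["trainer"]["num_epochs"])
```

Changing a single top-level key (for example `data_root`) therefore propagates to every step that references it.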
          

data.py

from typing import Optional, Tuple, Any
import os
import pandas as pd
import torch
from sklearn.model_selection import train_test_split
from itwinai.components import (
    DataGetter, DataProcessor, DataSplitter, monitor_exec
)

from src.dataset import (
    generate_dataset_aux_channels,
    generate_dataset_main_channel,
    generate_cut_image_dataset,
    normalize_
)


class TimeSeriesDatasetGenerator(DataGetter):
    # TODO: move configuration to the constructor.
    def __init__(
        self,
        data_root: str = "data",
        name: Optional[str] = None
    ) -> None:
        super().__init__(name)
        self.save_parameters(**self.locals2params(locals()))
        self.data_root = data_root
        if not os.path.exists(data_root):
            os.makedirs(data_root, exist_ok=True)

    @monitor_exec
    def execute(self) -> pd.DataFrame:
        """Generate a time-series dataset, convert it to Q-plots,
        save it to disk, and return it.

        Returns:
            pd.DataFrame: dataset of Q-plot images.
        """
        df_aux_ts = generate_dataset_aux_channels(
            1000, 3, duration=16, sample_rate=500,
            num_waves_range=(20, 25), noise_amplitude=0.6
        )
        df_main_ts = generate_dataset_main_channel(
            df_aux_ts, weights=None, noise_amplitude=0.1
        )

        # save datasets
        save_name_main = 'TimeSeries_dataset_synthetic_main.pkl'
        save_name_aux = 'TimeSeries_dataset_synthetic_aux.pkl'
        df_main_ts.to_pickle(os.path.join(self.data_root, save_name_main))
        df_aux_ts.to_pickle(os.path.join(self.data_root, save_name_aux))

        # Transform to images and save to disk
        df_ts = pd.concat([df_main_ts, df_aux_ts], axis=1)
        df = generate_cut_image_dataset(
            df_ts, list(df_ts.columns),
            num_processes=20, square_size=64
        )
        save_name = 'Image_dataset_synthetic_64x64.pkl'
        df.to_pickle(os.path.join(self.data_root, save_name))
        return df


class TimeSeriesDatasetSplitter(DataSplitter):
    def __init__(
        self,
        train_proportion: int | float,
        validation_proportion: int | float = 0.0,
        test_proportion: int | float = 0.0,
        rnd_seed: int | None = None,
        images_dataset: str = "data/Image_dataset_synthetic_64x64.pkl",
        name: str | None = None
    ) -> None:
        super().__init__(
            train_proportion, validation_proportion,
            test_proportion, name
        )
        self.save_parameters(**self.locals2params(locals()))
        self.validation_proportion = 1-train_proportion
        self.rnd_seed = rnd_seed
        self.images_dataset = images_dataset

    def get_or_load(self, dataset: Optional[pd.DataFrame] = None):
        """If the dataset is not given, load it from disk."""
        if dataset is None:
            print("WARNING: loading time series dataset from disk.")
            return pd.read_pickle(self.images_dataset)
        return dataset

    @monitor_exec
    def execute(
        self,
        dataset: Optional[pd.DataFrame] = None
    ) -> Tuple:
        """Splits a dataset into train, validation and test splits.

        Args:
            dataset (pd.DataFrame): input dataset.

        Returns:
            Tuple: tuple of train, validation and test splits. Test is None.
        """
        dataset = self.get_or_load(dataset)

        # Convert data to torch
        df = dataset.applymap(lambda x: torch.tensor(x))

        # Divide Image dataset in main and aux channels. Note that df
        # generated in the section Generate Synthetic Dataset will always have
        # the main channel as its first column
        main_channel = list(df.columns)[0]
        aux_channels = list(df.columns)[1:]

        df_aux_all_2d = pd.DataFrame(df[aux_channels])
        df_main_all_2d = pd.DataFrame(df[main_channel])
        X_train_2d, X_test_2d, y_train_2d, y_test_2d = train_test_split(
            df_aux_all_2d, df_main_all_2d,
            test_size=self.validation_proportion, random_state=self.rnd_seed)
        return (X_train_2d, y_train_2d), (X_test_2d, y_test_2d), None


class TimeSeriesProcessor(DataProcessor):
    def __init__(self, name: str | None = None) -> None:
        super().__init__(name)
        self.save_parameters(**self.locals2params(locals()))

    @monitor_exec
    def execute(
        self,
        train_dataset: Tuple,
        validation_dataset: Tuple,
        test_dataset: Any = None
    ) -> Tuple[torch.Tensor, torch.Tensor, None]:
        """Pre-process datasets: rearrange and normalize before training.

        Args:
            train_dataset (Tuple): training dataset.
            validation_dataset (Tuple): validation dataset.
            test_dataset (Any, optional): unused placeholder. Defaults to None.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, None]: train, validation, and
                test (placeholder) datasets. Ready to be used for training.
        """
        X_train_2d, y_train_2d = train_dataset
        X_test_2d, y_test_2d = validation_dataset

        # Name of the main channel (assuming it's in position 0)
        main_channel = list(y_train_2d.columns)[0]

        # TRAINING SET

        # # smaller dataset
        # signal_data_train_small_2d = torch.stack([
        #     torch.stack([y_train_2d[main_channel].iloc[i]])
        #     for i in range(100)
        # ])  # for i in range(y_train.shape[0])
        # aux_data_train_small_2d = torch.stack([
        #     torch.stack([X_train_2d.iloc[i][0], X_train_2d.iloc[i]
        #                 [1], X_train_2d.iloc[i][2]])
        #     for i in range(100)
        # ])  # for i in range(X_train.shape[0])

        # whole dataset
        signal_data_train_2d = torch.stack([
            torch.stack([y_train_2d[main_channel].iloc[i]])
            for i in range(y_train_2d.shape[0])
        ])
        aux_data_train_2d = torch.stack([
            torch.stack(
                [X_train_2d.iloc[i][0], X_train_2d.iloc[i][1],
                 X_train_2d.iloc[i][2]])
            for i in range(X_train_2d.shape[0])
        ])

        # concatenate torch.tensors
        train_data_2d = torch.cat(
            [signal_data_train_2d, aux_data_train_2d], dim=1)
        # train_data_small_2d = torch.cat(
        #     [signal_data_train_small_2d, aux_data_train_small_2d], dim=1)

        # VALIDATION SET

        # # smaller dataset
        # signal_data_test_small_2d = torch.stack([
        #     torch.stack(
        #         [y_test_2d[main_channel].iloc[i]])
        #     for i in range(100)
        # ])  # for i in range(y_test.shape[0])
        # aux_data_test_small_2d = torch.stack([
        #     torch.stack(
        #         [X_test_2d.iloc[i][0], X_test_2d.iloc[i][1],
        #          X_test_2d.iloc[i][2]])
        #     for i in range(100)
        # ])  # for i in range(X_test.shape[0])

        # whole dataset
        signal_data_test_2d = torch.stack([
            torch.stack(
                [y_test_2d[main_channel].iloc[i]])
            for i in range(y_test_2d.shape[0])
        ])
        aux_data_test_2d = torch.stack([
            torch.stack(
                [X_test_2d.iloc[i][0], X_test_2d.iloc[i][1],
                 X_test_2d.iloc[i][2]])
            for i in range(X_test_2d.shape[0])
        ])

        test_data_2d = torch.cat(
            [signal_data_test_2d, aux_data_test_2d], dim=1)
        # test_data_small_2d = torch.cat(
        #     [signal_data_test_small_2d, aux_data_test_small_2d], dim=1)

        # NORMALIZE
        train_data_2d = normalize_(train_data_2d)
        test_data_2d = normalize_(test_data_2d)

        return train_data_2d, test_data_2d, None
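
In data.py, the splitter derives the validation share as `1 - train_proportion`, and the processor stacks each sample with the main channel at index 0 followed by the three auxiliary channels. The stdlib-only sketch below mimics both ideas on toy data. It is an assumption-laden illustration: it uses a seeded shuffle rather than scikit-learn's `train_test_split`, so the indices it selects will differ from the real pipeline, and `split_indices` and `stack_channels` are hypothetical names.

```python
import random


def split_indices(n_samples: int, train_proportion: float, seed: int):
    """Shuffle indices with a fixed seed and cut them into train/validation."""
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    n_train = round(n_samples * train_proportion)
    return indices[:n_train], indices[n_train:]


def stack_channels(main, aux):
    """Put the main channel first, followed by the auxiliary channels."""
    return [main, *aux]


# 1000 samples with train_proportion=0.9, as in config.yaml.
train_idx, val_idx = split_indices(1000, train_proportion=0.9, seed=42)
print(len(train_idx), len(val_idx))

# Strings stand in for the 64x64 Q-plot tensors.
sample = stack_channels("main", ["aux_0", "aux_1", "aux_2"])
target, inputs = sample[0], sample[1:]  # same layout the trainer relies on
print(target, inputs)
```

The trainer later reads the target from channel 0 and the model inputs from channels 1 onwards, so the ordering chosen here matters.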

runall.sh

#!/bin/bash

# Python virtual environment (no conda/micromamba)
PYTHON_VENV="../../envAI_hdfml"

# Clear SLURM logs (*.out and *.err files)
rm -rf logs_slurm checkpoints*
mkdir logs_slurm
rm -rf logs_torchrun

# DDP itwinai
DIST_MODE="ddp"
RUN_NAME="ddp-itwinai"
TRAINING_CMD="$PYTHON_VENV/bin/itwinai exec-pipeline --config config.yaml --pipe-key training_pipeline --steps 1: -o strategy=ddp -o checkpoint_path=checkpoints_ddp/epoch_{}.pth"
sbatch --export=ALL,DIST_MODE="$DIST_MODE",RUN_NAME="$RUN_NAME",TRAINING_CMD="$TRAINING_CMD",PYTHON_VENV="$PYTHON_VENV" \
    --job-name="$RUN_NAME-n$N" \
    --output="logs_slurm/job-$RUN_NAME-n$N.out" \
    --error="logs_slurm/job-$RUN_NAME-n$N.err" \
    slurm.sh

# DeepSpeed itwinai
DIST_MODE="deepspeed"
RUN_NAME="deepspeed-itwinai"
TRAINING_CMD="$PYTHON_VENV/bin/itwinai exec-pipeline --config config.yaml --pipe-key training_pipeline --steps 1: -o strategy=deepspeed -o checkpoint_path=checkpoints_deepspeed/epoch_{}.pth"
sbatch --export=ALL,DIST_MODE="$DIST_MODE",RUN_NAME="$RUN_NAME",TRAINING_CMD="$TRAINING_CMD",PYTHON_VENV="$PYTHON_VENV" \
    --job-name="$RUN_NAME-n$N" \
    --output="logs_slurm/job-$RUN_NAME-n$N.out" \
    --error="logs_slurm/job-$RUN_NAME-n$N.err" \
    slurm.sh

# Horovod itwinai
DIST_MODE="horovod"
RUN_NAME="horovod-itwinai"
TRAINING_CMD="$PYTHON_VENV/bin/itwinai exec-pipeline --config config.yaml --pipe-key training_pipeline --steps 1: -o strategy=horovod -o checkpoint_path=checkpoints_horovod/epoch_{}.pth"
sbatch --export=ALL,DIST_MODE="$DIST_MODE",RUN_NAME="$RUN_NAME",TRAINING_CMD="$TRAINING_CMD",PYTHON_VENV="$PYTHON_VENV" \
    --job-name="$RUN_NAME-n$N" \
    --output="logs_slurm/job-$RUN_NAME-n$N.out" \
    --error="logs_slurm/job-$RUN_NAME-n$N.err" \
    slurm.sh
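
The three blocks in runall.sh differ only in the strategy name, which appears in the run name, in the `-o strategy=` override and in the checkpoint directory. As a sketch of that pattern (not a replacement for the script; `training_cmd` and `run_name` are hypothetical helpers), the command strings can be generated like this:

```python
PYTHON_VENV = "../../envAI_hdfml"
STRATEGIES = ("ddp", "deepspeed", "horovod")


def run_name(strategy: str) -> str:
    """Run name used in runall.sh, e.g. 'ddp-itwinai'."""
    return f"{strategy}-itwinai"


def training_cmd(strategy: str, venv: str = PYTHON_VENV) -> str:
    """Build the TRAINING_CMD string that runall.sh passes to slurm.sh."""
    args = [
        f"{venv}/bin/itwinai", "exec-pipeline",
        "--config", "config.yaml",
        "--pipe-key", "training_pipeline",
        "--steps", "1:",
        "-o", f"strategy={strategy}",
        # Doubled braces keep the literal '{}' placeholder for the epoch.
        "-o", f"checkpoint_path=checkpoints_{strategy}/epoch_{{}}.pth",
    ]
    return " ".join(args)


for strategy in STRATEGIES:
    print(run_name(strategy), "->", training_cmd(strategy))
```

Keeping a separate checkpoint directory per strategy prevents the three jobs from overwriting each other's files.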

slurm.sh

#!/bin/bash

# SLURM jobscript for JSC systems

# Job configuration
#SBATCH --job-name=distributed_training
#SBATCH --account=intertwin
#SBATCH --mail-user=
#SBATCH --mail-type=ALL
#SBATCH --output=job.out
#SBATCH --error=job.err
#SBATCH --time=00:30:00

# Resources allocation
#SBATCH --partition=batch
#SBATCH --nodes=2
#SBATCH --gpus-per-node=4
#SBATCH --cpus-per-gpu=4
#SBATCH --exclusive

# gres options have to be disabled for deepv
#SBATCH --gres=gpu:4

# Load environment modules
ml Stages/2024 GCC OpenMPI CUDA/12 MPI-settings/CUDA Python HDF5 PnetCDF libaio mpi4py

# Job info
echo "DEBUG: TIME: $(date)"
sysN="$(uname -n | cut -f2- -d.)"
sysN="${sysN%%[0-9]*}"
echo "Running on system: $sysN"
echo "DEBUG: EXECUTE: $EXEC"
echo "DEBUG: SLURM_SUBMIT_DIR: $SLURM_SUBMIT_DIR"
echo "DEBUG: SLURM_JOB_ID: $SLURM_JOB_ID"
echo "DEBUG: SLURM_JOB_NODELIST: $SLURM_JOB_NODELIST"
echo "DEBUG: SLURM_NNODES: $SLURM_NNODES"
echo "DEBUG: SLURM_NTASKS: $SLURM_NTASKS"
echo "DEBUG: SLURM_TASKS_PER_NODE: $SLURM_TASKS_PER_NODE"
echo "DEBUG: SLURM_SUBMIT_HOST: $SLURM_SUBMIT_HOST"
echo "DEBUG: SLURMD_NODENAME: $SLURMD_NODENAME"
echo "DEBUG: CUDA_VISIBLE_DEVICES: $CUDA_VISIBLE_DEVICES"
if [ "$DEBUG" = true ] ; then
  echo "DEBUG: NCCL_DEBUG=INFO" 
  export NCCL_DEBUG=INFO
fi
echo

# Setup env for distributed ML
export CUDA_VISIBLE_DEVICES="0,1,2,3"
export OMP_NUM_THREADS=1
if [ "$SLURM_CPUS_PER_GPU" -gt 0 ] ; then
  export OMP_NUM_THREADS=$SLURM_CPUS_PER_GPU
fi

# Env variables check
if [ -z "$DIST_MODE" ]; then 
  >&2 echo "ERROR: env variable DIST_MODE is not set. Allowed values are 'horovod', 'ddp' or 'deepspeed'"
  exit 1
fi
if [ -z "$RUN_NAME" ]; then 
  >&2 echo "WARNING: env variable RUN_NAME is not set. It's a way to identify some specific run of an experiment."
  RUN_NAME=$DIST_MODE
fi
if [ -z "$TRAINING_CMD" ]; then 
  >&2 echo "ERROR: env variable TRAINING_CMD is not set. It's the python command to execute."
  exit 1
fi
if [ -z "$PYTHON_VENV" ]; then 
  >&2 echo "WARNING: env variable PYTHON_VENV is not set. It's the path to a python virtual environment."
else
  # Activate Python virtual env
  source "$PYTHON_VENV/bin/activate"
fi

# Get GPUs info per node
srun --cpu-bind=none --ntasks-per-node=1 bash -c 'echo -e "NODE hostname: $(hostname)\n$(nvidia-smi)\n\n"'

# Launch training
if [ "$DIST_MODE" == "ddp" ] ; then
  echo "DDP training: $TRAINING_CMD"
  srun --cpu-bind=none --ntasks-per-node=1 \
    bash -c "torchrun \
    --log_dir='logs_torchrun' \
    --nnodes=$SLURM_NNODES \
    --nproc_per_node=$SLURM_GPUS_PER_NODE \
    --rdzv_id=$SLURM_JOB_ID \
    --rdzv_conf=is_host=\$(((SLURM_NODEID)) && echo 0 || echo 1) \
    --rdzv_backend=c10d \
    --rdzv_endpoint='$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)'i:29500 \
    $TRAINING_CMD"
elif [ "$DIST_MODE" == "deepspeed" ] ; then
  echo "DEEPSPEED training: $TRAINING_CMD"
  MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)i
  export MASTER_ADDR
  export MASTER_PORT=29500 

  srun --cpu-bind=none --ntasks-per-node=$SLURM_GPUS_PER_NODE --cpus-per-task=$SLURM_CPUS_PER_GPU \
    $TRAINING_CMD

  # # Run with deepspeed launcher: set --ntasks-per-node=1
  # # https://www.deepspeed.ai/getting-started/#multi-node-environment-variables
  # export NCCL_IB_DISABLE=1
  # export NCCL_SOCKET_IFNAME=eth0
  # nodelist=$(scontrol show hostname $SLURM_NODELIST)
  # echo "$nodelist" | sed -e 's/$/ slots=4/' > .hostfile
  # # Requires passwordless SSH access among compute node
  # srun --cpu-bind=none deepspeed --hostfile=.hostfile $TRAINING_CMD --deepspeed
  # rm .hostfile
elif [ "$DIST_MODE" == "horovod" ] ; then
  echo "HOROVOD training: $TRAINING_CMD"
  srun --cpu-bind=none --ntasks-per-node=$SLURM_GPUS_PER_NODE --cpus-per-task=$SLURM_CPUS_PER_GPU \
    $TRAINING_CMD
else
  >&2 echo "ERROR: unrecognized \$DIST_MODE env variable"
  exit 1
fi
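
Two details of slurm.sh are easy to misread: the `is_host` expression in the torchrun call, and the different number of tasks per node for each launcher. The sketch below restates both in Python. It only mirrors the logic of the script above; the function names are hypothetical.

```python
def is_rendezvous_host(node_id: int) -> int:
    """Mirror `$(((SLURM_NODEID)) && echo 0 || echo 1)`.

    In bash, ((x)) succeeds when x is non-zero, so the expression
    yields 1 only on the node whose SLURM_NODEID is 0.
    """
    return 1 if node_id == 0 else 0


def tasks_per_node(dist_mode: str, gpus_per_node: int) -> int:
    """Number of srun tasks per node used by each branch of slurm.sh."""
    if dist_mode == "ddp":
        # One torchrun launcher per node; torchrun spawns one
        # process per GPU through --nproc_per_node.
        return 1
    if dist_mode in ("deepspeed", "horovod"):
        # srun itself starts one task per GPU.
        return gpus_per_node
    raise ValueError(f"unrecognized DIST_MODE: {dist_mode!r}")


def world_size(nodes: int, gpus_per_node: int) -> int:
    """Total number of training workers, identical for all three modes."""
    return nodes * gpus_per_node


# Values from the #SBATCH header: 2 nodes with 4 GPUs each.
print([is_rendezvous_host(n) for n in range(2)])
print(tasks_per_node("ddp", 4), tasks_per_node("horovod", 4))
print(world_size(2, 4))
```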

trainer.py

from typing import Literal, Optional
import os
import torch.nn as nn
import torch
import time
import numpy as np

from itwinai.torch.trainer import TorchTrainer
from itwinai.torch.distributed import (
    DeepSpeedStrategy,
)
from itwinai.torch.config import TrainingConfiguration
from itwinai.loggers import Logger

from src.model import Decoder, Decoder_2d_deep, UNet, GeneratorResNet
from src.utils import init_weights


from tqdm import tqdm


class NoiseGeneratorTrainer(TorchTrainer):

    def __init__(
        self,
        batch_size: int,
        learning_rate: float = 1e-3,
        num_epochs: int = 2,
        generator: Literal["simple", "deep", "resnet", "unet"] = "unet",
        loss: Literal["L1", "L2"] = "L1",
        strategy: Literal["ddp", "deepspeed", "horovod"] = 'ddp',
        checkpoint_path: str = "checkpoints/epoch_{}.pth",
        save_best: bool = True,
        logger: Optional[Logger] = None,
        random_seed: Optional[int] = None,
        name: str | None = None
    ) -> None:
        super().__init__(
            epochs=num_epochs,
            config={},
            strategy=strategy,
            logger=logger,
            random_seed=random_seed,
            name=name
        )
        self.save_parameters(**self.locals2params(locals()))
        self.num_epochs = num_epochs
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self._generator = generator
        self._loss = loss
        self.checkpoint_path = checkpoint_path
        os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
        # Global training configuration
        self.config = TrainingConfiguration(
            batch_size=batch_size,
            save_best=save_best,
            shuffle_train=True
        )

    def create_model_loss_optimizer(self) -> None:
        # Select generator
        generator = self._generator.lower()
        if generator == "simple":
            self.model = Decoder(3, norm=False)
            init_weights(self.model, 'normal', scaling=.02)
        elif generator == "deep":
            self.model = Decoder_2d_deep(3)
            init_weights(self.model, 'normal', scaling=.02)
        elif generator == "resnet":
            self.model = GeneratorResNet(3, 12, 1)
            init_weights(self.model, 'normal', scaling=.01)
        elif generator == "unet":
            self.model = UNet(
                input_channels=3, output_channels=1, norm=False)
            init_weights(self.model, 'normal', scaling=.02)
        else:
            raise ValueError(f"Unrecognized generator type! Got {generator}")

        # Select loss
        loss = self._loss.upper()
        if loss == "L1":
            self.loss = nn.L1Loss()
        elif loss == "L2":
            self.loss = nn.MSELoss()
        else:
            raise ValueError(f"Unrecognized loss type! Got {loss}")

        # Optimizer
        self.optimizer = torch.optim.Adam(
            self.model.parameters(), lr=self.learning_rate)

        # IMPORTANT: model, optimizer, and scheduler need to be distributed

        # First, define strategy-wise optional configurations
        if isinstance(self.strategy, DeepSpeedStrategy):
            # Batch size definition is not optional for DeepSpeedStrategy!
            distribute_kwargs = dict(
                config_params=dict(
                    train_micro_batch_size_per_gpu=self.config.batch_size
                )
            )
        else:
            distribute_kwargs = {}

        # Distributed model, optimizer, and scheduler
        self.model, self.optimizer, _ = self.strategy.distributed(
            self.model, self.optimizer, **distribute_kwargs
        )

    def train(self):
        # uncomment all lines relative to accuracy if you want to measure
        # IOU between generated and real spectrograms.
        # Note that it significantly slows down the whole process
        # it also might not work as the function has not been fully
        # implemented yet

        loss_plot = []
        val_loss_plot = []
        acc_plot = []
        val_acc_plot = []
        best_val_loss = float('inf')
        for epoch in tqdm(range(self.num_epochs)):
            # itwinai - IMPORTANT: set current epoch ID
            self.set_epoch(epoch)

            st = time.time()
            epoch_loss = []
            # epoch_acc = []
            for i, batch in enumerate(self.train_dataloader):
                # batch= transform(batch)
                target = batch[:, 0].unsqueeze(1).to(self.device)
                # print(f'TARGET ON DEVICE: {target.get_device()}')
                target = target.float()
                input = batch[:, 1:].to(self.device)
                # print(f'INPUT ON DEVICE: {input.get_device()}')

                self.optimizer.zero_grad()
                generated = self.model(input.float())
                # generated=normalize_(generated,1)
                loss = self.loss(generated, target)
                loss.backward()
                self.optimizer.step()
                epoch_loss.append(loss.detach().cpu().numpy())
                # itwinai - log loss as metric
                self.log(loss.detach().cpu().numpy(),
                         'epoch_loss_batch',
                         kind='metric',
                         step=epoch*len(self.train_dataloader) + i,
                         batch_idx=i)
                # acc=accuracy(generated.detach().cpu().numpy(),target.detach().cpu().numpy(),20)
                # epoch_acc.append(acc)
            val_loss = []
            # val_acc = []
            for i, batch in enumerate(self.validation_dataloader):
                # batch= transform(batch)
                target = batch[:, 0].unsqueeze(1).to(self.device)
                target = target.float()
                input = batch[:, 1:].to(self.device)
                with torch.no_grad():
                    generated = self.model(input.float())
                    # generated=normalize_(generated,1)
                    loss = self.loss(generated, target)
                val_loss.append(loss.detach().cpu().numpy())
                # itwinai -log loss as metric
                self.log(loss.detach().cpu().numpy(),
                         'val_loss_batch',
                         kind='metric',
                         step=epoch*len(self.validation_dataloader) + i,
                         batch_idx=i)
                # acc=accuracy(generated.detach().cpu().numpy(),target.detach().cpu().numpy(),20)
                # val_acc.append(acc)
            loss_plot.append(np.mean(epoch_loss))
            val_loss_plot.append(np.mean(val_loss))
            # acc_plot.append(np.mean(epoch_acc))
            # val_acc_plot.append(np.mean(val_acc))

            # itwinai - Log metrics/losses
            self.log(np.mean(epoch_loss), 'epoch_loss',
                     kind='metric', step=epoch)
            self.log(np.mean(val_loss), 'val_loss',
                     kind='metric', step=epoch)
            # self.log(np.mean(epoch_acc), 'epoch_acc',
            #          kind='metric', step=epoch)
            # self.log(np.mean(val_acc), 'val_acc',
            #          kind='metric', step=epoch)

            # print('epoch: {} loss: {} val loss: {} accuracy: {} val
            # accuracy: {}'.format(epoch,loss_plot[-1],val_loss_plot[-1],
            # acc_plot[-1],val_acc_plot[-1]))
            et = time.time()
            # itwinai - print() in a multi-worker context (distributed)
            if self.strategy.is_main_worker:
                print('epoch: {} loss: {} val loss: {} time:{}s'.format(
                    epoch, loss_plot[-1], val_loss_plot[-1], et-st))

            # Checkpointing is evaluated at every epoch (epoch % 1 == 0);
            # raise the modulus (e.g., to 100) to evaluate it less often
            if epoch % 1 == 0:
                # uncomment the following if you want to save checkpoint every
                # 100 epochs regardless of the performance of the model
                # checkpoint = {
                #     'epoch': epoch,
                #     'model_state_dict': generator.state_dict(),
                #     'optim_state_dict': optimizer.state_dict(),
                #     'loss': loss_plot[-1],
                #     'val_loss': val_loss_plot[-1],
                # }
                # if self.strategy.is_main_worker:
                #     # Save only in the main worker
                #     checkpoint_filename = checkpoint_path.format(epoch)
                #     torch.save(checkpoint, checkpoint_filename)

                # Average loss among all workers
                # itwinai - gather local loss from all the workers
                worker_val_losses = self.strategy.gather_obj(val_loss_plot[-1])
                if self.strategy.is_main_worker:
                    # Save only in the main worker

                    # avg_loss has a meaning only in the main worker
                    avg_loss = np.mean(worker_val_losses)

                    # instead of val_loss and best_val loss we should
                    # use accuracy!!!
                    if self.config.save_best and avg_loss < best_val_loss:
                        # create checkpoint
                        checkpoint = {
                            'epoch': epoch,
                            'model_state_dict': self.model.state_dict(),
                            'optim_state_dict': self.optimizer.state_dict(),
                            'loss': loss_plot[-1],
                            'val_loss': val_loss_plot[-1],
                        }

                        # save checkpoint only if it is better than
                        # the previous ones
                        checkpoint_filename = self.checkpoint_path.format(
                            epoch)
                        torch.save(checkpoint, checkpoint_filename)
                        # itwinai - log checkpoint as artifact
                        self.log(checkpoint_filename,
                                 os.path.basename(checkpoint_filename),
                                 kind='artifact')

                        # update best model
                        # Track the loss averaged over all workers, which is
                        # the value compared against best_val_loss above
                        best_val_loss = avg_loss
                        best_checkpoint_filename = (
                            self.checkpoint_path.format('best')
                        )
                        torch.save(checkpoint, best_checkpoint_filename)
                        # itwinai - log checkpoint as artifact
                        self.log(best_checkpoint_filename,
                                 os.path.basename(best_checkpoint_filename),
                                 kind='artifact')
        # return (loss_plot, val_loss_plot,
        # acc_plot, val_acc_plot ,acc_plot, val_acc_plot)
        return loss_plot, val_loss_plot, acc_plot, val_acc_plot
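
The checkpointing logic in `train` gathers the validation loss from every worker, averages it on the main worker, and saves only when the average improves. The sketch below simulates that decision with plain Python lists in place of `self.strategy.gather_obj`; the loss values are made up for illustration, and `should_save` and `global_step` are hypothetical helpers.

```python
from statistics import mean


def should_save(worker_val_losses, best_val_loss):
    """Return (improved, avg_loss) for the losses gathered from all workers."""
    avg_loss = mean(worker_val_losses)
    return avg_loss < best_val_loss, avg_loss


def global_step(epoch: int, batches_per_epoch: int, batch_idx: int) -> int:
    """Step index used when logging per-batch metrics."""
    return epoch * batches_per_epoch + batch_idx


# Made-up per-worker validation losses for three epochs on two workers.
history = [[0.9, 1.1], [0.7, 0.8], [0.8, 0.9]]

best = float("inf")
saved_epochs = []
for epoch, losses in enumerate(history):
    improved, avg = should_save(losses, best)
    if improved:
        # The real trainer writes the checkpoint here (main worker only).
        best = avg
        saved_epochs.append(epoch)

print(saved_epochs)
print(global_step(epoch=1, batches_per_epoch=45, batch_idx=3))
```

In this toy run the third epoch is not saved because its averaged loss is worse than the best value seen so far.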