Virgo
The code is adapted from this notebook available on the Virgo use case's repository.
To learn more about the interTwin Virgo noise detector use case and its digital twin (DT), please see the published deliverables D4.2, D7.2 and D7.4.
Installation
Before continuing, install the required libraries in the pre-existing itwinai environment.
pip install -r requirements.txt
Training
You can run the whole pipeline in one shot, including dataset generation, or you can execute it from the second step (after the synthetic dataset has been generated).
itwinai exec-pipeline --config config.yaml --pipe-key training_pipeline
# Run from the second step (use python-like slicing syntax).
# In this case, the dataset is loaded from "data/Image_dataset_synthetic_64x64.pkl"
itwinai exec-pipeline --config config.yaml --pipe-key training_pipeline --steps 1:
Launch distributed training with SLURM using the dedicated slurm.sh job script:
# Distributed training with torch DistributedDataParallel
PYTHON_VENV="../../envAI_hdfml"
DIST_MODE="ddp"
RUN_NAME="ddp-virgo"
TRAINING_CMD="$PYTHON_VENV/bin/itwinai exec-pipeline --config config.yaml --steps 1: --pipe-key training_pipeline -o strategy=ddp"
sbatch --export=ALL,DIST_MODE="$DIST_MODE",RUN_NAME="$RUN_NAME",TRAINING_CMD="$TRAINING_CMD",PYTHON_VENV="$PYTHON_VENV" \
slurm.sh
…and check the results in the job.out and job.err log files.
To understand how to use all the distributed strategies supported by itwinai,
check the content of runall.sh:
bash runall.sh
> [!WARNING]
> The file train.py is not the recommended way to launch training:
> it is deprecated and is kept only to document an intermediate integration step
> of the use case into itwinai.
When using the MLflow logger, you can visualize the logs in the MLflow UI:
mlflow ui --backend-store-uri mllogs/mlflow
# In background
mlflow ui --backend-store-uri mllogs/mlflow > /dev/null 2>&1 &
config.yaml
# General configuration
data_root: data
epochs: 2
batch_size: 20
strategy: ddp
checkpoint_path: checkpoints/epoch_{}.pth
training_pipeline:
class_path: itwinai.pipeline.Pipeline
init_args:
steps:
- class_path: data.TimeSeriesDatasetGenerator
init_args:
data_root: ${data_root}
- class_path: data.TimeSeriesDatasetSplitter
init_args:
train_proportion: 0.9
rnd_seed: 42
images_dataset: data/Image_dataset_synthetic_64x64.pkl
- class_path: data.TimeSeriesProcessor
- class_path: trainer.NoiseGeneratorTrainer
init_args:
generator: unet
batch_size: ${batch_size}
num_epochs: ${epochs}
strategy: ${strategy}
checkpoint_path: ${checkpoint_path}
logger:
class_path: itwinai.loggers.LoggersCollection
init_args:
loggers:
- class_path: itwinai.loggers.ConsoleLogger
init_args:
log_freq: 100
- class_path: itwinai.loggers.MLFlowLogger
init_args:
experiment_name: Noise simulator (Virgo)
log_freq: batch
data.py
from typing import Optional, Tuple, Any
import os
import pandas as pd
import torch
from sklearn.model_selection import train_test_split
from itwinai.components import (
DataGetter, DataProcessor, DataSplitter, monitor_exec
)
from src.dataset import (
generate_dataset_aux_channels,
generate_dataset_main_channel,
generate_cut_image_dataset,
normalize_
)
class TimeSeriesDatasetGenerator(DataGetter):
    """Pipeline step that generates the synthetic dataset and stores it.

    Auxiliary-channel and main-channel time series are generated with the
    helpers from ``src.dataset``, pickled under ``data_root``, converted to
    an image dataset and pickled again.
    """
    # TODO: move configuration to the constructor.

    def __init__(
        self,
        data_root: str = "data",
        name: Optional[str] = None
    ) -> None:
        """Store the output directory and create it if it is missing.

        Args:
            data_root (str): directory where the pickled datasets are
                written. Defaults to "data".
            name (Optional[str]): name of the pipeline step, forwarded to
                the parent class. Defaults to None.
        """
        super().__init__(name)
        # locals() must be captured before any extra local is defined
        self.save_parameters(**self.locals2params(locals()))
        self.data_root = data_root
        if not os.path.exists(data_root):
            os.makedirs(data_root, exist_ok=True)

    @monitor_exec
    def execute(self) -> pd.DataFrame:
        """Generate a time-series dataset, convert it to images (Q-plots),
        save everything to disk, and return the image dataset.

        Returns:
            pd.DataFrame: dataset of Q-plot images.
        """
        aux_series = generate_dataset_aux_channels(
            1000, 3, duration=16, sample_rate=500,
            num_waves_range=(20, 25), noise_amplitude=0.6
        )
        main_series = generate_dataset_main_channel(
            aux_series, weights=None, noise_amplitude=0.1
        )

        # Persist the raw time series (main channel first, then aux)
        for frame, filename in (
            (main_series, 'TimeSeries_dataset_synthetic_main.pkl'),
            (aux_series, 'TimeSeries_dataset_synthetic_aux.pkl'),
        ):
            frame.to_pickle(os.path.join(self.data_root, filename))

        # Main channel goes first: downstream steps rely on column order
        all_series = pd.concat([main_series, aux_series], axis=1)
        images = generate_cut_image_dataset(
            all_series, list(all_series.columns),
            num_processes=20, square_size=64
        )
        images_path = os.path.join(
            self.data_root, 'Image_dataset_synthetic_64x64.pkl')
        images.to_pickle(images_path)
        return images
class TimeSeriesDatasetSplitter(DataSplitter):
    """Pipeline step that splits the image dataset into train/validation.

    The first column of the dataset is treated as the main channel (the
    target) and all remaining columns as auxiliary channels (the inputs).
    """

    def __init__(
        self,
        train_proportion: int | float,
        validation_proportion: int | float = 0.0,
        test_proportion: int | float = 0.0,
        rnd_seed: int | None = None,
        images_dataset: str = "data/Image_dataset_synthetic_64x64.pkl",
        name: str | None = None
    ) -> None:
        """Configure the split.

        Args:
            train_proportion (int | float): share of samples used for
                training.
            validation_proportion (int | float): forwarded to the parent
                class; the value actually used for splitting is always
                ``1 - train_proportion`` (see below). Defaults to 0.0.
            test_proportion (int | float): forwarded to the parent class;
                no test split is produced by this step. Defaults to 0.0.
            rnd_seed (int | None): random state passed to
                ``train_test_split``. Defaults to None.
            images_dataset (str): pickle file loaded when no dataset is
                received from the previous step.
            name (str | None): name of the pipeline step.
        """
        super().__init__(
            train_proportion, validation_proportion,
            test_proportion, name
        )
        self.save_parameters(**self.locals2params(locals()))
        # Everything that is not training data goes to validation,
        # regardless of the validation_proportion argument.
        self.validation_proportion = 1-train_proportion
        self.rnd_seed = rnd_seed
        self.images_dataset = images_dataset

    def get_or_load(self, dataset: Optional[pd.DataFrame] = None):
        """If the dataset is not given, load it from disk."""
        if dataset is None:
            print("WARNING: loading time series dataset from disk.")
            # NOTE(security): unpickling executes arbitrary code; only
            # load files produced by a trusted run of the generator step.
            return pd.read_pickle(self.images_dataset)
        return dataset

    @monitor_exec
    def execute(
        self,
        dataset: Optional[pd.DataFrame] = None
    ) -> Tuple:
        """Splits a dataset into train, validation and test splits.

        Args:
            dataset (pd.DataFrame): input dataset. When None, it is loaded
                from ``images_dataset``.

        Returns:
            Tuple: ``((X_train, y_train), (X_val, y_val), None)``, where X
            holds the auxiliary channels and y the main channel. Test is
            None.
        """
        dataset = self.get_or_load(dataset)

        # Convert every cell to a torch tensor. DataFrame.applymap is
        # deprecated since pandas 2.1 in favour of DataFrame.map; fall
        # back to applymap on older pandas versions. The check is done on
        # the class so that a column named "map" cannot shadow the method.
        elementwise = (
            dataset.map if hasattr(pd.DataFrame, "map")
            else dataset.applymap
        )
        df = elementwise(lambda x: torch.tensor(x))

        # Divide Image dataset in main and aux channels. Note that df
        # generated in the section Generate Synthetic Dataset will always
        # have the main channel as its first column
        columns = list(df.columns)
        main_channel = columns[0]
        aux_channels = columns[1:]

        df_aux_all_2d = pd.DataFrame(df[aux_channels])
        df_main_all_2d = pd.DataFrame(df[main_channel])

        X_train_2d, X_test_2d, y_train_2d, y_test_2d = train_test_split(
            df_aux_all_2d, df_main_all_2d,
            test_size=self.validation_proportion,
            random_state=self.rnd_seed)
        return (X_train_2d, y_train_2d), (X_test_2d, y_test_2d), None
class TimeSeriesProcessor(DataProcessor):
    """Pipeline step that turns the split dataframes into tensors.

    For each split, the main channel and the first three auxiliary
    channels are stacked sample by sample, concatenated along dimension 1
    (main channel first) and passed through ``normalize_``.
    """

    def __init__(self, name: str | None = None) -> None:
        super().__init__(name)
        self.save_parameters(**self.locals2params(locals()))

    @staticmethod
    def _stack_channels(
        aux_df: pd.DataFrame,
        main_df: pd.DataFrame
    ) -> torch.Tensor:
        """Stack one split into a single tensor, main channel first.

        Columns are accessed by position with ``.iloc`` on purpose:
        integer keys on a label-indexed Series (``row[0]``) only work
        through a positional fallback that pandas has deprecated.

        Args:
            aux_df (pd.DataFrame): auxiliary channels; the first three
                columns are used. Each cell is expected to hold a tensor.
            main_df (pd.DataFrame): main channel in its first column.

        Returns:
            torch.Tensor: main and auxiliary channels concatenated on
            dimension 1.

        Raises:
            IndexError: if ``aux_df`` has fewer than three columns.
        """
        signal = torch.stack([
            torch.stack([image]) for image in main_df.iloc[:, 0]
        ])
        aux = torch.stack([
            torch.stack(list(channels))
            for channels in zip(
                aux_df.iloc[:, 0], aux_df.iloc[:, 1], aux_df.iloc[:, 2])
        ])
        return torch.cat([signal, aux], dim=1)

    @monitor_exec
    def execute(
        self,
        train_dataset: Tuple,
        validation_dataset: Tuple,
        test_dataset: Any = None
    ) -> Tuple[torch.Tensor, torch.Tensor, None]:
        """Pre-process datasets: rearrange and normalize before training.

        Args:
            train_dataset (Tuple): training dataset, as an
                ``(aux_channels, main_channel)`` pair of dataframes.
            validation_dataset (Tuple): validation dataset, same layout.
            test_dataset (Any, optional): unused placeholder. Defaults to
                None.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, None]: train, validation, and
            test (placeholder) datasets. Ready to be used for training.
        """
        X_train_2d, y_train_2d = train_dataset
        X_test_2d, y_test_2d = validation_dataset

        train_data_2d = self._stack_channels(X_train_2d, y_train_2d)
        test_data_2d = self._stack_channels(X_test_2d, y_test_2d)

        # NORMALIZE (each split independently, as returned by normalize_)
        train_data_2d = normalize_(train_data_2d)
        test_data_2d = normalize_(test_data_2d)
        return train_data_2d, test_data_2d, None
runall.sh
#!/bin/bash
# Submit one SLURM job per distributed strategy supported by itwinai
# (ddp, deepspeed, horovod). Every job runs the training pipeline from its
# second step through slurm.sh and writes to its own checkpoint directory.

# Python virtual environment (no conda/micromamba)
PYTHON_VENV="../../envAI_hdfml"

# Number of nodes, used both for the allocation and to label the job and its
# log files. Defaults to 2 (the value hard-coded in slurm.sh); override it
# from the environment, e.g.: N=4 bash runall.sh
N="${N:-2}"

# Clear SLURM logs (*.out and *.err files) and old checkpoints
rm -rf logs_slurm checkpoints*
mkdir logs_slurm
rm -rf logs_torchrun

for DIST_MODE in ddp deepspeed horovod; do
    RUN_NAME="${DIST_MODE}-itwinai"
    TRAINING_CMD="$PYTHON_VENV/bin/itwinai exec-pipeline --config config.yaml --pipe-key training_pipeline --steps 1: -o strategy=${DIST_MODE} -o checkpoint_path=checkpoints_${DIST_MODE}/epoch_{}.pth"
    # Command-line options take precedence over the #SBATCH directives
    # found in slurm.sh
    sbatch --export=ALL,DIST_MODE="$DIST_MODE",RUN_NAME="$RUN_NAME",TRAINING_CMD="$TRAINING_CMD",PYTHON_VENV="$PYTHON_VENV" \
        --job-name="$RUN_NAME-n$N" \
        --nodes="$N" \
        --output="logs_slurm/job-$RUN_NAME-n$N.out" \
        --error="logs_slurm/job-$RUN_NAME-n$N.err" \
        slurm.sh
done
slurm.sh
#!/bin/bash

# SLURM jobscript for JSC systems
#
# Expected environment variables (passed with sbatch --export):
#   DIST_MODE     'ddp', 'deepspeed' or 'horovod' (required)
#   TRAINING_CMD  command launched on the allocated resources (required)
#   RUN_NAME      label of the run (optional, defaults to $DIST_MODE)
#   PYTHON_VENV   path to a Python virtual environment (optional)
#   DEBUG         set to 'true' to enable NCCL debug output (optional)

# Job configuration
#SBATCH --job-name=distributed_training
#SBATCH --account=intertwin
#SBATCH --mail-user=
#SBATCH --mail-type=ALL
#SBATCH --output=job.out
#SBATCH --error=job.err
#SBATCH --time=00:30:00

# Resources allocation
#SBATCH --partition=batch
#SBATCH --nodes=2
#SBATCH --gpus-per-node=4
#SBATCH --cpus-per-gpu=4
#SBATCH --exclusive

# gres options have to be disabled for deepv
#SBATCH --gres=gpu:4

# Load environment modules
ml Stages/2024 GCC OpenMPI CUDA/12 MPI-settings/CUDA Python HDF5 PnetCDF libaio mpi4py

# Job info
echo "DEBUG: TIME: $(date)"
sysN="$(uname -n | cut -f2- -d.)"
sysN="${sysN%%[0-9]*}"
echo "Running on system: $sysN"
# The command being executed is TRAINING_CMD ($EXEC was never defined)
echo "DEBUG: EXECUTE: $TRAINING_CMD"
echo "DEBUG: SLURM_SUBMIT_DIR: $SLURM_SUBMIT_DIR"
echo "DEBUG: SLURM_JOB_ID: $SLURM_JOB_ID"
echo "DEBUG: SLURM_JOB_NODELIST: $SLURM_JOB_NODELIST"
echo "DEBUG: SLURM_NNODES: $SLURM_NNODES"
echo "DEBUG: SLURM_NTASKS: $SLURM_NTASKS"
echo "DEBUG: SLURM_TASKS_PER_NODE: $SLURM_TASKS_PER_NODE"
echo "DEBUG: SLURM_SUBMIT_HOST: $SLURM_SUBMIT_HOST"
echo "DEBUG: SLURMD_NODENAME: $SLURMD_NODENAME"
echo "DEBUG: CUDA_VISIBLE_DEVICES: $CUDA_VISIBLE_DEVICES"
if [ "$DEBUG" = true ] ; then
    echo "DEBUG: NCCL_DEBUG=INFO"
    export NCCL_DEBUG=INFO
fi
echo

# Setup env for distributed ML
export CUDA_VISIBLE_DEVICES="0,1,2,3"
export OMP_NUM_THREADS=1
# Default to 0 so that the integer test does not fail when SLURM does not
# export SLURM_CPUS_PER_GPU
if [ "${SLURM_CPUS_PER_GPU:-0}" -gt 0 ] ; then
    export OMP_NUM_THREADS=$SLURM_CPUS_PER_GPU
fi

# Env variables check
if [ -z "$DIST_MODE" ]; then
    >&2 echo "ERROR: env variable DIST_MODE is not set. Allowed values are 'horovod', 'ddp' or 'deepspeed'"
    exit 1
fi
if [ -z "$RUN_NAME" ]; then
    >&2 echo "WARNING: env variable RUN_NAME is not set. It's a way to identify some specific run of an experiment."
    RUN_NAME=$DIST_MODE
fi
if [ -z "$TRAINING_CMD" ]; then
    >&2 echo "ERROR: env variable TRAINING_CMD is not set. It's the python command to execute."
    exit 1
fi
if [ -z "$PYTHON_VENV" ]; then
    >&2 echo "WARNING: env variable PYTHON_VENV is not set. It's the path to a python virtual environment."
else
    # Activate Python virtual env (quoted: the path may contain spaces)
    source "$PYTHON_VENV/bin/activate"
fi

# Get GPUs info per node
srun --cpu-bind=none --ntasks-per-node=1 bash -c 'echo -e "NODE hostname: $(hostname)\n$(nvidia-smi)\n\n"'

# Launch training
if [ "$DIST_MODE" == "ddp" ] ; then
    echo "DDP training: $TRAINING_CMD"
    # One torchrun launcher per node. The rendezvous endpoint is resolved
    # here, by the submitting shell; is_host is escaped so that it is
    # evaluated on each node (1 on the node with SLURM_NODEID 0, else 0).
    srun --cpu-bind=none --ntasks-per-node=1 \
        bash -c "torchrun \
        --log_dir='logs_torchrun' \
        --nnodes=$SLURM_NNODES \
        --nproc_per_node=$SLURM_GPUS_PER_NODE \
        --rdzv_id=$SLURM_JOB_ID \
        --rdzv_conf=is_host=\$(((SLURM_NODEID)) && echo 0 || echo 1) \
        --rdzv_backend=c10d \
        --rdzv_endpoint='$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)'i:29500 \
        $TRAINING_CMD"
elif [ "$DIST_MODE" == "deepspeed" ] ; then
    echo "DEEPSPEED training: $TRAINING_CMD"
    # First node of the allocation, with the same "i" hostname suffix used
    # for the DDP endpoint above. The variable must NOT be escaped here:
    # this line is not inside a 'bash -c' string.
    MASTER_ADDR="$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)i"
    export MASTER_ADDR
    export MASTER_PORT=29500

    srun --cpu-bind=none --ntasks-per-node=$SLURM_GPUS_PER_NODE --cpus-per-task=$SLURM_CPUS_PER_GPU \
        $TRAINING_CMD

    # # Run with deepspeed launcher: set --ntasks-per-node=1
    # # https://www.deepspeed.ai/getting-started/#multi-node-environment-variables
    # export NCCL_IB_DISABLE=1
    # export NCCL_SOCKET_IFNAME=eth0
    # nodelist=$(scontrol show hostname $SLURM_NODELIST)
    # echo "$nodelist" | sed -e 's/$/ slots=4/' > .hostfile
    # # Requires passwordless SSH access among compute node
    # srun --cpu-bind=none deepspeed --hostfile=.hostfile $TRAINING_CMD --deepspeed
    # rm .hostfile
elif [ "$DIST_MODE" == "horovod" ] ; then
    echo "HOROVOD training: $TRAINING_CMD"
    srun --cpu-bind=none --ntasks-per-node=$SLURM_GPUS_PER_NODE --cpus-per-task=$SLURM_CPUS_PER_GPU \
        $TRAINING_CMD
else
    >&2 echo "ERROR: unrecognized \$DIST_MODE env variable"
    exit 1
fi
trainer.py
from typing import Literal, Optional
import os
import torch.nn as nn
import torch
import time
import numpy as np
from itwinai.torch.trainer import TorchTrainer
from itwinai.torch.distributed import (
DeepSpeedStrategy,
)
from itwinai.torch.config import TrainingConfiguration
from itwinai.loggers import Logger
from src.model import Decoder, Decoder_2d_deep, UNet, GeneratorResNet
from src.utils import init_weights
from tqdm import tqdm
class NoiseGeneratorTrainer(TorchTrainer):
    """Trainer for the Virgo noise generator.

    Each batch is split into a target (channel 0) and the model input
    (remaining channels); the model is trained to generate the target
    from the input.
    """

    def __init__(
        self,
        batch_size: int,
        learning_rate: float = 1e-3,
        num_epochs: int = 2,
        generator: Literal["simple", "deep", "resnet", "unet"] = "unet",
        loss: Literal["L1", "L2"] = "L1",
        strategy: Literal["ddp", "deepspeed", "horovod"] = 'ddp',
        checkpoint_path: str = "checkpoints/epoch_{}.pth",
        save_best: bool = True,
        logger: Optional[Logger] = None,
        name: str | None = None
    ) -> None:
        """Configure the trainer.

        Args:
            batch_size (int): batch size stored in the training config.
            learning_rate (float): learning rate of the Adam optimizer.
            num_epochs (int): number of training epochs.
            generator (str): model architecture (case-insensitive).
            loss (str): "L1" or "L2" (case-insensitive).
            strategy (str): distributed strategy, forwarded to the parent.
            checkpoint_path (str): checkpoint file template; ``{}`` is
                replaced by the epoch number or by ``best``.
            save_best (bool): save a checkpoint whenever the validation
                loss improves.
            logger (Optional[Logger]): itwinai logger.
            name (str | None): name of the pipeline step.
        """
        super().__init__(
            epochs=num_epochs,
            config={},
            strategy=strategy,
            logger=logger,
            name=name
        )
        # locals() must be captured before any extra local is defined
        self.save_parameters(**self.locals2params(locals()))
        self.num_epochs = num_epochs
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self._generator = generator
        self._loss = loss
        self.checkpoint_path = checkpoint_path
        # dirname is '' for a bare file name, and os.makedirs('') raises
        checkpoint_dir = os.path.dirname(checkpoint_path)
        if checkpoint_dir:
            os.makedirs(checkpoint_dir, exist_ok=True)
        # Global configuration
        self.config = TrainingConfiguration(
            batch_size=batch_size,
            save_best=save_best
        )

    def create_model_loss_optimizer(self) -> None:
        """Instantiate model, loss and optimizer, then distribute them.

        Raises:
            ValueError: if the generator or the loss name is not supported.
        """
        # Select generator
        generator = self._generator.lower()
        if generator == "simple":
            self.model = Decoder(3, norm=False)
            init_weights(self.model, 'normal', scaling=.02)
        elif generator == "deep":
            self.model = Decoder_2d_deep(3)
            init_weights(self.model, 'normal', scaling=.02)
        elif generator == "resnet":
            self.model = GeneratorResNet(3, 12, 1)
            init_weights(self.model, 'normal', scaling=.01)
        elif generator == "unet":
            self.model = UNet(
                input_channels=3, output_channels=1, norm=False)
            init_weights(self.model, 'normal', scaling=.02)
        else:
            raise ValueError(
                f"Unrecognized generator type! Got {generator}")

        # Select loss
        loss = self._loss.upper()
        if loss == "L1":
            self.loss = nn.L1Loss()
        elif loss == "L2":
            self.loss = nn.MSELoss()
        else:
            raise ValueError(f"Unrecognized loss type! Got {loss}")

        # Optimizer
        self.optimizer = torch.optim.Adam(
            self.model.parameters(), lr=self.learning_rate)

        # IMPORTANT: model, optimizer, and scheduler need to be distributed
        # First, define strategy-wise optional configurations
        if isinstance(self.strategy, DeepSpeedStrategy):
            # Batch size definition is not optional for DeepSpeedStrategy!
            distribute_kwargs = dict(
                config_params=dict(
                    train_micro_batch_size_per_gpu=self.config.batch_size
                )
            )
        else:
            distribute_kwargs = {}
        # Distributed model, optimizer, and scheduler
        self.model, self.optimizer, _ = self.strategy.distributed(
            self.model, self.optimizer, **distribute_kwargs
        )

    def train(self):
        """Run the training loop with a validation pass after each epoch.

        When ``save_best`` is enabled, the main worker saves a checkpoint
        every time the validation loss averaged over all workers improves.

        Returns:
            tuple: ``(loss_plot, val_loss_plot, acc_plot, val_acc_plot)``,
            one entry per epoch. Accuracy (IOU between generated and real
            spectrograms) is not implemented yet, so both accuracy lists
            only contain NaN placeholders.
        """
        loss_plot = []
        val_loss_plot = []
        acc_plot = []
        val_acc_plot = []
        best_val_loss = float('inf')
        # NOTE(review): the model is never switched between train() and
        # eval() mode; confirm that the generators in src.model do not
        # contain mode-dependent layers.
        for epoch in tqdm(range(1, self.num_epochs+1)):
            st = time.time()

            # TRAINING
            epoch_loss = []
            for i, batch in enumerate(self.train_dataloader):
                # Channel 0 is the target, the others are the model input
                target = batch[:, 0].unsqueeze(1).to(self.device).float()
                aux_input = batch[:, 1:].to(self.device)
                self.optimizer.zero_grad()
                generated = self.model(aux_input.float())
                loss = self.loss(generated, target)
                loss.backward()
                self.optimizer.step()
                batch_loss = loss.detach().cpu().numpy()
                epoch_loss.append(batch_loss)
                self.log(batch_loss,
                         'epoch_loss_batch',
                         kind='metric',
                         step=epoch*len(self.train_dataloader) + i,
                         batch_idx=i)

            # VALIDATION
            val_loss = []
            # Use a dedicated index: reusing the training index would log
            # every validation batch under the same step.
            for val_i, batch in enumerate(self.validation_dataloader):
                target = batch[:, 0].unsqueeze(1).to(self.device).float()
                aux_input = batch[:, 1:].to(self.device)
                with torch.no_grad():
                    generated = self.model(aux_input.float())
                    loss = self.loss(generated, target)
                batch_loss = loss.detach().cpu().numpy()
                val_loss.append(batch_loss)
                self.log(batch_loss,
                         'val_loss_batch',
                         kind='metric',
                         step=epoch*len(self.validation_dataloader) + val_i,
                         batch_idx=val_i)

            loss_plot.append(np.mean(epoch_loss))
            val_loss_plot.append(np.mean(val_loss))
            # Accuracy is not computed: store the NaN placeholder directly
            # instead of averaging an empty list, which emits a warning.
            acc_plot.append(np.nan)
            val_acc_plot.append(np.nan)

            # Log metrics/losses
            self.log(loss_plot[-1], 'epoch_loss',
                     kind='metric', step=epoch)
            self.log(val_loss_plot[-1], 'val_loss',
                     kind='metric', step=epoch)

            et = time.time()
            if self.strategy.is_main_worker:
                print('epoch: {} loss: {} val loss: {} time:{}s'.format(
                    epoch, loss_plot[-1], val_loss_plot[-1], et-st))

            # Average loss among all workers. This is a collective call
            # and has to be executed by every worker.
            worker_val_losses = self.strategy.gather_obj(val_loss_plot[-1])
            if self.strategy.is_main_worker:
                # Save only in the main worker
                # avg_loss has a meaning only in the main worker
                avg_loss = np.mean(worker_val_losses)
                # instead of val_loss and best_val loss we should
                # use accuracy!!!
                if self.config.save_best and avg_loss < best_val_loss:
                    checkpoint = {
                        'epoch': epoch,
                        'model_state_dict': self.model.state_dict(),
                        'optim_state_dict': self.optimizer.state_dict(),
                        'loss': loss_plot[-1],
                        'val_loss': val_loss_plot[-1],
                    }
                    # save checkpoint only if it is better than
                    # the previous ones
                    checkpoint_filename = self.checkpoint_path.format(
                        epoch)
                    torch.save(checkpoint, checkpoint_filename)
                    self.log(checkpoint_filename,
                             os.path.basename(checkpoint_filename),
                             kind='artifact')

                    # update best model: track the same quantity that is
                    # compared above (the average over all workers)
                    best_val_loss = avg_loss
                    best_checkpoint_filename = (
                        self.checkpoint_path.format('best')
                    )
                    torch.save(checkpoint, best_checkpoint_filename)
                    self.log(best_checkpoint_filename,
                             os.path.basename(best_checkpoint_filename),
                             kind='artifact')
        return loss_plot, val_loss_plot, acc_plot, val_acc_plot