Source code for itwinai.torch.loggers

# --------------------------------------------------------------------------------------
# Part of the interTwin Project: https://www.intertwin.eu/
#
# Created by: Anna Elisa Lappe
#
# Credit:
# - Anna Lappe <anna.elisa.lappe@cern.ch> - CERN
# - Matteo Bunino <matteo.bunino@cern.ch> - CERN
# -------------------------------------------------------------------------------------

import os
import tempfile
from argparse import Namespace
from pathlib import Path
from typing import Any, Dict, Literal, Optional, Union

import yaml
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from lightning.pytorch.loggers.logger import Logger as LightningLogger
from lightning.pytorch.loggers.logger import rank_zero_experiment
from lightning.pytorch.loggers.utilities import _scan_checkpoints
from lightning.pytorch.utilities import rank_zero_only
from torch import Tensor
from typing_extensions import override

from itwinai.loggers import Logger as ItwinaiBaseLogger


[docs] class ItwinaiLogger(LightningLogger): """Adapter between PyTorch Lightning logger and itwinai logger. This adapter forwards logging calls from PyTorch Lightning to the itwinai Logger instance, using the itwinai Logger's `log` method. It supports the lightning logging of metrics, hyperparameters, and checkpoints. Additionally, any function calls can be forwarded to the itwinai logger instance though the `experiment` property of this Adapter. """ def __init__( self, itwinai_logger: ItwinaiBaseLogger, log_model: Union[Literal["all"], bool] = False, skip_finalize: bool = False, ): """Initializes the adapter with an itwinai logger instance. Args: itwinai_logger (Logger): An instance of itwinai Logger. log_model (Union[Literal["all"], bool], optional): Specifies which checkpoints to log. If "all", logs all checkpoints; if True, logs the best k checkpoints according to the specifications given as `save_top_k` in the Lightning ModelCheckpoint; if False, does not log checkpoints. skip_finalize (bool): if True, do not finalize the logger in the finalize method. This is useful when you also want to use the logger outside of lightning. Defaults to False. """ self.itwinai_logger = itwinai_logger self._log_model = log_model self._skip_finalize = skip_finalize self._logged_model_time = {} self._checkpoint_callback = None @property def name(self) -> Optional[str]: """Return the experiment name.""" self.experiment.experiment_id @property def version(self) -> Optional[Union[int, str]]: """Return the experiment version.""" self.experiment.run_id @property @override @rank_zero_only def save_dir(self) -> Optional[str]: """Return the directory where the logs are stored.""" return self.experiment.savedir @property @rank_zero_experiment def experiment(self) -> ItwinaiBaseLogger: """Lightning Logger function. Initializes and returns the itwinai Logger context for experiment tracking. Returns: Logger: The itwinai logger instance. """ if not self.itwinai_logger.is_initialized: # With the rank_zero_experiment decorators the rank will always be 0 self.itwinai_logger.create_logger_context(rank=0) return self.itwinai_logger
[docs] @override @rank_zero_only def finalize(self, status: str) -> None: """Lightning Logger function. Logs any remaining checkpoints and closes the logger context. Args: status (str): Describes the status of the training (e.g., 'completed', 'failed'). The status is not needed for this function but part of the parent classes' (LightningLogger) finalize functions signature, and therefore must be propagated here. """ if not self.itwinai_logger.is_initialized or self._skip_finalize: return # Log checkpoints if the last checkpoint was saved but not logged if self._checkpoint_callback: self._scan_and_log_checkpoints(self._checkpoint_callback) self.experiment.destroy_logger_context()
[docs] @override @rank_zero_only def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None: """Lightning Logger function. Logs the given metrics and is usually called by the Lightning Trainer. Args: metrics (Dict[str, float]): Dictionary of metrics to log. step (Optional[int], optional): Training step associated with the metrics. Defaults to None. """ for identifier, item in metrics.items(): self.experiment.log(item=item, identifier=identifier, kind="metric", step=step)
[docs] @override @rank_zero_only def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: """Lightning Logger function. Logs hyperparameters for the experiment. Args: params (Union[Dict[str, Any], Namespace]): Hyperparameters dictionary or object. """ if isinstance(params, Namespace): params = vars(params) self.experiment.save_hyperparameters(params)
[docs] @override def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None: """Lightning Logger function. Handles checkpoint saving to the logger after the ModelCheckpoint Callback of the Lightning Trainer is called. The checkpoints are logged as artifacts. Args: checkpoint_callback (ModelCheckpoint): Callback instance to manage checkpointing. """ if self._log_model == "all" or ( self._log_model is True and checkpoint_callback.save_top_k == -1 ): self._scan_and_log_checkpoints(checkpoint_callback) elif self._log_model is True: self._checkpoint_callback = checkpoint_callback
def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None: """Scans and logs checkpoints as artifacts in the experiment. This function retrieves new checkpoints, logs them as artifacts, and saves related metadata, including checkpoint score, filename, and any model checkpoint parameters. Args: checkpoint_callback (ModelCheckpoint): Callback instance used for checkpointing. """ # Get checkpoints to be saved with associated score checkpoints = _scan_checkpoints(checkpoint_callback, self._logged_model_time) # Log iteratively all new checkpoints for time, path, score, _ in checkpoints: metadata = { "score": score.item() if isinstance(score, Tensor) else score, "original_filename": Path(path).name, checkpoint_callback.__class__.__name__: { k: getattr(checkpoint_callback, k) for k in [ "monitor", "mode", "save_last", "save_top_k", "save_weights_only", "_every_n_train_steps", ] # ensure it does not break if `ModelCheckpoint` args change if hasattr(checkpoint_callback, k) }, } aliases = ( ["latest", "best"] if path == checkpoint_callback.best_model_path else ["latest"] ) artifact_path = Path(path).stem # Log the checkpoint self.experiment.log(item=path, identifier=artifact_path, kind="artifact") with tempfile.TemporaryDirectory( prefix="test", suffix="test", dir=os.getcwd() ) as tmp_dir: # Save the metadata with open(f"{tmp_dir}/metadata.yaml", "w") as tmp_file_metadata: yaml.dump(metadata, tmp_file_metadata, default_flow_style=False) # Save the aliases with open(f"{tmp_dir}/aliases.txt", "w") as tmp_file_aliases: tmp_file_aliases.write(str(aliases)) # Log metadata and aliases self.experiment.log(item=tmp_dir, identifier=artifact_path, kind="artifact") self._logged_model_time[path] = time