Source code for itwinai.torch.mlflow

# --------------------------------------------------------------------------------------
# Part of the interTwin Project: https://www.intertwin.eu/
#
# Created by: Matteo Bunino
#
# Credit:
# - Matteo Bunino <matteo.bunino@cern.ch> - CERN
# --------------------------------------------------------------------------------------

import logging
from pathlib import Path
from typing import Dict

import mlflow
import yaml

py_logger = logging.getLogger(__name__)


def _get_mlflow_logger_conf(pl_config: Dict) -> Dict | None:
    """Extract MLFLowLogger configuration from pytorch lightning
    configuration file, if present.

    Args:
        pl_config (Dict): lightning configuration loaded in memory.

    Returns:
        Optional[Dict]: if present, MLFLowLogger constructor arguments
        (under 'init_args' key).
    """
    if not pl_config["trainer"].get("logger"):
        return None
    if isinstance(pl_config["trainer"]["logger"], list):
        # If multiple loggers are provided
        for logger_conf in pl_config["trainer"]["logger"]:
            if logger_conf["class_path"].endswith("MLFlowLogger"):
                return logger_conf["init_args"]
    elif pl_config["trainer"]["logger"]["class_path"].endswith("MLFlowLogger"):
        return pl_config["trainer"]["logger"]["init_args"]


def _mlflow_log_pl_config(pl_config: Dict, local_yaml_path: str | Path) -> None:
    """Serialize ``pl_config`` to a local YAML file and log it to mlflow.

    Args:
        pl_config (Dict): configuration to dump (block style YAML).
        local_yaml_path (str | Path): destination of the YAML file. Missing
            parent directories are created and an existing file is overwritten.
            The file is then logged with ``mlflow.log_artifact``.
    """
    yaml_path = (
        Path(local_yaml_path)
        if isinstance(local_yaml_path, str)
        else local_yaml_path
    )

    yaml_path.parent.mkdir(parents=True, exist_ok=True)
    with open(yaml_path, "w") as stream:
        yaml.dump(pl_config, stream, default_flow_style=False)

    mlflow.log_artifact(str(yaml_path))


def init_lightning_mlflow(
    pl_config: Dict,
    default_experiment_name: str = "Default",
    tmp_dir: str = ".tmp",
    **autolog_kwargs,
) -> None:
    """Initialize mlflow for pytorch lightning, also setting up auto-logging
    (mlflow.pytorch.autolog(...)). Creates a new mlflow run and attaches it to
    the mlflow auto-logger.

    Does nothing if the lightning configuration declares no MLFlowLogger, or
    declares one with empty constructor arguments.

    Side effects: the MLFlowLogger ``init_args`` inside ``pl_config`` are
    updated in place with the resolved ``experiment_name`` and the new
    ``run_id``. The whole configuration is then dumped to
    ``<tmp_dir>/pl_config.yml`` and logged as an mlflow artifact.

    Args:
        pl_config (Dict): pytorch lightning configuration loaded in memory.
        default_experiment_name (str, optional): used as experiment name if it
            is not given in the lightning conf. Defaults to 'Default'.
        tmp_dir (str): where to temporarily store some artifacts.
        autolog_kwargs (kwargs): args for mlflow.pytorch.autolog(...).
    """
    mlflow_conf: Dict | None = _get_mlflow_logger_conf(pl_config)
    if not mlflow_conf:
        return

    tracking_uri = mlflow_conf.get("tracking_uri")
    if not tracking_uri:
        # No remote tracking server: use a local file store. If save_dir is
        # also unset, fall back to "./mlruns" (Lightning MLFlowLogger's default
        # save_dir) instead of failing on Path(None).
        save_path = mlflow_conf.get("save_dir") or "./mlruns"
        tracking_uri = "file://" + str(Path(save_path).resolve())

    experiment_name = mlflow_conf.get("experiment_name") or default_experiment_name

    mlflow.set_tracking_uri(tracking_uri)
    mlflow.set_experiment(experiment_name)
    mlflow.pytorch.autolog(**autolog_kwargs)

    run = mlflow.start_run()
    py_logger.info("MLFlow's artifacts URI: %s", run.info.artifact_uri)

    # mlflow_conf is the init_args dict held inside pl_config, so these writes
    # end up in the logged configuration. Presumably this also lets an
    # MLFlowLogger built from this config reuse the same run; confirm with
    # callers.
    mlflow_conf["experiment_name"] = experiment_name
    mlflow_conf["run_id"] = run.info.run_id

    _mlflow_log_pl_config(pl_config, Path(tmp_dir) / "pl_config.yml")
def teardown_lightning_mlflow() -> None:
    """Close the currently active mlflow run; a no-op when no run is open."""
    current_run = mlflow.active_run()
    if current_run is None:
        return
    mlflow.end_run()