# --------------------------------------------------------------------------------------
# Part of the interTwin Project: https://www.intertwin.eu/
#
# Created by: Matteo Bunino
#
# Credit:
# - Matteo Bunino <matteo.bunino@cern.ch> - CERN
# --------------------------------------------------------------------------------------
"""Provide functionalities to manage configuration files, including parsing,
execution, and dynamic override of fields.
"""
import argparse
import copy
import json
import logging
import os
from pathlib import Path
from typing import Any, Dict, List, Optional, Type, Union

from jsonargparse import ActionConfigFile
from jsonargparse import ArgumentParser as JAPArgumentParser
from jsonargparse._formatters import DefaultHelpFormatter
from omegaconf import OmegaConf

from .components import BaseComponent
from .pipeline import Pipeline
from .utils import load_yaml
class _ArgumentParser(JAPArgumentParser):
    """Parser whose ``error`` method raises instead of exiting execution."""

    def error(self, message: str, ex: Optional[Exception] = None) -> None:
        """Patch error method to re-raise exception instead of exiting execution.

        Args:
            message (str): description of the parsing error.
            ex (Optional[Exception], optional): exception associated with
                the error, if any. Defaults to None.

        Raises:
            Exception: ``ex`` itself, when it is provided.
            argparse.ArgumentError: carrying ``message``, when ``ex`` is None.
        """
        if ex is None:
            # ``raise None`` would fail with an unrelated TypeError and hide
            # the actual error message, so raise an error carrying it instead.
            raise argparse.ArgumentError(None, message)
        raise ex
def add_replace_field(config: Dict, key_chain: str, value: Any) -> None:
    """Write ``value`` at the location addressed by a dot-separated key path.

    The configuration is modified in place. Keys that traverse a list (or
    tuple) are converted to integer indexes. An intermediate entry that is
    missing from a dictionary, or that is not itself a dict/list/tuple, is
    replaced by an empty dictionary so that the walk can continue. Adding
    new items to a list is not supported.

    Args:
        config (Dict): dictionary to be modified.
        key_chain (str): path of nested (dot-separated) keys locating the
            value (e.g., 'foo.bar.line' adds/overwrites the value located
            at config['foo']['bar']['line']).
        value (Any): the value to insert.
    """
    # Everything but the last key only selects containers; the last key is
    # the one that actually receives the value.
    *parent_keys, leaf_key = key_chain.split(".")

    node = config
    for key in parent_keys:
        if isinstance(node, (list, tuple)):
            key = int(key)
            child = node[key]
        else:
            child = node.get(key)
        if not isinstance(child, (dict, list, tuple)):
            # Missing or scalar entry: make room for the nested keys
            node[key] = dict()
        node = node[key]

    if isinstance(node, (list, tuple)):
        leaf_key = int(leaf_key)
    node[leaf_key] = value
def get_root_cause(exception: Exception) -> Exception:
    """Return the first exception of the ``__cause__`` chain.

    Args:
        exception (Exception): exception from which the walk starts.

    Returns:
        Exception: the exception reached by following ``__cause__`` until
        it is None; ``exception`` itself when it has no cause.
    """
    current = exception
    while True:
        cause = current.__cause__
        if cause is None:
            # Nothing further up the chain: this is the root
            return current
        current = cause
class ConfigParser:
    """Parses a pipeline from a configuration file.

    It also provides functionalities for dynamic override
    of fields by means of nested key notation.

    Args:
        config (Union[str, Path, Dict]): path to YAML configuration file
            or dict storing a configuration. A dict is copied before
            overrides are applied, so the caller's object is not modified.
        override_keys (Optional[Dict[str, Any]], optional): dict mapping
            nested keys to the value to override. Defaults to None.

    Example:
        >>> # pipeline.yaml file
        >>> pipeline:
        >>>   class_path: itwinai.pipeline.Pipeline
        >>>   init_args:
        >>>     steps:
        >>>       - class_path: dataloader.MNISTDataModuleTorch
        >>>         init_args:
        >>>           save_path: .tmp/
        >>>
        >>>       - class_path: itwinai.torch.trainer.TorchTrainer
        >>>         init_args:
        >>>           model:
        >>>             class_path: model.Net
        >>>
        >>> from itwinai.parser import ConfigParser
        >>>
        >>> parser = ConfigParser(
        >>>     config='pipeline.yaml',
        >>>     override_keys={
        >>>         'pipeline.init_args.steps.0.init_args.save_path': '/save/path'
        >>>     }
        >>> )
        >>> pipeline = parser.parse_pipeline()
        >>> print(pipeline)
        >>> print(pipeline.steps)
        >>>
        >>> dataloader = parser.parse_step(0)
        >>> print(dataloader)
        >>> print(dataloader.save_path)
    """

    #: Configuration to parse.
    config: Dict
    #: Pipeline object instantiated from the configuration file.
    pipeline: Pipeline

    def __init__(
        self,
        config: Union[str, Path, Dict],
        override_keys: Optional[Dict[str, Any]] = None,
    ) -> None:
        if isinstance(config, (str, Path)):
            config = load_yaml(config)
        elif override_keys:
            # Overrides are applied in place: work on a private copy so that
            # the dictionary owned by the caller is left untouched.
            config = copy.deepcopy(config)
        self.config = config
        self.override_keys = override_keys
        self._dynamic_override_keys()
        self._omegaconf_interpolate()

    def _dynamic_override_keys(self) -> None:
        """Apply each entry of ``override_keys`` to the configuration."""
        if self.override_keys is None:
            return
        for key_chain, value in self.override_keys.items():
            add_replace_field(self.config, key_chain, value)

    def _omegaconf_interpolate(self) -> None:
        """Performs variable interpolation with OmegaConf on internal
        configuration file.
        """
        conf = OmegaConf.create(self.config)
        self.config = OmegaConf.to_container(conf, resolve=True)

    def _get_nested(self, nested_key: str) -> Any:
        """Return the configuration entry addressed by a dot-separated key.

        Args:
            nested_key (str): dot-separated keys, followed one at a time
                starting from the root of the configuration.

        Returns:
            Any: the entry found at the end of the key path.
        """
        node = self.config
        for key in nested_key.split("."):
            node = node[key]
        return node

    def parse_pipeline(
        self, pipeline_nested_key: str = "pipeline", verbose: bool = False
    ) -> Pipeline:
        """Merges steps into pipeline and parses it.

        Args:
            pipeline_nested_key (str, optional): nested key in the
                configuration file identifying the pipeline object.
                Defaults to "pipeline".
            verbose (bool): if True, prints the assembled pipeline
                to console formatted as JSON.

        Returns:
            Pipeline: instantiated pipeline.

        Raises:
            Exception: the root cause of any exception raised while
                parsing or instantiating the pipeline.
        """
        pipe_parser = _ArgumentParser()
        pipe_parser.add_subclass_arguments(Pipeline, "pipeline")

        # Wrap the pipeline config under the key registered on the parser
        pipe_dict = {"pipeline": self._get_nested(pipeline_nested_key)}
        if verbose:
            print("Assembled pipeline:")
            print(json.dumps(pipe_dict, indent=4))

        try:
            # Parse pipeline dict once merged with steps
            conf = pipe_parser.parse_object(pipe_dict)
            pipe = pipe_parser.instantiate_classes(conf)
        except Exception as exc:
            # Surface the original error rather than the wrapping ones
            raise get_root_cause(exc)
        self.pipeline = pipe["pipeline"]
        return self.pipeline

    def parse_step(
        self,
        step_idx: Union[str, int],
        pipeline_nested_key: str = "pipeline",
        verbose: bool = False,
    ) -> BaseComponent:
        """Parses and instantiates a single step of the pipeline.

        Args:
            step_idx (Union[str, int]): index (or key) selecting the step
                among the ``init_args.steps`` of the pipeline configuration.
            pipeline_nested_key (str, optional): nested key in the
                configuration file identifying the pipeline object.
                Defaults to "pipeline".
            verbose (bool): if True, prints the configuration of the step
                to console formatted as JSON.

        Returns:
            BaseComponent: instantiated step.

        Raises:
            Exception: the root cause of any exception raised while
                parsing or instantiating the step.
        """
        pipeline_dict = self._get_nested(pipeline_nested_key)
        step_dict_config = pipeline_dict["init_args"]["steps"][step_idx]

        if verbose:
            print(f"STEP '{step_idx}' CONFIG:")
            print(json.dumps(step_dict_config, indent=4))

        # Wrap config under "step" field and parse it
        step_dict_config = {"step": step_dict_config}

        step_parser = _ArgumentParser()
        step_parser.add_subclass_arguments(BaseComponent, "step")
        try:
            parsed_namespace = step_parser.parse_object(step_dict_config)
            step = step_parser.instantiate_classes(parsed_namespace)["step"]
        except Exception as exc:
            # Surface the original error rather than the wrapping ones
            raise get_root_cause(exc)
        return step
class ArgumentParser(JAPArgumentParser):
    """Wrapper of ``jsonargparse.ArgumentParser``.

    On top of the wrapped parser, it registers a ``-c/--config`` argument,
    so that arguments can be parsed from a series of configuration files.
    Example:

    >>> python main.py --config base-conf.yaml --config other-conf.yaml \\
    >>>                --param OVERRIDE_VAL

    All the arguments from the initializer of `argparse.ArgumentParser
    <https://docs.python.org/3/library/argparse.html#argparse.ArgumentParser>`_
    are supported. Additionally it accepts:

    Args:
        env_prefix (Union[bool, str], optional): Prefix for environment
            variables. ``True`` to derive from ``prog``. Defaults to True.
        formatter_class (Type[DefaultHelpFormatter], optional): Class for
            printing help messages. Defaults to DefaultHelpFormatter.
        exit_on_error (bool, optional): Defaults to True.
        logger (Union[bool, str, dict, logging.Logger], optional):
            Configures the logger, see :class:`.LoggerProperty`.
            Defaults to False.
        version (Optional[str], optional): Program version which will be
            printed by the --version argument. Defaults to None.
        print_config (Optional[str], optional): Add this as argument to
            print config, set None to disable.
            Defaults to "--print_config".
        parser_mode (str, optional): Mode for parsing config files:
            ``'yaml'``, ``'jsonnet'`` or ones added via
            :func:`.set_loader`. Defaults to "yaml".
        dump_header (Optional[List[str]], optional): Header to include
            as comment when dumping a config object. Defaults to None.
        default_config_files
            (Optional[List[Union[str, os.PathLike]]], optional):
            Default config file locations, e.g.
            :code:`['~/.config/myapp/*.yaml']`. Defaults to None.
        default_env (bool, optional): Set the default value on whether
            to parse environment variables. Defaults to False.
        default_meta (bool, optional): Set the default value on whether
            to include metadata in config objects. Defaults to True.
    """

    def __init__(
        self,
        *args,
        env_prefix: Union[bool, str] = True,
        formatter_class: Type[DefaultHelpFormatter] = DefaultHelpFormatter,
        exit_on_error: bool = True,
        logger: Union[bool, str, dict, logging.Logger] = False,
        version: Optional[str] = None,
        print_config: Optional[str] = "--print_config",
        parser_mode: str = "yaml",
        dump_header: Optional[List[str]] = None,
        default_config_files: Optional[List[Union[str, os.PathLike]]] = None,
        default_env: bool = False,
        default_meta: bool = True,
        **kwargs,
    ) -> None:
        # Explicitly named options, forwarded by keyword to the base class
        named_options = dict(
            env_prefix=env_prefix,
            formatter_class=formatter_class,
            exit_on_error=exit_on_error,
            logger=logger,
            version=version,
            print_config=print_config,
            parser_mode=parser_mode,
            dump_header=dump_header,
            default_config_files=default_config_files,
            default_env=default_env,
            default_meta=default_meta,
        )
        super().__init__(*args, **named_options, **kwargs)

        # Option used to load arguments from a configuration file
        config_flags = ("-c", "--config")
        self.add_argument(
            *config_flags,
            action=ActionConfigFile,
            help="Path to a configuration file in json or yaml format.",
        )
# type: ignore