Source code for itwinai.utils

# --------------------------------------------------------------------------------------
# Part of the interTwin Project: https://www.intertwin.eu/
#
# Created by: Matteo Bunino
#
# Credit:
# - Matteo Bunino <matteo.bunino@cern.ch> - CERN
# --------------------------------------------------------------------------------------

"""Utilities for itwinai package."""

import inspect
import sys
from collections.abc import MutableMapping
from typing import Callable, Dict, Hashable, Tuple, Type

import yaml


def load_yaml(path: str) -> Dict:
    """Read a YAML file and return its parsed content.

    Args:
        path (str): location of the YAML file on disk.

    Raises:
        yaml.YAMLError: if the file content cannot be parsed. The error is
            printed before being re-raised.

    Returns:
        Dict: nested dict representation of the parsed YAML file.
    """
    with open(path, mode="r", encoding="utf-8") as stream:
        try:
            return yaml.safe_load(stream)
        except yaml.YAMLError as parse_error:
            # Surface the parser message on stdout before propagating it.
            print(parse_error)
            raise parse_error
def dynamically_import_class(name: str) -> Type:
    """Resolve a dotted path to the object it names, importing its module.

    Adapted from https://stackoverflow.com/a/547867

    Args:
        name (str): dotted path to the class
            (e.g., mypackage.mymodule.MyClass).

    Raises:
        ModuleNotFoundError: if the module part of the path cannot be imported.
        Exception: any other error raised while splitting the path, importing
            the module or reading the attribute is re-raised unchanged.

    Returns:
        Type: the attribute found under the last path component.
    """
    try:
        # Everything before the last dot is the module, the rest is the class.
        module_path, attr_name = name.rsplit(".", 1)
        imported = __import__(module_path, fromlist=[attr_name])
        return getattr(imported, attr_name)
    except ModuleNotFoundError as import_error:
        message = (
            f"Module not found when trying to dynamically import '{name}'. "
            "Make sure that the module's file is reachable from your current "
            "directory."
        )
        print(message)
        raise import_error
    except Exception as generic_error:
        message = (
            f"Exception occurred when trying to dynamically import '{name}'. "
            "Make sure that the module's file is reachable from your current "
            "directory and that the class is present in that module."
        )
        print(message)
        raise generic_error
def flatten_dict(d: MutableMapping, parent_key: str = "", sep: str = ".") -> MutableMapping:
    """Flatten a nested mapping into a single-level dictionary.

    Keys of nested mappings are joined to the key of their parent using
    ``sep``. Keys that are not strings (e.g., integers) are formatted with
    ``str()`` when they become part of a joined key; a top-level key with no
    prefix is kept as is. Nested mappings that are empty contribute no entry
    to the result.

    Args:
        d (MutableMapping): nested dictionary to flatten.
        parent_key (str, optional): prefix for all keys. Defaults to ''.
        sep (str, optional): separator for nested key concatenation.
            Defaults to '.'.

    Returns:
        MutableMapping: flattened dictionary with new keys.
    """
    flat = {}
    for key, value in d.items():
        # Format instead of concatenating, so that non-string keys do not
        # raise TypeError when joined to the prefix.
        new_key = f"{parent_key}{sep}{key}" if parent_key else key
        if isinstance(value, MutableMapping):
            # Pass the prefix as a string: a falsy non-string key (e.g., 0)
            # would otherwise be dropped by the emptiness check above.
            flat.update(flatten_dict(value, str(new_key), sep=sep))
        else:
            flat[new_key] = value
    return flat
class SignatureInspector:
    """Provides the functionalities to inspect the signature of a function
    or a method.

    Args:
        func (Callable): function to be inspected.
    """

    # Value returned by ``max_params_num`` when there is no upper bound.
    INFTY: int = sys.maxsize

    def __init__(self, func: Callable) -> None:
        self.func = func
        # View of (name, inspect.Parameter) pairs from the signature.
        self.func_params = inspect.signature(func).parameters.items()

    @property
    def has_varargs(self) -> bool:
        """Checks if the function has ``*args`` parameter."""
        return any(
            param.kind == param.VAR_POSITIONAL for _, param in self.func_params
        )

    @property
    def has_kwargs(self) -> bool:
        """Checks if the function has ``**kwargs`` parameter."""
        return any(
            param.kind == param.VAR_KEYWORD for _, param in self.func_params
        )

    @property
    def required_params(self) -> Tuple[str, ...]:
        """Names of required parameters. Class method's 'self' is skipped.

        A parameter is required when it has no default value and is neither
        ``*args`` nor ``**kwargs``.
        """
        variadic_kinds = (
            inspect.Parameter.VAR_POSITIONAL,
            inspect.Parameter.VAR_KEYWORD,
        )
        return tuple(
            name
            for name, param in self.func_params
            if name != "self"
            # Identity check against the public sentinel: ``==`` would call
            # the default value's own ``__eq__``, which may return a truthy
            # or non-boolean result for arbitrary objects.
            and param.default is inspect.Parameter.empty
            and param.kind not in variadic_kinds
        )

    @property
    def min_params_num(self) -> int:
        """Minimum number of arguments required."""
        return len(self.required_params)

    @property
    def max_params_num(self) -> int:
        """Max number of supported input arguments.

        If no limit, ``SignatureInspector.INFTY`` is returned. Otherwise
        every parameter in the inspected signature is counted.
        """
        if self.has_kwargs or self.has_varargs:
            return self.INFTY
        return len(self.func_params)
def str_to_slice(interval: str) -> slice:
    """Transform string interval to Python slice.

    Each of the up to three colon-separated components is either an integer
    or empty; an empty component maps to ``None``. Whitespace around a
    component is ignored.

    Example:
        "1:17:3" -> slice(1, 17, 3)
        ":20:3" -> slice(None, 20, 3)
        "1:" -> slice(1, None)

    Args:
        interval (str): interval to parse.

    Raises:
        ValueError: when interval is invalid (more than three components, a
            component that is not an integer, or an empty string).

    Returns:
        slice: parsed slice. NOTE: when ``interval`` contains no colon
        (e.g., "5"), the parsed integer itself is returned instead of a
        slice.
    """
    error_message = f"Received invalid interval for slice: '{interval}'"

    components = [part.strip() for part in interval.split(":")]
    # slice() takes at most start, stop and step.
    if len(components) > 3:
        raise ValueError(error_message)

    try:
        bounds = [int(part) if part else None for part in components]
    except ValueError:
        # Report every malformed component with the same message.
        raise ValueError(error_message) from None

    if len(bounds) == 1:
        # No colon in the input: it is a plain index, which cannot be empty.
        if bounds[0] is None:
            raise ValueError(error_message)
        return bounds[0]
    return slice(*bounds)
def clear_key(my_dict: Dict, dict_name: str, key: Hashable, complain: bool = True) -> Dict:
    """Delete a key from a dictionary in place, optionally printing a notice.

    Args:
        my_dict (Dict): dictionary to modify.
        dict_name (str): name of the dictionary, used in the printed notice.
        key (Hashable): key to remove.
        complain (bool, optional): whether to print a notice when the key is
            found. Defaults to True.

    Returns:
        Dict: the same dictionary object that was passed in.
    """
    # Nothing to do when the key is absent.
    if key not in my_dict:
        return my_dict

    if complain:
        notice = f"Field '{key}' should not be present " f"in dictionary '{dict_name}'"
        print(notice)

    del my_dict[key]
    return my_dict