Source code for itwinai.utils

# --------------------------------------------------------------------------------------
# Part of the interTwin Project: https://www.intertwin.eu/
#
# Created by: Matteo Bunino
#
# Credit:
# - Matteo Bunino <matteo.bunino@cern.ch> - CERN
# --------------------------------------------------------------------------------------

"""Utilities for itwinai package."""

import inspect
import os
import random
import sys
from collections.abc import MutableMapping
from pathlib import Path
from typing import Callable, Dict, Hashable, List, Tuple, Type
from urllib.parse import urlparse

import yaml

adjectives = [
    "quantum",
    "relativistic",
    "wavy",
    "entangled",
    "chiral",
    "tachyonic",
    "superluminal",
    "anomalous",
    "hypercharged",
    "fermionic",
    "hadronic",
    "quarky",
    "holographic",
    "dark",
    "force-sensitive",
    "chaotic",
]

names = [
    "neutrino",
    "graviton",
    "muon",
    "gluon",
    "tachyon",
    "quasar",
    "pulsar",
    "blazar",
    "meson",
    "boson",
    "hyperon",
    "starlord",
    "groot",
    "rocket",
    "yoda",
    "skywalker",
    "sithlord",
    "midichlorian",
    "womp-rat",
    "beskar",
    "mandalorian",
    "ewok",
    "vibranium",
    "nova",
    "gamora",
    "drax",
    "ronan",
    "thanos",
    "cosmo",
]


[docs] def generate_random_name(): return f"{random.choice(adjectives)}-{random.choice(names)}"
[docs] def load_yaml(path: str) -> Dict: """Load YAML file as dict. Args: path (str): path to YAML file. Raises: yaml.YAMLError: for loading/parsing errors. Returns: Dict: nested dict representation of parsed YAML file. """ with open(path, "r", encoding="utf-8") as yaml_file: try: loaded_config = yaml.safe_load(yaml_file) except yaml.YAMLError as exc: print(exc) raise exc return loaded_config
[docs] def dynamically_import_class(name: str) -> Type: """ Dynamically import class by module path. Adapted from https://stackoverflow.com/a/547867 Args: name (str): path to the class (e.g., mypackage.mymodule.MyClass) Returns: __class__: class type. """ try: module, class_name = name.rsplit(".", 1) mod = __import__(module, fromlist=[class_name]) klass = getattr(mod, class_name) except ModuleNotFoundError as err: print( f"Module not found when trying to dynamically import '{name}'. " "Make sure that the module's file is reachable from your current " "directory." ) raise err except Exception as err: print( f"Exception occurred when trying to dynamically import '{name}'. " "Make sure that the module's file is reachable from your current " "directory and that the class is present in that module." ) raise err return klass
[docs] def flatten_dict(d: MutableMapping, parent_key: str = "", sep: str = ".") -> MutableMapping: """Flatten dictionary Args: d (MutableMapping): nested dictionary to flatten parent_key (str, optional): prefix for all keys. Defaults to ''. sep (str, optional): separator for nested key concatenation. Defaults to '.'. Returns: MutableMapping: flattened dictionary with new keys. """ items = [] for k, v in d.items(): new_key = parent_key + sep + k if parent_key else k if isinstance(v, MutableMapping): items.extend(flatten_dict(v, new_key, sep=sep).items()) else: items.append((new_key, v)) return dict(items)
[docs] class SignatureInspector: """Provides the functionalities to inspect the signature of a function or a method. Args: func (Callable): function to be inspected. """ INFTY: int = sys.maxsize def __init__(self, func: Callable) -> None: self.func = func self.func_params = inspect.signature(func).parameters.items() @property def has_varargs(self) -> bool: """Checks if the function has ``*args`` parameter.""" return any(map(lambda p: p[1].kind == p[1].VAR_POSITIONAL, self.func_params)) @property def has_kwargs(self) -> bool: """Checks if the function has ``**kwargs`` parameter.""" return any(map(lambda p: p[1].kind == p[1].VAR_KEYWORD, self.func_params)) @property def required_params(self) -> Tuple[str]: """Names of required parameters. Class method's 'self' is skipped.""" required_params = list( filter( lambda p: ( p[0] != "self" and p[1].default == inspect._empty and p[1].kind != p[1].VAR_POSITIONAL and p[1].kind != p[1].VAR_KEYWORD ), self.func_params, ) ) return tuple(map(lambda p: p[0], required_params)) @property def min_params_num(self) -> int: """Minimum number of arguments required.""" return len(self.required_params) @property def max_params_num(self) -> int: """Max number of supported input arguments. If no limit, ``SignatureInspector.INFTY`` is returned. """ if self.has_kwargs or self.has_varargs: return self.INFTY return len(self.func_params)
[docs] def clear_key(my_dict: Dict, dict_name: str, key: Hashable, complain: bool = True) -> Dict: """Remove key from dictionary if present and complain. Args: my_dict (Dict): Dictionary. dict_name (str): name of the dictionary. key (Hashable): Key to remove. """ if key in my_dict: if complain: print(f"Field '{key}' should not be present in dictionary '{dict_name}'") del my_dict[key] return my_dict
[docs] def make_config_paths_absolute(args: List[str]): """Process CLI arguments to make paths specified for `--config-path` or `-cp` absolute. Returns the modified arguments list. Args: args (List[str]): a list of system arguments Returns: List(str): the updated list of system arguments, where all the config path argument is absolute. """ updated_args = args.copy() for i, arg in enumerate(updated_args): if arg.startswith("--config-path=") or arg.startswith("-cp="): prefix, path = arg.split("=", 1) abs_path = os.path.abspath(path) updated_args[i] = f"{prefix}={abs_path}" sys.path.append(abs_path) break elif arg in {"--config-path", "-cp"}: # Handle the case where the path is in the next argument abs_path = os.path.abspath(updated_args[i + 1]) updated_args[i + 1] = abs_path sys.path.append(abs_path) break return updated_args
[docs] def get_root_cause(exception: Exception) -> Exception: """Recursively extract the first exception in the exception chain.""" root = exception while root.__cause__ is not None: # Traverse the exception chain root = root.__cause__ return root
[docs] def to_uri(path_str: str | Path) -> str: """Parse a path and convert it to a URI. Args: path_str (str): path to convert. Returns: str: URI. """ if isinstance(path_str, Path): return str(Path(path_str).resolve()) parsed = urlparse(path_str) if parsed.scheme: # If it has a scheme, assume it's a URI and return as-is return path_str # Otherwise, make it absolute return str(Path(path_str).resolve())