Source code for itwinai.serialization

# --------------------------------------------------------------------------------------
# Part of the interTwin Project: https://www.intertwin.eu/
#
# Created by: Matteo Bunino
#
# Credit:
# - Matteo Bunino <matteo.bunino@cern.ch> - CERN
# --------------------------------------------------------------------------------------

import abc
import inspect
import json
from pathlib import Path
from typing import Any, Dict, Union

import yaml

from .type import MLModel
from .utils import SignatureInspector


def is_jsonable(x):
    """Tell whether ``x`` can be encoded by :func:`json.dumps` as-is.

    Args:
        x: any Python object.

    Returns:
        bool: ``True`` if ``json.dumps(x)`` succeeds, ``False`` if it raises
        any exception (e.g., an unsupported type or a circular reference).
    """
    try:
        json.dumps(x)
    except Exception:
        # Any failure while encoding means "not natively JSON-serializable".
        return False
    else:
        return True
def fullname(o):
    """Return the import path of the class of ``o``.

    Built-in types are reported by their bare qualified name (``"str"``
    rather than ``"builtins.str"``); any other class is prefixed with the
    name of its defining module, e.g. ``"collections.OrderedDict"``.

    Args:
        o: any object.

    Returns:
        str: ``"<module>.<qualname>"``, or just ``"<qualname>"`` for builtins.
    """
    cls = o.__class__
    module_name = cls.__module__
    qualified_name = cls.__qualname__
    is_builtin = module_name == "builtins"
    return qualified_name if is_builtin else module_name + "." + qualified_name
class SerializationError(Exception):
    """Raised when an object cannot be turned into a serializable form.

    Used by :class:`Serializable` to report missing constructor arguments,
    locally defined classes, and constructor arguments that cannot be
    serialized.
    """
class Serializable:
    """Base class making a component serializable to dict, JSON and YAML.

    Subclasses record their constructor arguments with
    ``self.save_parameters(...)``, ideally as the first instruction of their
    constructor. :meth:`to_dict` then returns those arguments (recursively
    serialized) plus a ``_target_`` key containing the class path.
    """

    #: Dictionary storing constructor arguments. Needed to serialize the
    #: class to dictionary. Set by ``self.save_parameters()`` method.
    parameters: Union[Dict[str, Any], None] = None

    def save_parameters(self, **kwargs) -> None:
        """Store constructor arguments to enable YAML/JSON serialization.

        The keyword arguments are merged into ``self.parameters``, which is
        created on the first call. Repeated calls (e.g., from the
        constructors of both a subclass and its superclasses) accumulate,
        and a later value overrides an earlier one saved under the same name.
        """
        if self.parameters is None:
            # Assigning here creates a per-instance dict that shadows the
            # class-level ``None`` default.
            self.parameters = {}
        self.parameters.update(kwargs)

    @staticmethod
    def locals2params(locals: Dict, pop_self: bool = True) -> Dict:
        """Remove ``self`` from the output of ``locals()``.

        The given dictionary is modified in place and also returned.

        Args:
            locals (Dict): output of ``locals()`` called in the constructor
                of a class.
            pop_self (bool, optional): whether to remove ``self``.
                Defaults to True.

        Returns:
            Dict: cleaned ``locals()``.
        """
        if pop_self:
            locals.pop("self", None)
        return locals

    def update_parameters(self, **kwargs) -> None:
        """Updates stored parameters (same as :meth:`save_parameters`)."""
        self.save_parameters(**kwargs)

    def to_dict(self) -> Dict:
        """Returns a dict serialization of the current object.

        Returns:
            Dict: the saved constructor arguments of this class, recursively
            serialized, plus a ``_target_`` key holding the class path.

        Raises:
            SerializationError: if the constructor arguments were not saved,
                the class is defined locally, or an argument cannot be
                serialized.
        """
        self._validate_parameters()
        class_path = self._get_class_path()
        init_args = {
            par_name: self._recursive_serialization(par, par_name)
            for par_name, par in self._saved_constructor_parameters().items()
        }
        init_args["_target_"] = class_path
        return init_args

    def _validate_parameters(self) -> None:
        """Check that the constructor arguments have been saved.

        Raises:
            SerializationError: if ``save_parameters`` was never called, or
                if a parameter that ``SignatureInspector`` reports as
                required for ``__init__`` was not saved.
        """
        if self.parameters is None:
            raise SerializationError(
                f"{self.__class__.__name__} cannot be serialized "
                "because its constructor arguments were not saved. "
                "Please add 'self.save_parameters(param_1=param_1, "
                "..., param_n=param_n)' as first instruction of its "
                "constructor."
            )
        init_inspector = SignatureInspector(self.__init__)
        for par_name in init_inspector.required_params:
            # Test for presence rather than ``is None``: a required argument
            # explicitly passed (and saved) as ``None`` is a valid value.
            if par_name not in self.parameters:
                raise SerializationError(
                    f"Required parameter '{par_name}' of "
                    f"{self.__class__.__name__} class not present in "
                    "saved parameters. "
                    "Please add 'self.save_parameters(param_1=param_1, "
                    "..., param_n=param_n)' as first instruction of its "
                    f"constructor, including also '{par_name}'."
                )

    def _get_class_path(self) -> str:
        """Return the import path of this object's class.

        Raises:
            SerializationError: if the class is defined inside a function
                (its qualified name contains ``<locals>``), since it could
                not be imported back from the resulting path.
        """
        class_path = fullname(self)
        if "<locals>" in class_path:
            raise SerializationError(
                f"{self.__class__.__name__} is "
                "defined locally, which is not supported for serialization. "
                "Move the class to a separate Python file and import it "
                "from there."
            )
        return class_path

    def _saved_constructor_parameters(self) -> Dict[str, Any]:
        """Extracts the current constructor parameters from all the saved
        parameters, as some of them may have been added by superclasses.

        Returns:
            Dict[str, Any]: subset of saved parameters containing only the
            constructor parameters for this class.
        """
        empty = inspect.Parameter.empty
        init_param_names = inspect.signature(self.__init__).parameters
        return {
            par_name: self.parameters[par_name]
            for par_name in init_param_names
            # Identity check against the sentinel: ``!=`` would invoke the
            # stored value's own comparison operators, which may raise or
            # return a non-boolean (e.g., element-wise array comparison).
            if self.parameters.get(par_name, empty) is not empty
        }

    def _recursive_serialization(self, item: Any, item_name: str) -> Any:
        """Convert ``item`` into JSON/YAML-friendly built-in types.

        Tuples, lists and sets become lists and dicts are rebuilt; in both
        cases the contained values are serialized recursively (dict keys are
        kept as-is). JSON-serializable values are returned unchanged and
        :class:`Serializable` objects are replaced by their :meth:`to_dict`.

        Args:
            item: value to serialize.
            item_name: name of the constructor argument ``item`` belongs to,
                used only in the error message.

        Raises:
            SerializationError: if ``item`` (or a nested value) is neither
                JSON-serializable nor a :class:`Serializable`.
        """
        if isinstance(item, (tuple, list, set)):
            return [self._recursive_serialization(x, item_name) for x in item]
        elif isinstance(item, dict):
            return {
                k: self._recursive_serialization(v, item_name)
                for k, v in item.items()
            }
        elif is_jsonable(item):
            return item
        elif isinstance(item, Serializable):
            return item.to_dict()
        else:
            raise SerializationError(
                f"{self.__class__.__name__} cannot be serialized "
                f"because its constructor argument '{item_name}' "
                "is not a Python built-in type and it does not "
                "extend 'itwinai.serialization.Serializable' class."
            )

    def to_json(self, file_path: Union[str, Path], nested_key: str) -> None:
        """Save a component to JSON file.

        Args:
            file_path (Union[str, Path]): JSON file path.
            nested_key (str): root field containing the serialized object.

        Raises:
            SerializationError: if the object cannot be serialized.
        """
        with open(file_path, "w", encoding="utf-8") as fp:
            json.dump({nested_key: self.to_dict()}, fp)

    def to_yaml(self, file_path: Union[str, Path], nested_key: str) -> None:
        """Save a component to YAML file.

        Args:
            file_path (Union[str, Path]): YAML file path.
            nested_key (str): root field containing the serialized object.

        Raises:
            SerializationError: if the object cannot be serialized.
        """
        with open(file_path, "w", encoding="utf-8") as fp:
            yaml.dump({nested_key: self.to_dict()}, fp)
class ModelLoader(abc.ABC, Serializable):
    """Loads a machine learning model from somewhere.

    Args:
        model_uri (str): location of the model to load; how it is
            interpreted is left to concrete subclasses.
    """

    def __init__(self, model_uri: str) -> None:
        super().__init__()
        # Record the constructor argument, as required by ``Serializable``,
        # so that loaders relying on this constructor can be serialized with
        # ``to_dict()`` instead of failing with "arguments were not saved".
        self.save_parameters(model_uri=model_uri)
        self.model_uri = model_uri

    @abc.abstractmethod
    def __call__(self) -> MLModel:
        """Loads model from model URI."""