# Source code for itwinai.tensorflow.trainer

# --------------------------------------------------------------------------------------
# Part of the interTwin Project: https://www.intertwin.eu/
#
# Created by: Matteo Bunino
#
# Credit:
# - Matteo Bunino <matteo.bunino@cern.ch> - CERN
# --------------------------------------------------------------------------------------

"""Base TensorFlow trainer module."""

from typing import Any, Dict, List, Optional, Tuple, Union

import keras
import tensorflow as tf
from jsonargparse import ArgumentParser
from keras.callbacks import Callback
from tensorflow.data import Dataset

from itwinai.tensorflow.distributed import get_strategy

from ..components import Trainer, monitor_exec


def _import_class(name):
    """Resolve a dotted path such as ``"package.module.ClassName"`` to an object.

    The first component is imported as a top-level module and the remaining
    components are resolved one at a time. Each component is first looked up
    as an attribute; if that fails on a module, the corresponding dotted
    prefix is imported as a submodule. The fallback is needed because a
    package does not expose a submodule as an attribute until that submodule
    has been imported.

    Args:
        name (str): fully qualified dotted name of the object to load.

    Returns:
        Any: the object found at ``name`` (a module when ``name`` has no dots
        or points to a module).

    Raises:
        AttributeError: if a component is neither an attribute nor an
            importable submodule of the object resolved so far.
        ImportError: if the top-level module, or a dependency of one of the
            submodules, cannot be imported.
    """
    import importlib
    from types import ModuleType

    components = name.split(".")
    obj = importlib.import_module(components[0])
    for depth, comp in enumerate(components[1:], start=2):
        try:
            obj = getattr(obj, comp)
        except AttributeError:
            # Only modules can have lazily-imported children; a missing
            # attribute on a class or function is a genuine error.
            if not isinstance(obj, ModuleType):
                raise
            dotted = ".".join(components[:depth])
            try:
                obj = importlib.import_module(dotted)
            except ModuleNotFoundError as err:
                # A different missing module means a broken dependency
                # inside the submodule: let that surface untouched.
                if err.name != dotted:
                    raise
                raise AttributeError(
                    f"module {obj.__name__!r} has no attribute {comp!r}"
                ) from err
    return obj


def _instance_from_dict(obj_dict: Dict, fail_untyped: bool = True) -> Any:
    """Build an object from a jsonargparse-style dictionary configuration.

    Args:
        obj_dict (Dict): configuration holding a ``class_path`` entry (dotted
            path of the class to instantiate) and, typically, ``init_args``.
        fail_untyped (bool, optional): forwarded to
            ``ArgumentParser.add_subclass_arguments``. Defaults to True.

    Returns:
        Any: the ``object`` entry produced by
        ``ArgumentParser.instantiate_classes``.

    Raises:
        ValueError: if ``obj_dict`` is not a dictionary or its ``class_path``
            entry is missing or None.
    """
    is_valid_config = isinstance(obj_dict, dict) and obj_dict.get("class_path") is not None
    if not is_valid_config:
        raise ValueError(
            "Unable to instantiate object with this "
            f"dict configuration: {obj_dict}.\nIt should have "
            "valid 'class_path' and 'init_args' fields"
        )

    # The configuration follows the jsonargparse layout: register the target
    # class under the "object" key and let the parser instantiate it.
    target_class = _import_class(obj_dict["class_path"])
    arg_parser = ArgumentParser()
    arg_parser.add_subclass_arguments(target_class, "object", fail_untyped=fail_untyped)
    instantiated = arg_parser.instantiate_classes({"object": obj_dict})
    return instantiated.object


class TensorflowTrainer(Trainer):
    """Trains a Keras model.

    Args:
        epochs (int): number of training epochs.
        micro_batch_size (int): per-worker batch size. Equals macro batch size
            when not running distributed.
        shuffle_buffer (Optional[int], optional): if given, shuffles dataset
            using a buffer of given size. See ``tf.data.Dataset.shuffle``.
            Defaults to None.
        callbacks (Optional[List], optional): list of Keras callbacks. Can be
            a list of dictionary configurations. Defaults to None.
        model_config (Optional[Dict], optional): model configuration. If given,
            a model is instantiated from this configuration.
            Defaults to None.
        model_compile_config (Optional[Dict], optional): configuration for
            ``keras.Model.compile``. Defaults to None.
        rnd_seed (Optional[int], optional): random seed. Defaults to None.
        verbose (Union[str, int], optional): verbosity level for
            ``keras.Model.fit``. Defaults to 'auto'.
    """

    #: TensorFlow distributed strategy.
    strategy: tf.distribute.Strategy
    #: Total number of workers in distributed strategy.
    num_workers: int
    #: List of Keras callbacks. Defaults to None.
    callbacks: Optional[List] = None
    #: Total number of training epochs.
    epochs: int
    #: Buffer used to shuffle dataset. Defaults to None.
    shuffle_buffer: Optional[int] = None
    #: Per-worker batch size (when distributed).
    micro_batch_size: int
    #: Total batch size. When distributed, it is the sum of
    #: ``micro_batch_size`` across all workers.
    macro_batch_size: int
    #: Random seed for reproducibility. Defaults to None.
    rnd_seed: Optional[int] = None

    def __init__(
        self,
        epochs: int,
        micro_batch_size: int,
        shuffle_buffer: Optional[int] = None,
        callbacks: Optional[List[Union[Dict, Callback]]] = None,
        model_config: Optional[Dict] = None,
        model_compile_config: Optional[Dict] = None,
        rnd_seed: Optional[int] = None,
        verbose: Union[str, int] = "auto",
    ):
        super().__init__()
        # Must run before any other local variable is created, so that
        # ``locals()`` only holds the constructor arguments.
        self.save_parameters(**self.locals2params(locals()))
        self.epochs = epochs
        self.micro_batch_size = micro_batch_size
        self.shuffle_buffer = shuffle_buffer
        self.rnd_seed = rnd_seed
        self.verbose = verbose
        if callbacks is not None:
            self.callbacks = self.instantiate_callbacks(callbacks)
        else:
            self.callbacks = []

        # Distributed strategy
        self.strategy, self.num_workers = get_strategy()
        print(f"Distributed strategy is working with: {self.num_workers} devices")

        self.macro_batch_size = self.micro_batch_size * self.num_workers

        # Compile model from configuration, if given
        if model_config is not None and model_compile_config is not None:
            # Model, optimizer and loss are all created under the strategy
            # scope, as recommended in the guidance comment of ``execute``.
            with self.strategy.scope():
                self.model: tf.keras.Model = _instance_from_dict(model_config)
                model_compile_config = self.instantiate_compile_conf(model_compile_config)
                self.model.compile(**model_compile_config)
        else:
            print(
                "Either model_config or model_compile_config were not given. "
                "Skipping automatic model compilation."
            )

    @staticmethod
    def instantiate_compile_conf(model_compile_config: Dict) -> Dict[str, Any]:
        """Instantiate fields of Keras ``model.compile()`` from their
        dictionary serialization.

        Values that are dictionaries are instantiated through
        ``_instance_from_dict``; every other value is passed through unchanged.

        Args:
            model_compile_config (Dict): fields of Keras ``model.compile()``
                serialized as dictionary.

        Returns:
            Dict[str, Any]: dictionary mapping compile argument names to the
            instantiated objects.
        """
        final_conf = {}
        for item_name, item in model_compile_config.items():
            if isinstance(item, dict):
                item = _instance_from_dict(item)
            final_conf[item_name] = item
        return final_conf

    @staticmethod
    def instantiate_callbacks(callbacks: List[Union[Dict, Callback]]) -> List[Callback]:
        """Instantiate Keras callbacks from dictionaries.

        Items that are already callback objects are kept as they are.

        Args:
            callbacks (List[Union[Dict, Callback]]): list of Keras callbacks,
                possibly serialized as dictionaries.

        Returns:
            List[Callback]: list of instantiated callbacks.
        """
        final_callbacks = []
        for item in callbacks:
            if isinstance(item, dict):
                # Not all constructor args in keras callbacks
                # are typed!
                item = _instance_from_dict(item, fail_untyped=False)
            final_callbacks.append(item)
        return final_callbacks

    @monitor_exec
    def execute(
        self,
        train_dataset: Dataset,
        validation_dataset: Dataset,
        test_dataset: Optional[Dataset] = None,
    ) -> Tuple[Dataset, Dataset, Optional[Dataset], keras.Model]:
        """Run training. Users should override this method.

        Args:
            train_dataset (Dataset): train dataset of type
                ``tensorflow.data.Dataset``.
            validation_dataset (Dataset): validation dataset of type
                ``tensorflow.data.Dataset``.
            test_dataset (Optional[Dataset], optional): test dataset of type
                ``tensorflow.data.Dataset``. It is not used for training and
                is returned unchanged. Defaults to None.

        Returns:
            Tuple[Dataset, Dataset, Optional[Dataset], keras.Model]: tuple of
            train_dataset, validation_dataset, test_dataset, and trained
            Keras model.
        """
        print(f"len(train_dataset): {len(train_dataset)}")
        print(f"len(validation_dataset): {len(validation_dataset)}")
        print("micro_batch_size: ", self.micro_batch_size, flush=True)
        print("macro_batch_size: ", self.macro_batch_size, flush=True)

        # Shuffle dataset
        if self.shuffle_buffer:
            train_ds = train_dataset.shuffle(self.shuffle_buffer, seed=self.rnd_seed)
            valid_ds = validation_dataset.shuffle(self.shuffle_buffer, seed=self.rnd_seed)
        else:
            train_ds = train_dataset
            valid_ds = validation_dataset

        # Set batch size to the dataset and repeat
        train_ds = train_ds.batch(
            self.macro_batch_size, drop_remainder=True, num_parallel_calls=tf.data.AUTOTUNE
        ).repeat(self.epochs)
        valid_ds = valid_ds.batch(
            self.macro_batch_size, drop_remainder=True, num_parallel_calls=tf.data.AUTOTUNE
        ).repeat(self.epochs)

        print(f"len(train_ds): {len(train_ds)}")
        print(f"len(valid_ds): {len(valid_ds)}")

        # Distribute datasets among strategy's replica
        dist_train_dataset = self.strategy.experimental_distribute_dataset(train_ds)
        dist_valid_dataset = self.strategy.experimental_distribute_dataset(valid_ds)

        # The lengths are read from the batched datasets the distributed
        # ones were built from.
        print(f"len(dist_train_dataset): {len(train_ds)}")
        print(f"len(dist_valid_dataset): {len(valid_ds)}")

        # Compute the steps per epoch for train and valid
        steps_per_epoch = len(train_dataset) // self.macro_batch_size
        validation_steps = len(validation_dataset) // self.macro_batch_size

        print(f"steps_per_epoch: {steps_per_epoch}")
        print(f"validation_steps: {validation_steps}")

        #####################################################################
        # Instantiate here model, optimizer, loss under the strategy scope, #
        # if not done previously through `model_compile_config` and         #
        # `model_config` !                                                  #
        # Always remember that they should be instantiated under the        #
        # distributed strategy scope: ``with self.strategy.scope():``       #
        #                                                                   #
        # Example:                                                          #
        # with self.strategy.scope():                                       #
        #     model = tf.keras.Sequential(...)                              #
        #     optimizer = tf.keras.optimizers.Adam(...)                     #
        #     loss = tf.keras.losses.BinaryCrossentropy(...)                #
        #####################################################################

        # Train the model
        self.model.fit(
            dist_train_dataset,
            validation_data=dist_valid_dataset,
            steps_per_epoch=steps_per_epoch,
            validation_steps=validation_steps,
            epochs=self.epochs,
            callbacks=self.callbacks,
            verbose=self.verbose,
        )

        print("Training completed")
        return train_dataset, validation_dataset, test_dataset, self.model