Source code for itwinai.torch.reproducibility

# --------------------------------------------------------------------------------------
# Part of the interTwin Project: https://www.intertwin.eu/
#
# Created by: Matteo Bunino
#
# Credit:
# - Matteo Bunino <matteo.bunino@cern.ch> - CERN
# --------------------------------------------------------------------------------------

"""This module provides the tools to support reproducible execution of torch scripts."""

import random
from typing import Optional

import numpy as np
import torch


def seed_worker(worker_id):
    """Seed the Python and NumPy global RNGs of a DataLoader worker.

    The seed is derived from the torch seed currently active in the calling
    process, so the three libraries stay aligned.

    Args:
        worker_id: index of the worker. It is accepted for signature
            compatibility but is not used by this function.
    """
    # Keep only the low 32 bits of the torch seed: NumPy's legacy seeding
    # API rejects values that do not fit in an unsigned 32-bit integer.
    derived_seed = torch.initial_seed() % (1 << 32)
    random.seed(derived_seed)
    np.random.seed(derived_seed)
def set_seed(rnd_seed: Optional[int], deterministic_cudnn: bool = True) -> torch.Generator:
    """Set torch random seed and return a PRNG object.

    When a seed is given, it is applied to NumPy's global RNG, Python's
    ``random`` module, torch's default generator, the returned generator
    and, if CUDA is available, the CUDA generators.

    Args:
        rnd_seed (Optional[int]): random seed. If None, no seed is set and
            the cuDNN flags are left untouched; a fresh, unseeded generator
            is still returned.
        deterministic_cudnn (bool): if True (and ``rnd_seed`` is not None),
            sets ``torch.backends.cudnn.benchmark = False`` and
            ``torch.backends.cudnn.deterministic = True``, which may affect
            performances.

    Returns:
        torch.Generator: PRNG object, seeded with ``rnd_seed`` when given.
    """
    g = torch.Generator()
    if rnd_seed is None:
        return g

    # Deterministic execution.
    # NumPy's legacy seeding API only accepts values in [0, 2**32 - 1], while
    # torch and ``random`` accept wider seeds. Reduce the seed for NumPy only
    # (the same reduction used by ``seed_worker``) so that a seed valid for
    # torch does not raise ValueError here.
    np.random.seed(rnd_seed % 2**32)
    random.seed(rnd_seed)
    torch.manual_seed(rnd_seed)
    g.manual_seed(rnd_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(rnd_seed)
        torch.cuda.manual_seed_all(rnd_seed)
    if deterministic_cudnn:
        # Disable cuDNN autotuning and request deterministic algorithms.
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
    return g