Source code for itwinai.tensorflow.distributed

# --------------------------------------------------------------------------------------
# Part of the interTwin Project: https://www.intertwin.eu/
#
# Created by: Matteo Bunino
#
# Credit:
# - Matteo Bunino <matteo.bunino@cern.ch> - CERN
# --------------------------------------------------------------------------------------

"""TensorFlow distributed strategies."""

import os
from typing import Tuple

import tensorflow as tf
import tensorflow.distribute as dist


def get_strategy(port_base: int = 12345) -> Tuple[tf.distribute.Strategy, int]:
    """Strategy for distributed TensorFlow training.

    It will automatically detect if you are running in a multi-node SLURM
    environment, returning the correct TensorFlow distributed strategy for
    data parallel distributed training:

    - outside SLURM, or in a SLURM job with fewer than two nodes, a
      ``MirroredStrategy`` over the local devices;
    - in a multi-node SLURM job, a ``MultiWorkerMirroredStrategy`` using a
      ``SlurmClusterResolver`` and NCCL collective communication.

    Args:
        port_base (int, optional): base port handed to
            ``SlurmClusterResolver`` in the multi-node case. Ignored in the
            single-node case. Defaults to 12345.

    Returns:
        Tuple[tf.distribute.Strategy, int]: strategy and number of
        `parallel workers`_, i.e. ``strategy.num_replicas_in_sync``.

    .. _parallel workers:
        https://stackoverflow.com/questions/66005641/why-we-are-using-strategy-num-replicas-in-sync.
    """
    slurm_jobid = os.environ.get("SLURM_JOB_ID")
    slurm_nnodes = int(os.environ.get("SLURM_NNODES", 0))

    if not slurm_jobid or slurm_nnodes < 2:
        # Single-node environment: mirror the model across local devices.
        if slurm_jobid:
            # Inside SLURM, but the allocation spans a single node.
            print("Single-node SLURM job detected: using MirroredStrategy")
        else:
            print("Not in SLURM env! Assuming that you are running on a single node")
        mirrored_strategy = dist.MirroredStrategy()
        return mirrored_strategy, mirrored_strategy.num_replicas_in_sync

    # Multi-node environment in SLURM
    cluster_resolver = dist.cluster_resolver.SlurmClusterResolver(port_base=port_base)
    implementation = dist.experimental.CommunicationImplementation.NCCL
    communication_options = dist.experimental.CommunicationOptions(
        implementation=implementation
    )
    # declare distribution strategy
    tf_dist_strategy = dist.MultiWorkerMirroredStrategy(
        cluster_resolver=cluster_resolver,
        communication_options=communication_options,
    )

    # Diagnostic only: estimate the total GPU count, assuming every worker
    # (SLURM task) sees the same number of GPUs as this one. A missing
    # SLURM_NTASKS must not crash training after the strategy is built.
    n_gpus_per_worker = len(tf.config.list_physical_devices("GPU"))
    slurm_ntasks = os.environ.get("SLURM_NTASKS")
    if slurm_ntasks is not None:
        n_gpus = int(slurm_ntasks) * n_gpus_per_worker
        print("Number of detected devices: {}".format(n_gpus))
    # total number of replicas kept in sync by the strategy
    print("Number of devices: {}".format(tf_dist_strategy.num_replicas_in_sync))
    return tf_dist_strategy, tf_dist_strategy.num_replicas_in_sync