Ist meine Verwendung von tf.distribute.MirroredStrategy und strategy.scope() korrekt und sicher für Multi-GPU-Training in Python?

Python-Programme
Anonymous
 Ist meine Verwendung von tf.distribute.MirroredStrategy und strategy.scope() korrekt und sicher für Multi-GPU-Training in Python?

Post by Anonymous »

Ich trainiere ein Modell mit Keras + TensorFlow mit tf.distribute.MirroredStrategy in einem Multi-GPU-Setup. Ich möchte überprüfen, ob ich strategy.scope() korrekt verwende.

Code: Select all

import time
import logging
import os
import json
import datetime
os.environ["TF_GPU_ALLOCATOR"] = "cuda_malloc_async"
os.environ['TF_CUDNN_DETERMINISTIC'] = '1'
os.environ["KERAS_BACKEND"] = "tensorflow"
#os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
#os.environ['TF_CPP_MAX_VLOG_LEVEL'] = '0'
import tensorflow as tf
# Let TensorFlow grow GPU memory on demand instead of reserving all of it up
# front. This has to run before any GPU has been initialised.
gpus = tf.config.list_physical_devices('GPU')
for gpu in gpus:  # simply does nothing when no GPU is visible
    tf.config.experimental.set_memory_growth(gpu, True)
from tensorflow.python.profiler import profiler_v2 as profiler
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import keras
from keras import ops
from keras import layers
from keras import mixed_precision
from medicai.models import UNETRPlusPlus
from medicai.metrics import BinaryDiceMetric
from medicai.losses import BinaryDiceCELoss
from medicai.utils.inference import SlidingWindowInference
from medicai.callbacks import SlidingWindowInferenceCallback
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from src.experiment_config import ExperimentConfig
from src.data_pipeline.data_loader import data_loader

class TFCheckpointCallback(keras.callbacks.Callback):
    """Save model + optimizer + epoch via TF checkpointing after every epoch.

    Also persists the sliding-window-inference (SWI) callback's ``best_score``
    to ``swi_best_score.json`` inside ``checkpoint_dir`` so it can be restored
    when training resumes.

    Args:
        ckpt: ``tf.train.Checkpoint`` that has an integer ``epoch`` variable.
        ckpt_manager: ``tf.train.CheckpointManager`` wrapping ``ckpt``.
        swi_callback: callback whose ``best_score`` attribute (if present) is
            saved; ``-inf`` is written when it has none.
        checkpoint_dir: directory that receives the best-score JSON file.
    """

    def __init__(self, ckpt, ckpt_manager, swi_callback, checkpoint_dir):
        super().__init__()
        self.ckpt = ckpt
        self.ckpt_manager = ckpt_manager
        self.swi_callback = swi_callback
        self.best_score_file = os.path.join(checkpoint_dir, "swi_best_score.json")

    def on_epoch_end(self, epoch, logs=None):
        # `epoch` is the 0-based index of the epoch that just finished, so
        # epoch + 1 epochs are complete. Assigning (rather than incrementing)
        # keeps the stored counter equal to the `initial_epoch` needed on
        # resume, even if the counter and Keras' epoch index ever diverge.
        self.ckpt.epoch.assign(epoch + 1)
        save_path = self.ckpt_manager.save()
        print(f"Saved checkpoint: {save_path} (epoch {int(self.ckpt.epoch.numpy())})")

        # Save SWI best score externally.
        best_score = getattr(self.swi_callback, "best_score", -float("inf"))
        if best_score is not None:
            # json cannot serialise NumPy/tensor scalars; convert to a Python float.
            best_score = float(best_score)
        # Write to a temporary file first and swap it in atomically, so an
        # interruption mid-write never leaves a truncated JSON file behind.
        tmp_file = self.best_score_file + ".tmp"
        with open(tmp_file, "w") as f:
            json.dump({"best_score": best_score}, f)
        os.replace(tmp_file, self.best_score_file)
        print(f"[CheckpointCallback] Saved SWI best score: {best_score}")

class HistorySaverCallback(keras.callbacks.Callback):
    """Accumulate per-epoch logs and rewrite them to a CSV file every epoch.

    Args:
        history_file: path of the CSV file rewritten at each epoch end.
        initial_history: optional ``{log_key: [value_per_epoch, ...]}`` dict
            (e.g. read from a previous run's CSV) to continue from. It is
            copied, so the caller's dict is never mutated.
    """

    def __init__(self, history_file, initial_history=None):
        super().__init__()
        self.history_file = history_file
        # Copy every list so the appends below do not alter the caller's data.
        self.full_history = {
            key: list(values) for key, values in (initial_history or {}).items()
        }

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        # Number of epochs already recorded (length of the longest column).
        n_rows = max((len(values) for values in self.full_history.values()), default=0)
        # Existing columns keep their order; keys first seen now go at the end.
        keys = list(self.full_history) + [k for k in logs if k not in self.full_history]
        for key in keys:
            column = self.full_history.setdefault(key, [])
            # Pad with NaN for epochs where this key was not logged (e.g. a
            # metric reported only every few epochs). pandas refuses to build
            # a DataFrame from columns of different lengths.
            column.extend([float("nan")] * (n_rows - len(column)))
            column.append(logs.get(key, float("nan")))

        # Save updated history
        pd.DataFrame(self.full_history).to_csv(self.history_file, index=False)

def get_model(total_device, total_train_samples=387):
    """Build and compile the UNETR++ segmentation model.

    Call this inside ``strategy.scope()``. The model weights, the AdamW
    optimizer and the Dice metrics all create TensorFlow variables, and under
    ``MirroredStrategy`` variables must be created in the scope to be mirrored
    across replicas.

    Args:
        total_device: number of replicas (``strategy.num_replicas_in_sync``).
            The global training batch is ``batch_size_train * total_device``.
        total_train_samples: number of training samples, used only to size
            the learning-rate schedule. Defaults to 387 (approx. 80 % split).

    Returns:
        The compiled ``UNETRPlusPlus`` model.
    """
    model = UNETRPlusPlus(
        encoder_name="unetr_plusplus_encoder",
        input_shape=ExperimentConfig.input_shape,
        num_classes=ExperimentConfig.num_classes,
        classifier_activation=None,
    )

    # Compute steps per epoch and total steps. Clamp to >= 1 so a global batch
    # larger than the dataset cannot yield a zero/negative decay length.
    # NOTE(review): floor division assumes the input pipeline drops the last
    # partial batch; if it does not, training runs one more step per epoch and
    # the schedule simply reaches its floor slightly early.
    global_batch = ExperimentConfig.batch_size_train * total_device
    steps_per_epoch = max(1, total_train_samples // global_batch)
    print(f"Steps per epoch : {steps_per_epoch}")
    total_steps = steps_per_epoch * ExperimentConfig.epochs

    # Warmup: 10% of total steps
    warmup_steps = int(total_steps * 0.1)

    # Linear warmup from 1% of the target LR up to ExperimentConfig.lr, then
    # cosine decay over the remaining steps.
    lr_schedule = keras.optimizers.schedules.CosineDecay(
        initial_learning_rate=0.01 * ExperimentConfig.lr,
        decay_steps=max(1, total_steps - warmup_steps),
        alpha=ExperimentConfig.alpha,
        warmup_target=ExperimentConfig.lr,
        warmup_steps=warmup_steps,
    )

    dice_kwargs = {
        "from_logits": True,
        "ignore_empty": True,
        "num_classes": ExperimentConfig.num_classes,
    }
    # One Dice metric without target_class_ids (presumably averaged over all
    # classes), then one per class id: 0 -> dice_tc, 1 -> dice_wt, 2 -> dice_et.
    metrics = [BinaryDiceMetric(name='dice', **dice_kwargs)]
    metrics += [
        BinaryDiceMetric(target_class_ids=[class_id], name=name, **dice_kwargs)
        for class_id, name in enumerate(('dice_tc', 'dice_wt', 'dice_et'))
    ]

    model.compile(
        optimizer=keras.optimizers.AdamW(
            learning_rate=lr_schedule,
            weight_decay=ExperimentConfig.weight_decay,
        ),
        loss=BinaryDiceCELoss(
            from_logits=True,
            dice_weight=1.0,
            ce_weight=1.0,
            reduction="mean",
            num_classes=ExperimentConfig.num_classes,
        ),
        metrics=metrics,
    )

    return model

def get_inference_metric():
    """Return a new Dice metric named ``val_dice`` for sliding-window validation.

    It expects logits and ignores empty masks. No ``target_class_ids`` are
    given, so it covers all ``ExperimentConfig.num_classes`` classes. It
    presumably holds state variables like other Keras metrics, so create it
    inside ``strategy.scope()`` when using MirroredStrategy.
    """
    return BinaryDiceMetric(
        name='val_dice',
        num_classes=ExperimentConfig.num_classes,
        ignore_empty=True,
        from_logits=True,
    )

# NOTE(review): the triple-quoted string below is disabled code, not a function
# definition and not a docstring; Python evaluates and discards it at import
# time, so `run_sliding_window_inference_per_class_average` does not exist.
# If re-enabled: it calls metric.reset_states(); Keras 3 metrics expose
# reset_state() instead, so verify against the installed Keras version.
"""def run_sliding_window_inference_per_class_average(model, ds, roi_size, sw_batch_size, overlap, metrics_list):

#    Run sliding window inference on a dataset and compute all metrics (average + per class)

for metric in metrics_list:
metric.reset_states()

swi = SlidingWindowInference(
model,
num_classes=metrics_list[0].num_classes,
roi_size=roi_size,
sw_batch_size=sw_batch_size,
overlap=overlap
)

for x, y in ds:
y_pred = swi(x)
for metric in metrics_list:
metric.update_state(ops.convert_to_tensor(y), ops.convert_to_tensor(y_pred))

# Gather results
results = {}
for metric in metrics_list:
results[metric.name] = float(ops.convert_to_numpy(metric.result()))

return results"""

def _load_resumed_history(history_file, initial_epoch):
    """Return saved history rows for the epochs covered by a restored checkpoint.

    Returns an empty dict when starting from scratch (``initial_epoch == 0``)
    or when no CSV exists, so a stale history from an earlier run is never
    merged into a fresh one. Otherwise only the first ``initial_epoch`` rows
    are kept.
    """
    if initial_epoch <= 0 or not os.path.exists(history_file):
        return {}
    previous = pd.read_csv(history_file).iloc[:initial_epoch]
    return previous.to_dict(orient='list')


def _save_curve(values, label, ylabel, title, out_file):
    """Plot one per-epoch curve and save it as an image at ``out_file``."""
    plt.figure(figsize=(10, 5))
    plt.plot(values, label=label)
    plt.xlabel("Epoch")
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend()
    plt.grid()
    plt.savefig(out_file)
    plt.close()


def main():
    """Train UNETR++ on MSD brain data with MirroredStrategy, resuming if possible.

    Scope rules used here:
    * inside ``strategy.scope()``: everything that creates TF variables, i.e.
      the model, optimizer (incl. slot variables), loss, metrics, and the
      checkpoint's epoch counter, plus the callback that holds the metric;
    * outside: datasets, paths, file I/O and plain-Python helpers.
    """
    # reproducibility
    keras.utils.set_random_seed(101)

    print(
        f"keras backend: {keras.config.backend()}\n"
        f"keras version: {keras.version()}\n"
        f"tensorflow version:  {tf.__version__}\n"
    )

    # get keras backend
    keras_backend = keras.config.backend()

    strategy = tf.distribute.MirroredStrategy()
    total_device = strategy.num_replicas_in_sync

    print('Keras backend ', keras_backend)
    print('Total device found ', total_device)

    project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
    base_save_path = os.path.join(project_root, "experiments", "msd_brain")
    unetrplusplus_path = os.path.join(base_save_path, "unetrplusplus")
    os.makedirs(unetrplusplus_path, exist_ok=True)

    # Subfolders
    logs_path = os.path.join(unetrplusplus_path, "logs")
    history_path = os.path.join(unetrplusplus_path, "history")
    plots_path = os.path.join(unetrplusplus_path, "plots")
    for folder in (logs_path, history_path, plots_path):
        os.makedirs(folder, exist_ok=True)

    # Timestamp
    timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

    # Save path for best model weights.
    # NOTE(review): the per-run timestamp means a resumed run writes its best
    # weights to a new file while best_score is restored from the previous
    # run; confirm this is intended.
    save_path = os.path.join(unetrplusplus_path, f"best_model_weights_{timestamp}.weights.h5")

    # File for containing the learning history
    history_file = os.path.join(history_path, "training_history.csv")

    # Load datasets (outside the scope; Keras distributes the training set).
    tfrecord_pattern = os.path.join(project_root, "data", "msd_brain", "tfrecords", "{}_shard_*.tfrec")

    # With MirroredStrategy the dataset batch is the *global* batch, which is
    # split across replicas: each GPU sees batch_size_train samples per step.
    train_batch = ExperimentConfig.batch_size_train * total_device

    train_ds = data_loader(
        tfrecord_pattern.format("training"),
        batch_size=train_batch,
        shuffle=True,
    )
    val_ds = data_loader(
        tfrecord_pattern.format("validation"),
        batch_size=ExperimentConfig.batch_size_val,
        shuffle=False,
    )
    # NOTE(review): test_ds is currently not used below.
    test_ds = data_loader(
        tfrecord_pattern.format("test"),
        batch_size=ExperimentConfig.batch_size_val,
        shuffle=False,
    )

    checkpoint_dir = os.path.join(unetrplusplus_path, "checkpoints")
    os.makedirs(checkpoint_dir, exist_ok=True)

    with strategy.scope():
        model = get_model(total_device)
        # Create the optimizer slot variables now, inside the scope (so they
        # are mirrored), instead of lazily at the first train step. This gives
        # ckpt.restore() concrete variables for the saved optimizer state.
        # Skipped if the model has no weights yet, to avoid building an empty
        # optimizer.
        if model.trainable_variables:
            model.optimizer.build(model.trainable_variables)

        # Validation with sliding window callback; the metric holds state.
        swi_callback_metric = get_inference_metric()

        ckpt = tf.train.Checkpoint(
            epoch=tf.Variable(0),          # epoch counter, saved as part of checkpoint
            optimizer=model.optimizer,     # optimizer state
            model=model,                   # model weights
        )

        # Create SWI callback.
        # NOTE(review): sw_batch_size is scaled by the replica count; if the
        # callback runs inference on a single device rather than across
        # replicas, this multiplies that GPU's memory use. Confirm against
        # SlidingWindowInferenceCallback.
        swi_callback = SlidingWindowInferenceCallback(
            model,
            dataset=val_ds,
            metrics=swi_callback_metric,
            num_classes=ExperimentConfig.num_classes,
            interval=ExperimentConfig.sliding_window_interval,
            overlap=ExperimentConfig.sliding_window_overlap,
            roi_size=tuple(ExperimentConfig.input_shape[:3]),
            sw_batch_size=ExperimentConfig.sw_batch_size * total_device,
            save_path=save_path,
        )

    # The manager only handles checkpoint files and does not need the scope.
    ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_dir, max_to_keep=3)

    # TFCheckpointCallback (save model, optimizer, epoch + SWI best score)
    tf_ckpt_callback = TFCheckpointCallback(ckpt, ckpt_manager, swi_callback, checkpoint_dir)

    # Resume or start from scratch
    if ckpt_manager.latest_checkpoint:
        ckpt.restore(ckpt_manager.latest_checkpoint)
        initial_epoch = int(ckpt.epoch.numpy())
        print(f"[Resume] Restored checkpoint:  starting from epoch {initial_epoch}")

        # Restore SWI best score
        best_score_file = os.path.join(checkpoint_dir, "swi_best_score.json")
        if os.path.exists(best_score_file):
            with open(best_score_file, "r") as f:
                swi_callback.best_score = json.load(f).get("best_score", -float("inf"))
            print(f"[Resume] Restored SWI best validation score: {swi_callback.best_score}")
        else:
            print("[Resume] Couldn't Restore SWI best validation score")
    else:
        initial_epoch = 0
        print("[Resume] No checkpoint found. Starting from scratch.")

    # History callback: continue the saved history only when resuming, trimmed
    # to the epochs the restored checkpoint actually covers.
    prev_history = _load_resumed_history(history_file, initial_epoch)
    history_callback = HistorySaverCallback(history_file, initial_history=prev_history)

    print(f"Model size: {model.count_params() / 1e6:.2f} M")

    start_time = time.time()

    # The model remembers the strategy it was created under, so fit() does not
    # strictly need the scope; keeping it is harmless and also covers any
    # variables a callback might create lazily during training.
    with strategy.scope():
        model.fit(
            train_ds,
            epochs=ExperimentConfig.epochs,
            initial_epoch=initial_epoch,
            callbacks=[
                swi_callback,
                tf_ckpt_callback,
                history_callback,
            ],
        )

    training_time = time.time() - start_time
    print(f"Total training time (seconds): {training_time:.2f}")

    # Save history to CSV
    full_history = history_callback.full_history
    pd.DataFrame(full_history).to_csv(history_file, index=False)

    # Plot loss (absent if no epoch ran and no history was restored)
    if 'loss' in full_history:
        _save_curve(
            full_history['loss'], 'train_loss', "Loss", "Training Loss",
            os.path.join(plots_path, f"loss_curve_{timestamp}.png"),
        )

    # Plot average Dice
    if 'dice' in full_history:
        _save_curve(
            full_history['dice'], 'train_dice', "Average Dice", "Training Average Dice",
            os.path.join(plots_path, f"dice_curve_{timestamp}.png"),
        )

    print("Training and saving plots finished successfully.")

# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

Um Fehler zu vermeiden, habe ich derzeit fast alles, was mit Training zu tun hat, in strategy.scope() eingefügt, einschließlich einiger Objekte, bei denen ich nicht sicher bin, ob sie TensorFlow-Variablen erstellen oder nicht.
Innerhalb von strategy.scope() erstelle ich konkret:
  • Das Modell
  • Der Optimierer
  • Der Verlust
  • Alle Trainingsmetriken
  • Eine Metrik, die von einem benutzerdefinierten Validierungs-Callback verwendet wird
  • Checkpoint-Objekte (tf.train.Checkpoint, CheckpointManager)
  • Callbacks, die auf das Modell und die Metriken verweisen
Datensätze, Pfade, Logging und reine Python-Hilfsfunktionen werden außerhalb des Scopes erstellt.
Mein aktuelles Verständnis ist:
  • Objekte, die TensorFlow-Variablen erstellen (Modell, Optimierer, Metriken), müssen innerhalb von strategy.scope() erstellt werden.
  • Objekte, die Metriken besitzen oder aktualisieren (z. B. benutzerdefinierte Callbacks, die Validierungswerte verfolgen), sollten ebenfalls innerhalb des Scopes erstellt werden.
  • Checkpoint-Objekte sollten innerhalb des Bereichs erstellt werden, damit sie verteilte Variablen korrekt verfolgen.
  • Die Erstellung von Datensätzen muss nicht innerhalb des Bereichs erfolgen.
Meine größte Unsicherheit betrifft also Objekte, bei denen ich nicht zu 100 % sicher bin, ob sie intern TensorFlow-Variablen erstellen (z. B. benutzerdefinierte Callbacks oder Hilfsklassen, die Metriken oder Modelle entgegennehmen). Ist meine Verwendung von tf.distribute.MirroredStrategy und strategy.scope() unter diesen Umständen korrekt und sicher?

Quick Reply

Change Text Case: 
   
  • Similar Topics
    Replies
    Views
    Last post