Änderung der Batchnormalisierungsimpuls beim Training in Keras 3
Posted: 21 May 2025, 14:39
Ich kaufe ein benutzerdefiniertes Modell mit Keras (Keras 3.3.3 mit Python 3.9.19) und ich möchte den Impuls meiner Batchnormalisierung Schichten während des Trainings erhöhen. Verwenden einer benutzerdefinierten Trainingsschleife. self.bn_momentum = tf.variable (bn_momentum, trainable = false) , übergeben Sie es an meine Ebenen und aktualisieren Sie es mit einem Rückruf während des Trainings. Wenn Sie jedoch versuchen, eine keras zu verwenden.Variable Ich erhalte die folgende Fehlermeldung:
Ich habe also versucht, die Keras.Variable zu entfernen und stattdessen einen einfachen Float zu verwenden. Das Training scheint damit einverstanden zu sein, aber ich vermute, dass nichts unter der Motorhaube passiert. Reproduzierbares Beispiel: < /p>
Code: Select all
TypeError: Exception encountered when calling CustomClassifier.call().
float() argument must be a string or a number, not 'Variable'
Code: Select all
import numpy as np
import keras
from keras import layers
class CustomClassifier(keras.Model):
def __init__(self, bn_momentum=0.99, **kwargs):
super().__init__(**kwargs)
self.bn_momentum = bn_momentum # or keras.Variable(bn_momentum, trainable=False)
self.input_layer = layers.Dense(8, activation="softmax", name="input_layer")
self.hidden_layer = CustomLayer(16, bn_momentum=self.bn_momentum, name="hidden_layer")
self.output_layer = layers.Dense(4, activation="softmax", name="output_scores")
def call(self, input_points, training=None):
x = self.input_layer(input_points, training=training)
x = self.hidden_layer(x, training=training)
return self.output_layer(x)
class CustomLayer(layers.Layer):
def __init__(self, units, bn_momentum, **kwargs):
super().__init__(**kwargs)
self.units = units
self.bn_momentum = bn_momentum
def build(self, batch_input_shape):
self.dense = layers.Dense(self.units, input_shape=batch_input_shape)
self.bn = layers.BatchNormalization(momentum=self.bn_momentum)
self.activation = layers.ReLU()
def call(self, x, training=None):
x = self.dense(x)
x = self.bn(x, training=training)
return self.activation(x)
class BatchNormalizationMomentumScheduler(keras.callbacks.Callback):
"""The decay rate for batch normalization starts with 0.5 and is gradually
increased to 0.99."""
def __init__(self,):
super().__init__()
self.initial_momentum = 0.5
self.final_momentum = 0.99
self.rate = 0.05
def on_train_begin(self, logs=None):
self.model.bn_momentum = self.initial_momentum
print(f"Initial BatchNormalization momentum is {self.model.bn_momentum:.3f}.")
def on_epoch_begin(self, epoch, logs=None):
if epoch:
new_bn_momentum = self.initial_momentum + self.rate * epoch
new_bn_momentum = np.min([new_bn_momentum, self.final_momentum])
self.model.bn_momentum = new_bn_momentum
print(f"Epoch {epoch}: BatchNormalization momentum is {self.model.bn_momentum:.3f}.")
if __name__ =="__main__":
# Generate random data
X = np.random.random((1024, 8))
y = np.random.choice([0, 1, 2, 3], 1024)
# Instanciate and train model
model = CustomClassifier()
model.build((64, 8))
model.summary()
model.compile(loss="sparse_categorical_crossentropy", optimizer="adam")
history = model.fit(X, y, epochs=10, callbacks=[BatchNormalizationMomentumScheduler()])
# Check final
print("Model momentum after training:", model.bn_momentum)