Wird tf.keras.Sequential, das mehrere benutzerdefinierte Ebenen enthält, in meinem Fall korrekt vollständig serialisierbPython

Python-Programme
Anonymous
 Wird tf.keras.Sequential, das mehrere benutzerdefinierte Ebenen enthält, in meinem Fall korrekt vollständig serialisierb

Post by Anonymous »

Ich implementieren eine U-Net-Variante in TensorFlow/Keras mit benutzerdefinierten Ebenen. In einer meiner benutzerdefinierten Ebenen UPDoubleConv habe ich einen Sequential self.blocks, der ein wiederholtes Muster von UpSampling2D enthält, gefolgt von einer benutzerdefinierten DoubleConv-Ebene. Ich möchte sicherstellen, dass dieser Teil mithilfe von model.save() und tf.keras.models.load_model() vollständig serialisierbar und deserialisierbar ist. Hier ist der relevante Code:

Code: Select all

@tf.keras.utils.register_keras_serializable(package="CustomLayers", name="UPDoubleConv")
class UPDoubleConv(layers.Layer):
def __init__(self, out_channels, num_layer=1, **kwargs):
super(UPDoubleConv, self).__init__(**kwargs)
self.out_channels = out_channels
self.num_layer = num_layer

self.deconv = tf.keras.Sequential([
layers.UpSampling2D(size=(2, 2), interpolation="bilinear"),
DoubleConv(out_channels=out_channels, spatial_dim=2)
])

blocks_layers = []
for _ in range(num_layer):
blocks_layers.append(layers.UpSampling2D(size=(2,2), interpolation="bilinear"))
blocks_layers.append(DoubleConv(out_channels=out_channels, spatial_dim=2))
self.blocks = tf.keras.Sequential(blocks_layers)

def call(self, x, training=None):
x = self.deconv(x, training=training)
x = self.blocks(x, training=training)
return x

def get_config(self):
config = super().get_config()
config.update({
"out_channels": self.out_channels,
"num_layer": self.num_layer,
})
return config

@classmethod
def from_config(cls, config):
return cls(**config)

@tf.keras.utils.register_keras_serializable(package="CustomLayers", name="DoubleConv")
class DoubleConv(tf.keras.layers.Layer):
def __init__(self, out_channels, spatial_dim=2, **kwargs):
super(DoubleConv, self).__init__(**kwargs)
assert spatial_dim in (2, 3), f"spatial_dim must be 2 or 3, got {spatial_dim}"
self.out_channels = out_channels
self.spatial_dim = spatial_dim

Conv = layers.Conv2D if spatial_dim == 2 else layers.Conv3D

self.double_conv = tf.keras.Sequential([
Conv(out_channels, kernel_size=3, strides=1, padding="same", use_bias=False),
layers.BatchNormalization(),
layers.ReLU(),
Conv(out_channels, kernel_size=3, strides=1, padding="same", use_bias=False),
layers.BatchNormalization(),
layers.ReLU()
])

self.identity = tf.keras.Sequential([
Conv(out_channels, kernel_size=1, strides=1, padding="valid", use_bias=False),
layers.BatchNormalization(),
layers.ReLU()
])

def call(self, x, training=None):
out = self.double_conv(x, training=training)
skip = self.identity(x, training=training)
return out + skip

def get_config(self):
config = super().get_config()
config.update({
"out_channels": self.out_channels,
"spatial_dim": self.spatial_dim,
})
return config

@classmethod
def from_config(cls, config):
return cls(**config)

Meine Fragen:
1- Wird das

Code: Select all

self.blocks = tf.keras.Sequential(blocks_layers)
Stellen Sie sicher, dass alle darin enthaltenen UpSampling2D-Ebenen und DoubleConv-Ebenen beim Speichern und Laden eines Modells vollständig serialisiert und korrekt deserialisiert werden? Gibt es irgendwelche Fallstricke, die ich beachten sollte?
2- Wird das Trainingsargument in

Code: Select all

x = self.blocks(x, training=training)
korrekt an die Ebenen innerhalb von self.blocks weitergegeben werden, die es benötigen, und für die Ebenen übersprungen werden, die es nicht benötigen?

Quick Reply

Change Text Case: 
   
  • Similar Topics
    Replies
    Views
    Last post