Um dies zu handhaben, erstelle ich die Unterebenen in einer Schleife, weise jede einzelne mit setattr als Attribut zu und halte sie zur einfachen Iteration auch in einer Python-Liste:
Code: Select all
@keras.saving.register_keras_serializable()
class MyBlock(tf.keras.layers.Layer):
def __init__(self, n_layers):
super().__init__()
self.n_layers = n_layers
self.blocks = []
for i in range(n_layers):
layer = tf.keras.layers.Dense(8)
setattr(self, f"block_{i}", layer)
self.blocks.append(getattr(self, f"block_{i}"))
def call(self, x):
for layer in self.blocks:
x = layer(x)
return x
def get_config(self):
return {"n_layers": self.n_layers}
My understanding is that:
- Keras verfolgt nur Layer, die als Attribute zugewiesen sind
- Listen allein werden nicht verfolgt
- Mit setattr wird sichergestellt, dass jeder Sublayer korrekt registriert und serialisiert wird
Mobile version