Hier ist eine vereinfachte Version meines Codes:
Code: Select all
@keras.saving.register_keras_serializable(package="U-Net", name="Encoder")
class Encoder(layers.Layer):
    """Encoder stage that wraps a single ``basic_module`` layer.

    Args:
        in_channels: Number of input channels, forwarded to ``basic_module``.
        out_channels: Number of output channels, forwarded to ``basic_module``.
        basic_module: Layer *class* (not an instance). It is instantiated as
            ``basic_module(in_channels=..., out_channels=..., encoder=True)``.
            To survive a ``get_config``/``from_config`` round trip, the class
            must be registered with ``keras.saving.register_keras_serializable``
            or supplied via ``custom_objects`` when loading.
        **kwargs: Standard ``keras.layers.Layer`` keyword arguments
            (e.g. ``name``, ``dtype``, ``trainable``).
    """

    def __init__(self, in_channels, out_channels, basic_module=DoubleConv, **kwargs):
        super().__init__(**kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Keep the class itself so get_config can record which one was used.
        self.basic_module_class = basic_module
        # Instantiate the basic module
        self.basic_module = basic_module(
            in_channels=in_channels,
            out_channels=out_channels,
            encoder=True,
        )

    def call(self, x, training=False):
        """Apply the wrapped module to ``x``, forwarding the ``training`` flag."""
        x = self.basic_module(x, training=training)
        return x

    def get_config(self):
        """Return this layer's config.

        ``basic_module`` is stored as the class's registered name (for a
        registered class this has the form ``"package>name"``; otherwise
        it is the class's ``__name__``). ``serialize_keras_object`` is not
        used because it is designed for object instances and functions,
        not for class objects.
        """
        config = super().get_config()
        config.update(
            {
                "in_channels": self.in_channels,
                "out_channels": self.out_channels,
                "basic_module": keras.saving.get_registered_name(
                    self.basic_module_class
                ),
            }
        )
        return config

    @classmethod
    def from_config(cls, config):
        """Rebuild an ``Encoder`` from the output of ``get_config``.

        Args:
            config: Dict as returned by ``get_config``. It is not modified.

        Returns:
            A new ``Encoder`` instance built from ``config``.

        Raises:
            ValueError: If the stored ``basic_module`` name cannot be resolved
                to a registered class or one passed via ``custom_objects``.
        """
        # Work on a copy so pop() does not mutate the caller's dict.
        config = dict(config)
        spec = config.pop("basic_module", None)
        if spec is None:
            # Nothing recorded: use the constructor's default basic_module.
            return cls(**config)
        if isinstance(spec, str):
            # Looks in custom-object scopes, globally registered objects, etc.
            basic_module_class = keras.saving.get_registered_object(spec)
            if basic_module_class is None:
                raise ValueError(
                    f"Could not resolve basic_module {spec!r}. Register the "
                    "class with @keras.saving.register_keras_serializable or "
                    "pass it via custom_objects when loading."
                )
        else:
            # Non-string value (e.g. a dict produced by serialize_keras_object):
            # keep the previous deserialization path for compatibility.
            basic_module_class = keras.saving.deserialize_keras_object(spec)
        return cls(basic_module=basic_module_class, **config)
Beispielverwendung:
Code: Select all
# Build one encoder stage, then rebuild a new instance from its config.
encoder = Encoder(
    basic_module=DoubleConv,
    in_channels=1,
    out_channels=32,
)
# Plain dict holding the constructor arguments plus base Layer settings.
encoder_config = encoder.get_config()
# A new, separately constructed layer created from the same config.
new_encoder = Encoder.from_config(encoder_config)
Mobile version