Das Convolutional Neural Network erreicht eine unerwartet hohe Validierungsgenauigkeit. Was könnte die Ursache dafür sei
Posted: 28 Dec 2024, 19:11
Ich baue/optimiere ein CNN für die Klassifizierung von Autos aus diesem Datensatz.
Mein Basismodell erreicht durch eine sehr einfache Modellarchitektur eine überraschend hohe Genauigkeit. Ich befürchte, dass es zu Datenlecks kommt, wenn die Daten nicht geladen werden korrekt eingegeben, ich hätte also bitte ein paar Ratschläge.
Der Datensatz wird geladen
Erstellen des Modells
Passendes Modell

Mein Basismodell erreicht durch eine sehr einfache Modellarchitektur eine überraschend hohe Genauigkeit. Ich befürchte, dass es zu Datenlecks kommt, wenn die Daten nicht geladen werden korrekt eingegeben, ich hätte also bitte ein paar Ratschläge.
Der Datensatz wird geladen
Code: Select all
batch_size = 16
img_size = (64, 64)
train_dataset, val_dataset = tf.keras.utils.image_dataset_from_directory(
data_dir,
label_mode='categorical',
seed=1,
subset='both',
validation_split=0.2,
image_size=img_size,
batch_size=batch_size,
)
normalization_layer = Rescaling(1./255)
train_dataset = train_dataset.map(lambda x, y: (normalization_layer(x), y))
val_dataset = val_dataset.map(lambda x, y: (normalization_layer(x), y))
Code: Select all
def baseline_model(input_shape=[64, 64, 3]):
model = Sequential([
# 1st Conv Layer
Conv2D(filters=16, kernel_size=(3, 3), activation='relu', padding='valid', input_shape=input_shape),
# Pool Layer
MaxPooling2D((2, 2)),
# 2nd Conv Layer
Conv2D(filters=32, kernel_size=(3, 3), activation='relu', padding='valid'),
# Pool Layer
MaxPooling2D((2, 2)),
# Flatten
Flatten(),
# Fully connected layer
Dense(64, activation='relu'),
Dense(5, activation='softmax')
])
model.compile(optimizer='rmsprop',
loss='categorical_crossentropy',
metrics=['accuracy', Precision(name='precision'), Recall(name='recall')])
return model
baseline_model = baseline_model()
Code: Select all
history = baseline_model.fit(
train_dataset,
epochs=10,
validation_data=val_dataset,
#callbacks=[ConfusionMatrixCallback(val_dataset, class_names)]
)
plot_training_history(history)
