为什么 Keras 的 ModelCheckPoint 没有在训练期间以最高验证准确度保存我的最佳模型?

青葱年少 tensorflow 245

原文标题Why doesn’t Keras’ ModelCheckPoint save my best model with the highest validation accuracy during training?

我正在用 Keras 训练 ResNet18。如下图,我使用 ModelCheckPoint 来保存基于验证准确度的最佳模型。

model = ResNet18(2)
model.build(input_shape = (None,128,128,3))

model.summary()
model.save_weights('./Adam_resnet18_original.hdf5')
opt = tf.keras.optimizers.Adam(learning_rate=0.001)

model.compile(loss='categorical_crossentropy', optimizer=opt, metrics=['accuracy'])
mcp_save = ModelCheckpoint('Adam_resnet18_weights.hdf5', save_best_only=True, monitor='val_accuracy', mode='max')

batch_size = 128
model.fit(generator(batch_size, x_train, y_train), steps_per_epoch = len(x_train) // batch_size, validation_data = generator(batch_size, x_valid, y_valid), validation_steps = len(x_valid) // batch_size, callbacks=[mcp_save], epochs = 300)

如下图所示,在训练过程中验证准确率最高可达 0.8281。 训练历史

但是,当我使用最终模型通过下面的代码获得最终验证准确度时,我得到的准确度仅为 0.78109。谁能告诉我这里可能存在什么问题?非常感谢!

model.load_weights('Adam_resnet18_weights.hdf5')

predictions_validation = model.predict(generator(batch_size, x_valid, y_valid), steps = len(x_valid) // batch_size + 1)
predictions_validation_label = np.argmax(predictions_validation, axis=1)
Y_valid_label = np.argmax(Y_valid, axis=1)
accuracy_validation_conventional = accuracy_score(Y_valid_label, predictions_validation_label[:len(Y_valid_label)])
print(f'Accuracy on the validation set: {accuracy_validation_conventional}')

原文链接:https://stackoverflow.com//questions/71993936/why-doesnt-keras-modelcheckpoint-save-my-best-model-with-the-highest-validatio

回复

我来回复
  • Greystormy的头像
    Greystormy 评论

    这里最大的线索是,最近几个 epoch 的准确率一直保持在 1.000。由此看来,这个模型是过拟合的。对过拟合的直观理解就像一个学生一遍又一遍地进行完全相同的测试,以他们只是记住每个问题的答案,无法适应措辞的微小变化。网络已经“记住”了训练数据,但无法适应测试数据。

    弄清楚最好的方法是什么有点棘手,因为我不知道您正在使用的数据集的大小或模型的细节。我假设数据集大小合适(如果不是,请尝试数据扩充)并且您已经定义了一个多层网络(如果您从 Keras 导入此模型,您的选择可能会更有限)。不过这里有一些建议:

    1. 早点停下来。将您的 epochs 设置为较小的数字以防止过度训练。这是最简单和最简单的解决方案,在您的情况下这很有意义,因为过去几个时期的准确度已经达到 1.00。如果您能够随时间绘制准确度和损失,这将有所帮助,因为您将能够直观地确定过度拟合开始的时期数,如您在此示例中所见。有更好的方法来实现提前停止,但简单地运行更少的 epoch 可能就足以满足您的目的。
    2. 添加辍学层。简而言之,这将“关闭”网络中的随机权重,从而防止网络过度依赖一小部分节点。这也是防止过拟合的常用技术。

    可以在此处找到更完整的解释以及其他建议。希望这有帮助!

    2年前 0条评论