Start transfer learning on a saved Keras Inception v3 model
I have a trained Keras Inception v3 model named "nsfw.299x299.h5", which I downloaded from GitHub: https://github.com/GantMan/nsfw_model. The model classifies images into the following 5 classes:
- neutral
- porn
- hentai
- sexy
- drawings
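The model's last layer is a 5-way softmax, so each prediction is a vector of five probabilities. A minimal sketch for turning that vector into a label, assuming the class indices follow the alphabetical order of the training folder names (Keras flow_from_directory assigns indices in alphanumeric order of the subfolders when classes is not given); the CLASS_NAMES order is an assumption to check against the repository before relying on it:

```python
import numpy as np

# ASSUMPTION: index order is alphabetical, as flow_from_directory would assign it
CLASS_NAMES = ["drawings", "hentai", "neutral", "porn", "sexy"]


def decode(probs):
    """Map one softmax vector to its top (label, probability) and the full ranking."""
    probs = np.asarray(probs, dtype=float)
    order = np.argsort(probs)[::-1]  # class indices, most likely first
    ranking = [(CLASS_NAMES[i], float(probs[i])) for i in order]
    return ranking[0], ranking


if __name__ == "__main__":
    # In practice this vector would be model.predict(batch)[0] for one image
    top, ranking = decode([0.05, 0.02, 0.30, 0.60, 0.03])
    print(top)  # ('porn', 0.6)
```

A coffee-mug false positive shows up here as "porn" ranked above "neutral"; comparing the full ranking before and after retraining is a quick way to see whether the extra data moved the decision.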
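The model gives false positives on some coffee mug images, classifying them as "porn" when they should be classified as "neutral". To remove this bias, I downloaded about 400 coffee mug images, plus the same number of images for each of the other classes, and I want to train the model again. How should I go about the training?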
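Since the model has exactly 5 outputs, the coffee-mug images would go into the existing neutral class rather than a new one (a new class would need a new output layer). A minimal sketch that checks a one-subfolder-per-class layout, which is what Keras flow_from_directory expects, and counts the images in each class; the folder root and the extension list are assumptions:

```python
from pathlib import Path

EXPECTED = {"drawings", "hentai", "neutral", "porn", "sexy"}
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".gif"}  # assumed image extensions


def count_images(root):
    """Return {class_name: number_of_images} for each subfolder of root.

    Raises ValueError if the subfolders do not match the five expected classes.
    """
    root = Path(root)
    counts = {}
    for class_dir in sorted(p for p in root.iterdir() if p.is_dir()):
        counts[class_dir.name] = sum(
            1 for f in class_dir.iterdir() if f.suffix.lower() in IMAGE_EXTS
        )
    missing = EXPECTED - set(counts)
    extra = set(counts) - EXPECTED
    if missing or extra:
        raise ValueError(f"missing classes: {sorted(missing)}, unexpected: {sorted(extra)}")
    return counts
```

With the mugs placed under neutral, count_images("data/train") (a hypothetical path) should report roughly 400 images per class, confirming the set is balanced before training.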
Below are the last few lines of the model summary.
Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
...
__________________________________________________________________________________________________
activation_94 (Activation)      (None, 8, 8, 192)    0           batch_normalization_94[0][0]
__________________________________________________________________________________________________
mixed10 (Concatenate)           (None, 8, 8, 2048)   0           activation_86[0][0]
                                                                 mixed9_1[0][0]
                                                                 concatenate_2[0][0]
                                                                 activation_94[0][0]
__________________________________________________________________________________________________
average_pooling2d_10 (AveragePo (None, 1, 1, 2048)   0           mixed10[0][0]
__________________________________________________________________________________________________
flatten_1 (Flatten)             (None, 2048)         0           average_pooling2d_10[0][0]
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 256)          524544      flatten_1[0][0]
__________________________________________________________________________________________________
dropout_2 (Dropout)             (None, 256)          0           dense_1[0][0]
__________________________________________________________________________________________________
dense_2 (Dense)                 (None, 128)          32896       dropout_2[0][0]
__________________________________________________________________________________________________
dropout_3 (Dropout)             (None, 128)          0           dense_2[0][0]
__________________________________________________________________________________________________
dense (Dense)                   (None, 5)            645         dropout_3[0][0]
==================================================================================================
Total params: 22,360,869
Trainable params: 17,076,261
Non-trainable params: 5,284,608
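The summary shows the replacement head (average_pooling2d_10, flatten_1, dense_1, dense_2, dense) sitting on top of the InceptionV3 base, whose last block is mixed10. With only about 400 images per class, one common option is to freeze the base and retrain only the head first. A minimal sketch of that step; freeze_up_to is a hypothetical helper that relies only on each layer having .name and .trainable (which Keras layers do), and in Keras a change to trainable only takes effect after compiling the model again:

```python
def freeze_up_to(layers, last_frozen):
    """Freeze layers[0..last_frozen] (inclusive) and unfreeze every layer after it.

    Returns the names of the layers left trainable.
    """
    names = [layer.name for layer in layers]
    if last_frozen not in names:
        raise ValueError(f"no layer named {last_frozen!r}")
    cut = names.index(last_frozen)
    for i, layer in enumerate(layers):
        layer.trainable = i > cut  # True only for layers after the cut
    return [layer.name for layer in layers[cut + 1:]]


# Usage with the loaded Keras model (recompile afterwards):
#   trainable = freeze_up_to(model.layers, "mixed10")
#   model.compile(...)
```

Because model.layers is in topological order and every base layer feeds into mixed10, everything listed after mixed10 belongs to the head.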
Below is the code used to train this model. I took it from the same GitHub repository I downloaded the model from: https://github.com/GantMan/nsfw_model/blob/master/tf1/training/inceptionv3_transfer/train_initialization.py
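# Imports assumed for a standalone run; the repository script has its own import block.
# constants, callbacks and generators are helper modules from the repository's training
# folder, and height, width and weights_file are defined earlier in that script.
import os

from tensorflow.keras import initializers, regularizers
from tensorflow.keras.applications import InceptionV3
from tensorflow.keras.layers import AveragePooling2D, Dense, Dropout, Flatten
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import SGD

import callbacks
import constants
import generators

conv_base = InceptionV3(
    weights='imagenet',
    include_top=False,
    input_shape=(height, width, constants.NUM_CHANNELS)
)

# First time run, no unlocking
conv_base.trainable = False

# Let's see it
print('Summary')
conv_base.summary()  # summary() prints by itself and returns None

# Let's construct that top layer replacement
x = conv_base.output
x = AveragePooling2D(pool_size=(8, 8))(x)
# The original line was `x - Dropout(0.4)(x)`: its result was discarded, so this dropout
# never entered the model (the summary goes straight from average_pooling2d_10 to flatten_1).
x = Dropout(0.4)(x)
x = Flatten()(x)
x = Dense(256, activation='relu',
          kernel_initializer=initializers.he_normal(seed=None),
          kernel_regularizer=regularizers.l2(.0005))(x)
x = Dropout(0.5)(x)
# Essential to have another layer for better accuracy
x = Dense(128, activation='relu', kernel_initializer=initializers.he_normal(seed=None))(x)
x = Dropout(0.25)(x)
predictions = Dense(constants.NUM_CLASSES, kernel_initializer="glorot_uniform", activation='softmax')(x)

print('Stacking New Layers')
model = Model(inputs=conv_base.input, outputs=predictions)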
# Load checkpoint if one is found
if os.path.exists(weights_file):
    print("loading ", weights_file)
    model.load_weights(weights_file)
# Get all model callbacks
callbacks_list = callbacks.make_callbacks(weights_file)

print('Compile model')
# originally adam, but research says SGD with scheduler
# opt = Adam(lr=0.001, amsgrad=True)
opt = SGD(momentum=.9)
model.compile(
    loss='categorical_crossentropy',
    optimizer=opt,
    metrics=['accuracy']
)

# Get training/validation data via generators
train_generator, validation_generator = generators.create_generators(height, width)

print('Start training!')
history = model.fit_generator(
    train_generator,
    callbacks=callbacks_list,
    epochs=constants.TOTAL_EPOCHS,
    steps_per_epoch=constants.STEPS_PER_EPOCH,
    shuffle=True,
    workers=4,
    use_multiprocessing=False,
    validation_data=validation_generator,
    validation_steps=constants.VALIDATION_STEPS
)

# Save it for later
print('Saving Model')
model.save("nsfw." + str(width) + "x" + str(height) + ".h5")
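Instead of rebuilding the network from InceptionV3(weights='imagenet') as the script does, the already trained nsfw.299x299.h5 can be loaded and trained further on the new images. A minimal sketch, assuming TensorFlow 2.x with tf.keras.utils.image_dataset_from_directory available, the saved model loading in your TensorFlow version, train/validation folders with one subfolder per class, and inputs scaled to [0, 1] (check the repository's preprocessing). The helper names, paths, learning rate and epoch count are illustrative, not tuned values from the repository:

```python
import tensorflow as tf


def fine_tune(model, train_ds, val_ds, epochs=5, learning_rate=1e-4):
    """Recompile model with a small learning rate and continue training.

    A small learning rate keeps the pretrained weights from being overwritten
    by a few hundred new images per class.
    """
    model.compile(
        optimizer=tf.keras.optimizers.SGD(learning_rate=learning_rate, momentum=0.9),
        loss="categorical_crossentropy",  # expects one-hot labels
        metrics=["accuracy"],
    )
    return model.fit(train_ds, validation_data=val_ds, epochs=epochs)


def load_split(directory, image_size=(299, 299), batch_size=32):
    """Build an (image, one-hot label) dataset from one-subfolder-per-class images."""
    ds = tf.keras.utils.image_dataset_from_directory(
        directory,
        label_mode="categorical",  # one-hot labels; classes in alphanumeric folder order
        image_size=image_size,
        batch_size=batch_size,
    )
    # ASSUMPTION: the model was trained on pixels rescaled to [0, 1]
    return ds.map(lambda x, y: (x / 255.0, y))


# Usage (hypothetical paths):
#   model = tf.keras.models.load_model("nsfw.299x299.h5", compile=False)
#   history = fine_tune(model, load_split("data/train"), load_split("data/val"))
#   model.save("nsfw.299x299.retrained.h5")
```

Because the class order comes from the folder names, the five folders must be named so that their alphabetical order matches the order the original model was trained with.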
Answers
Jirayu Kaewprateep answered:
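There are many examples out there. I spent some time on a similar approach using ImageDataGenerator and fit_generator, but according to the warning messages, fit_generator will soon be deprecated.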
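In TensorFlow 2.x, Model.fit_generator is deprecated and Model.fit accepts generators directly, so the call in the question's script needs little more than a rename. A minimal sketch of that migration, where fit_from_generators is a hypothetical wrapper and the generator and step arguments stand for the objects built in the question's script. workers and use_multiprocessing are left out because Keras 3 (the default from TensorFlow 2.16) no longer accepts them in fit:

```python
import tensorflow as tf


def fit_from_generators(model, train_gen, val_gen, epochs, steps_per_epoch,
                        validation_steps, callbacks=None):
    """Replacement for the deprecated model.fit_generator(...) call."""
    return model.fit(
        train_gen,  # Keras generators and plain Python generators are both accepted
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
        validation_data=val_gen,
        validation_steps=validation_steps,
        callbacks=callbacks,
    )


# Usage in the question's script:
#   history = fit_from_generators(model, train_generator, validation_generator,
#                                 constants.TOTAL_EPOCHS, constants.STEPS_PER_EPOCH,
#                                 constants.VALIDATION_STEPS, callbacks_list)
```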
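As for your question: to retrain this model on CIFAR-10 or CIFAR-100, you can use pip install cifar10 or pip install cifar100 (the example below loads CIFAR-10 through tf.keras.datasets.cifar10, which needs no extra package). Alternatively, you can build your own dataset with tf.data or a data generator. Here is an example using the standard CIFAR-10 dataset:
# ---- Model Summary ----
# model, optimizer and lossfn are defined earlier in the script (not shown).
# CIFAR-10 images are 32x32x3, so the model must accept that input size;
# the 299x299 NSFW model would need the images resized first.
import matplotlib.pyplot as plt
import tensorflow as tf

model.compile(optimizer=optimizer, loss=lossfn, metrics=['accuracy'])

train_generator = tf.keras.preprocessing.image.ImageDataGenerator(
    featurewise_center=True,
    featurewise_std_normalization=True,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True,
    validation_split=0.2)

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
# y_train holds integer labels; if lossfn is categorical_crossentropy,
# one-hot encode them first with tf.keras.utils.to_categorical(y_train, 10).
# featurewise_center / featurewise_std_normalization need statistics computed from the data:
train_generator.fit(x_train)

history = model.fit(
    train_generator.flow(x_train, y_train, batch_size=32, subset='training'),
    epochs=50,
    steps_per_epoch=1  # one batch of 32 per epoch; drop this to use the whole training subset
)

plt.plot(history.history['accuracy'], label='accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.show()
plt.close()
input('...')
...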
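To build a dataset from your own image files instead of CIFAR, you can do something like this:
# ---- DataSets ----
import tensorflow as tf

IMG_SIZE = (299, 299)  # assumed target size; must match the model's input

dataset_cat = tf.data.Dataset.list_files("F:\\datasets\\downloads\\PetImages\\train\\Cat\\*.png")
dataset_len = tf.data.experimental.cardinality(dataset_cat).numpy()  # number of matching files

list_image = []
for elem in dataset_cat.take(10):  # originally `dataset.take(10)`, but `dataset` is not defined yet
    # elem is a scalar string tensor holding the file path; decode it as a 3-channel image
    image = tf.image.decode_image(tf.io.read_file(elem), channels=3, expand_animations=False)
    # resize to one common shape so the images can be stacked, and scale to [0, 1]
    # (the range plt.imread returned for PNG files in the original version)
    list_image.append(tf.image.resize(image, IMG_SIZE) / 255.0)

# one label (0 = cat) per loaded image; the original built dataset_len labels for only 10 images
list_label_cat = [0] * len(list_image)
dataset = tf.data.Dataset.from_tensor_slices((tf.stack(list_image), list_label_cat))
dataset = dataset.batch(10)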
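The snippet above builds a dataset for a single class with label 0. For the five-class NSFW model, one such set of images is needed per class, with one-hot labels to match categorical_crossentropy. A minimal sketch that merges per-class image tensors into one shuffled, batched dataset; combine_classes and its arguments are hypothetical names, and the list position of each tensor is taken as its class index:

```python
import tensorflow as tf


def combine_classes(images_per_class, batch_size=32, seed=0):
    """Merge per-class image tensors into one dataset of (image, one_hot_label).

    images_per_class[i] is a tensor of shape (n_i, H, W, 3) holding class i.
    """
    num_classes = len(images_per_class)
    parts = []
    for label, images in enumerate(images_per_class):
        n = int(images.shape[0])
        labels = tf.one_hot(tf.fill([n], label), depth=num_classes)
        parts.append(tf.data.Dataset.from_tensor_slices((images, labels)))
    dataset = parts[0]
    for part in parts[1:]:
        dataset = dataset.concatenate(part)
    total = sum(int(images.shape[0]) for images in images_per_class)
    # shuffle across all classes so batches are not grouped by label
    return dataset.shuffle(total, seed=seed).batch(batch_size)
```

For example, passing the stacked image tensors for drawings, hentai, neutral, porn and sexy, in that order, gives labels that line up with an alphabetically ordered 5-way output.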