基于vgg16的猫狗识别(二分类)

基于vgg16的猫狗识别(二分类)

python代码如下:

from tensorflow.keras.applications import VGG16
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.optimizers import RMSprop
import matplotlib.pyplot as plt

train_path = r"D:\Desktop\Dog_Cat\train" #训练集目录
valid_path = r"D:\Desktop\Dog_Cat\valid" #验证集目录
i1 = ImageDataGenerator(rescale=1/255, rotation_range=40, width_shift_range=0.2)
i2 = ImageDataGenerator(rescale=1/255)
f1 = i1.flow_from_directory(train_path, target_size=(150, 150), batch_size=20, class_mode="binary")
f2 = i1.flow_from_directory(valid_path, target_size=(150, 150), batch_size=20, class_mode="binary")
# 2.构建模型
model = Sequential()
vgg = VGG16(include_top=False, input_shape=(150, 150, 3))
vgg.summary()
for i, j in enumerate(vgg.layers):
    if i >= 17:
        j.trainable = True
    else:
        j.trainable = False
vgg.summary()
model.add(vgg)
model.add(Flatten())
model.add(Dense(units=1, activation="sigmoid"))
model.compile(optimizer=RMSprop(learning_rate=1E-4), loss="binary_crossentropy", metrics=["acc"])
model.summary()
history = model.fit_generator(generator=f1, epochs=15, validation_data=f2)
plt.rcParams['font.family'] = ['sans-serif']
plt.rcParams['font.sans-serif'] = ['SimHei']
accuracy = history.history['acc']
val_accuracy = history.history['val_acc']
epochs = range(1, len(accuracy)+1)
plt.plot(epochs, accuracy, label='训练精度', c = 'r')
plt.plot(epochs, val_accuracy, label='验证精度', c = 'b')
plt.xlabel('epochs')
plt.ylabel('accuracy')
plt.title('author:peiyuanman')
plt.legend()
plt.show()

model.save('DogCatModelVGG16.h5')


文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
心中带点小风骚的头像心中带点小风骚普通用户
上一篇 2022年6月13日 上午11:31
下一篇 2022年6月13日 上午11:33

相关推荐