# 使用神经网络进行图像分类

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from time import time
from tensorflow.keras.datasets import mnist

# Load MNIST as (images, labels) pairs for the training and test splits.
train_split, test_split = mnist.load_data()
x_train_image, y_train_label = train_split
x_test_image, y_test_label = test_split

# Report the shapes of the training images and training labels.
for caption, arr in (('x_train:', x_train_image), ('y_train:', y_train_label)):
    print(caption, arr.shape)

import matplotlib.pyplot as plt
def plot_image(image):
    """Show one image in a small 2 x 2 inch black-and-white grayscale figure.

    Args:
        image: pixel array accepted by ``plt.imshow``, e.g. one MNIST
            training image.
    """
    # Resize the current figure: 2 inches wide, 2 inches tall.
    plt.gcf().set_size_inches(2, 2)
    plt.imshow(image, cmap='binary')  # black-and-white grayscale rendering
    plt.show()


plot_image(x_train_image[0])  # preview the first training image

# Flatten each image into a 784-element vector (784 must match the model's
# input_dim). The sample count comes from the arrays themselves rather than
# hard-coded 60000 / 10000, so a subset of the data also works.
x_train = x_train_image.reshape(len(x_train_image), 784).astype('float32')
x_test = x_test_image.reshape(len(x_test_image), 784).astype('float32')

# Normalize: divide by 255 so pixel values (presumably 0-255 integers, as
# loaded from MNIST) fall into [0, 1].
x_train = x_train / 255
x_test = x_test / 255

print(y_train_label[:5])  # inspect the first 5 raw (integer) labels

from tensorflow.python.keras.utils import np_utils

# One-hot encode the integer labels with the public Keras utility; the
# private `tensorflow.python.keras` module path is not a stable API.
# num_classes=10 guarantees exactly 10 columns, matching the 10-unit softmax
# output layer, even if some digit were missing from a label array.
y_train = tf.keras.utils.to_categorical(y_train_label, num_classes=10)
y_test = tf.keras.utils.to_categorical(y_test_label, num_classes=10)

print('one-hot:')
print(y_train[:5])  # inspect the first 5 one-hot encoded labels

from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense

# Multilayer perceptron: 784 inputs -> 256 ReLU hidden units -> 10-way softmax.
model = Sequential()
model.add(Dense(units=256,
                input_dim=784,
                kernel_initializer='normal',
                activation='relu'))
model.add(Dense(units=10,
                kernel_initializer='normal',
                activation='softmax'))
# summary() prints the layer table itself and returns None; wrapping it in
# print() only appended a stray "None" line to the output.
model.summary()

# categorical_crossentropy pairs with the one-hot y_train / y_test targets.
model.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])

# Train for 10 epochs in batches of 200, holding out 20% of the training data
# for validation. The returned object's .history maps metric names to
# per-epoch values (used by show_train_history).
train_history = model.fit(x=x_train,
                          y=y_train,
                          validation_split=0.2,
                          epochs=10,
                          batch_size=200,
                          verbose=2)

def show_train_history(train_history, train, validation):
    """Plot a training metric and its validation counterpart against epochs.

    Args:
        train_history: object returned by ``model.fit``; its ``.history``
            mapping is indexed by metric name.
        train: name of the training metric to plot (e.g. ``'accuracy'``).
        validation: name of the matching validation metric
            (e.g. ``'val_accuracy'``).
    """
    metrics = train_history.history
    # Plot order fixes legend order: training curve first, validation second.
    for key in (train, validation):
        plt.plot(metrics[key])
    plt.title('Train history')
    plt.ylabel(train)
    plt.xlabel('Epoch')
    plt.legend(['train', 'validation'], loc='upper left')
    plt.show()


show_train_history(train_history, train='accuracy', validation='val_accuracy')
# With metrics=['accuracy'], evaluate() returns [loss, accuracy];
# index 1 is the test-set accuracy.
scores = model.evaluate(x_test, y_test)
print()
print('accuracy=', scores[1])

# Sequential.predict_classes() was removed in TensorFlow 2.6. Take the
# arg-max over the softmax outputs instead to get one integer class per sample.
prediction = np.argmax(model.predict(x_test), axis=-1)
print(prediction)

def plot_images_labels_prediction(images, labels, prediction, idx, num=10):
    """Show consecutive images in a 5x5 grid, titled with their labels.

    Args:
        images: indexable collection of images, each accepted by ``imshow``.
        labels: true labels, indexed in step with ``images``.
        prediction: predicted labels, indexed in step with ``images``; pass an
            empty sequence to title each image with its true label only.
        idx: index of the first image to show.
        num: how many images to show (default 10). Capped at 25 (the grid
            size) and at the number of images remaining from ``idx`` on.
    """
    fig = plt.gcf()
    fig.set_size_inches(12, 5)
    # Cap at the 25 grid cells and at the images left after idx, so asking
    # for more than remain no longer raises IndexError part-way through.
    num = min(num, 25, len(images) - idx)
    for cell in range(num):
        pos = idx + cell  # index of the image drawn in this grid cell
        ax = plt.subplot(5, 5, cell + 1)
        ax.imshow(images[pos], cmap='binary')
        title = 'label=' + str(labels[pos])
        if len(prediction) > 0:
            title += '.predict=' + str(prediction[pos])
        ax.set_title(title, fontsize=10)
        # Hide axis ticks; they carry no information for pixel images.
        ax.set_xticks([])
        ax.set_yticks([])
    plt.show()


plot_images_labels_prediction(x_test_image, y_test_label, prediction, idx=340)


from PIL import Image
import numpy as np

# Classify a user-supplied digit image with the trained model.
# The with-block closes the source file once the grayscale copy is made.
with Image.open('white_0.jpg') as src:
    img = src.convert('L')  # single-channel grayscale
print('img size:', img.size)
# The model takes 28x28 = 784 inputs. Resize first so the reshape below works
# for images of any size; a 28x28 image is left unchanged.
img = img.resize((28, 28))
img = np.array(img).reshape(1, 784).astype('float32')
img = img / 255  # same scaling as the training data
# NOTE(review): if this image is a dark digit on a light background (its name
# may suggest so), it likely needs inverting (1 - img) to resemble the
# training digits before predicting. Confirm against the actual file.
# Sequential.predict_classes() was removed in TensorFlow 2.6; take the
# arg-max over the predicted class probabilities instead.
pred_label = np.argmax(model.predict(img), axis=-1)
print('predict my img', pred_label)

# 文章出处登录后可见！

# 已经登录？立即刷新

# 共计 0 人评分，平均分：暂无

# 到目前为止还没有投票！成为第一个评论此文章的人吧。

# (0)
# 作者：青葱年少（普通用户）
# 上一篇：2022年5月18日
# 下一篇：2022年5月18日

# 相关推荐