Image Segmentation with U-Net (Using the Crack 500 Dataset)

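Table of Contents

    • 0. Environment Requirements
    • 1. Loading Packages and the Dataset
      • 1.1 Loading Packages
      • 1.2 Loading the Data
    • 2. Generator for Loading and Augmenting Images
    • 3. Augmenting the Training Set for Model Training
    • 3. Building the ResUNet Model
    • 4. Loss & Compile
    • 5. Training
    • 6. Testing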

0. Environment Requirements

Crack 500 dataset download:
https://download.csdn.net/download/QH2107/87423329

Create an environment with Python 3.6.13.

Create a requirements.txt file:

# requirements.txt
absl-py==0.15.0
aiohttp==3.7.4.post0
albumentations==1.3.0
argon2-cffi==20.1.0
astor==0.8.1
async-generator==1.10
async-timeout==3.0.1
attrs==21.4.0
backcall==0.2.0
bleach==4.1.0
blinker==1.4
brotlipy==0.7.0
cachetools==4.2.2
certifi==2021.5.30
cffi==1.14.6
chardet==4.0.0
charset-normalizer==2.0.4
click==8.0.3
colorama==0.4.4
coverage==5.5
cryptography==3.4.7
cycler==0.11.0
Cython==0.29.24
dataclasses==0.8
decorator==5.1.1
defusedxml==0.7.1
entrypoints==0.3
gast==0.2.2
google-auth==2.6.0
google-auth-oauthlib==0.4.4
google-pasta==0.2.0
graphviz==0.19.1
grpcio==1.36.1
h5py==2.10.0
idna==3.3
idna-ssl==1.1.0
imageio==2.15.0
importlib-metadata==4.8.1
ipykernel==5.3.4
ipython==7.16.1
ipython-genutils==0.2.0
ipywidgets==7.6.5
jedi==0.17.0
Jinja2==3.0.3
joblib==1.1.1
jsonschema==3.0.2
jupyter==1.0.0
jupyter-client==7.1.2
jupyter-console==6.4.3
jupyter-contrib-core==0.4.0
jupyter-contrib-nbextensions==0.7.0
jupyter-core==4.8.1
jupyter-highlight-selected-word==0.2.0
jupyter-latex-envs==1.4.6
jupyter-nbextensions-configurator==0.6.1
jupyterlab-pygments==0.1.2
jupyterlab-widgets==1.0.0
Keras==2.3.1
Keras-Applications==1.0.8
Keras-Preprocessing==1.1.2
kiwisolver==1.3.1
lxml==3.8.0
Markdown==3.3.4
MarkupSafe==2.0.1
matplotlib==3.3.4
mistune==0.8.4
mkl-fft==1.3.0
mkl-random==1.1.1
mkl-service==2.3.0
multidict==5.1.0
nb-conda==2.2.1
nb-conda-kernels==2.3.1
nbclient==0.5.3
nbconvert==6.0.7
nbformat==5.1.3
nest-asyncio==1.5.1
networkx==2.5.1
notebook==6.4.3
numpy==1.19.2
oauthlib==3.2.0
olefile==0.46
opencv-python==4.5.5.62
opt-einsum==3.3.0
packaging==21.3
pandas==1.1.5
pandocfilters==1.5.0
parso==0.8.3
pickleshare==0.7.5
Pillow==8.4.0
pip==21.3.1
prometheus-client==0.13.1
prompt-toolkit==3.0.20
protobuf==3.17.2
pyasn1==0.4.8
pyasn1-modules==0.2.8
pycparser==2.21
pydot==1.4.2
pydot-ng==2.0.0
pydotplus==2.0.2
Pygments==2.11.2
PyJWT==2.1.0
pyOpenSSL==21.0.0
pyparsing==3.0.4
pyreadline==2.1
pyrsistent==0.17.3
PySocks==1.7.1
python-dateutil==2.8.2
pytz==2022.7.1
PyWavelets==1.1.1
pywin32==228
pywinpty==0.5.7
PyYAML==6.0
pyzmq==22.2.1
qtconsole==5.2.2
QtPy==2.0.1
qudida==0.0.4
requests==2.27.1
requests-oauthlib==1.3.0
rsa==4.7.2
scikit-image==0.17.2
scikit-learn==0.24.2
scipy==1.5.2
seaborn==0.11.2
Send2Trash==1.8.0
setuptools==58.0.4
six==1.16.0
tensorboard==2.4.0
tensorboard-plugin-wit==1.6.0
tensorflow==2.1.0
tensorflow-estimator==2.1.0
termcolor==1.1.0
terminado==0.9.4
testpath==0.5.0
tf-unet==0.1.2
threadpoolctl==3.1.0
tifffile==2020.9.3
tornado==6.1
traitlets==4.3.3
typing_extensions==4.1.1
urllib3==1.26.8
wcwidth==0.2.5
webencodings==0.5.1
Werkzeug==0.16.1
wheel==0.37.1
widgetsnbextension==3.5.1
win-inet-pton==1.1.0
wincertstore==0.2
wrapt==1.12.1
yarl==1.6.3
zipp==3.6.0

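Install the required packages with either of the following commands:

conda install --yes --file requirements.txt
or
pip install -r requirements.txt

Some of the pinned packages (pywin32, pywinpty, win-inet-pton, wincertstore) are Windows-specific. pywin32 in particular will not install on Linux or macOS, so these entries may need to be removed on those systems.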
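After installing, it can help to confirm that the environment actually matches the pins. The sketch below is not part of the original tutorial. `parse_pins` and `check_pins` are hypothetical helper names. It relies on `importlib.metadata`, which is in the standard library from Python 3.8. On Python 3.6 it falls back to the `importlib_metadata` backport, which the list above already pins as importlib-metadata==4.8.1.

```python
try:
    from importlib.metadata import version, PackageNotFoundError  # standard library, Python 3.8+
except ImportError:
    from importlib_metadata import version, PackageNotFoundError  # backport for Python 3.6/3.7


def parse_pins(text):
    """Return {name: version} for every 'name==version' line, ignoring comments and blank lines."""
    pins = {}
    for line in text.splitlines():
        line = line.split("#", 1)[0].strip()  # drop full-line and inline comments
        if "==" in line:
            name, ver = line.split("==", 1)
            pins[name.strip()] = ver.strip()
    return pins


def check_pins(pins):
    """Return (name, pinned, installed) for every package that is missing or has another version."""
    mismatches = []
    for name, pinned in pins.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            installed = None  # not installed (or installed under a different distribution name)
        if installed != pinned:
            mismatches.append((name, pinned, installed))
    return mismatches
```

For example, `check_pins(parse_pins(open("requirements.txt").read()))` returns an empty list when every pin is satisfied.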

1. Loading Packages and the Dataset

! nvidia-smi 
    Mon Feb  6 12:48:07 2023       
    +-----------------------------------------------------------------------------+
    | NVIDIA-SMI 512.78       Driver Version: 512.78       CUDA Version: 11.6     |
    |-------------------------------+----------------------+----------------------+
    | GPU  Name            TCC/WDDM | Bus-Id        Disp.A | Volatile Uncorr. ECC |
    | Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
    |                               |                      |               MIG M. |
    |===============================+======================+======================|
    |   0  NVIDIA GeForce ... WDDM  | 00000000:01:00.0 Off |                  N/A |
    | N/A   43C    P0    26W /  N/A |      0MiB /  6144MiB |      0%      Default |
    |                               |                      |                  N/A |
    +-------------------------------+----------------------+----------------------+
                                                                                   
    +-----------------------------------------------------------------------------+
    | Processes:                                                                  |
    |  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
    |        ID   ID                                                   Usage      |
    |=============================================================================|
    |  No running processes found                                                 |
    +-----------------------------------------------------------------------------+

1.1 Loading Packages

import os
import cv2
import shutil
import math
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
import tensorflow.keras.backend as K
from tensorflow.keras.utils import Sequence
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Activation, MaxPooling2D, Conv2DTranspose, Add, concatenate, average, Dropout
from tensorflow.keras.losses import binary_crossentropy
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from sklearn.metrics import classification_report, roc_auc_score, accuracy_score
from albumentations import Compose, OneOf, Flip, Rotate, RandomContrast, RandomGamma, RandomBrightness, ElasticTransform, GridDistortion, OpticalDistortion, RGBShift, CLAHE
from skimage.transform import resize

sns.set()  # use seaborn's default plot style

1.2 Loading the Data

# Dataset folder paths
train_image_dir = r'E:\dataset\CRACK500\train\image'
train_mask_dir = r'E:\dataset\CRACK500\train\mask'

valid_image_dir = r'E:\dataset\CRACK500\validation\image'
valid_mask_dir = r'E:\dataset\CRACK500\validation\mask'

test_image_dir = r'E:\dataset\CRACK500\test\image'
test_mask_dir = r'E:\dataset\CRACK500\test\mask'
# Dataset file paths (sorted so that each image lines up with its mask)
# Test set
test_image_paths = sorted([os.path.join(test_image_dir, fname) for fname in os.listdir(test_image_dir) if fname.endswith(".png") and not fname.startswith(".")])
test_mask_paths = sorted([os.path.join(test_mask_dir, fname) for fname in os.listdir(test_mask_dir) if fname.endswith(".png") and not fname.startswith(".")])
print("Number of testing images : ", len(test_image_paths))
print("Number of testing masks : ", len(test_mask_paths))

# Training set
train_image_files = sorted([os.path.join(train_image_dir, fname) for fname in os.listdir(train_image_dir) if fname.endswith(".png") and not fname.startswith(".")])
train_mask_files = sorted([os.path.join(train_mask_dir, fname) for fname in os.listdir(train_mask_dir) if fname.endswith(".png") and not fname.startswith(".")])
print("Number of training images : ", len(train_image_files))
print("Number of training masks : ", len(train_mask_files))

# Validation set
valid_image_files = sorted([os.path.join(valid_image_dir, fname) for fname in os.listdir(valid_image_dir) if fname.endswith(".png") and not fname.startswith(".")])
valid_mask_files = sorted([os.path.join(valid_mask_dir, fname) for fname in os.listdir(valid_mask_dir) if fname.endswith(".png") and not fname.startswith(".")])
print("Number of validation images : ", len(valid_image_files))
print("Number of validation masks : ", len(valid_mask_files))
# Output
    Number of testing images :  1124
    Number of testing masks :  1124
    Number of training images :  1896
    Number of training masks :  1896
    Number of validation images :  348
    Number of validation masks :  348
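The listing code pairs each image with its mask purely by sorting. That only works if both folders use the same file names. The sketch below checks this, under an assumption the original does not state: an image and its mask share an identical file name. `list_png_files` and `paired_files` are hypothetical helpers that reuse the same `.png` and hidden-file filter.

```python
import os


def list_png_files(directory):
    """Sorted full paths of the visible .png files in `directory`."""
    return sorted(
        os.path.join(directory, fname)
        for fname in os.listdir(directory)
        if fname.endswith(".png") and not fname.startswith(".")
    )


def paired_files(image_dir, mask_dir):
    """Return (image_paths, mask_paths); raise if the sorted file names do not match one-to-one."""
    images = list_png_files(image_dir)
    masks = list_png_files(mask_dir)
    image_names = [os.path.basename(p) for p in images]
    mask_names = [os.path.basename(p) for p in masks]
    if image_names != mask_names:
        raise ValueError("image and mask file names do not match")
    return images, masks
```

If your copy of the dataset names masks differently (for example with a suffix), the name comparison would need to be adapted.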
batch_size = 4  # batch size; lower it further if GPU memory runs out
img_dim = (320, 640)  # image size as (height, width)
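The Generator in section 2 reports `math.ceil(len(x) / batch_size)` batches per epoch and slices the path lists in consecutive chunks. With the sample counts printed above and `batch_size = 4`, the division happens to be exact for all three splits. The helper names below are hypothetical; the arithmetic mirrors `__len__` and the slicing in `__getitem__`.

```python
import math


def steps_per_epoch(num_samples, batch_size):
    """Batches per epoch for a Sequence whose __len__ is ceil(num_samples / batch_size)."""
    return math.ceil(num_samples / batch_size)


def batch_lengths(num_samples, batch_size):
    """Length of each slice [idx * batch_size:(idx + 1) * batch_size]; only the last may be short."""
    indices = list(range(num_samples))
    return [len(indices[idx * batch_size:(idx + 1) * batch_size])
            for idx in range(steps_per_epoch(num_samples, batch_size))]


# Sample counts printed above for this copy of CRACK500
for name, n in [("train", 1896), ("validation", 348), ("test", 1124)]:
    print(name, steps_per_epoch(n, 4))
```

This prints 474, 87 and 281 batches respectively. With a count that is not a multiple of the batch size, e.g. 10 samples, the last batch holds only the remainder (4, 4, 2).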

2. Generator for Loading and Augmenting Images

class Generator(Sequence):
    """Keras Sequence that loads image/mask pairs, resizes them to img_dim and optionally augments them."""

    # albumentations pipeline; applied jointly to each image and its mask when augment=True
    augmentations = Compose([
        Flip(p=0.7),
        Rotate(p=0.7),
        OneOf([
            RandomContrast(),
            RandomGamma(),
            RandomBrightness()
        ], p=0.3),
        OneOf([
            ElasticTransform(alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03),
            GridDistortion(),
            OpticalDistortion(distort_limit=2, shift_limit=0.5)
        ], p=0.3),
    ])

    def __init__(self, x_set, y_set, batch_size=5, img_dim=(128, 128), augment=False):
        self.x = x_set              # list of image paths
        self.y = y_set              # list of mask paths, in the same order as x_set
        self.batch_size = batch_size
        self.img_dim = img_dim      # (height, width)
        self.augment = augment

    def __len__(self):
        # number of batches per epoch; the last batch may be smaller
        return math.ceil(len(self.x) / self.batch_size)

    def __getitem__(self, idx):
        batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]

        # read (BGR) -> convert to RGB -> resize; cv2.resize takes (width, height)
        batch_x = np.array([cv2.resize(cv2.cvtColor(cv2.imread(file_name, -1), cv2.COLOR_BGR2RGB),
                                       (self.img_dim[1], self.img_dim[0])) for file_name in batch_x])
        # read mask -> resize -> binarize (any nonzero pixel counts as crack)
        batch_y = np.array([(cv2.resize(cv2.imread(file_name, -1), (self.img_dim[1], self.img_dim[0])) > 0).astype(np.uint8)
                            for file_name in batch_y])

        if self.augment:
            # albumentations applies the same spatial transform to the image and its mask
            aug = [self.augmentations(image=i, mask=j) for i, j in zip(batch_x, batch_y)]
            batch_x = np.array([i['image'] for i in aug])
            batch_y = np.array([j['mask'] for j in aug])

        batch_y = np.expand_dims(batch_y, -1)   # (N, H, W) -> (N, H, W, 1)

        return batch_x / 255, batch_y / 1       # images scaled to [0, 1]; masks as float 0/1
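To see exactly what the model receives, the post-processing in `__getitem__` can be reproduced on in-memory arrays. `prepare_batch` is a hypothetical helper. It skips file reading, resizing and augmentation, and keeps only the binarization, the added channel axis and the scaling.

```python
import numpy as np


def prepare_batch(images, masks):
    """Mimic the last steps of Generator.__getitem__ on already-loaded arrays.

    images: uint8 array (N, H, W, 3); masks: array (N, H, W) where any nonzero value marks a crack.
    """
    x = np.asarray(images) / 255                    # scale to [0, 1]; uint8 / 255 gives float64
    y = (np.asarray(masks) > 0).astype(np.uint8)   # binarize: any nonzero pixel -> 1
    y = np.expand_dims(y, -1) / 1                   # add channel axis -> (N, H, W, 1) as float64
    return x, y
```

Both outputs are float64, because dividing a uint8 array by a Python int promotes to float64. Casting to `np.float32` would halve the memory per batch, though this is only a suggestion and not what the original code does.

In the Generator itself, masks are resized with `cv2.resize`'s default bilinear interpolation before the `> 0` threshold. Whenever the resize changes the size, border pixels that receive a small interpolated value therefore become 1. Passing `interpolation=cv2.INTER_NEAREST` for the mask would likely avoid this slight thickening of the cracks.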
# Generator for the test set (no augmentation)
test1_generator = Generator(test_image_paths, test_mask_paths, batch_size, img_dim, False)
# Generated samples (not augmented)
for i, j in test1_generator:
    break

fig, axes = plt.subplots(1, 4, figsize=(13,2.5))
fig.suptitle('Original Images (test)', fontsize=15)
axes = axes.flatten()
for img, ax in zip(i[:4], axes[:4]):
    ax.imshow(img)
    ax.axis('off')
plt.tight_layout()
plt.show()

fig, axes = plt.subplots(1, 4, figsize=(13,3))
fig.suptitle('Original Masks (test)', fontsize=15)
axes = axes.flatten()
for img, ax in zip(j[:4], axes[:4]):
    ax.imshow(np.squeeze(img, -1), cmap='gray')
    ax.axis('off')
plt.tight_layout()
plt.show()


# Same processing for the training and validation sets (no augmentation)
train_generator = Generator(train_image_files, train_mask_files, batch_size, img_dim, False)
validation_generator = Generator(valid_image_files, valid_mask_files, batch_size, img_dim, False)
for i, j in train_generator:
    break

print(i.shape)
print(j.shape)
# Output
    (4, 320, 640, 3)
    (4, 320, 640, 1)
for i, j in validation_generator:
    break

print(i.shape)
print(j.shape)
# Output
    (4, 320, 640, 3)
    (4, 320, 640, 1)
# Generated samples (not augmented)
for i, j in train_generator:
    break

fig, axes = plt.subplots(1, 4, figsize=(13,2.5))
fig.suptitle('Original Images (train)', fontsize=15)
axes = axes.flatten()
for img, ax in zip(i[:4], axes[:4]):
    ax.imshow(img)
    ax.axis('off')
plt.tight_layout()
plt.show()

fig, axes = plt.subplots(1, 4, figsize=(13,2.5))
fig.suptitle('Original Masks (train)', fontsize=15)
axes = axes.flatten()
for img, ax in zip(j[:4], axes[:4]):
    ax.imshow(np.squeeze(img, -1), cmap='gray')
    ax.axis('off')
plt.tight_layout()
plt.show()


# Generated samples (not augmented)
for i, j in validation_generator:
    break

fig, axes = plt.subplots(1, 4, figsize=(13,2.5))
fig.suptitle('Original Images (validation)', fontsize=15)
axes = axes.flatten()
for img, ax in zip(i[:4], axes[:4]):
    ax.imshow(img)
    ax.axis('off')
plt.tight_layout()
plt.show()

fig, axes = plt.subplots(1, 4, figsize=(13,2.5))
fig.suptitle('Original Masks (validation)', fontsize=15)
axes = axes.flatten()
for img, ax in zip(j[:4], axes[:4]):
    ax.imshow(np.squeeze(img, -1), cmap='gray')
    ax.axis('off')
plt.tight_layout()
plt.show()

3. Augmenting the Training Set for Model Training

tg = Generator(train_image_files, train_mask_files, batch_size, img_dim, augment=True)   # training set (augmented)
vg = Generator(valid_image_files, valid_mask_files, batch_size, img_dim, augment=False)  # validation set (not augmented)
for i, j in tg:
    break

print(i.shape)
print(j.shape)
# Output
    (4, 320, 640, 3)
    (4, 320, 640, 1)
for i, j in vg:
    break

print(i.shape)
print(j.shape)
# Output
    (4, 320, 640, 3)
    (4, 320, 640, 1)
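Even with augmentation switched on, a training image can come out of `tg` untouched. The estimate below rests on three assumptions. First, each of the four top-level steps in `Generator.augmentations` fires independently with its own `p`. Second, a `OneOf` that fires always applies one of its members. Third, `Compose` runs with its default `p=1.0`. Under these assumptions, the chance that no step fires is 0.3 * 0.3 * 0.7 * 0.7, about 4.4%. A short check of that arithmetic, with hypothetical helper names:

```python
import random


def prob_untouched(step_probs):
    """Probability that none of several independently gated steps fires: product of (1 - p)."""
    result = 1.0
    for p in step_probs:
        result *= 1.0 - p
    return result


def simulate_untouched(step_probs, trials=100000, seed=0):
    """Monte Carlo estimate of the same probability, drawing one uniform number per step."""
    rng = random.Random(seed)
    untouched = sum(
        all(rng.random() >= p for p in step_probs)  # a step fires when its draw is below p
        for _ in range(trials)
    )
    return untouched / trials


# Flip(p=0.7), Rotate(p=0.7), colour OneOf(p=0.3), distortion OneOf(p=0.3)
pipeline_probs = [0.7, 0.7, 0.3, 0.3]
print(prob_untouched(pipeline_probs))  # about 0.0441
```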
# Augmented training samples
for i, j in tg:
    break

fig, axes = plt.subplots(1, 4, figsize=(13,2.5))
fig.suptitle('Augmented Images', fontsize=15)
axes = axes.flatten()
for img, ax in zip(i[:4], axes[:4]):
    ax.imshow(img)
    ax.axis('off')
plt.tight_layout()
plt.show()

fig, axes = plt.subplots(1, 4, figsize=(13,2.5))
fig.suptitle('Augmented Masks', fontsize=15)
axes = axes.flatten()
for img, ax in zip(j[:4], axes[:4]):
    ax.imshow(np.squeeze(img, -1), cmap='gray')
    ax.axis('off')
plt.tight_layout()
plt.show()
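The augmented masks stay usable only because albumentations applies the same spatial transform to `image` and `mask` in each call. The idea can be illustrated with a simplified NumPy stand-in for `Flip`. `joint_random_flip` is a hypothetical helper and does not reproduce albumentations' exact random-number handling.

```python
import numpy as np


def joint_random_flip(image, mask, rng, p=0.7):
    """Apply one random flip to image (H, W, C) and mask (H, W) together, keeping them aligned.

    With probability p, flip vertically, horizontally or both (chosen uniformly);
    the same choice is applied to both arrays.
    """
    if rng.random() < p:
        axes = [(0,), (1,), (0, 1)][rng.integers(3)]
        image = np.flip(image, axis=axes)
        mask = np.flip(mask, axis=axes)
    return image, mask
```

A color-only transform such as `RandomGamma` would change the image alone, which is why albumentations leaves the mask unchanged for those steps.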

3. Building the ResUNet Model

import numpy as np
from tensorflow.keras.backend import int_shape
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, MaxPooling2D, UpSampling2D, Add, BatchNormalization, Input, Activation, Concatenate
from tensorflow.keras.regularizers import l2  # import from tf.keras so the standalone keras package is not mixed in
# BatchNormalization followed by an optional ReLU activation
def BN_Act(x, act=True):
    x = BatchNormalization()(x)
    if act:
        x = Activation("relu")(x)
    return x
# conv2d block in pre-activation order: BN -> ReLU -> Conv2D
def conv2d_block(x, filters, kernel_size=(3, 3), padding="same", strides=1):
    conv = BN_Act(x)
    conv = Conv2D(filters, kernel_size, padding=padding, strides=strides)(conv)
    return conv
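`conv2d_block` uses the pre-activation order BN -> ReLU -> Conv2D. At inference time, `BN_Act` reduces to a per-channel affine map followed by ReLU. During training, BatchNormalization normalizes with the current batch statistics instead of the moving averages. Below is a NumPy sketch of the inference-time computation. `bn_act_inference` is a hypothetical name, and `epsilon=1e-3` is taken from the Keras layer's default.

```python
import numpy as np


def bn_act_inference(x, gamma, beta, moving_mean, moving_var, epsilon=1e-3, act=True):
    """NumPy sketch of BN_Act at inference time for a channels-last tensor x (..., C).

    Each channel is normalized with its moving statistics, then scaled by gamma and
    shifted by beta; ReLU is applied only when act=True.
    """
    y = gamma * (x - moving_mean) / np.sqrt(moving_var + epsilon) + beta
    return np.maximum(y, 0.0) if act else y
```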
# Stem (first encoder block): Conv2D -> BN-ReLU-Conv2D, plus a 1x1 Conv2D + BN shortcut
def stem(x, filters, kernel_size=(3, 3), padding="same", strides=1):
    conv = Conv2D(filters, kernel_size, padding = padding, strides = strides)(x)
    conv = conv2d_block(conv, filters, kernel_size = kernel_size, padding = padding, strides = strides)
    
    #skip
    shortcut = Conv2D(filters, kernel_size = (1, 1), padding = padding, strides = strides)(x)
    shortcut = BN_Act(shortcut, act = False) # No activation in skip connection
    
    output = Add()([conv, shortcut])
    return output
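The parameter counts that `model.summary()` prints for the stem can be reproduced by hand. A Conv2D layer holds kh * kw * C_in * filters weights plus one bias per filter. BatchNormalization stores four values per channel: gamma, beta, moving mean and moving variance. The helper names below are hypothetical.

```python
def conv2d_params(kernel_size, in_channels, filters, use_bias=True):
    """Parameters of a Conv2D layer: kh * kw * in_channels * filters weights, plus one bias per filter."""
    kh, kw = kernel_size
    return kh * kw * in_channels * filters + (filters if use_bias else 0)


def batchnorm_params(channels):
    """BatchNormalization keeps gamma, beta, moving mean and moving variance for every channel."""
    return 4 * channels


# Stem on an RGB input with f[0] = 16 filters (layer names as printed by model.summary())
stem_params = {
    "conv2d (3x3, 3 -> 16)": conv2d_params((3, 3), 3, 16),
    "conv2d_1 (3x3, 16 -> 16)": conv2d_params((3, 3), 16, 16),
    "conv2d_2 shortcut (1x1, 3 -> 16)": conv2d_params((1, 1), 3, 16),
    "batch_normalization (16 channels)": batchnorm_params(16),
}
for layer, count in stem_params.items():
    print(layer, count)
```

This gives 448, 2320, 64 and 64. The same formula yields 4640 for conv2d_3 (3x3, 16 -> 32) and 884992 for conv2d_17 (3x3, 384 -> 256), matching the summary.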
# Residual Block
def residual_block(x, filters, kernel_size = (3, 3), padding = "same", strides = 1):
    res = conv2d_block(x, filters, kernel_size = kernel_size, padding = padding, strides = strides)
    res = conv2d_block(res, filters, kernel_size = kernel_size, padding = padding, strides = 1)
    
    shortcut = Conv2D(filters, kernel_size = (1, 1), padding = padding, strides = strides)(x)
    shortcut = BN_Act(shortcut, act = False) # No activation in skip connection
    
    output = Add()([shortcut, res])
    return output
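In `residual_block`, `Add()` needs both branches to have the same shape. When `strides=2`, the first 3x3 convolution halves the resolution, so the 1x1 shortcut uses the same stride and projects to `filters` channels. With a 1x1 kernel, "same" padding adds no border. A strided 1x1 convolution is therefore just a strided slice followed by a per-pixel matrix product. The NumPy sketch below shows this. `conv1x1` is a hypothetical helper, and the random `residual` array is only a stand-in for the convolution branch.

```python
import numpy as np


def conv1x1(x, weights, bias, stride=1):
    """1x1 Conv2D with padding="same" on a channels-last map x of shape (H, W, C_in).

    Output pixel (i, j) reads only input pixel (i * stride, j * stride),
    then mixes channels with weights of shape (C_in, C_out) and adds bias (C_out,).
    """
    return x[::stride, ::stride, :] @ weights + bias


rng = np.random.default_rng(0)
x = rng.standard_normal((8, 12, 16))                  # e.g. an encoder map with 16 channels
shortcut = conv1x1(x, rng.standard_normal((16, 32)), np.zeros(32), stride=2)
residual = rng.standard_normal((4, 6, 32))            # stand-in for the stride-2 conv branch output
output = residual + shortcut                          # what Add() computes; shapes must match
print(shortcut.shape, output.shape)                   # (4, 6, 32) (4, 6, 32)
```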
# Upsampling Concatenation block
def upsample_concat_block(x, xskip):
    u = UpSampling2D((2, 2))(x)
    c = Concatenate()([u, xskip])
    return c
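`upsample_concat_block` doubles the spatial size and then stacks the encoder skip connection along the channel axis. `UpSampling2D((2, 2))` uses nearest-neighbor interpolation by default, which amounts to repeating every row and column twice. A NumPy sketch for a single channels-last map follows; `upsample_concat` is a hypothetical name.

```python
import numpy as np


def upsample_concat(x, x_skip):
    """NumPy sketch of upsample_concat_block for channels-last maps of shape (H, W, C)."""
    up = np.repeat(np.repeat(x, 2, axis=0), 2, axis=1)   # nearest-neighbor 2x upsampling
    return np.concatenate([up, x_skip], axis=-1)         # Concatenate() on the channel axis
```

For the first decoder stage, the 20 x 40 x 256 bridge output and the 40 x 80 x 128 skip `e4` give 40 x 80 x 384. This matches the `concatenate` layer in the summary.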
# MODEL
def ResUNet():
    f = [16, 32, 64, 128, 256]
    inputs = Input((img_dim[0], img_dim[1], 3))
    
    ## Encoder/downsampling/contracting path
    e0 = inputs
    e1 = stem(e0, f[0])
    e2 = residual_block(e1, f[1], strides = 2)
    e3 = residual_block(e2, f[2], strides = 2)
    e4 = residual_block(e3, f[3], strides = 2)
    e5 = residual_block(e4, f[4], strides = 2)
    
    ## Bridge/Bottleneck
    b0 = conv2d_block(e5, f[4], strides = 1)
    b1 = conv2d_block(b0, f[4], strides = 1)
    
    ## Decoder/upsampling/expansive path
    u1 = upsample_concat_block(b1, e4)
    d1 = residual_block(u1, f[4])
    
    u2 = upsample_concat_block(d1, e3)
    d2 = residual_block(u2, f[3])
    
    u3 = upsample_concat_block(d2, e2)
    d3 = residual_block(u3, f[2])
    
    u4 = upsample_concat_block(d3, e1)
    d4 = residual_block(u4, f[1])
    
    outputs = Conv2D(1, (1, 1), padding = "same", activation = "sigmoid")(d4)
    model = Model(inputs, outputs)
    return model
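The encoder halves the resolution four times, and each decoder stage doubles it again before concatenating with the matching encoder output. The input height and width must therefore be divisible by 2**4 = 16; `img_dim = (320, 640)` gives a 20 x 40 bottleneck. Below is a pure-Python sketch tracing (H, W, C) through `ResUNet()`. `resunet_shapes` is a hypothetical helper that mirrors the filter list `f = [16, 32, 64, 128, 256]`.

```python
def resunet_shapes(height, width, filters=(16, 32, 64, 128, 256)):
    """Trace (H, W, C) through ResUNet(): stem + 4 stride-2 blocks, bridge, 4 decoder stages."""
    if height % 16 or width % 16:
        raise ValueError("height and width must be divisible by 16")
    f = filters
    enc = [(height // 2 ** k, width // 2 ** k, f[k]) for k in range(5)]   # e1 .. e5
    shapes = {"e%d" % (k + 1): s for k, s in enumerate(enc)}
    shapes["bridge"] = enc[4]                                  # b0, b1 keep e5's shape
    prev = enc[4]
    for i, skip in enumerate(reversed(enc[:4])):               # skips e4, e3, e2, e1
        h, w, c = prev
        shapes["u%d" % (i + 1)] = (2 * h, 2 * w, c + skip[2])  # after UpSampling2D + Concatenate
        prev = (2 * h, 2 * w, f[4 - i])                        # residual_block output d1 .. d4
        shapes["d%d" % (i + 1)] = prev
    shapes["output"] = prev[:2] + (1,)                         # 1x1 sigmoid Conv2D
    return shapes


for name, shape in resunet_shapes(320, 640).items():
    print(name, shape)
```

The concatenation widths 384, 320 and 160 match the `concatenate`, `concatenate_1` and `concatenate_2` layers in the summary.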
K.clear_session()
model = ResUNet()
model.summary()
# Output
    Model: "model"
    __________________________________________________________________________________________________
    Layer (type)                    Output Shape         Param #     Connected to                     
    ==================================================================================================
    input_1 (InputLayer)            [(None, 320, 640, 3) 0                                            
    __________________________________________________________________________________________________
    conv2d (Conv2D)                 (None, 320, 640, 16) 448         input_1[0][0]                    
    __________________________________________________________________________________________________
    batch_normalization (BatchNorma (None, 320, 640, 16) 64          conv2d[0][0]                     
    __________________________________________________________________________________________________
    activation (Activation)         (None, 320, 640, 16) 0           batch_normalization[0][0]        
    __________________________________________________________________________________________________
    conv2d_2 (Conv2D)               (None, 320, 640, 16) 64          input_1[0][0]                    
    __________________________________________________________________________________________________
    conv2d_1 (Conv2D)               (None, 320, 640, 16) 2320        activation[0][0]                 
    __________________________________________________________________________________________________
    batch_normalization_1 (BatchNor (None, 320, 640, 16) 64          conv2d_2[0][0]                   
    __________________________________________________________________________________________________
    add (Add)                       (None, 320, 640, 16) 0           conv2d_1[0][0]                   
                                                                     batch_normalization_1[0][0]      
    __________________________________________________________________________________________________
    batch_normalization_2 (BatchNor (None, 320, 640, 16) 64          add[0][0]                        
    __________________________________________________________________________________________________
    activation_1 (Activation)       (None, 320, 640, 16) 0           batch_normalization_2[0][0]      
    __________________________________________________________________________________________________
    conv2d_3 (Conv2D)               (None, 160, 320, 32) 4640        activation_1[0][0]               
    __________________________________________________________________________________________________
    batch_normalization_3 (BatchNor (None, 160, 320, 32) 128         conv2d_3[0][0]                   
    __________________________________________________________________________________________________
    conv2d_5 (Conv2D)               (None, 160, 320, 32) 544         add[0][0]                        
    __________________________________________________________________________________________________
    activation_2 (Activation)       (None, 160, 320, 32) 0           batch_normalization_3[0][0]      
    __________________________________________________________________________________________________
    batch_normalization_4 (BatchNor (None, 160, 320, 32) 128         conv2d_5[0][0]                   
    __________________________________________________________________________________________________
    conv2d_4 (Conv2D)               (None, 160, 320, 32) 9248        activation_2[0][0]               
    __________________________________________________________________________________________________
    add_1 (Add)                     (None, 160, 320, 32) 0           batch_normalization_4[0][0]      
                                                                     conv2d_4[0][0]                   
    __________________________________________________________________________________________________
    batch_normalization_5 (BatchNor (None, 160, 320, 32) 128         add_1[0][0]                      
    __________________________________________________________________________________________________
    activation_3 (Activation)       (None, 160, 320, 32) 0           batch_normalization_5[0][0]      
    __________________________________________________________________________________________________
    conv2d_6 (Conv2D)               (None, 80, 160, 64)  18496       activation_3[0][0]               
    __________________________________________________________________________________________________
    batch_normalization_6 (BatchNor (None, 80, 160, 64)  256         conv2d_6[0][0]                   
    __________________________________________________________________________________________________
    conv2d_8 (Conv2D)               (None, 80, 160, 64)  2112        add_1[0][0]                      
    __________________________________________________________________________________________________
    activation_4 (Activation)       (None, 80, 160, 64)  0           batch_normalization_6[0][0]      
    __________________________________________________________________________________________________
    batch_normalization_7 (BatchNor (None, 80, 160, 64)  256         conv2d_8[0][0]                   
    __________________________________________________________________________________________________
    conv2d_7 (Conv2D)               (None, 80, 160, 64)  36928       activation_4[0][0]               
    __________________________________________________________________________________________________
    add_2 (Add)                     (None, 80, 160, 64)  0           batch_normalization_7[0][0]      
                                                                     conv2d_7[0][0]                   
    __________________________________________________________________________________________________
    batch_normalization_8 (BatchNor (None, 80, 160, 64)  256         add_2[0][0]                      
    __________________________________________________________________________________________________
    activation_5 (Activation)       (None, 80, 160, 64)  0           batch_normalization_8[0][0]      
    __________________________________________________________________________________________________
    conv2d_9 (Conv2D)               (None, 40, 80, 128)  73856       activation_5[0][0]               
    __________________________________________________________________________________________________
    batch_normalization_9 (BatchNor (None, 40, 80, 128)  512         conv2d_9[0][0]                   
    __________________________________________________________________________________________________
    conv2d_11 (Conv2D)              (None, 40, 80, 128)  8320        add_2[0][0]                      
    __________________________________________________________________________________________________
    activation_6 (Activation)       (None, 40, 80, 128)  0           batch_normalization_9[0][0]      
    __________________________________________________________________________________________________
    batch_normalization_10 (BatchNo (None, 40, 80, 128)  512         conv2d_11[0][0]                  
    __________________________________________________________________________________________________
    conv2d_10 (Conv2D)              (None, 40, 80, 128)  147584      activation_6[0][0]               
    __________________________________________________________________________________________________
    add_3 (Add)                     (None, 40, 80, 128)  0           batch_normalization_10[0][0]     
                                                                     conv2d_10[0][0]                  
    __________________________________________________________________________________________________
    batch_normalization_11 (BatchNo (None, 40, 80, 128)  512         add_3[0][0]                      
    __________________________________________________________________________________________________
    activation_7 (Activation)       (None, 40, 80, 128)  0           batch_normalization_11[0][0]     
    __________________________________________________________________________________________________
    conv2d_12 (Conv2D)              (None, 20, 40, 256)  295168      activation_7[0][0]               
    __________________________________________________________________________________________________
    batch_normalization_12 (BatchNo (None, 20, 40, 256)  1024        conv2d_12[0][0]                  
    __________________________________________________________________________________________________
    conv2d_14 (Conv2D)              (None, 20, 40, 256)  33024       add_3[0][0]                      
    __________________________________________________________________________________________________
    activation_8 (Activation)       (None, 20, 40, 256)  0           batch_normalization_12[0][0]     
    __________________________________________________________________________________________________
    batch_normalization_13 (BatchNo (None, 20, 40, 256)  1024        conv2d_14[0][0]                  
    __________________________________________________________________________________________________
    conv2d_13 (Conv2D)              (None, 20, 40, 256)  590080      activation_8[0][0]               
    __________________________________________________________________________________________________
    add_4 (Add)                     (None, 20, 40, 256)  0           batch_normalization_13[0][0]     
                                                                     conv2d_13[0][0]                  
    __________________________________________________________________________________________________
    batch_normalization_14 (BatchNo (None, 20, 40, 256)  1024        add_4[0][0]                      
    __________________________________________________________________________________________________
    activation_9 (Activation)       (None, 20, 40, 256)  0           batch_normalization_14[0][0]     
    __________________________________________________________________________________________________
    conv2d_15 (Conv2D)              (None, 20, 40, 256)  590080      activation_9[0][0]               
    __________________________________________________________________________________________________
    batch_normalization_15 (BatchNo (None, 20, 40, 256)  1024        conv2d_15[0][0]                  
    __________________________________________________________________________________________________
    activation_10 (Activation)      (None, 20, 40, 256)  0           batch_normalization_15[0][0]     
    __________________________________________________________________________________________________
    conv2d_16 (Conv2D)              (None, 20, 40, 256)  590080      activation_10[0][0]              
    __________________________________________________________________________________________________
    up_sampling2d (UpSampling2D)    (None, 40, 80, 256)  0           conv2d_16[0][0]                  
    __________________________________________________________________________________________________
    concatenate (Concatenate)       (None, 40, 80, 384)  0           up_sampling2d[0][0]              
                                                                     add_3[0][0]                      
    __________________________________________________________________________________________________
    batch_normalization_16 (BatchNo (None, 40, 80, 384)  1536        concatenate[0][0]                
    __________________________________________________________________________________________________
    activation_11 (Activation)      (None, 40, 80, 384)  0           batch_normalization_16[0][0]     
    __________________________________________________________________________________________________
    conv2d_17 (Conv2D)              (None, 40, 80, 256)  884992      activation_11[0][0]              
    __________________________________________________________________________________________________
    batch_normalization_17 (BatchNo (None, 40, 80, 256)  1024        conv2d_17[0][0]                  
    __________________________________________________________________________________________________
    conv2d_19 (Conv2D)              (None, 40, 80, 256)  98560       concatenate[0][0]                
    __________________________________________________________________________________________________
    activation_12 (Activation)      (None, 40, 80, 256)  0           batch_normalization_17[0][0]     
    __________________________________________________________________________________________________
    batch_normalization_18 (BatchNo (None, 40, 80, 256)  1024        conv2d_19[0][0]                  
    __________________________________________________________________________________________________
    conv2d_18 (Conv2D)              (None, 40, 80, 256)  590080      activation_12[0][0]              
    __________________________________________________________________________________________________
    add_5 (Add)                     (None, 40, 80, 256)  0           batch_normalization_18[0][0]     
                                                                     conv2d_18[0][0]                  
    __________________________________________________________________________________________________
    up_sampling2d_1 (UpSampling2D)  (None, 80, 160, 256) 0           add_5[0][0]                      
    __________________________________________________________________________________________________
    concatenate_1 (Concatenate)     (None, 80, 160, 320) 0           up_sampling2d_1[0][0]            
                                                                     add_2[0][0]                      
    __________________________________________________________________________________________________
    batch_normalization_19 (BatchNo (None, 80, 160, 320) 1280        concatenate_1[0][0]              
    __________________________________________________________________________________________________
    activation_13 (Activation)      (None, 80, 160, 320) 0           batch_normalization_19[0][0]     
    __________________________________________________________________________________________________
    conv2d_20 (Conv2D)              (None, 80, 160, 128) 368768      activation_13[0][0]              
    __________________________________________________________________________________________________
    batch_normalization_20 (BatchNo (None, 80, 160, 128) 512         conv2d_20[0][0]                  
    __________________________________________________________________________________________________
    conv2d_22 (Conv2D)              (None, 80, 160, 128) 41088       concatenate_1[0][0]              
    __________________________________________________________________________________________________
    activation_14 (Activation)      (None, 80, 160, 128) 0           batch_normalization_20[0][0]     
    __________________________________________________________________________________________________
    batch_normalization_21 (BatchNo (None, 80, 160, 128) 512         conv2d_22[0][0]                  
    __________________________________________________________________________________________________
    conv2d_21 (Conv2D)              (None, 80, 160, 128) 147584      activation_14[0][0]              
    __________________________________________________________________________________________________
    add_6 (Add)                     (None, 80, 160, 128) 0           batch_normalization_21[0][0]     
                                                                     conv2d_21[0][0]                  
    __________________________________________________________________________________________________
    up_sampling2d_2 (UpSampling2D)  (None, 160, 320, 128 0           add_6[0][0]                      
    __________________________________________________________________________________________________
    concatenate_2 (Concatenate)     (None, 160, 320, 160 0           up_sampling2d_2[0][0]            
                                                                     add_1[0][0]                      
    __________________________________________________________________________________________________
    batch_normalization_22 (BatchNo (None, 160, 320, 160 640         concatenate_2[0][0]              
    __________________________________________________________________________________________________
    activation_15 (Activation)      (None, 160, 320, 160 0           batch_normalization_22[0][0]     
    __________________________________________________________________________________________________
    conv2d_23 (Conv2D)              (None, 160, 320, 64) 92224       activation_15[0][0]              
    __________________________________________________________________________________________________
    batch_normalization_23 (BatchNo (None, 160, 320, 64) 256         conv2d_23[0][0]                  
    __________________________________________________________________________________________________
    conv2d_25 (Conv2D)              (None, 160, 320, 64) 10304       concatenate_2[0][0]              
    __________________________________________________________________________________________________
    activation_16 (Activation)      (None, 160, 320, 64) 0           batch_normalization_23[0][0]     
    __________________________________________________________________________________________________
    batch_normalization_24 (BatchNo (None, 160, 320, 64) 256         conv2d_25[0][0]                  
    __________________________________________________________________________________________________
    conv2d_24 (Conv2D)              (None, 160, 320, 64) 36928       activation_16[0][0]              
    __________________________________________________________________________________________________
    add_7 (Add)                     (None, 160, 320, 64) 0           batch_normalization_24[0][0]     
                                                                     conv2d_24[0][0]                  
    __________________________________________________________________________________________________
    up_sampling2d_3 (UpSampling2D)  (None, 320, 640, 64) 0           add_7[0][0]                      
    __________________________________________________________________________________________________
    concatenate_3 (Concatenate)     (None, 320, 640, 80) 0           up_sampling2d_3[0][0]            
                                                                     add[0][0]                        
    __________________________________________________________________________________________________
    batch_normalization_25 (BatchNo (None, 320, 640, 80) 320         concatenate_3[0][0]              
    __________________________________________________________________________________________________
    activation_17 (Activation)      (None, 320, 640, 80) 0           batch_normalization_25[0][0]     
    __________________________________________________________________________________________________
    conv2d_26 (Conv2D)              (None, 320, 640, 32) 23072       activation_17[0][0]              
    __________________________________________________________________________________________________
    batch_normalization_26 (BatchNo (None, 320, 640, 32) 128         conv2d_26[0][0]                  
    __________________________________________________________________________________________________
    conv2d_28 (Conv2D)              (None, 320, 640, 32) 2592        concatenate_3[0][0]              
    __________________________________________________________________________________________________
    activation_18 (Activation)      (None, 320, 640, 32) 0           batch_normalization_26[0][0]     
    __________________________________________________________________________________________________
    batch_normalization_27 (BatchNo (None, 320, 640, 32) 128         conv2d_28[0][0]                  
    __________________________________________________________________________________________________
    conv2d_27 (Conv2D)              (None, 320, 640, 32) 9248        activation_18[0][0]              
    __________________________________________________________________________________________________
    add_8 (Add)                     (None, 320, 640, 32) 0           batch_normalization_27[0][0]     
                                                                     conv2d_27[0][0]                  
    __________________________________________________________________________________________________
    conv2d_29 (Conv2D)              (None, 320, 640, 1)  33          add_8[0][0]                      
    ==================================================================================================
    Total params: 4,723,057
    Trainable params: 4,715,761
    Non-trainable params: 7,296
    __________________________________________________________________________________________________
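The parameter counts in the summary can be checked by hand. A Conv2D layer with a k x k kernel, c_in input channels and c_out filters (plus bias) has k·k·c_in·c_out + c_out parameters. A BatchNormalization layer stores 4 values per channel: gamma and beta are trainable, while the moving mean and moving variance are not. That is presumably where the 7,296 non-trainable parameters come from. The helper names below are illustrative. The kernel sizes are inferred from the counts, because the summary itself does not show them.

```python
def conv2d_params(k, c_in, c_out, use_bias=True):
    """Weights of a k x k Conv2D layer plus one bias per filter."""
    return k * k * c_in * c_out + (c_out if use_bias else 0)


def batchnorm_params(channels):
    """BatchNormalization: gamma, beta (trainable) + moving mean, moving variance (not)."""
    return {"total": 4 * channels, "trainable": 2 * channels, "non_trainable": 2 * channels}


# Rows read off the summary above (kernel sizes inferred from the counts).
print(conv2d_params(3, 80, 32))    # conv2d_26: 3x3 conv on the 80-channel concatenate_3
print(conv2d_params(1, 80, 32))    # conv2d_28: 1x1 shortcut conv on concatenate_3
print(conv2d_params(3, 160, 64))   # conv2d_23: 3x3 conv on the 160-channel concatenate_2
print(conv2d_params(1, 32, 1))     # conv2d_29: final 1x1 output conv
print(batchnorm_params(80))        # batch_normalization_25
```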

4.Loss & Compile

smooth = 1.

def dice_coef(y_true, y_pred):
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = tf.reduce_sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth)


def dice_coef_loss(y_true, y_pred):
    return 1.0 - dice_coef(y_true, y_pred)

def IOU(y_true, y_pred):

    y_true = K.flatten(y_true)
    y_pred = K.flatten(y_pred)

    thresh = 0.5

    y_true = K.cast(K.greater_equal(y_true, thresh), 'float32')
    y_pred = K.cast(K.greater_equal(y_pred, thresh), 'float32')

    union = K.sum(K.maximum(y_true, y_pred)) + K.epsilon()
    intersection = K.sum(K.minimum(y_true, y_pred)) + K.epsilon()

    iou = intersection/union

    return iou
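The two metrics behave differently:

- `dice_coef` is a "soft" Dice. It multiplies probabilities directly, with no threshold, which keeps it differentiable and usable as a loss.
- `IOU` first binarises both tensors at 0.5 and then divides the AND count by the OR count.

The NumPy mirror below is a sketch for checking values by hand on small arrays. The function names are hypothetical, and it computes in float64 rather than float32. `EPS = 1e-7` assumes the default `keras.backend.epsilon()`.

```python
import numpy as np

SMOOTH = 1.0   # same value as `smooth` above
EPS = 1e-7     # assumed default of keras.backend.epsilon()


def dice_coef_np(y_true, y_pred, smooth=SMOOTH):
    """Soft Dice on flattened arrays: probabilities are used as-is, no threshold."""
    t = np.asarray(y_true, dtype=np.float64).ravel()
    p = np.asarray(y_pred, dtype=np.float64).ravel()
    intersection = np.sum(t * p)
    return (2.0 * intersection + smooth) / (np.sum(t) + np.sum(p) + smooth)


def iou_np(y_true, y_pred, thresh=0.5, eps=EPS):
    """Hard IoU: binarise both inputs at `thresh`, then |A AND B| / |A OR B|."""
    t = (np.asarray(y_true).ravel() >= thresh).astype(np.float64)
    p = (np.asarray(y_pred).ravel() >= thresh).astype(np.float64)
    union = np.sum(np.maximum(t, p)) + eps          # max of 0/1 values = logical OR
    intersection = np.sum(np.minimum(t, p)) + eps   # min of 0/1 values = logical AND
    return intersection / union


y_true = np.array([0, 1, 1, 0])
y_pred = np.array([0.1, 0.9, 0.4, 0.7])

print(dice_coef_np(y_true, y_pred))        # (2*1.3 + 1) / (2 + 2.1 + 1) = 3.6 / 5.1
print(iou_np(y_true, y_pred))              # about 1/3: 0.4 falls below the threshold
print(1.0 - dice_coef_np(y_true, y_pred))  # the value dice_coef_loss would report
```

Because of `smooth` and the epsilon, an all-background mask paired with an all-background prediction scores 1.0 under both metrics instead of producing 0/0.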


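def lr_schedule(epoch):
    # Step schedule starting from lr = 0.0035 (Keras passes a 0-based epoch).
    # Note: these are elif branches and each multiplies by 2**-1, so at most one
    # fires: the rate is 0.0035 for epochs 0-30 and 0.00175 for every later epoch.
    lr = 0.0035
    if epoch > 150:
        lr *= 2**-1
    elif epoch > 80:
        lr *= 2**-1
    elif epoch > 50:
        lr *= 2**-1
    elif epoch > 30:
        lr *= 2**-1

    print('Learning rate: ', lr)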
    return lr
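In `lr_schedule` every branch multiplies by 2**-1, and only one `elif` branch can fire. The thresholds at 50, 80 and 150 therefore never reduce the rate further. The sketch below tabulates a print-free copy of the function against a hypothetical multi-step variant that halves once per milestone passed. That variant, its name and its milestones are an assumption about what a progressive schedule could look like, not the author's code.

```python
def lr_schedule_quiet(epoch):
    """Print-free copy of lr_schedule above (same thresholds and factors)."""
    lr = 0.0035
    if epoch > 150:
        lr *= 2 ** -1
    elif epoch > 80:
        lr *= 2 ** -1
    elif epoch > 50:
        lr *= 2 ** -1
    elif epoch > 30:
        lr *= 2 ** -1
    return lr


def lr_schedule_progressive(epoch, base_lr=0.0035, milestones=(30, 50, 80, 150)):
    """Hypothetical multi-step variant: halve once for every milestone already passed."""
    n_passed = sum(epoch > m for m in milestones)
    return base_lr * 0.5 ** n_passed


for epoch in (0, 30, 31, 51, 81, 151):
    print(epoch, lr_schedule_quiet(epoch), lr_schedule_progressive(epoch))
```

The first column changes only once, at epoch 31. The variant keeps halving at 51, 81 and 151. With `epochs=3`, as in the training cell, both return 0.0035 throughout, which matches the log.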
import time

import numpy as np
import tensorflow as tf
from tensorflow.keras.callbacks import LearningRateScheduler, ReduceLROnPlateau

start_time = time.time()

# Learning-rate callbacks: the step schedule above plus a plateau-based reducer.
lr_scheduler = LearningRateScheduler(lr_schedule)

# Multiply the LR by sqrt(0.1) (about 0.316) when val_loss, the default monitored
# quantity, has not improved for 5 epochs. Because LearningRateScheduler resets the
# LR from lr_schedule at the start of every epoch, a reduction made here is
# overwritten at the next epoch.
lr_reducer = ReduceLROnPlateau(factor=np.sqrt(0.1),
                               cooldown=0,
                               patience=5,
                               min_lr=0.5e-6)

callbacks = [lr_reducer, lr_scheduler]

# Adam with the AMSGrad variant, starting from the schedule's epoch-0 value.
optimiser = tf.keras.optimizers.Adam(
    learning_rate=lr_schedule(0),
    beta_1=0.9,
    beta_2=0.999,
    epsilon=1e-07,
    amsgrad=True,
    name="Adam"
)
model.compile(optimizer=optimiser, loss=dice_coef_loss, metrics=['accuracy', IOU, dice_coef])
# Output
    Learning rate:  0.0035
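`amsgrad=True` switches Adam to the AMSGrad variant. It keeps a running maximum of the second-moment estimate, so the per-parameter step size can never grow again once gradients shrink. The NumPy sketch below performs single updates with the same hyperparameters as the optimiser above. As far as I know, it mirrors the update order of tf.keras's Adam with AMSGrad: bias correction is folded into the step size and epsilon is added after the square root. It is a sketch for intuition, not TensorFlow's kernel, and `init_state`/`amsgrad_step` are hypothetical names.

```python
import numpy as np


def init_state(shape):
    """Zero moment estimates and step counter for one parameter tensor."""
    return {"m": np.zeros(shape), "v": np.zeros(shape), "v_hat": np.zeros(shape), "t": 0}


def amsgrad_step(param, grad, state, lr=0.0035, beta_1=0.9, beta_2=0.999, eps=1e-7):
    """One AMSGrad update using the hyperparameters passed to Adam above."""
    t = state["t"] + 1
    m = beta_1 * state["m"] + (1 - beta_1) * grad              # first-moment estimate
    v = beta_2 * state["v"] + (1 - beta_2) * grad ** 2         # second-moment estimate
    v_hat = np.maximum(state["v_hat"], v)                      # AMSGrad: running max of v
    lr_t = lr * np.sqrt(1 - beta_2 ** t) / (1 - beta_1 ** t)   # bias correction in the step
    new_param = param - lr_t * m / (np.sqrt(v_hat) + eps)
    return new_param, {"m": m, "v": v, "v_hat": v_hat, "t": t}


w, state = np.array([1.0]), init_state(1)
w, state = amsgrad_step(w, np.array([2.0]), state)   # first step moves w by about lr
w, state = amsgrad_step(w, np.array([0.0]), state)   # v decays, v_hat keeps its maximum
print(w, state["v"], state["v_hat"])
```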

5.Training

train_steps = len(train_image_files)//batch_size
valid_steps = len(valid_image_files)//batch_size

history = model.fit(
    tg,                  # training-set generator built above
    steps_per_epoch=train_steps,
    initial_epoch=0,
    epochs=3,            # only 3 epochs here
    validation_data=vg,  # validation-set generator built above
    validation_steps=valid_steps,
    callbacks=callbacks)
WARNING:tensorflow:sample_weight modes were coerced from
  ...
    to  
  ['...']
WARNING:tensorflow:sample_weight modes were coerced from
  ...
    to  
  ['...']
Train for 474 steps, validate for 87 steps
Learning rate:  0.0035
Epoch 1/3
474/474 [==============================] - 10471s 22s/step - loss: 0.4093 - accuracy: 0.9513 - IOU: 0.4435 - dice_coef: 0.5907 - val_loss: 0.4153 - val_accuracy: 0.9350 - val_IOU: 0.4442 - val_dice_coef: 0.5847
Learning rate:  0.0035
Epoch 2/3
474/474 [==============================] - 5823s 12s/step - loss: 0.3736 - accuracy: 0.9561 - IOU: 0.4787 - dice_coef: 0.6264 - val_loss: 0.4357 - val_accuracy: 0.8781 - val_IOU: 0.4241 - val_dice_coef: 0.5643
Learning rate:  0.0035
Epoch 3/3
474/474 [==============================] - 6416s 14s/step - loss: 0.3337 - accuracy: 0.9618 - IOU: 0.5186 - dice_coef: 0.6663 - val_loss: 0.4244 - val_accuracy: 0.9004 - val_IOU: 0.4297 - val_dice_coef: 0.5756
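`dice_coef_loss` is defined as `1 - dice_coef`. In every line of the log above, `loss` and `dice_coef` should therefore sum to 1, and so should `val_loss` and `val_dice_coef`. The sketch below parses the three epoch lines, copied verbatim from the log, and checks that identity. The `parse_epoch_line` helper and its regular expression are illustrative, not part of Keras.

```python
import re

# Epoch summary lines copied verbatim from the training log above.
LOG = """\
474/474 [==============================] - 10471s 22s/step - loss: 0.4093 - accuracy: 0.9513 - IOU: 0.4435 - dice_coef: 0.5907 - val_loss: 0.4153 - val_accuracy: 0.9350 - val_IOU: 0.4442 - val_dice_coef: 0.5847
474/474 [==============================] - 5823s 12s/step - loss: 0.3736 - accuracy: 0.9561 - IOU: 0.4787 - dice_coef: 0.6264 - val_loss: 0.4357 - val_accuracy: 0.8781 - val_IOU: 0.4241 - val_dice_coef: 0.5643
474/474 [==============================] - 6416s 14s/step - loss: 0.3337 - accuracy: 0.9618 - IOU: 0.5186 - dice_coef: 0.6663 - val_loss: 0.4244 - val_accuracy: 0.9004 - val_IOU: 0.4297 - val_dice_coef: 0.5756
"""


def parse_epoch_line(line):
    """Return {metric_name: value} for one Keras progress-bar summary line."""
    return {name: float(value) for name, value in re.findall(r"(\w+): ([0-9.]+)", line)}


epochs = [parse_epoch_line(line) for line in LOG.strip().splitlines()]
for i, m in enumerate(epochs, start=1):
    # dice_coef_loss = 1 - dice_coef, so both sums should print as 1.0
    print(i, round(m["loss"] + m["dice_coef"], 4), round(m["val_loss"] + m["val_dice_coef"], 4))
```

All six sums come out as 1.0. The parsed values also show that `val_dice_coef` is highest after epoch 1 (0.5847) while the training Dice keeps rising. With only three epochs, this says little about convergence.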
train_loss = history.history['loss']
valid_loss = history.history['val_loss']

train_acc = history.history['accuracy']
valid_acc = history.history['val_accuracy']
fig, axes = plt.subplots(1, 2, figsize=(13,4))
axes = axes.flatten()

axes[0].plot(train_acc, label='training')
axes[0].plot(valid_acc, label='validation')
axes[0].set_title('Accuracy Curve')
axes[0].set_xlabel('epochs')
axes[0].set_ylabel('accuracy')
axes[0].legend()


axes[1].plot(train_loss, label='training')
axes[1].plot(valid_loss, label='validation')
axes[1].set_title('Loss Curve')
axes[1].set_xlabel('epochs')
axes[1].set_ylabel('loss')
axes[1].legend()

plt.show()

train_dice = history.history['dice_coef']
valid_dice = history.history['val_dice_coef']

train_IOU = history.history['IOU']
valid_IOU = history.history['val_IOU']
fig, axes = plt.subplots(1, 2, figsize=(20,7))
axes = axes.flatten()

axes[0].plot(train_IOU, label='training')
axes[0].plot(valid_IOU, label='validation')
axes[0].set_title('IOU Curve [Adam lr : 0.0035]')
axes[0].set_xlabel('epochs')
axes[0].set_ylabel('IOU')
axes[0].legend()


axes[1].plot(train_dice, label='training')
axes[1].plot(valid_dice, label='validation')
axes[1].set_title('Dice coefficient Curve [Adam lr : 0.0035]')
axes[1].set_xlabel('epochs')
axes[1].set_ylabel('dice_coef')
axes[1].legend()

plt.show()
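Both plotting cells repeat the same pattern for each metric: one training curve and one validation curve. The hypothetical helper below groups a `History.history` dict into `(train, val)` pairs, so one loop could drive the `axes[i].plot(...)` calls. It also reports the best epoch for each curve. The dict is filled with the values from the 3-epoch log above. In the notebook, you would pass `history.history` instead. Keys without a `val_` counterpart, such as an `lr` entry that the learning-rate callbacks may add, are skipped.

```python
def train_val_pairs(history_dict):
    """Group a Keras History.history dict into {metric: (train_values, val_values)}."""
    return {
        name: (values, history_dict["val_" + name])
        for name, values in history_dict.items()
        if not name.startswith("val_") and "val_" + name in history_dict
    }


def best_epoch(values, mode="max"):
    """1-based epoch with the best value: 'max' for Dice/IoU/accuracy, 'min' for loss."""
    pick = max if mode == "max" else min
    return 1 + pick(range(len(values)), key=lambda i: values[i])


# Per-epoch values copied from the training log above.
history_dict = {
    "loss": [0.4093, 0.3736, 0.3337],
    "accuracy": [0.9513, 0.9561, 0.9618],
    "IOU": [0.4435, 0.4787, 0.5186],
    "dice_coef": [0.5907, 0.6264, 0.6663],
    "val_loss": [0.4153, 0.4357, 0.4244],
    "val_accuracy": [0.9350, 0.8781, 0.9004],
    "val_IOU": [0.4442, 0.4241, 0.4297],
    "val_dice_coef": [0.5847, 0.5643, 0.5756],
}

for name, (train, val) in train_val_pairs(history_dict).items():
    mode = "min" if name == "loss" else "max"
    print(name, "best train epoch:", best_epoch(train, mode), "best val epoch:", best_epoch(val, mode))
```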

6.Testing

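# Evaluate on the validation set: one batch of 396, meant to hold every validation image.
test_generator = Generator(valid_image_files, valid_mask_files, 396, img_dim, augment=False)

x_test, y_test = next(iter(test_generator))

y_pred = model.predict(x_test)

# Binarise masks and predictions at 0.5 and flatten them into pixel vectors.
yy_true = (y_test>0.5).flatten()
yy_pred = (y_pred>0.5).flatten()
report = classification_report(yy_true, yy_pred, output_dict=True)

Accuracy = accuracy_score(yy_true, yy_pred)

# 'True' is the crack class, 'False' the background class.
Precision = report['True']['precision']
Recall = report['True']['recall']
F1_score = report['True']['f1-score']

Sensitivity = Recall                      # sensitivity = recall of the crack class
Specificity = report['False']['recall']   # specificity = recall of the background class

# ROC AUC from the raw probabilities against the binarised ground truth.
AUC = roc_auc_score(yy_true, y_pred.flatten())

# IoU = TP / (TP + FP + FN), expressed through precision and recall.
# Stored as iou_score so the IOU metric function defined earlier is not overwritten.
iou_score = (Precision*Recall)/(Precision+Recall-Precision*Recall)

print("Accuracy: {0:.4f}\n".format(Accuracy))
print("Precision: {0:.4f}\n".format(Precision))
print("Recall: {0:.4f}\n".format(Recall))
print("F1-Score: {0:.4f}\n".format(F1_score))
print("Sensitivity: {0:.4f}\n".format(Sensitivity))
print("Specificity: {0:.4f}\n".format(Specificity))
print("AUC: {0:.4f}\n".format(AUC))
print("IOU: {0:.4f}\n".format(iou_score))
print('-'*50,'\n')
print(classification_report(yy_true, yy_pred))
# Output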
    Accuracy: 0.9004
    
    Precision: 0.3555
    
    Recall: 0.8970
    
    F1-Score: 0.5092
    
    Sensitivity: 0.8970
    
    Specificity: 0.9006
    
    AUC: 0.9202
    
    IOU: 0.3415

    -------------------------------------------------- 
    
                  precision    recall  f1-score   support
    
           False       0.99      0.90      0.94  67167304
            True       0.36      0.90      0.51   4103096
    
        accuracy                           0.90  71270400
       macro avg       0.67      0.90      0.73  71270400
    weighted avg       0.96      0.90      0.92  71270400
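The IoU line in the cell above relies on an identity. With TP, FP and FN the pixel counts of the crack class, IoU can be written purely in terms of precision P and recall R, or equivalently in terms of F1. The same counts also explain why precision is low while recall and specificity are both around 0.90. The support column shows that only a fraction π of the pixels are cracks. A specificity S applied to the large background class therefore still yields many false positives. The derivation below uses only the numbers printed above.

```latex
\begin{aligned}
P &= \frac{TP}{TP+FP}, \qquad R = \frac{TP}{TP+FN}, \qquad
  \mathrm{IoU} = \frac{TP}{TP+FP+FN} \\[4pt]
\frac{1}{P} + \frac{1}{R} - 1
  &= \frac{TP+FP}{TP} + \frac{TP+FN}{TP} - \frac{TP}{TP}
   = \frac{TP+FP+FN}{TP} \\[4pt]
\Rightarrow\quad \mathrm{IoU}
  &= \frac{1}{1/P + 1/R - 1} = \frac{PR}{P+R-PR} = \frac{F_1}{2-F_1},
  \qquad F_1 = \frac{2PR}{P+R} \\[4pt]
\text{check: } \mathrm{IoU}
  &\approx \frac{0.3555 \cdot 0.8970}{0.3555 + 0.8970 - 0.3555 \cdot 0.8970} \approx 0.3416 \\[4pt]
\pi &= \frac{4\,103\,096}{71\,270\,400} \approx 0.0576, \qquad
  P = \frac{R\,\pi}{R\,\pi + (1-S)(1-\pi)}
    \approx \frac{0.8970 \cdot 0.0576}{0.8970 \cdot 0.0576 + 0.0994 \cdot 0.9424} \approx 0.355
\end{aligned}
```

The check gives 0.3416 rather than the printed 0.3415 only because P and R are rounded to four digits here. The last line shows that a 90% specificity on the roughly 94% background pixels produces about 1.8 false-positive pixels per true-positive pixel.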

for i, j in test1_generator:
  break

print(i.shape)
print(j.shape)
# Output
    (4, 320, 640, 3)
    (4, 320, 640, 1)
ttg = Generator(test_image_paths,test_mask_paths, batch_size, img_dim, augment = False)
for i, j in ttg:
  break

print(i.shape)
print(j.shape)
# Output
    (4, 320, 640, 3)
    (4, 320, 640, 1)
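The printed shapes show a batch size of 4. Together with the log line "Train for 474 steps, validate for 87 steps", this bounds the split sizes, because `steps = len(files) // batch_size`. The pixel support in the validation report (71,270,400 pixels at 320 x 640) pins the number of evaluated validation images to exactly 348. This is consistent with the commonly cited Crack500 split of 1896 / 348 / 1124 images, and 1124 is also the batch size used for the test generator. The helper names below are illustrative.

```python
import math

BATCH_SIZE = 4                       # inferred from the (4, 320, 640, 3) batch shape above
TRAIN_STEPS, VALID_STEPS = 474, 87   # from "Train for 474 steps, validate for 87 steps"
H, W = 320, 640                      # image size produced by the generators


def file_count_range(steps, batch_size):
    """steps = n // batch_size, so n lies in [steps*bs, steps*bs + bs - 1]."""
    return steps * batch_size, steps * batch_size + batch_size - 1


def steps_covering_all(n_files, batch_size):
    """Alternative to floor division: add one partial batch instead of dropping files."""
    return math.ceil(n_files / batch_size)


print(file_count_range(TRAIN_STEPS, BATCH_SIZE))  # possible training-set sizes
print(file_count_range(VALID_STEPS, BATCH_SIZE))  # possible validation-set sizes
print(71270400 // (H * W))                        # validation images behind the "support" total
print(steps_covering_all(1899, BATCH_SIZE))       # 475 steps would keep the last 3 files
```

Floor division silently skips up to `batch_size - 1` files per epoch. Here that is harmless if the training set really has 1896 images, which divides evenly by 4.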
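test_generator1 = Generator(test_image_paths, test_mask_paths, 1124, img_dim, augment=False)

# Take the first batch from the test generator (test_generator1, not the validation
# generator test_generator); a batch size of 1124 is meant to cover the whole test set.
x_test, y_test = next(iter(test_generator1))

y_pred = model.predict(x_test)

# Binarise masks and predictions at 0.5 and flatten them into pixel vectors.
yy_true = (y_test>0.5).flatten()
yy_pred = (y_pred>0.5).flatten()
report = classification_report(yy_true, yy_pred, output_dict=True)

Accuracy = accuracy_score(yy_true, yy_pred)

Precision = report['True']['precision']
Recall = report['True']['recall']
F1_score = report['True']['f1-score']

Sensitivity = Recall                      # sensitivity = recall of the crack class
Specificity = report['False']['recall']   # specificity = recall of the background class

AUC = roc_auc_score(yy_true, y_pred.flatten())

# Stored as iou_score so the IOU metric function defined earlier is not overwritten.
iou_score = (Precision*Recall)/(Precision+Recall-Precision*Recall)

print("Accuracy: {0:.4f}\n".format(Accuracy))
print("Precision: {0:.4f}\n".format(Precision))
print("Recall: {0:.4f}\n".format(Recall))
print("F1-Score: {0:.4f}\n".format(F1_score))
print("Sensitivity: {0:.4f}\n".format(Sensitivity))
print("Specificity: {0:.4f}\n".format(Specificity))
print("AUC: {0:.4f}\n".format(AUC))
print("IOU: {0:.4f}\n".format(iou_score))
print('-'*50,'\n')
print(classification_report(yy_true, yy_pred))
# Output (identical to the validation output above: the original run iterated over test_generator, i.e. the validation set, rather than test_generator1)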
    Accuracy: 0.9004
    
    Precision: 0.3555
    
    Recall: 0.8970
    
    F1-Score: 0.5092
    
    Sensitivity: 0.8970
    
    Specificity: 0.9006
    
    AUC: 0.9202
    
    IOU: 0.3415
    
    -------------------------------------------------- 
    
                  precision    recall  f1-score   support
    
           False       0.99      0.90      0.94  67167304
            True       0.36      0.90      0.51   4103096
    
        accuracy                           0.90  71270400
       macro avg       0.67      0.90      0.73  71270400
    weighted avg       0.96      0.90      0.92  71270400
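`classification_report` builds per-class dictionaries over more than 71 million pixels. All the metrics printed above can instead be read off a single binary confusion matrix. The sketch below is a hypothetical helper (`binary_seg_metrics` is not a scikit-learn function). It uses the same strict `> 0.5` threshold as the cells above and treats cracks as the positive class. The small example is synthetic, chosen so that every value can be checked by hand.

```python
import numpy as np


def binary_seg_metrics(y_true, y_prob, thresh=0.5):
    """Pixel-level metrics from one confusion matrix (positive class = crack)."""
    t = np.asarray(y_true).ravel() > thresh
    p = np.asarray(y_prob).ravel() > thresh
    tp = np.count_nonzero(t & p)
    fp = np.count_nonzero(~t & p)
    fn = np.count_nonzero(t & ~p)
    tn = np.count_nonzero(~t & ~p)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0          # = sensitivity
    return {
        "accuracy": (tp + tn) / t.size,
        "precision": precision,
        "recall": recall,
        "specificity": tn / (tn + fp) if tn + fp else 0.0,
        "f1": 2 * precision * recall / (precision + recall) if precision + recall else 0.0,
        "iou": tp / (tp + fp + fn) if tp + fp + fn else 0.0,
    }


# Synthetic example: TP = 1, FP = 2, FN = 1, TN = 2.
m = binary_seg_metrics([1, 1, 0, 0, 0, 0], [0.9, 0.2, 0.7, 0.1, 0.3, 0.6])
print(m)
P, R = m["precision"], m["recall"]
print(P * R / (P + R - P * R), m["iou"])   # the precision/recall identity gives the same IoU
```

In the notebook, `binary_seg_metrics(y_test, y_pred)` would replace the report-based lookups. AUC still needs the raw probabilities, via `roc_auc_score`.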

Reference: https://github.com/Subham2901/Concrete_Crack_Segmentation
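Follow the WeChat official account 【智能建造小硕】, which shares experience in programming, artificial intelligence, intelligent construction, daily study, research and writing. Comments and discussion are welcome.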

Posted by: 扎眼的阳光