【UNet3+】遥感影像分割

1. 项目准备

1.1. 问题导入

  • 图像分割
    在计算机视觉领域,图像分割指的是将数字图像细分为多个图像子区域的过程,其目的是简化或改变图像的表示形式,使得图像更容易理解和分析。图像分割通常用于定位图像中的物体和边界,更精确的说,它是对图像中的每个像素加标签的一个过程,这一过程使得具有相同标签的像素具有某种共同视觉特性。

  • 实验任务
    本例简要介绍如何使用UNet3+模型实现遥感影像分割,我们需要将遥感影像中存在的建筑物分割、标注出来。

1.2. 数据集简介

武汉大学2019年发布了Aerial Imagery Dataset,该数据集原始航拍数据来自新西兰土地信息服务网站,数据集共有8,189张具有0.3m分辨率、大小为512×512像素的遥感图像,数据集共包含18,7000座建筑物。数据集包含存放遥感图像的image文件夹和存放分割图像的label文件夹,例图如下图所示:

【UNet3+】遥感影像分割

这是数据集的下载链接:Aerial Imagery Dataset – AI Studio

2. UNet3+模型

2.1. 背景介绍

Hinton等人(2006)提出了一种Encoder-Decoder结构,当时这个Encoder-Decoder结构提出的主要作用并不是分割,而是压缩图像和去噪声。输入是一幅图,经过下采样的编码,得到一串比原先图像更小的特征,相当于压缩,然后再经过一个解码,理想状况就是能还原到原来的图像。

后来,Jonathan等人(2015)在论文中基于该拓扑结构提出了FCN(Fully Convolutional Networks)。自提出以后,FCN就成为了语义分割的基本框架,后续算法(如UNet)其实都是在这个框架中改进而来。其中的UNet由于其对称结构简单易懂,且模型效果优秀,于是就成为了许多网络改进的范本之一。

UNet(2015)是医学影像分割领域应用最广泛的的网络,它使用跳跃连接(skip connection)来结合来自解码器的高级语义特征图和来自编码器的相应尺度的低级语义特征图,其性能和网络中多尺度特征的融合密切相关。为了避免纯跳跃连接在语义上融合不相似的特征,此后的UNet++(2018)引入嵌套结构和密集的跳跃连接对网络进行了改进。而最新的UNet3+(2020)通过全尺度的跳跃连接和深度监督(deep supervisions)来融合深层和浅层特征的同时对各个尺度的特征进行监督,它还可以在减少网络参数的同时提高计算效率。

【UNet3+】遥感影像分割

2.2. 模型介绍

Huang等人(2020)在论文中提出了UNet3+模型,Huang等人使用该模型在肝脏和脾脏数据集上进行广泛的实验,发现它的表现得到了提高并且超过了很多baselines。下面介绍一下UNet3+模型的三个创新点:

(1) 全尺度跳跃连接

UNet3+充分利用多尺度特征,引入全尺度跳跃连接(Full-scale Skip Connections),该连接结合了来自全尺度特征图的低级语义和高级语义,并且参数更少。

在许多分割实验的研究中,不同尺度的特征图展示着不同的信息:低级语义特征图捕捉丰富的空间信息,能够突出物体的边界;而高级语义特征图则体现了物体所在的位置信息。为此,UNet3+的每个解码器层都融合了来自编码器中的小尺度和同尺度的低级语义特征图,以及来自解码器的大尺度的高级语义特征图,这些特征图捕获了全尺度下的细粒度语义和粗粒度语义。

【UNet3+】遥感影像分割

如上图所示,为了构造特征图【UNet3+】遥感影像分割,第3层解码器不仅需要接收同尺度编码器层的特征图【UNet3+】遥感影像分割,还需要接收小尺度编码器层的特征图【UNet3+】遥感影像分割【UNet3+】遥感影像分割(为了统一特征图的分辨率,在接收前需进行下采样操作),同时也需要接收大尺度解码器层的特征图【UNet3+】遥感影像分割【UNet3+】遥感影像分割(为了统一特征图的分辨率,在接收前需进行上采样操作)。在统一特征图的分辨率之后,我们还需用64个3×3的卷积核统一特征图的数量,以减少多余信息。在完成上述操作之后,我们就能用“通道维度拼接”的方法融合特征了,融合上述5个特征后便得到了320个特征图。接着,我们用320个3×3的卷积核对其进行卷积操作,最后通过批正则化(Batch Normalize)和ReLU(Rectified Linear Unit)便得到【UNet3+】遥感影像分割

于是,特征图【UNet3+】遥感影像分割的计算公式可总结为:
【UNet3+】遥感影像分割

(2) 全尺度深度监督

UNet3+采用全尺度深度监督(Full-scale Deep Supervision),从全面的聚合特征图中学习层次表示,优化了混合损失函数以增强器官边界。

不同于UNet++对全分辨率特征图进行深度监督,UNet3+中每个解码器都有一个侧输出,它是由真实标准(ground truth)来进行监督的。为实现深度监督,每个解码器的侧输出都会被送入1个3×3卷积层、1个双线性上采样层以及1个sigmoid函数层中。

为了进一步增强器官边界,UNet3+提出了一种多尺度结构相似指数(Multi-Scale Structural Similarity index,MS-SSIM)损失函数来赋予模糊边界更大的权重。由于区域分布差异越大,MS-SSIM值越高,故UNet3+将更加关注模糊边界。假设我们从分割结果P和真实标准G中分别裁剪了两个N×N的块【UNet3+】遥感影像分割【UNet3+】遥感影像分割,并且有【UNet3+】遥感影像分割【UNet3+】遥感影像分割,那么我们可定义【UNet3+】遥感影像分割【UNet3+】遥感影像分割的MS-SSIM损失函数为:
【UNet3+】遥感影像分割

UNet3+融合了focal损失函数、MS-SSIM损失函数和IoU损失函数,提出了一种用于三个不同层次级别(像素级、块级、图像级)分割的混合损失函数,它能捕获边界清晰的大尺度结构和精细结构。该混合损失函数的定义为:
【UNet3+】遥感影像分割

(3) 分类指导模块

UNet3+提出分类指导模块(Classification-guided Module,CGM),通过图像级分类联合训练,减少非器官图像的过度分割。

在大多数医学图像分割实验中,由于来自背景的噪声信息停留在较浅层次中,这导致非器官图像出现过度分割的现象。为解决这一问题,UNet3+增加了一个预测输入图像是否有器官的额外分类任务。

【UNet3+】遥感影像分割

如上图所示,最深层的特征图【UNet3+】遥感影像分割依次通过Dropout层、1×1卷积层、最大池化层和Sigmoid函数层,以得到代表【UNet3+】遥感影像分割中有/无器官概率的二维张量。然后,我们可以用argmax函数处理二维张量,以得到仅包含0和1的二分类结果。接着,我们用这些分类结果与每个侧边分割输出相乘,以得到修正后的侧边分割输出。我们可以通过优化二分类的交叉损失函数,来获得更准确的分类结果,以此指导模型避免对非器官图像过度分割。

3. 代码实现

3.0. 前期准备

  • 导入模块

注意:本案例仅适用于Paddle 2.0+版本,建议根据显存大小合理调整超参数batch_sizeimg_size的大小!

import cv2
import os
import random
import zipfile
import numpy as np
from copy import deepcopy
from PIL import Image, ImageEnhance
from matplotlib import pyplot as plt
from matplotlib.colors import LinearSegmentedColormap as LSC

import paddle
from paddle import nn
from paddle.framework import ParamAttr
from paddle.io import DataLoader, Dataset
from paddle.nn import initializer as I, functional as F
from paddle.optimizer import Adam
from paddle.optimizer.lr import CosineAnnealingDecay
  • 设置超参数
BATCH_SIZE = 2           # 每批次的样本数
EPOCHS = 10              # 模型训练的总轮数
LOG_GAP = 360            # 输出训练信息的间隔

N_CLASSES = 2            # 图像分类种类数量
IMG_SIZE = (256, 256)    # 图像缩放尺寸

INIT_LR = 2e-4           # 初始学习率
T_MAX = EPOCHS           # 余弦周期的一半

SRC_PATH = "./data/data69911/BuildData.zip"  # 压缩包路径
DST_PATH = "./data"                          # 解压路径
DATA_PATH = {                                # 实验数据集路径
    "img": DST_PATH + "/image",    # 正常图像
    "lab": DST_PATH + "/label",    # 分割图像
}
INFER_PATH = {                               # 预测数据集路径
    "img": ["./work/1.jpg", "./work/2.jpg"],   # 正常图像
    "lab": ["./work/1.png", "./work/2.png"],   # 分割图像
}
MODEL_PATH = "UNet3+.pdparams"               # 模型参数保存路径

3.1. 数据准备

  • 解压数据集
    由于数据集中的数据是以压缩包的形式存放的,因此我们需要先解压数据压缩包。
if not os.path.isdir(DATA_PATH["img"]) or not os.path.isdir(DATA_PATH["lab"]):
    z = zipfile.ZipFile(SRC_PATH, "r")   # 以只读模式打开zip文件
    z.extractall(path=DST_PATH)          # 解压zip文件至目标路径
    z.close()
print("The dataset has been unpacked successfully!")
  • 划分数据集
    我们需要按9:1比例划分训练集和测试集,分别生成两个包含数据路径和标签路径映射关系的列表。
train_list, test_list = [], []         # 存放图像路径与标签路径的映射
images = os.listdir(DATA_PATH["img"])  # 统计数据集下的图像文件

for idx, img in enumerate(images):
    lab = os.path.join(DATA_PATH["lab"], img.replace(".jpg", ".png"))
    img = os.path.join(DATA_PATH["img"], img)
    if idx % 10 != 0:                  # 按照1:9的比例划分数据集
        train_list.append((img, lab))
    else:
        test_list.append((img, lab))
  • 数据增强
    数据増广(Data Augmentation),即数据增强,数据增强的目的主要是减少网络的过拟合现象,通过对训练图片进行变换可以得到泛化能力更强的网络,更好地适应应用场景。
    由于实验模型较为复杂,直接训练容易发生过拟合,故在处理实验数据集时采用数据增强的方法扩充数据集的多样性。本实验中用到的数据增强方法有:随机改变亮度,随机改变对比度,随机改变饱和度,随机改变清晰度,随机旋转图像,随机翻转图像,随机加高斯噪声等。
def random_brightness(img, lab, low=0.5, high=1.5):
    ''' 随机改变亮度(0.5~1.5) '''
    x = random.uniform(low, high)
    img = ImageEnhance.Brightness(img).enhance(x)
    return img, lab

def random_contrast(img, lab, low=0.5, high=1.5):
    ''' 随机改变对比度(0.5~1.5) '''
    x = random.uniform(low, high)
    img = ImageEnhance.Contrast(img).enhance(x)
    return img, lab

def random_color(img, lab, low=0.5, high=1.5):
    ''' 随机改变饱和度(0.5~1.5) '''
    x = random.uniform(low, high)
    img = ImageEnhance.Color(img).enhance(x)
    return img, lab

def random_sharpness(img, lab, low=0.5, high=1.5):
    ''' 随机改变清晰度(0.5~1.5) '''
    x = random.uniform(low, high)
    img = ImageEnhance.Sharpness(img).enhance(x)
    return img, lab

def random_rotate(img, lab, low=0, high=360):
    ''' 随机旋转图像(0~360度) '''
    angle = random.choice(range(low, high))
    img, lab = img.rotate(angle), lab.rotate(angle)
    return img, lab

def random_flip(img, lab, prob=0.5):
    ''' 随机翻转图像(p=0.5) '''
    if random.random() < prob:   # 上下翻转
        img = img.transpose(Image.FLIP_TOP_BOTTOM)
        lab = lab.transpose(Image.FLIP_TOP_BOTTOM)
    if random.random() < prob:   # 左右翻转
        img = img.transpose(Image.FLIP_LEFT_RIGHT)
        lab = lab.transpose(Image.FLIP_LEFT_RIGHT)
    return img, lab

def random_noise(img, lab, low=0, high=10):
    ''' 随机加高斯噪声(0~10) '''
    img = np.asarray(img)
    sigma = np.random.uniform(low, high)
    noise = np.random.randn(img.shape[0], img.shape[1], 3) * sigma
    img = img + np.round(noise).astype('uint8')
    # 将矩阵中的所有元素值限制在0~255之间:
    img[img > 255], img[img < 0] = 255, 0
    img = Image.fromarray(img)
    return img, lab

def image_augment(img, lab, prob=0.5):
    ''' 叠加多种数据增强方法 '''
    opts = [random_brightness, random_contrast, random_color, random_flip,
            random_noise, random_rotate, random_sharpness,]  # 数据增强方法
    for func in opts:
        if random.random() < prob:
            img, lab = func(img, lab)   # 处理图像和标签
    return img, lab
  • 数据预处理
    我们需要对数据集图像进行缩放和归一化处理。
class MyDataset(Dataset):
    ''' 自定义的数据集类
    * `label_list`: 图像路径和标签路径的映射列表
    * `transform`: 图像处理函数
    * `augment`: 数据增强函数
    '''
    def __init__(self, label_list, transform, augment=None):
        super(MyDataset, self).__init__()
        random.shuffle(label_list)       # 打乱映射列表
        self.label_list = label_list
        self.transform = transform
        self.augment = augment
    
    def __getitem__(self, index):
        ''' 根据位序获取对应数据 '''
        img_path, lab_path = self.label_list[index]
        img, lab = self.transform(img_path, lab_path, self.augment)
        return img, lab
    
    def __len__(self):
        ''' 获取数据集的样本总数 '''
        return len(self.label_list)


def data_mapper(img_path, lab_path, augment=None):
    ''' 图像处理函数 '''
    img = Image.open(img_path).convert("RGB")
    lab = cv2.cvtColor(cv2.imread(lab_path), cv2.COLOR_RGB2GRAY)
    # 将标签文件进行灰度二值化:
    _, lab = cv2.threshold(src=lab,                     # 待处理图片
                           thresh=170,                  # 起始阈值
                           maxval=255,                  # 最大阈值
                           type=cv2.THRESH_BINARY_INV)  # 算法类型
    lab = Image.fromarray(lab).convert("L")       # 转换为PIL.Image
    # 将图像缩放为IMG_SIZE大小的高质量图像:
    img = img.resize(IMG_SIZE, Image.ANTIALIAS)
    lab = lab.resize(IMG_SIZE, Image.ANTIALIAS)
    if augment is not None:    # 数据增强
        img, lab = augment(img, lab)
    # 将图像转为numpy数组,并转换图像的格式:
    img = np.array(img).astype("float32").transpose((2, 0, 1))
    lab = np.array(lab).astype("int64")
    # 将图像数据归一化,并转换成Tensor格式:
    img = paddle.to_tensor(img / 255.0)
    lab = paddle.to_tensor(lab // 255)
    return img, lab
train_dataset = MyDataset(train_list, data_mapper, image_augment)  # 训练集
test_dataset = MyDataset(test_list, data_mapper, augment=None)     # 测试集
  • 定义数据提供器
    我们需要分别构建用于训练和测试的数据提供器,其中训练数据提供器是乱序、按批次提供数据的。
train_loader = DataLoader(train_dataset,          # 训练数据集
                          batch_size=BATCH_SIZE,  # 每批次的样本数
                          num_workers=2,          # 加载数据的子进程数
                          shuffle=True,           # 打乱数据集
                          drop_last=False)        # 不丢弃不完整的样本批次
test_loader = DataLoader(test_dataset,            # 测试数据集
                         batch_size=BATCH_SIZE,   # 每批次的样本数
                         num_workers=2,           # 加载数据的子进程数
                         shuffle=False,           # 不打乱数据集
                         drop_last=False)         # 不丢弃不完整的样本批次

3.2. 网络配置

本次实验使用的是UNet3+模型,UNet系列模型包含下采样(编码器,特征提取)和上采样(解码器,分辨率还原)两个阶段,因模型结构比较像U型而得名。

  • 定义网络初始化函数
def init_weights(net, init_type="normal"):
    ''' 初始化网络的权重与偏置
    * `net`: 需要初始化的神经网络层
    * `init_type`: 初始化机制(normal/xavier/kaiming/truncated)
    '''
    if init_type == "normal":
        attr = ParamAttr(initializer=I.Normal())
    elif init_type == "xavier":
        attr = ParamAttr(initializer=I.XavierNormal())
    elif init_type == "kaiming":
        attr = ParamAttr(initializer=I.KaimingNormal())
    elif init_type == "truncated":
        attr = ParamAttr(initializer=I.TruncatedNormal())
    else:
        error = "Initialization method [%s] is not implemented!"
        raise NotImplementedError(error % init_type)
    # 初始化网络层net的权重系数和偏置系数:
    net.param_attr, net.bias_attr = attr, deepcopy(attr)
  • 构建编码器
class Encoder(nn.Layer):
    ''' 用于构建编码器模块
    * `in_size`: 输入通道数
    * `out_size`: 输出通道数
    * `is_batchnorm`: 是否批正则化
    * `n`: 卷积层数量(默认为2)
    * `ks`: 卷积核大小(默认为3)
    * `s`: 卷积运算步长(默认为1)
    * `p`: 卷积填充大小(默认为1)
    '''
    def __init__(self, in_size, out_size, is_batchnorm, 
                 n=2, ks=3, s=1, p=1):
        super(Encoder, self).__init__()
        self.n = n

        for i in range(1, self.n+1):    # 定义多层卷积神经网络
            if is_batchnorm:
                block = nn.Sequential(nn.Conv2D(in_size, out_size, ks, s, p),
                                      nn.BatchNorm2D(out_size),
                                      nn.ReLU())
            else:
                block = nn.Sequential(nn.Conv2D(in_size, out_size, ks, s, p),
                                      nn.ReLU())
            setattr(self, "block%d" % i, block)
            in_size = out_size
        
        for m in self.children():       # 初始化各层网络的系数
            init_weights(m, init_type="kaiming")
    
    def forward(self, x):
        for i in range(1, self.n+1):
            block = getattr(self, "block%d" % i)
            x = block(x)                # 进行前向传播运算
        return x
  • 构建解码器
class Decoder(nn.Layer):
    ''' 用于构建解码器模块
    * `cur_stage`(int): 当前解码器所在层数
    * `cat_size`(int): 统一后的特征图通道数
    * `up_size`(int): 特征融合后的通道总数
    * `filters`(list): 各卷积网络的卷积核数
    * `ks`: 卷积核大小(默认为3)
    * `s`: 卷积运算步长(默认为1)
    * `p`: 卷积填充大小(默认为1)
    '''
    def __init__(self, cur_stage, cat_size, up_size,
                 filters, ks=3, s=1, p=1):
        super(Decoder, self).__init__()
        self.n = len(filters)      # 卷积网络模块的个数

        for idx, num in enumerate(filters):
            idx += 1               # 待处理输出所在层数
            if idx < cur_stage:
                # he[idx]_PT_hd[cur_stage], Pool [ps] times
                ps = 2 ** (cur_stage - idx)
                block = nn.Sequential(nn.MaxPool2D(ps, ps, ceil_mode=True),
                                      nn.Conv2D(num, cat_size, ks, s, p),
                                      nn.BatchNorm2D(cat_size),
                                      nn.ReLU())
            elif idx == cur_stage:
                # he[idx]_Cat_hd[cur_stage], Concatenate
                block = nn.Sequential(nn.Conv2D(num, cat_size, ks, s, p),
                                      nn.BatchNorm2D(cat_size),
                                      nn.ReLU())
            else:
                # hd[idx]_UT_hd[cur_stage], Upsample [us] times
                us = 2 ** (idx - cur_stage)
                num = num if idx == 5 else up_size
                block = nn.Sequential(nn.Upsample(scale_factor=us, mode="bilinear"),
                                      nn.Conv2D(num, cat_size, ks, s, p),
                                      nn.BatchNorm2D(cat_size),
                                      nn.ReLU())
            setattr(self, "block%d" % idx, block)

        # fusion(he[]_PT_hd[], ..., he[]_Cat_hd[], ..., hd[]_UT_hd[])
        self.fusion = nn.Sequential(nn.Conv2D(up_size, up_size, ks, s, p),
                                    nn.BatchNorm2D(up_size),
                                    nn.ReLU())
        
        for m in self.children():       # 初始化各层网络的系数
            init_weights(m, init_type="kaiming")

    def forward(self, inputs):
        outputs = []       # 记录各层的输出,以便于拼接起来
        for i in range(self.n):
            block = getattr(self, "block%d" % (i+1))
            outputs.append( block(inputs[i]) )
        hd = self.fusion(paddle.concat(outputs, 1))
        return hd
  • 定义网络结构
class UNet3Plus(nn.Layer):
    ''' UNet3+ with Deep Supervision and Class-guided Module
    * `in_channels`: 输入通道数(默认为3)
    * `n_classes`: 物体的分类种数(默认为2)
    * `is_batchnorm`: 是否批正则化(默认为True)
    * `deep_sup`: 是否开启深度监督机制(Deep Supervision)
    * `set_cgm`: 是否设置分类引导模块(Class-guided Module)
    '''
    def __init__(self, in_channels=3, n_classes=2, 
                 is_batchnorm=True, deep_sup=True, set_cgm=True):
        super(UNet3Plus, self).__init__()
        self.deep_sup = deep_sup
        self.set_cgm = set_cgm
        filters = [64, 128, 256, 512, 1024]      # 各模块的卷积核大小
        cat_channels = filters[0]                # 统一后的特征图通道数
        cat_blocks = 5                           # 编(解)码器的层数
        up_channels = cat_channels * cat_blocks  # 特征融合后的通道数

        # ====================== Encoders ======================
        self.conv_e1 = Encoder(in_channels, filters[0], is_batchnorm)
        self.pool_e1 = nn.MaxPool2D(kernel_size=2)
        self.conv_e2 = Encoder(filters[0], filters[1], is_batchnorm)
        self.pool_e2 = nn.MaxPool2D(kernel_size=2)
        self.conv_e3 = Encoder(filters[1], filters[2], is_batchnorm)
        self.pool_e3 = nn.MaxPool2D(kernel_size=2)
        self.conv_e4 = Encoder(filters[2], filters[3], is_batchnorm)
        self.pool_e4 = nn.MaxPool2D(kernel_size=2)
        self.conv_e5 = Encoder(filters[3], filters[4], is_batchnorm)
        
        # ====================== Decoders ======================
        self.conv_d4 = Decoder(4, cat_channels, up_channels, filters)
        self.conv_d3 = Decoder(3, cat_channels, up_channels, filters)
        self.conv_d2 = Decoder(2, cat_channels, up_channels, filters)
        self.conv_d1 = Decoder(1, cat_channels, up_channels, filters)

        # ======================= Output =======================
        if self.set_cgm:
            # -------------- Class-guided Module ---------------
            self.cls = nn.Sequential(nn.Dropout(p=0.5),
                                     nn.Conv2D(filters[4], 2, 1),
                                     nn.AdaptiveMaxPool2D(1),
                                     nn.Sigmoid())
        if self.deep_sup:
            # -------------- Bilinear Upsampling ---------------
            self.upscore5 = nn.Upsample(scale_factor=16, mode="bilinear")
            self.upscore4 = nn.Upsample(scale_factor=8, mode="bilinear")
            self.upscore3 = nn.Upsample(scale_factor=4, mode="bilinear")
            self.upscore2 = nn.Upsample(scale_factor=2, mode="bilinear")
            # ---------------- Deep Supervision ----------------
            self.outconv5 = nn.Conv2D(filters[4], n_classes, 3, 1, 1)
            self.outconv4 = nn.Conv2D(up_channels, n_classes, 3, 1, 1)
            self.outconv3 = nn.Conv2D(up_channels, n_classes, 3, 1, 1)
            self.outconv2 = nn.Conv2D(up_channels, n_classes, 3, 1, 1)
        self.outconv1 = nn.Conv2D(up_channels, n_classes, 3, 1, 1)
    
        # ================= Initialize Weights =================
        for m in self.sublayers():
            if isinstance(m, nn.Conv2D) or isinstance(m, nn.BatchNorm):
                init_weights(m, init_type='kaiming')

    def dot_product(self, seg, cls):
        B, N, H, W = seg.shape
        seg = seg.reshape((B, N, H * W))
        clssp = paddle.ones((1, N))
        ecls = (cls * clssp).reshape((B, N, 1))
        final = (seg * ecls).reshape((B, N, H, W))
        return final

    def forward(self, x):
        # ====================== Encoders ======================
        e1 = self.conv_e1(x)                  # e1: 320*320*64
        e2 = self.pool_e1(self.conv_e2(e1))   # e2: 160*160*128
        e3 = self.pool_e2(self.conv_e3(e2))   # e3: 80*80*256
        e4 = self.pool_e3(self.conv_e4(e3))   # e4: 40*40*512
        e5 = self.pool_e4(self.conv_e5(e4))   # e5: 20*20*1024

        # ================ Class-guided Module =================
        if self.set_cgm:
            cls_branch = self.cls(e5).squeeze(3).squeeze(2)
            cls_branch_max = cls_branch.argmax(axis=1)
            cls_branch_max = cls_branch_max[:, np.newaxis].astype("float32")

        # ====================== Decoders ======================
        d5 = e5
        d4 = self.conv_d4((e1, e2, e3, e4, d5))
        d3 = self.conv_d3((e1, e2, e3, d4, d5))
        d2 = self.conv_d2((e1, e2, d3, d4, d5))
        d1 = self.conv_d1((e1, d2, d3, d4, d5))
        
        # ======================= Output =======================
        if self.deep_sup:
            y5 = self.upscore5( self.outconv5(d5) )  # 16 => 256
            y4 = self.upscore4( self.outconv4(d4) )  # 32 => 256
            y3 = self.upscore3( self.outconv3(d3) )  # 64 => 256
            y2 = self.upscore2( self.outconv2(d2) )  # 128 => 256
            y1 = self.outconv1(d1)                   # 256
            if self.set_cgm:
                y5 = self.dot_product(y5, cls_branch_max)
                y4 = self.dot_product(y4, cls_branch_max)
                y3 = self.dot_product(y3, cls_branch_max)
                y2 = self.dot_product(y2, cls_branch_max)
                y1 = self.dot_product(y1, cls_branch_max)
            return F.sigmoid(y1), F.sigmoid(y2), F.sigmoid(y3),\
                   F.sigmoid(y4), F.sigmoid(y5)
        else:
            y1 = self.outconv1(d1)                   # 320*320*n_classes
            if self.set_cgm:
                y1 = self.dot_product(y1, cls_branch_max)
            return F.sigmoid(y1)
  • 实例化模型
model = UNet3Plus(n_classes=N_CLASSES, deep_sup=False, set_cgm=False)
# paddle.Model(model).summary((1, 3) + IMG_SIZE)  # 可视化模型结构

3. 模型训练

model.train()                # 开启训练模式
scheduler = CosineAnnealingDecay(
    learning_rate=INIT_LR,
    T_max=T_MAX,
)                            # 定义学习率衰减器
optimizer = Adam(
    learning_rate=scheduler,
    parameters=model.parameters()
)                            # 定义Adam优化器
loss_arr = []                # 记录每批训练的误差

for ep in range(EPOCHS):
    for batch_id, data in enumerate(train_loader()):
        image, label = data
        pred = model(image)                          # 预测结果
        loss = F.cross_entropy(pred, label, axis=1)  # 计算损失函数值
        if batch_id % LOG_GAP == 0:                  # 定期输出训练结果
            print("Epoch:%d,Batch:%3d,Loss:%.5f" % (ep, batch_id, loss))
        loss_arr.append(loss.item())
        optimizer.clear_grad()
        loss.backward()
        optimizer.step()
    scheduler.step()       # 衰减一次学习率
    paddle.save(model.state_dict(), MODEL_PATH)  # 保存训练好的模型

模型训练的结果如下:

Epoch:0,Batch:  0,Loss:1.39092
Epoch:0,Batch:360,Loss:0.17174
Epoch:0,Batch:720,Loss:0.16681
Epoch:1,Batch:  0,Loss:0.11368
Epoch:1,Batch:360,Loss:0.11665
Epoch:1,Batch:720,Loss:0.06234
Epoch:2,Batch:  0,Loss:0.12535
Epoch:2,Batch:360,Loss:0.12542
Epoch:2,Batch:720,Loss:0.11362
Epoch:3,Batch:  0,Loss:0.12906
Epoch:3,Batch:360,Loss:0.11927
Epoch:3,Batch:720,Loss:0.11524
Epoch:4,Batch:  0,Loss:0.07827
Epoch:4,Batch:360,Loss:0.15802
Epoch:4,Batch:720,Loss:0.09502
Epoch:5,Batch:  0,Loss:0.13487
Epoch:5,Batch:360,Loss:0.09628
Epoch:5,Batch:720,Loss:0.10007
Epoch:6,Batch:  0,Loss:0.07204
Epoch:6,Batch:360,Loss:0.11167
Epoch:6,Batch:720,Loss:0.13266
Epoch:7,Batch:  0,Loss:0.05692
Epoch:7,Batch:360,Loss:0.16079
Epoch:7,Batch:720,Loss:0.10594
Epoch:8,Batch:  0,Loss:0.05400
Epoch:8,Batch:360,Loss:0.06496
Epoch:8,Batch:720,Loss:0.09775
Epoch:9,Batch:  0,Loss:0.07335
Epoch:9,Batch:360,Loss:0.07723
Epoch:9,Batch:720,Loss:0.06590
  • 可视化训练过程
fig = plt.figure(figsize=[10, 5])

# 训练误差图像:
ax = fig.add_subplot(111, facecolor="#E8E8F8")
ax.set_ylabel("Loss", fontsize=18)
plt.tick_params(labelsize=14)
ax.plot(range(len(loss_arr)), loss_arr, color="orangered")
ax.grid(linewidth=1.5, color="white")  # 显示网格

fig.tight_layout()
plt.show()
plt.close()

【UNet3+】遥感影像分割

3.4. 模型评估

model.eval()                 # 开启评估模式
test_costs = []

for batch_id, data in enumerate(test_loader()):
    image, label = data
    pred = model(image)                          # 预测结果
    loss = F.cross_entropy(pred, label, axis=1)  # 计算损失函数值
    test_costs.append(loss.item())
print("Eval \t Avg_Loss:%.5f" % (np.mean(test_costs)))

模型评估的结果如下:

Eval 	 Avg_Loss:0.07250

3.5. 模型预测

def show_result(img_path, lab_path, pred):
    ''' 展示原图、标签以及预测结果 '''

    def add_subimg(img, loc, title, cmap=None):
        ''' 添加子图以展示图像 '''
        plt.subplot(loc)
        plt.title(title)
        plt.imshow(img, cmap)
        plt.xticks([])         # 去除X刻度
        plt.yticks([])         # 去除Y刻度

    def colormap(colors=['#A0C185', '#A6A6A6']):
        ''' 自定义ColorMap '''
        return LSC.from_list('cmap', colors, 256)

    img = Image.open(img_path).resize(IMG_SIZE)
    lab = Image.open(lab_path).resize(IMG_SIZE)
    pred = pred.argmax(axis=1).numpy().reshape(IMG_SIZE)
    plt.figure(figsize=(12, 4))
    add_subimg(img, 131, "Image")
    add_subimg(lab, 132, "Label")
    add_subimg(pred, 133, "Predict", colormap())
    plt.tight_layout()
    plt.show()
    plt.close()
model.eval()                 # 开启评估模式
model.set_state_dict(
    paddle.load(MODEL_PATH)
)   # 载入预训练模型参数

for i in range(len(INFER_PATH["img"])):
    img_path, lab_path = INFER_PATH["img"][i], INFER_PATH["lab"][i]
    img, lab = data_mapper(img_path, lab_path)  # 处理预测图像
    pred = model(img[np.newaxis, ...])          # 开始模型预测
    show_result(img_path, lab_path, pred)

第1组图像分割结果如下:
【UNet3+】遥感影像分割

第2组图像分割结果如下:
【UNet3+】遥感影像分割

写在最后

  • 如果您发现项目存在问题,或者如果您有更好的建议,欢迎在下方评论区中留言讨论~
  • 这是本项目的链接:实验项目 – AI Studio,点击fork可直接在AI Studio运行~
  • 这是我的个人主页:个人主页 – AI Studio,来AI Studio互粉吧,等你哦~
  • 【友链滴滴】欢迎大家随时访问我的个人博客~

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
扎眼的阳光的头像扎眼的阳光普通用户
上一篇 2023年3月4日 下午1:11
下一篇 2023年3月4日

相关推荐