UNet-3D个人理解及代码实现(PyTorch)

以下内容均为个人理解,如有错误,欢迎指正。

UNet-3D

论文链接:地址

网络结构

在这里插入图片描述
UNet-3D和UNet-2D的基本结构是差不多的,分成小模块来看,也是有连续两次卷积,下采样,上采样,特征融合以及最后一次卷积。
UNet-2D可参考:VGG16+UNet个人理解及代码实现(Pytorch)

不同的是,UNet-3D的卷积是三维的卷积。
关于2D卷积和3D卷积的区别可参见这篇文章:链接

需要注意的是,UNet-3D的连续两次卷积操作中,第一次卷积和第二次卷积的输出通道数是不同的(UNet-2D的连续两次卷积操作的输出通道数是相同的)。

单从图示的网络结构来看,UNet-3D的网络深度为4,2D的网络深度为5,这个深度可以改变。

理清楚这些之后就可以着手写代码实现网络结构了。

代码实现

使用PyTorch实现

import torch
import torch.nn as nn
from torch.nn import functional as F

class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels, bath_normal=False):
        super(DoubleConv, self).__init__()
        channels = out_channels / 2
        if in_channels > out_channels:
            channels = in_channels / 2

        layers = [
            # in_channels:输入通道数
            # channels:输出通道数
            # kernel_size:卷积核大小
            # stride:步长
            # padding:边缘填充
            nn.Conv3d(in_channels, channels, kernel_size=3, stride=1, padding=0),
            nn.ReLU(True),

            nn.Conv3d(channels, out_channels, kernel_size=3, stride=1, padding=0),
            nn.ReLU(True)
        ]
        if bath_normal: # 如果要添加BN层
            layers.insert(1, nn.BatchNorm3d(channels))
            layers.insert(len(layers) - 1, nn.BatchNorm3d(out_channels))

        # 构造序列器
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.double_conv(x)

class DownSampling(nn.Module):
    def __init__(self, in_channels, out_channels, batch_normal=False):
        super(DownSampling, self).__init__()
        self.maxpool_to_conv = nn.Sequential(
            nn.MaxPool3d(kernel_size=2, stride=2),
            DoubleConv(in_channels, out_channels, batch_normal)
        )

    def forward(self, x):
        return self.maxpool_to_conv(x)

class UpSampling(nn.Module):
    def __init__(self, in_channels, out_channels, batch_normal=False, bilinear=True):
        super(UpSampling, self).__init__()
        if bilinear:
            # 采用双线性插值的方法进行上采样
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            # 采用反卷积进行上采样
            self.up = nn.ConvTranspose3d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels + in_channels / 2, out_channels, batch_normal)

    # inputs1:上采样的数据(对应图中黄色箭头传来的数据)
    # inputs2:特征融合的数据(对应图中绿色箭头传来的数据)
    def forward(self, inputs1, inputs2):
        # 进行一次up操作
        inputs1 = self.up(inputs1)

        # 进行特征融合
        outputs = torch.cat([inputs1, inputs2], dim=1)
        outputs = self.conv(outputs)
        return outputs

class LastConv(nn.Module):
    def __init__(self, in_channels, out_channels ):
        super(LastConv, self).__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=1 )

    def forward(self, x):
        return self.conv(x)

class UNet3D(nn.Module):
    def __init__(self, in_channels, num_classes=2, batch_normal=False, bilinear=True):
        super(UNet3D, self).__init__()
        self.in_channels = in_channels
        self.batch_normal = batch_normal
        self.bilinear = bilinear

        self.inputs = DoubleConv(in_channels, 64, self.batch_normal)
        self.down_1 = DownSampling(64, 128, self.batch_normal)
        self.down_2 = DownSampling(128, 256, self.batch_normal)
        self.down_3 = DownSampling(256, 512, self.batch_normal)

        self.up_1 = UpSampling(512, 256, self.batch_normal, self.bilinear)
        self.up_2 = UpSampling(256, 128, self.batch_normal, self.bilinear)
        self.up_3 = UpSampling(128, 64, self.batch_normal, self.bilinear)
        self.outputs = LastConv(64, num_classes)

    def forward(self, x):
        # down 部分
        x1 = self.inputs(x)
        x2 = self.down_1(x1)
        x3 = self.down_2(x2)
        x4 = self.down_3(x3)

        # up部分
        x5 = self.up_1(x4, x3)
        x6 = self.up_2(x5, x2)
        x7 = self.up_3(x6, x1)
        x = self.outputs(x7)

        return x

总结

学习参考

最近学习UNet-3D的内容时,我觉得网络上现有的学习资料不是很多,没有像UNet-2D的学习资料那么多,在这里略总结一下可以参考的学习资料:

1.视频:【中文字幕】3D-图像分割超详细教程U-Net
这是B站上一位up主上传的油管上 的一条视频,有中文字幕,主要讲解的是3D格式的数据,以及如何实现3D格式的图像的分割,也介绍了依赖的库等,视频中的博主使用notebook做的项目,他也演示了整个项目的实现和运行的效果。

2.视频:医学分割项目:自己写的2D、3D医学分割项目
也是B站上的视频,是一位up主介绍的自己做的2D和3D的分割项目,视频主要介绍的是如何使用代码(视频下方有代码的github链接),我觉得可以参考up主的代码是怎么实现的,参考他的数据处理方式。

3.博文:3D-UNet
分析了论文中的细节,介绍了3维数据以及3D卷积等

4.博文+代码实现
3D-UNet的Pytorch实现
pytorch实战-Unet3d(LiTS)
3D U-Net脑胶质瘤分割BraTs + Pytorch实现
3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation 论文解读与程序复现
【MICCAI2018论文翻译】使用集成的3D U-Net分割脑肿瘤和放射特征预测总生存率

个人感受

对于我个人来说,想要快速学习的话就是先看论文理解原理,然后直接看代码,参考一些条理清晰的项目代码,理解该项目的流程,学习代码是如何实现的。
我觉得对于我个人来说,数据处理方面有点儿难,数据处理也是花时间最多的一步,怎么处理数据,怎么把处理好的数据输入到网络中训练,等等。

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
xiaoxingxing的头像xiaoxingxing管理团队
上一篇 2023年3月7日 下午10:36
下一篇 2023年3月7日 下午10:37

相关推荐