UNet – 预测数据predict(多个图像的分割)

目录


1. 介绍

之前已经将unet的网络模块、dataset数据加载和train训练数据已经解决了,这次要将unet网络去分割图像,下面是之前的链接

unet 网络:UNet – unet网络

dataset 数据处理:UNet – 数据加载 Dataset

train 网络训练:UNet – 训练数据train

待分割的图像如下:

UNet - 预测数据predict(多个图像的分割)

 存放的路径在U-net项目的predict里面

UNet - 预测数据predict(多个图像的分割)

我们的目标是将predict里面所有的图片分割出来,按照名称顺序保存在result文件夹里面:

UNet - 预测数据predict(多个图像的分割)

2. predict 预测分割图片

首先定义图片的预处理,按照dataset里面相同的方式进行预处理

UNet - 预测数据predict(多个图像的分割)

然后是加载网络的模型和网络参数

UNet - 预测数据predict(多个图像的分割)

 然后加载predict里面所有待处理图片的路径

需要注意的是,os.listdir 加载的只是里面每个图片,并不是图片的具体路径。tests_path 里面的内容如下面的注释所示:

UNet - 预测数据predict(多个图像的分割)

接下来就可以分割图片了

因为tests_path 里面每个文件是 x.png 即文件名+后缀的方式。通过split的 '.' 分割成x和后缀名png的形式,[-2]代表取倒数第二个值,就可以将每个文件名x取出来,然后将路径拼接就可以存放到result里面

open图像的时候,也要注意,test_path 只是遍历tests_path 里面的文件,需要加上之前的predict路径才能正确的读取到每个待分割的图片

因为这里处理图像会改变size成480*480的形式,想要将输出的结果保持不变的话,在网络预测前将图像的大小保存下来就可以了。(注:这里的size和opencv里面的shape返回值是反过来的

这里不清楚的可以通过调试,打印每个变量的内容看一下就可以了

UNet - 预测数据predict(多个图像的分割)

接下来就是网络预测的部分,这里输出的size是(batch,channel,height,width),因为这里的batch是1,channel 灰度图片因此也是1,这里通过squeeze将1的维度删去,只需要图像的大小

下面是squeeze的用法UNet - 预测数据predict(多个图像的分割)

然后图像保存的话,要转到cpu上面 ,这一步不知道为啥,但是不加这一步会报错

UNet - 预测数据predict(多个图像的分割)

 最后就是保存图像了,将网络的结果二值化后,还原图像再保存就可以了

UNet - 预测数据predict(多个图像的分割)

3. 结果展示

predict里面待预测的图片

UNet - 预测数据predict(多个图像的分割)

result 里面分割好的图片

UNet - 预测数据predict(多个图像的分割)

下面是 参考文章 博主的分割结果

 UNet - 预测数据predict(多个图像的分割)

对比发现,有些小的细节会丢失,但是大概的轮廓分割出来了

4. 完整代码

完整的项目可以在 这里 下载

import numpy as np
import torch
import cv2
from model import UNet
from torchvision import transforms
from PIL import Image
import os


# 预处理
transform = transforms.Compose([
    transforms.Resize((480,480)),        # 缩放图像
    transforms.ToTensor(),
])

# 加载模型
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = UNet(in_channels=1, num_classes=1)
net.load_state_dict(torch.load('Unet.pth', map_location=device))
net.to(device)

# 测试模式
net.eval()
# 读取所有图片路径
tests_path = os.listdir('./predict/')   # 获取 './predict/' 路径下所有文件,这里的路径只是里面文件的路径
''''
print(tests_path)
['0.png', '1.png', '10.png', '11.png', '12.png', '13.png', '14.png', 
'15.png', '16.png', '17.png', '18.png', '19.png', '2.png', '20.png', 
'21.png', '22.png', '23.png', '24.png', '25.png', '26.png', '27.png',
 '28.png', '29.png', '3.png', '4.png', '5.png', '6.png', '7.png', '8.png', '9.png']
'''


with torch.no_grad():                   # 预测的时候不需要计算梯度
    for test_path in tests_path:        # 遍历每个predict的文件
        save_pre_path = './result/'+test_path.split('.')[-2] + '_res.png'    # 将保存的路径按照原图像的后缀,按照数字排序保存
        img = Image.open('./predict/' +test_path)           # 预测图片的路径
        width,height = img.size[0],img.size[1]              # 保存图像的大小
        img = transform(img)
        img = torch.unsqueeze(img,dim = 0)                  # 扩展图像的维度

        pred = net(img.to(device))                          # 网络预测
        pred = torch.squeeze(pred)                          # 将(batch、channel)维度去掉
        pred = np.array(pred.data.cpu())                    # 保存图片需要转为cpu处理

        pred[pred >= 0] = 255                               # 处理结果二值化
        pred[pred < 0] = 0

        pred = np.uint8(pred)                               # 转为图片的形式
        pred = cv2.resize(pred,(width,height),cv2.INTER_CUBIC)          # 还原图像的size
        cv2.imwrite(save_pre_path, pred)                    # 保存图片

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
乘风的头像乘风管理团队
上一篇 2023年3月4日 上午10:52
下一篇 2023年3月4日 上午10:55

相关推荐