参考文章:pytorch复现U-Net 及常见问题汇总(2021.11.14亲测可行)_奶盖芒果的博客-CSDN博客_pytorch 无法复现
1.安装过程
1.1 github代码
代码网址:https://github.com/milesial/Pytorch-UNet
1.2 环境配置
requirements:
matplotlib
numpy
Pillow
torch
torchvision
tqdm
wandb
requirements中未包含具体版本信息,笔者亲测torch11.3 + python3.6可用
1.3 网络数据集验证模型可用
1.3.1 数据集地址
kaggle:Carvana Image Masking Challenge | Kaggle
数据集内容,这里我们下载tran.zip及train_masks.zip文件即可
下载完毕后,数据样式如下:
注意:蒙版图片格式为.gif文件,如果为jpg、png蒙版训练会出错,笔者会在错误总结中具体介绍
1.3.2 网络数据集训练
打开train.py文件,修改数据集路径
如果采用conda环境,在命令行执行命令如下:
conda activate u-net # u-net执行环境名
cd E:\u-net\Pytorch-UNet-master # u-net网络文件夹
python train.py
一切顺利的话,可以看到如下界面:
1.3.3 网络数据集预测
将程序目录中checkpoint文件夹下需要模型文件复制到predict.py文件所在目录
(1)修改predict.py中模型文件名称
或者执行命令时添加-m,后接权重文件名称即可
例如:python predict.py -m checkpoint_epoch5.pth
(2)如果采用conda环境,在命令行执行命令如下:
conda activate u-net # u-net执行环境名
cd E:\u-net\Pytorch-UNet-master # u-net网络文件夹
python predict.py -i 541_36.png --vi -v # -i后为预测图像名称
运行结果如下:
恭喜,代表环境配置成功,可以继续下一阶段了!
2.训练自己的数据集
2.1 数据集准备
(1)图片采用3通道RGB.png图像
通道数通过如下图查看
笔者输入位深度为24位图像正确,输入位深度为32图像报错
(2)图像蒙版为.gif格式二值图
mask文件名称为对应图像名称+_mask,如图所示:
注意:如果采用.png格式或者.jpg格式二值图会报错,错误原因详见错误汇总
2.2 模型训练
2.2.1 predict文件修改
如1.3.2所述,修改相应地址名,epoches,batch_size,lr即可
因为笔者是二分类问题-背景+波,所以不需要修改classes
2.3 模型预测
如1.3.3所述,修改相应图片名即可
3.问题汇总
3.1 模型预测结果全黑
3.1.1 原因1:输入数据集负样本含量过多
如图,笔者第一次训练时未剔除无效数据,导致预测结果全黑
解决方法:增加数据集中正样本数量即可
3.2 模型训练错误
3.2.1 原因1:
RuntimeError:CUDA error: device-side assert triggered
报错如上图所示
解决方法:
(1)类别不匹配,修改classes数量即可
(2)mask格式采用.png,.jpg格式,需要修改为.gif格式
因为.png格式二值图以数组形式读入结果为[0, 255]
而.gif格式以数组形式读入结果为[0,1]
笔者猜测是这个原因导致错误的发生
文章出处登录后可见!