超分辨网络SRCNN的Pytorch实现

整体框架

SR,即super resolution,即超分辨率。CNN相对来说比较著名,就是卷积神经网络了。从名字可以看出,SRCNN是首个应用于超分辨领域的卷积神经网络,事实上也的确如此。

所谓超分辨率,就是把低分辨率(LR, Low Resolution)图片放大为高分辨率(HR, High Resolution)的过程。由于是开山之作,SRCNN相对比较简单,总共分三步

  1. 输入LR图像超分辨网络SRCNN的Pytorch实现,经双三次(bicubic)插值,被放大成目标尺寸,得到超分辨网络SRCNN的Pytorch实现
  2. 通过三层卷积网络拟合非线性映射
  3. 输出HR图像结果超分辨网络SRCNN的Pytorch实现

训练的目标损失是最小化SR图像超分辨网络SRCNN的Pytorch实现和原高分辨率图像超分辨网络SRCNN的Pytorch实现像素差的均方误差

超分辨网络SRCNN的Pytorch实现

其中,超分辨网络SRCNN的Pytorch实现为训练样本数,参数更新公式为

超分辨网络SRCNN的Pytorch实现

网络模型

其网络结构如下

如前所述,网络分为三个卷积层

  1. 维度是超分辨网络SRCNN的Pytorch实现,表示输入图像通道数为1,进行卷积运算的核尺寸为超分辨网络SRCNN的Pytorch实现,输出深度为64。
  2. 维度是超分辨网络SRCNN的Pytorch实现,64即上一层输出,32为下一层输出。
  3. 尺寸为超分辨网络SRCNN的Pytorch实现。它的输出是单通道图像,与输入相同。

所以这个模型实现起来并不难

# models.py
class SRCNN(nn.Module):
    def __init__(self, nChannel=1):
        super(SRCNN,self).__init__()
        self.conv1 = nn.Conv2d(nChannel, 64,
            kernel_size=9, padding=9//2)
        self.conv2 = nn.Conv2d(64, 32,
            kernel_size=5, padding=5//2)
        self.conv3 = nn.Conv2d(32, nChannel, 
            kernel_size=5, padding=5//2)
        self.relu = nn.ReLU(inplace=True)

    def forward(self,x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.conv3(x)
        return x

数据集

训练数据集可手动生成,设放大倍数为scale,考虑到原始数据未必会被scale整除,所以要重新规划一下图像尺寸,所以训练数据集的生成分为三步:

  1. 将原始图像通过双三次插值重设尺寸,使之可被scale整除,作为高分辨图像数据HR
  2. 将HR通过双三次插值压缩scale倍,为低分辨图像的原始数据
  3. 将低分辨图像通过双三次插值放大scale倍,与HR图像维度相等,作为低分辨图像数据LR

最后,可通过h5py将训练数据分块并打包,其生成代码为

import h5py
import PIL.Image as pImg

def rgb2gray(img):
    return 16. + (64.738 * img[:, :, 0] + 129.057 * img[:, :, 1] + 25.064 * img[:, :, 2]) / 256.

# imgPath为图像路径;h5Path为存储路径;scale为放大倍数
# pSize为patch尺寸; pStride为步长
def setTrianData(imgPath, h5Path, scale=3, pSize=33, pStride=14):
    h5_file = h5py.File(h5Path, 'w')
    lrPatches, hrPatches = [], []       #用于存储低分辨率和高分辨率的patch
    for p in sorted(glob.glob(f'{imgPath}/*')):
        hr = pImg.open(p).convert('RGB')
        lrWidth, lrHeight = hr.width // scale, hr.height // scale
        # width, height为可被scale整除的训练数据尺寸
        width, height = lrWidth*scale, lrHeight*scale
        hr = hr.resize((width, height), resample=pImg.BICUBIC)
        lr = hr.resize((lrWidth, lrHeight), resample=pImg.BICUBIC)
        lr = lr.resize((width, height), resample=pImg.BICUBIC)
        hr = np.array(hr).astype(np.float32)
        lr = np.array(lr).astype(np.float32)
        hr = rgb2gray(hr)
        lr = rgb2gray(lr)
        # 将数据分割
        for i in range(0, height - pSize + 1, pStride):
            for j in range(0, width - pSize + 1, pStride):
                lrPatches.append(lr[i:i + pSize, j:j + pSize])
                hrPatches.append(hr[i:i + pSize, j:j + pSize])
    h5_file.create_dataset('lr', data=np.array(lrPatches))
    h5_file.create_dataset('hr', data=np.array(hrPatches))
    h5_file.close()

以比较常见的T91数据集为例,通过上面的方法,可以得到一个181M的h5文件。

对预测数据执行相同的操作。

在做好训练数据之后,需要为这些数据创建一个读取类,以便torch中的DataLoader调用,而DataLoader中的内容则是Dataset,所以新建的读取类需要继承Dataset,并实现其__getitem__和__len__这两个成员方法。

这两个方法只是看上去吓人,但对Python稍有一点深入了解,就会知道__getitem__是字典索引的方法,而__len__则设定了len函数的返回值。

import h5py
import numpy as np
from torch.utils.data import Dataset

class DataSet(Dataset):
    def __init__(self, h5_file):
        super(Dataset, self).__init__()
        self.h5_file = h5_file

    def __getitem__(self, idx):
        with h5py.File(self.h5_file, 'r') as f:
            return np.expand_dims(f['lr'][idx] / 255., 0), np.expand_dims(f['hr'][idx] / 255., 0)

    def __len__(self):
        with h5py.File(self.h5_file, 'r') as f:
            return len(f['lr'])

训练

首先,训练需要一点准备工作,比如数据集准备好,相关的文件夹需要建好,建好模型之后,需要采用什么样的优化方式。训练设备是用cpu还是cuda,然后将数据集和模型装载到设备上。

数据准备

import os
import copy
import torch
from torch import nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.utils.data.dataloader import DataLoader
from models import SRCNN

trainFile = "91-image.h5"
evalFile = "Set5.h5"

cudnn.benchmark = True
# 设置训练设备 是CPU还是cuda
device = torch.device(
  'cuda:0' if torch.cuda.is_available() else 'cpu')

# 装载训练数据
trainData = Dataset(trainFile)
trainLoader = DataLoader(dataset=trainData,
  bSize=bSize,
  shuffle=True,               # 表示打乱样本
  num_workers=nWorker,        # 线程数
  pin_memory=True,            # 方便载入CUDA
  drop_last=True)

# 装载预测数据
evalDatas = Dataset(evalFile)
evalLoader = DataLoader(dataset=evalDatas, bSize=1)

模型准备

# 模型和设备
lr = 1e-4       #学习率
torch.manual_seed(seed)     #设置随机数种子
model = SRCNN().to(device)  #将模型载入设备
criterion = nn.MSELoss()    #设置损失函数
optimizer = optim.Adam([
  {'params': model.conv1.parameters()},
  {'params': model.conv2.parameters()},
  {'params': model.conv3.parameters(), 'lr': lr * 0.1}
], lr=lr)

train

outPath = "outputs"
scale = 3
bSize = 16
nEpoch = 400
nWorker = 8     #线程数
seed = 42       #随机数种子

def initPSNR():
    return {'avg':0, 'sum':0, 'count':0}

def updatePSNR(psnr, val, n=1):
    s = psnr['sum'] + val*n
    c = psnr['count'] + n
    return {'avg':s/c, 'sum':s, 'count':c}

bestWeights = copy.deepcopy(model.state_dict()) #最佳模型
bestEpoch = 0   #最佳训练结果
bestPSNR = 0.0  #最佳psnr

# 训练主循环
for epoch in range(nEpoch):
  model.train()
  epochLosses = initPSNR()

  for data in trainLoader:
      inputs, labels = data
      inputs = inputs.to(device)
      labels = labels.to(device)
      preds = model(inputs)
      loss = criterion(preds, labels)
      epochLosses = updatePSNR(epochLosses,loss.item(), len(inputs))
      optimizer.zero_grad()   #清空梯度
      loss.backward()         #反向传播
      optimizer.step()        #根据梯度更新网络参数
      print(f'{epochLosses['avg']:.6f}')

  torch.save(model.state_dict(), 
      os.path.join(outPath, f'epoch_{epoch}.pth'))

  model.eval()    #取消dropout
  psnr = AverageMeter()

  for data in evalLoader:
      inputs, labels = data
      inputs = inputs.to(device)
      labels = labels.to(device)
      # 令reqires_grad自动设为False,关闭自动求导
      # clamp将inputs归一化为0到1区间
      with torch.no_grad():
          preds = model(inputs).clamp(0.0, 1.0)

      tmp_psnr = 10. * torch.log10(
          1. / torch.mean((preds - labels) ** 2))
      psnr = updatePSNR(psnr, tmp_psnr, len(inputs))

  print(f'eval psnr: {psnr.avg:.2f}')

  if psnr['avg'] > bestPSNR:
      bestEpoch = epoch
      bestPSNR = psnr['avg']
      bestWeights = copy.deepcopy(model.state_dict())

print(f'best epoch: {bestEpoch}, psnr: {bestPSNR:.2f}')
torch.save(bestWeights, os.path.join(outPath, 'best.pth'))

最终结果是
原图来自网络

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
乘风的头像乘风管理团队
上一篇 2022年5月6日
下一篇 2022年5月6日

相关推荐