python 参加某图像去噪比赛有感

使用之前的去噪图像来抑制文本

1.经验

本菜鸡本科毕设在FPGA上搞过图像滤波等算法,研究生期间虽然搞的是基于深度学习的图形学,但是主干网络用的还是卷积… 感觉自己代码能力还可以,基础还行,参赛之前还是比较自信的:
觉着看几篇顶会去噪的文章,复现借鉴一下应该能取得一个不错的结果,但是——-大概1000+人参赛,一多半没有提交的或者只提交个baseline,本菜最终100+ 额还没结束 明天结束了估计排名快接近200了 实在卷不动了
主要存在三个问题:

  1. 和男生的能力差距还是不小的,可能也和研究方向有关。毕竟不是专业的。
  2. Money is all my need?? 经历了学校服务器排队人数爆满、维修,小组里的机器也排不上队(我一个参加比赛的也不好意思和别人抢- -),就想调个参,很难,即便排上队 – 跑的时候batch_size都调的很小才能跑
  3. 我——工具人……论文直接被拒了审稿人到底提了什么——

虽然我知道我自己的食物,但我还是想尝试一下。吐槽结束,进入话题:::::

2. 收获

谈收获,虽难,但收获也不少

  1. 看了几篇cv顶会的去噪文章,了解并尝试了cv算法中low-level的方向
  2. 尝试复现了两篇顶会,效果并没有baseline好,差不太多 – – (可能复现的不太对,毕竟只是借用思想不是完全拷贝) 最终魔改了一篇别的论文.
  3. 从dataloder、网络框架、网络初始化、训练策略到最后的损失函数等等,第一次完整的写了一个深度学习的项目(以前都是拿别人代码框架改改),遇到很多坑,也学到了许多新的知识点

3.经验分享(部分源码展示和注释)

3.1 输入

图片是要切片的,一整张图太大了,网络稍大点,32G的显卡也会爆显存
将一个图划分为多个图,伪代码如下:

# 外层是一个循环 根据图像大小进行切片 
tmp['imgs'] = data['imgs'][:, :, a:b, c:d]     # batch 通道数 图片的长和宽
tmp['gts'] = data['gts'][:, :, a:b, c:d]       # 标签
model.set_input(tmp)                           # 网络输入

3.2 网络

我的网络借鉴的主要思想:

1.不直接学习端到端的像素值,而是学习噪声(网络更容易拟合?)
2.使用通道可分离的卷积,适当增加通道数(显存太小,跑起来速度很慢)
3.尝试增加卷积核大小(显存太小,跑起来速度很慢)

(比赛有模型大小限制)——增加通道和卷积核会增加显存的使用,设备不好,所以只有增加通道数。具体实施细则如下:

纯纯的Unet baseline修改而来

class Unet2(nn.Module):
    def __init__(self, dim=4):
        super(Unet2, self).__init__()
        self.dims = [32, 64, 128, 256, 512]
        self.ks = [3, 3, 3, 3, 3]
        self.dims_up = self.dims[::-1]
        self.ks_up = self.ks[-2::-1]

        self.first_block = Block2(dim, self.dims[0], self.ks[0])
        self.first_pool = nn.MaxPool2d(kernel_size=2)  # AvgPool2d pnsr: 37.683, ssim: 0.902, score: 30.679, time: 52.650

        for i, dim_in in enumerate(self.dims[:-2]):
            dim_out = self.dims[i+1]
            setattr(self, 'Block{}'.format(i), Block2(dim_in, dim_out, k=self.ks[i+1]))
            setattr(self, 'pool{}'.format(i), nn.MaxPool2d(kernel_size=2))

        self.conv_mid = Block2(self.dims[-2], self.dims[-1], self.ks[-1])

        for i, dim_in in enumerate(self.dims_up[:-1]):
            dim_out = self.dims_up[i+1]
            setattr(self, 'ConvTrans{}'.format(i), nn.ConvTranspose2d(dim_in, dim_out, 2, stride=2, bias=True))
            setattr(self, 'up_Block{}'.format(i), Block2(dim_in, dim_out, k=self.ks_up[i]))

        self.last_conv = nn.Conv2d(self.dims[0], dim, 1, bias=True)

    def forward(self, x):
        n, c, h, w = x.shape
        h_pad = 32 - h % 32 if not h % 32 == 0 else 0
        w_pad = 32 - w % 32 if not w % 32 == 0 else 0
        padded_image = F.pad(x, (0, w_pad, 0, h_pad), 'replicate')
        list_pools = []

        x_bk = x
        # 1.first Block
        x = self.first_block(padded_image)
        list_pools.append(x)
        x = self.first_pool(x)

        # 2.Blocks
        for i, dim_in in enumerate(self.dims[:-2]):
            x = getattr(self, 'Block{}'.format(i))(x)
            list_pools.append(x)
            x = getattr(self, 'pool{}'.format(i))(x)

        x = self.conv_mid(x)

        for i, dim_in in enumerate(self.dims_up[:-1]):
            x = getattr(self, 'ConvTrans{}'.format(i))(x)
            # tmp = list_pools.pop()
            x = torch.cat([x, list_pools.pop()], 1)
            x = getattr(self, 'up_Block{}'.format(i))(x)

        # 3.last
        x = self.last_conv(x)
        out = x[:, :, :h, :w] + x_bk

        return out


class Block2(nn.Module):
    def __init__(self, dim_in, dim_out, k=3):
        super(Block2, self).__init__()
        self.conv1 = nn.Conv2d(dim_in, dim_in, kernel_size=k, padding=k // 2, padding_mode='zeros', bias=True)
        self.conv2 = nn.Conv2d(dim_in, dim_out, kernel_size=k, padding=k // 2, padding_mode='zeros', bias=True)

    def forward(self, x):
        x = self.conv1(x)
        x = self.leaky_relu(x)
        x = self.conv2(x)
        x = self.leaky_relu(x)
        return x

    def leaky_relu(self, x, a=0.2):
        out = torch.max(a * x, x)
        return out

我使用的网络 魔改ConvNet

class Our(nn.Module):
    def __init__(self, dim=4):
        super(Our, self).__init__()
        self.dims = [128, 256, 512, 1024]
        self.ks = [3, 3, 3, 3]
        # 内存不够啊

        # self.dims = [16, 32, 64, 128, 256]
        # self.ks = [23, 23, 23, 17, 3]
        ######################################
        self.dims_up = self.dims[::-1]
        self.ks_up = self.ks[-2::-1]

        self.first_block = Block(dim, self.dims[0], self.ks[0])
        self.first_pool = nn.MaxPool2d(kernel_size=2)

        for i, dim_in in enumerate(self.dims[:-2]):
            dim_out = self.dims[i+1]
            setattr(self, 'Block{}'.format(i), Block(dim_in, dim_out, k=self.ks[i+1]))
            setattr(self, 'pool{}'.format(i), nn.MaxPool2d(kernel_size=2))

        self.conv_mid = Block(self.dims[-2], self.dims[-1], self.ks[-1])

        for i, dim_in in enumerate(self.dims_up[:-1]):
            dim_out = self.dims_up[i+1]
            setattr(self, 'ConvTrans{}'.format(i), nn.ConvTranspose2d(dim_in, dim_out, 2, stride=2))
            setattr(self, 'up_Block{}'.format(i), Block(dim_in, dim_out, k=self.ks_up[i]))

        self.last_ln = nn.LayerNorm(self.dims[0], eps=1e-6)
        self.last_conv = nn.Linear(self.dims[0], dim)

    def forward(self, x):
        n, c, h, w = x.shape
        h_pad = 32 - h % 32 if not h % 32 == 0 else 0
        w_pad = 32 - w % 32 if not w % 32 == 0 else 0
        padded_image = F.pad(x, (0, w_pad, 0, h_pad), 'replicate')
        list_pools = []
        x_bk = x
        # 1.first Block
        x = self.first_block(padded_image)
        list_pools.append(x)
        x = self.first_pool(x)

        # 2.Blocks
        for i, dim_in in enumerate(self.dims[:-2]):
            x = getattr(self, 'Block{}'.format(i))(x)
            list_pools.append(x)
            x = getattr(self, 'pool{}'.format(i))(x)

        x = self.conv_mid(x)

        for i, dim_in in enumerate(self.dims_up[:-1]):
            x = getattr(self, 'ConvTrans{}'.format(i))(x)
            # tmp = list_pools.pop()
            x = torch.cat([x, list_pools.pop()], 1)
            x = getattr(self, 'up_Block{}'.format(i))(x)

        # 3.last
        x = x.permute(0, 2, 3, 1).contiguous()
        x = self.last_ln(x)
        x = self.last_conv(x)
        x = x.permute(0, 3, 1, 2).contiguous()
        out = x[:, :, :h, :w] + x_bk

        return out



class Block(nn.Module):
    def __init__(self, dim_in, dim_out, k=9):
        super(Block, self).__init__()
        self.conv = nn.Conv2d(dim_in, dim_in, groups=dim_in, kernel_size=k, padding=k // 2)
        self.ln = nn.LayerNorm(dim_in,eps=1e-6)
        self.conv1x1up = nn.Linear(dim_in, dim_in * 2) #nn.Conv2d(dim, dim * 2, 1)
        self.act = nn.GELU()
        self.conv1x1dn = nn.Linear(dim_in * 2, dim_out) #nn.Conv2d(dim * 2, dim, 1)
        self.w = nn.Parameter(torch.zeros(1))
        # res
        self.res_conv = nn.Conv2d(dim_in, dim_out, 1)

    def forward(self, x):
        identity = x
        x = self.conv(x)
        x = x.permute(0, 2, 3, 1).contiguous()
        x = self.ln(x)
        x = self.conv1x1up(x)
        x = self.act(x)
        x = self.conv1x1dn(x)
        x = x.permute(0, 3, 1, 2).contiguous()
        x = x * self.w
        x = x + self.res_conv(identity)
        return x

3.3 损失函数

loss = torch.nn.L1Loss()

实测了一下,还是L1效果好啊
其它L2、SSIM之类的花里胡哨的效果并不理想 (毕竟是炼丹,可能只是不适合我的网络)

3.4 传统滤波方法

哈、我还试了一下传统的去噪,顺便使用纯python写了一个双边滤波(参考我以前matlab的代码),不得不说,还是深度学习yyds!

def bilateral_filter(img):
    # 参考自己博客 matlab的实现 https://blog.csdn.net/qq_38204686/article/details/106929922
    r = 20                      # 窗口半径     核大小为 2*r + 1
    sigma_space = 15.0          # 空间标准差
    sigma_color = 10.0          # 相似标准差
    w_space = np.zeros((2*r + 1, 2*r + 1))
    for i in range(-r-1, r):
        for j in range(-r-1, r):
            tmp = i * i + j * j
            w_space[i + r+1, j + r+1] = np.exp(-float(tmp) / (2 * sigma_space * sigma_space))
    w_color = np.zeros((1, 256))
    for i in range(256):
        w_color[0, i] = np.exp(-float(i * i) / (2 * sigma_color * sigma_color))

    # 开始滤波
    height, width, channel = img.shape
    dst_img = img.copy()
    for h in range(r, height - r):
        # s = time.time()   0.3s
        for w in range(r, height - r):
            for c in range(channel):     # 通道遍历
                p_c = img[h, w, c]       # 像素值
                p_win = img[h-r:h+r+1, w-r:w+r+1, c]  # 窗口内所有像素
                c_w = np.abs(p_win - p_c).astype(int)
                c_w = w_color[0, c_w]
                w_tmp = w_space * c_w
                p_sum = p_win * w_tmp
                p_sum = np.sum(p_sum) / np.sum(w_tmp)
                dst_img[h, w, c] = p_sum

    return dst_img

4. 主要参考链接

  • https://zhuanlan.zhihu.com/p/455913104 (ConvNeXt: A ConvNet for the 2020s)
  • https://zhuanlan.zhihu.com/p/349644858 (如何白嫖GPU)
  • https://blog.csdn.net/u011447962/article/details/123510680 (CVPR 2022 | RepLKNet)
  • https://github.com/gbstack/CVPR-2022-papers#SG (CVPR2022 Papers (Papers/Codes/Demos))

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
青葱年少的头像青葱年少普通用户
上一篇 2022年5月9日
下一篇 2022年5月9日

相关推荐