cutout是2017年提出的一种数据增强方法,想法比较简单,即在训练时随机裁剪掉图像的一部分,也可以看作是一种类似dropout的正则化方法。
Improved Regularization of Convolutional Neural Networks with Cutout
paper: https://arxiv.org/pdf/1708.04552.pdf
code: https://github.com/uoguelph-mlrg/Cutout
cutout采用的操作是随机裁剪掉图像中的一块正方形区域,并在原图中补0。由于作者在cutout早期版本中使用了不规则大小区域的方式,但是对比发现,固定大小区域能达到同等的效果,因此就没必要这么麻烦去生成不规则区域了。
实现代码比较简单,cutout.py,如下:
import torch
import numpy as np
class Cutout(object):
"""Randomly mask out one or more patches from an image.
Args:
n_holes (int): Number of patches to cut out of each image.
length (int): The length (in pixels) of each square patch.
"""
def __init__(self, n_holes=1, length=16):
self.n_holes = n_holes
self.length = length
def __call__(self, img):
"""
Args:
img (Tensor): Tensor image of size (C, H, W).
Returns:
Tensor: Image with n_holes of dimension length x length cut out of it.
"""
h = img.size(1)
w = img.size(2)
mask = np.ones((h, w), np.float32)
for n in range(self.n_holes):
y = np.random.randint(h)
x = np.random.randint(w)
y1 = np.clip(y - self.length // 2, 0, h)
y2 = np.clip(y + self.length // 2, 0, h)
x1 = np.clip(x - self.length // 2, 0, w)
x2 = np.clip(x + self.length // 2, 0, w)
mask[y1: y2, x1: x2] = 0.
mask = torch.from_numpy(mask)
mask = mask.expand_as(img)
img = img * mask
return img
上面代码中有两个参数,具体如下:
n_holes:表示裁剪掉的图像块的数目,默认都是设置为1;
length:每个正方形块的边长,作者经过多轮尝试后,不同数据集最优设置不同,CIFAR10为16,CIFAR100为8,SVHN为20;# 这里觉得挺麻烦的,cutout调参很重要
看看在图像上cutout是什么效果,代码如下:
import cv2
from torchvision import transforms
from cutout import Cutout
# 执行cutout
img = cv2.imread('cat.png')
img = transforms.ToTensor()(img)
cut = Cutout(length=100)
img = cut(img)
# cutout图像写入本地
img = img.mul(255).byte()
img = img.numpy().transpose((1, 2, 0))
cv2.imwrite('cutout.png', img)
由于原图比较大,这里把正方形边长调到了100,效果如下:
实际训练看看效果到底怎么样,为了保证公平,训练时参数统一,且每种模型训练了8次以减少随机性,结果见下表。
Method | CIFAR-10 | CIFAR-100 |
ResNet-50 | 96.76/96.82/96.81/96.79 96.72/96.69/96.60/96.82 (96.75) |
83.80/83.66/84.19/83.26 83.89/83.90/83.57/83.69 (83.74) |
ResNet-50+cutout | 96.73/96.58/96.78/96.65 96.65/96.58/96.77/96.65 (96.67) |
83.63/83.78/83.80/83.49 83.92/83.57/83.71/83.60 (83.69) |
从实验结果来看,在CIFAR10和CIFAR100这两个数据集上使用cutout,训练出来的模型精度都会掉一点。看来cutout涨点并没有那么容易,和调参、模型深度、数据集都有很大的关系。
版权声明:本文为博主一个菜鸟的奋斗原创文章,版权归属原作者,如果侵权,请联系我们删除!
原文链接:https://blog.csdn.net/u013685264/article/details/122562509