机器学习笔记 – 使用 GAN 进行数据增强以进行缺陷检测

一、用于数据增强的 GAN

        在机器学习中,训练数据量不足往往会阻碍分类算法的性能。经验表明,训练数据不足是常态,而不是例外,这就是为什么人们提出了数据增强方法。

        我们可以使用数据增强,例如稍微旋转或翻转原始数据以生成新的训练数据。但这当然不会给我们带来真正的新形象。

        反过来,GAN 确实输出了全新的图像。您可能听说过 GAN 作为一种创建极其逼真的假图像和视频的手段(在“Deepfake”一词下广为流传)。正如最近的研究(例如Antoniou 等人 2017、Wang 等人 2018和Frid-Adar 等人 2018)所表明的,它们还可以通过生成额外的训练数据来提高机器学习分类器的性能。

(1)工业应用

        当我们处理稀缺的训练数据时,GAN 数据增强方法特别有前途。

        想象一下,我们想要训练一个机器学习模型来识别工业生产流程中的缺陷组件。希望缺陷很少发生;但这也意味着我们可能只有少量图像显示出典型的缺陷来训练网络。

        使用 GAN,我们可以为任何给定的缺陷类型生成额外的图像。

(2)数据

        我们使用NEU 表面缺陷数据库,其中包含 300 张生产过程中出现的金属划痕图像。

http://faculty.neu.edu.cn/yunhyan/NEU_surface_defect_database.html机器学习笔记 - 使用 GAN 进行数据增强以进行缺陷检测http://faculty.neu.edu.cn/yunhyan/NEU_surface_defect_database.html        数据集内一些图像示例

机器学习笔记 - 使用 GAN 进行数据增强以进行缺陷检测

        GAN 是一种无监督学习方法,因此我们不需要任何标签。我们没有想要区分的不同类型的标记图像,而是有一组我们试图模仿的未标记数据。

二、GAN网络

(1)工作原理概述

        GAN 不是一个单一的神经网络。相反,它结合了两个互相玩游戏的神经网络。我将简要解释一下游戏规则。

        首先,有一个鉴别器网络,它只是一个简单的卷积神经网络(CNN)。然后我们有生成器网络,它或多或少是一个反向的 CNN。它获得一个随机输入,并通过使用转置卷积对输入进行上采样来创建图像作为输出。

        游戏进行如下:生成器获取随机输入并生成图片。鉴别器交替获取生成的图像和原始图像(不知道哪个是哪个),并尝试预测给定图像是原始图像还是生成的,仅考虑图像的特征。

        随着时间的推移,这两个网络都试图变得更好。鉴别器试图将真实图像与生成的图像区分开来,而生成器旨在欺骗鉴别器,使其认为其图像是真实的。

机器学习笔记 - 使用 GAN 进行数据增强以进行缺陷检测

        鉴别器试图最大化它的成功,而生成器试图最小化它。

        下图是训练 GAN 时发生的情况的一个很好的可视化。图像来源

GAN Lab: Play with Generative Adversarial Networks in Your Browser!机器学习笔记 - 使用 GAN 进行数据增强以进行缺陷检测https://poloclub.github.io/ganlab/

机器学习笔记 - 使用 GAN 进行数据增强以进行缺陷检测

 (2)网络配置

batch_size = 12
generator_depth = 64
discriminator_depth = 128 loss_function=nn.BCELoss()
number_of_epochs = 128
discriminator_optimizer = optim.Adam(discriminator.parameters(), lr=0.0004, betas(0.5,0.999))
generator_optimizer = optim.Adam(generator.parameters(), lr=0.0001, betas=(0.5,0.999))

        在下面的代码块中,我们定义了鉴别器,它将图像作为输入。我们定义了模型用来分类这个输入图像的一系列过滤器。当我们训练它时,我们会调整这些过滤器,以便它学会区分原始图像和生成的图像。

class Discriminator(nn.Module):
    '''
    The Discriminator that shall distinguish between dataset images and the ones generated by the generator.
    '''
    def __init__(self, number_of_gpus):
        super(Discriminator, self).__init__()
        self.ngpu = number_of_gpus
        
        self.layer1 = nn.Sequential(
            spectral_norm(nn.Conv2d(in_channels=3, out_channels=discriminator_depth, 
                                    kernel_size=(4,4), stride=2, padding=1, bias=False)),
            nn.LeakyReLU(0.2, inplace=True)
        )
        
        self.layer2 = nn.Sequential(
            spectral_norm(nn.Conv2d(in_channels=discriminator_depth, out_channels=discriminator_depth*2, 
                                    kernel_size=(4,4), stride=2, padding=1, bias=False)),
            nn.BatchNorm2d(discriminator_depth*2),
            nn.LeakyReLU(0.2, inplace=True)
        )
        
        self.layer3 = nn.Sequential(
            spectral_norm(nn.Conv2d(in_channels=discriminator_depth*2, out_channels=discriminator_depth*4, 
                                    kernel_size=(4,4), stride=2, padding=1, bias=False)),
            nn.BatchNorm2d(discriminator_depth*4),
            nn.LeakyReLU(0.2, inplace=True)
        )
        
        self.layer4 = nn.Sequential(
            spectral_norm(nn.Conv2d(in_channels=discriminator_depth*4, out_channels=discriminator_depth*8, 
                                    kernel_size=(4,4), stride=2, padding=1, bias=False)),
            nn.BatchNorm2d(discriminator_depth*8),
            nn.LeakyReLU(0.2, inplace=True)
        )
        
        self.layer5 = nn.Sequential(
            spectral_norm(nn.Conv2d(in_channels=discriminator_depth*8, out_channels=discriminator_depth*16, 
                                    kernel_size=(4,4), stride=2, padding=1, bias=False)),
            nn.BatchNorm2d(discriminator_depth*16),
            nn.LeakyReLU(0.2, inplace=True)
        )
        
        self.output_layer = nn.Sequential(
            nn.Conv2d(in_channels=discriminator_depth*16, out_channels=1, 
                                    kernel_size=(4,4), stride=1, padding=0, bias=False),
            nn.Sigmoid()
        )


    def forward(self, input_image):
    
        layer1 = self.layer1(input_image)
        layer2 = self.layer2(layer1)
        layer3 = self.layer3(layer2)
        layer4 = self.layer4(layer3)
        layer5 = self.layer5(layer4)
        return self.output_layer(layer5)

        生成器具有与鉴别器相似的过滤器,只是相反。它不是查看图片来检测模式,而是根据我们教它绘制的模式返回图像。输入是一堆随机数,它们激活这些过滤器以绘制图像。

class Generator(nn.Module):
    '''
    The Generator Network. It is mostly a reversed discriminator with a random input noise which outputs an image.
    '''
    def __init__(self, number_of_gpus):
        super(Generator, self).__init__()
        self.ngpu = number_of_gpus
        
        self.layer1 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=100, out_channels=generator_depth*16, 
                               kernel_size=(4,4), stride=1, padding=0, bias=False),
            nn.BatchNorm2d(num_features=generator_depth*16),
            nn.ReLU(inplace=True)
        )

        self.layer2 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=generator_depth*16, out_channels=generator_depth*8, 
                               kernel_size=(4,4), stride=2, padding=1, bias=False),
            nn.BatchNorm2d(num_features=generator_depth*8),
            nn.ReLU(inplace=True)
        )
        
        self.layer3 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=generator_depth*8, out_channels=generator_depth*4, 
                               kernel_size=(4,4), stride=2, padding=1, bias=False),
            nn.BatchNorm2d(num_features=generator_depth*4),
            nn.ReLU(inplace=True)
        )
            
        self.layer4 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=generator_depth*4, out_channels=generator_depth*2, 
                               kernel_size=(4,4), stride=2, padding=1, bias=False),
            nn.BatchNorm2d(num_features=generator_depth*2),
            nn.ReLU(inplace=True)
        )
        
        self.layer5 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=generator_depth*2, out_channels=generator_depth, 
                               kernel_size=(4,4), stride=2, padding=1, bias=False),
            nn.BatchNorm2d(num_features=generator_depth),
            nn.ReLU(inplace=True)
        )
        
        self.output_layer = nn.Sequential(
            nn.ConvTranspose2d(in_channels=generator_depth, out_channels=3, 
                               kernel_size=(4,4), stride=2, padding=1, bias=False),
            nn.Tanh()
        )
        

    def forward(self, input_noise):
        
        layer1 = self.layer1(input_noise)
        layer2 = self.layer2(layer1)
        layer3 = self.layer3(layer2)
        layer4 = self.layer4(layer3)
        layer5 = self.layer5(layer4)
        return self.output_layer(layer5)

(3)训练

        我们可以将训练分为三个部分。

        用真实图像训练判别器:

discriminator.zero_grad()

prediction = discriminator(batch)

labels_for_dataset_images = torch.ones((batch_size,), device=device).view(-1)

loss_discriminator = loss_function(prediction.view(-1), labels_for_dataset_images)
loss_discriminator.backward()

        用生成器生成的图像训练鉴别器:

random_noise = torch.randn(batch_size,100,1,1, device=device)        
generated_image = generator(random_noise)

labels_for_generated_images = torch.zeros(np.prod(prediction.size()), device=device)

prediction = discriminator(generated_image.detach())

loss_generator = loss_function(prediction.view(-1), labels_for_generated_images)
loss_generator.backward()

discriminator_optimizer.step()

        训练生成器:

generator.zero_grad()

prediction = discriminator(generated_image).view(-1)

loss_generator = loss_function(prediction, labels_for_dataset_images)
loss_generator.backward()

generator_optimizer.step()
机器学习笔记 - 使用 GAN 进行数据增强以进行缺陷检测
生成图像的示例

        如果生成器过拟合,我们可以从数据集中获得与图像非常相似甚至几乎相同的图像。这当然不是我们想要的结果。所以我们测试我们生成的图像与数据集中的图像有多相似。

        使用k最近邻方法。这是一种分类算法,可以从要分类的图像中搜索“最近”的图像到数据集中的所有图像,以观察是否有过拟合产生。

def euclidean_distance(a, b):
    '''
    Calculates the euklidean Distance of two torch tensors of the same size.
    '''
    return torch.sqrt(((a - b) ** 2).sum())


def get_k_nearest_samples(image, k):
    '''
    Searches for the k-nearest samples in the dataset of a given image based on the euclidean distance.
    '''
    return np.argsort([euclidean_distance(image[0][0], sample[0][0]) for sample in dataset])[:k]

        这些图像与数据集图像相似,但它们不太相似——因此生成器没有过度拟合。

三、结论

        生成对抗网络确实学会了如何从给定的数据分布中生成新图像:它们是真正的新图像,因为它们不仅仅是原始图像的副本,而且仍然无法与原始图像区分开来。因此,我们可以使用这些新创建的图像来训练缺陷检测或缺陷分类模型。

        当然,在实际情况下,您应该始终仔细检查 GAN 创建的图像是否真的对模型性能产生了积极影响。情况可能并非总是如此。

        话虽如此,GAN(不仅限于)在工业生产中有很多潜在的用例。由于目前对 GAN 的研究兴趣,我们很快就会对何时以及如何使用它们有很多新的见解。

        值得注意的是,调整GAN微小的变化可能会导致输出失真。

四、代码参考

        代码参考

https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.htmlicon-default.png?t=M7J4https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html         NEU表面缺陷数据集分享

链接:https://pan.baidu.com/s/1bkSKlWFGkvEy2WdkXYEPVQ 
提取码:zl4y

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
心中带点小风骚的头像心中带点小风骚普通用户
上一篇 2023年2月25日 下午9:08
下一篇 2023年2月25日 下午9:14

相关推荐