目录
GAN
生成式对抗网络(GAN, Generative Adversarial Networks)是一种深度学习模型。主要包括两部分:生成模型和判别模型。也就是对应神经网络的生成器与判别器:
- 生成器G(Generator):通过生成器G生成数据。
- 判别器D(Discriminator):判断这张图像是真实的还是机器生成的,目的是判别数据是否是生成器做的“假数据”
生成器与判别器互相对抗,不断调整参数。最终的目的是使判别网络无法判断生成网络的输出结果是否真实。
生成器G是一个生成图片的网络,它接收一个随机的噪声,通过这个噪声生成图片,生成的图片记做。
判别器D判别一张图片是不是“真实的”。它的输入是 ,代表一张图片(其中,包含生成图片和真实图片,对于生成图片有 ),输出代表为真实图片的概率,如果为1,就代表100%是真实的图片,而输出为0,就代表图片0%是真的(或者说100%是假的)
其中网络结构如下所示:
其中,真实数据分布中的数据与生成数据可以认为是相同形状的。
生成网络G(Generative)
生成网络从隐空间(latent space)中随机采样作为输入,其输出结果需要尽量模仿训练集中的真实样本。
对抗网络D(Discriminative)
对抗网络也可称判别网络,判别网络的输入则为真实样本或生成网络的输出,其目的是将生成网络的输出从真实样本中尽可能分辨出来。
两分布之间差异性评价
真实数据分布中的数据 服从分布 ,生成数据分布中的数据 服从分布
那么衡量两个分布之间的差异性指标有:KL散度,JS散度,交叉熵和Wasserstein距离。
KL散度
离散概率分布的KL散度计算公式:
连续概率分布的KL散度计算公式:
JS散度
损失函数
以分布的角度来看GAN网络结构,然后考虑其损失函数。
对于生成网络G,其输入的 (,表示 服从正态分布的数据),通过训练出来的参数 的生成网络生成的图片为 。
对于判别网络,可以认为是二分类问题,一类是生成网络的输出,即 ;另一类是真实数据,(其中, ,表示 服从一种真实的分布distribution)。将 (其中,) 数据输入到判别网络中,输出结果分别为:
分别从生成网络和判别网络的角度来看:
-
对于生成网络的标准就是:我希望我生成的图片越接近真实越好,那么也就是使 越接近1越好。也就是训练生成网络中的参数 满足:
-
对于判别网络的标准就是:我能够很好的区分哪些是真的,哪些是假的。也就是说能够很好的将真的和假的区分开来。也就是希望真实数据的输出 越趋近于1,而生成数据的输出 越趋近于0。
将其看成二分类问题,二分类问题的损失函数可以使用交叉熵损失函数来表示,对于二分类,只有正样本(label=1)与负样本(label=0)。并且两者概率之和为1。对于一个输入,经过模型输出为。y是真实的标签。于是单个样本的损失函数就是:
如果是计算 N 个样本的平均损失函数,只要将 N 个 Loss 叠加起来再除以N就行:
对于 的输出是真实的,也就是标签为1,那么其单个样本损失函数就是:
平均损失函数就是(其实就是求平均):
那么对于 的输出是假的,也就是标签为0,那么其单个样本损失函数就是:
平均损失函数就是:
由于上面损失函数是负数,并且需要最小化损失函数,那么反过来的最大化损失函数就是:
在训练过程中,当G固定的时候,有:
全局优化首先固定G,然后优化D(这种情况也就是生成数据的分布 与真实分布已知),D的最佳情况为:(推导可以看GAN入门理解及公式推导 – 知乎 (zhihu.com))
将最佳的D代入目标loss函数,有:
也就是说,原始GAN的loss实际上等价于JS散度。
一次代码实验
很多的GAN网络结构代码可以参考:PyTorch-GAN
其中一个网络结构如下:
在实验中,对应的形状如下所示:
- 高斯随机变量:
torch.Size([batch_size, 100])
- 生成的fake_image, 真实image:
torch.Size([batch_size, 3,64,64])
- 判别真假:
torch.Size([batch_size, 1])
代码如下:
生成网络代码
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
def block(in_feat, out_feat, normalize=True):
layers = [nn.Linear(in_feat, out_feat)]
if normalize:
layers.append(nn.BatchNorm1d(out_feat, 0.8))
layers.append(nn.LeakyReLU(0.2, inplace=True))
return layers
self.model = nn.Sequential(
*block(opt.latent_dim, 128, normalize=False),
*block(128, 256),
*block(256, 512),
*block(512, 1024),
nn.Linear(1024, int(np.prod(img_shape))),
nn.Tanh()
)
def forward(self, z):
img = self.model(z)
img = img.view(img.shape[0], *img_shape)
return img
对抗网络代码
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.model = nn.Sequential(
nn.Linear(int(np.prod(img_shape)), 512),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(512, 256),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(256, 1),
)
然后在Anime_Faces数据集中训练,获得的1000个epochs后生成的数据有:
感觉效果不太行,估计是网络结构的局限性。
WGAN
参考论文:[1701.07875] Wasserstein GAN (arxiv.org)
在生成对抗网络中,当判断网络为最优时,生成网络的优化目标是最小化真实分布 和模型分布 之间的JS散度。当两个分布相同时,JS散度为0,最优生成网络对应的损失为−2log2。但是使用JS散度来训练生成对抗网络的一个问题是当两个分布没有重叠时,它们之间的JS散度恒等于常数log2。对生成网络来说,目标函数关于参数的梯度为0。
在GAN的基础上加入了Wasserstein距离,Wasserstein距离用于衡量两个分布之间的距离。相比KL散度和JS散度的优势在于即使两个分布没有重叠或者重叠非常少,Wasserstein距离仍然能反映两个分布的远近。其数学公式如下:
其中 , 是边际分布为 的所有可能的联合分布集合, 为 和 的 距离, 比如 距离等, 表示期望,表示下确界。
下确界
,如果是一个集合的下确界, 即表示小于或等于集合E
的所有其他元素的最大元素
, 这个数不一定
在集合E中。举例来说:
- ; 也就是说集合的下确界为1
- ;
- ;
当然,换一种角度解读:将两个分布看作是两个土堆,联合分布 看作是从土堆 的位置 到土堆 的位置 的搬运土的数量。Wasserstein距离可以理解为搬运土堆的最小工作量,也称为推土机距离(Earth-Mover’s Distance,EMD)
WGAN-GP
参考论文:[1704.00028] Improved Training of Wasserstein GANs (arxiv.org)
WGAN还是有问题:
- 权重裁剪会导致参数基本都在限制的边界值,极大浪费了模型的参数。
- 还是很容易梯度消失或者梯度爆炸,需要仔细的调参
WGAN-GP,核心只有一个:Gradient Penalty
Gradient Penalty:判别器相对于输入的梯度的二范数要约束在1附近,这样就能够保证Lipschitz连续。
# 计算Gradient Penalty
def compute_gradient_penalty(D, real_samples, fake_samples):
"""Calculates the gradient penalty loss for WGAN GP"""
# Random weight term for interpolation between real and fake samples
alpha = Tensor(np.random.random((real_samples.size(0), 1, 1, 1)))
# Get random interpolation between real and fake samples
interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
d_interpolates = D(interpolates)
fake = Variable(Tensor(real_samples.shape[0], 1).fill_(1.0), requires_grad=False)
# Get gradient w.r.t. interpolates
gradients = autograd.grad(
outputs=d_interpolates,
inputs=interpolates,
grad_outputs=fake,
create_graph=True,
retain_graph=True,
only_inputs=True,
)[0]
gradients = gradients.view(gradients.size(0), -1)
gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
return gradient_penalty
Conditional GAN
条件GAN,顾名思义就是根据条件针对性的生成数据。具体有AC-GAN。
文章出处登录后可见!