扩散模型 (Diffusion Model) 简要介绍与源码分析
前言
近期同事分享了 Diffusion Model, 这才发现生成模型的发展已经到了如此惊人的地步, OpenAI 推出的 Dall-E 2 可以根据文本描述生成极为逼真的图像, 质量之高直让人惊呼哇塞. 今早公众号给我推送了一篇关于 Stability AI 公司的报道, 他们推出的 AI 文生图扩散模型 Stable Diffusion 已开源, 能够在消费级显卡上实现 Dall-E 2 级别的图像生成, 效率提升了 30 倍.
于是找到他们的开源产品体验了一把, 在线体验地址在 https://huggingface.co/spaces/stabilityai/stable-diffusion (开源代码在 Github 上: https://github.com/CompVis/stable-diffusion), 在搜索框中输入 “A dog flying in the sky” (一只狗在天空飞翔), 生成效果如下:
Amazing! 当然, 不是每一张图片都符合预期, 但好在可以生成无数张图片, 其中总有效果好的. 在震惊之余, 不免对 Diffusion Model (扩散模型) 背后的原理感兴趣, 就想看看是怎么实现的.
当时同事分享时, PPT 上那一堆堆公式扑面而来, 把我给整懵圈了, 但还是得撑起下巴, 表现出似有所悟、深以为然的样子, 在讲到关键处不由暗暗点头以表示理解和赞许. 后面花了个周末专门学习了一下, 公式推导+代码分析, 感觉终于了解了基本概念, 于是记录下来形成此文, 不敢说自己完全懂了, 毕竟我不做这个方向, 但回过头去看 PPT 上的公式就不再发怵了.
广而告之
可以在微信中搜索 “珍妮的算法之路” 或者 “world4458” 关注我的微信公众号, 可以及时获取最新原创技术文章更新.
另外可以看看知乎专栏 PoorMemory-机器学习, 以后文章也会发在知乎专栏中.
总览
本文对 Diffusion Model 扩散模型的原理进行简要介绍, 然后对源码进行分析. 扩散模型的实现有多种形式, 本文关注的是 DDPM (denoising diffusion probabilistic models). 在介绍完基本原理后, 对作者释放的 Tensorflow 源码进行分析, 加深对各种公式的理解.
参考文章
在理解扩散模型的路上, 受到下面这些文章的启发, 强烈推荐阅读:
- Lilian 的博客, 内容非常非常详实, 干货十足, 而且每篇文章都极其用心, 向大佬学习: What are Diffusion Models?
- ewrfcas 的知乎, 公式推导补充了更多的细节: 由浅入深了解Diffusion Model
- Lilian 的博客, 介绍变分自动编码器 VAE: From Autoencoder to Beta-VAE, Diffusion Model 需要从分布中随机采样样本, 该过程无法求导, 需要使用到 VAE 中介绍的重参数技巧.
- Denoising Diffusion Probabilistic Models 论文,
- 其 TF 源码位于: https://github.com/hojonathanho/diffusion, 源码介绍以该版本为主
- PyTorch 的开源实现: https://github.com/lucidrains/denoising-diffusion-pytorch, 核心逻辑和上面 Tensorflow 版本是一致的, Stable Diffusion 参考的是 pytorch 版本的代码.
扩散模型介绍
基本原理
Diffusion Model (扩散模型) 是一类生成模型, 和 VAE (Variational Autoencoder, 变分自动编码器), GAN (Generative Adversarial Network, 生成对抗网络) 等生成网络不同的是, 扩散模型在前向阶段对图像逐步施加噪声, 直至图像被破坏变成完全的高斯噪声, 然后在逆向阶段学习从高斯噪声还原为原始图像的过程.
具体来说, 前向阶段在原始图像 上逐步增加噪声, 每一步得到的图像 只和上一步的结果 相关, 直至第 步的图像 变为纯高斯噪声. 前向阶段图示如下:
而逆向阶段则是不断去除噪声的过程, 首先给定高斯噪声 , 通过逐步去噪, 直至最终将原图像 给恢复出来, 逆向阶段图示如下:
模型训练完成后, 只要给定高斯随机噪声, 就可以生成一张从未见过的图像. 下面分别介绍前向阶段和逆向阶段, 只列出重要公式,
前向阶段
由于前向过程中图像 只和上一时刻的 有关, 该过程可以视为马尔科夫过程, 满足:
其中 为高斯分布的方差超参, 并满足 . 另外公式 (2) 中为何均值 前乘上系数 的原因将在后面的推导介绍. 上述过程的一个美妙性质是我们可以在任意 time step 下通过 重参数技巧 采样得到 .
重参数技巧 (reparameterization trick) 是为了解决随机采样样本这一过程无法求导的问题. 比如要从高斯分布 中采样样本 , 可以通过引入随机变量 , 使得 , 此时 依旧具有随机性, 且服从高斯分布 , 同时 与 (通常由网络生成) 可导.
简要了解了重参数技巧后, 再回到上面通过公式 (2) 采样 的方法, 即生成随机变量 ,
然后令 , 以及 , 从而可以得到:
其中公式 (3-1) 到公式 (3-2) 的推导是由于独立高斯分布的可见性, 有 , 因此:
注意公式 (3-2) 中 , 因此还需乘上 . 从公式 (3) 可以看出
注意由于 且 , 而 , 因此 并且有 , 另外由于 , 因此当 时, 以及 , 此时 . 从这里的推导来看, 在公式 (2) 中的均值 前乘上系数 会使得 最后收敛到标准高斯分布.
逆向阶段
前向阶段是加噪声的过程, 而逆向阶段则是将噪声去除, 如果能得到逆向过程的分布 , 那么通过输入高斯噪声 , 我们将生成一个真实的样本. 注意到当 足够小时, 也是高斯分布, 具体的证明在 ewrfcas 的知乎文章: 由浅入深了解Diffusion Model 推荐的论文中: On the theory of stochastic processes, with particular reference to applications
. 我大致看了一下, 哈哈, 没太看明白, 不过想到这个不是我关注的重点, 因此 pass. 由于我们无法直接推断 , 因此我们将使用深度学习模型 去拟合分布 , 模型参数为 :
注意到, 虽然我们无法直接求得 (注意这里是 而不是模型 ), 但在知道 的情况下, 可以通过贝叶斯公式得到 为:
推导过程如下:
上面推导过程中, 通过贝叶斯公式巧妙的将逆向过程转换为前向过程, 且最终得到的概率密度函数和高斯概率密度函数的指数部分 能对应, 即有:
通过公式 (8) 和公式 (9), 我们能得到 (见公式 (7)) 的分布. 此外由于公式 (3) 揭示的 和 之间的关系: , 可以得到
代入公式 (9) 中得到:
补充一下公式 (11) 的详细推导过程:
前面说到, 我们将使用深度学习模型 去拟合逆向过程的分布 , 由公式 (6) 知 , 我们希望训练模型 以预估 . 由于 在训练阶段会作为输入, 因此它是已知的, 我们可以转而让模型去预估噪声 , 即令:
模型训练
前面谈到, 逆向阶段让模型去预估噪声 , 那么应该如何设计 Loss 函数 ? 我们的目标是在真实数据分布下, 最大化模型预测分布的对数似然, 即优化在 下的 交叉熵:
和 变分自动编码器 VAE 类似, 使用 Variational Lower Bound 来优化: :
对公式 (15) 左右两边取期望 , 利用到重积分中的 Fubini 定理 可得:
因此最小化 就可以优化公式 (14) 中的目标函数. 之后对 做进一步的推导, 这部分的详细推导见上面的参考文章, 最终的结论是:
最终是优化两个高斯分布 (详见公式 (7)) 与 (详见公式(6), 此为模型预估的分布)之间的 KL 散度. 由于多元高斯分布的 KL 散度存在闭式解, 详见: Multivariate_normal_distributions, 从而可以得到:
DDPM 将 Loss 简化为如下形式:
因此 Diffusion 模型的目标函数即是学习高斯噪声 和 (来自模型输出) 之间的 MSE loss.
最终算法
最终 DDPM 的算法流程如下:
训练阶段重复如下步骤:
- 从数据集中采样
- 随机选取 time step
- 生成高斯噪声
- 调用模型预估
- 计算噪声之间的 MSE Loss: , 并利用反向传播算法训练模型.
逆向阶段采用如下步骤进行采样:
- 从高斯分布采样
- 按照 的顺序进行迭代:
- 如果 , 令 ; 如果 , 从高斯分布中采样
- 利用公式 (12) 学习出均值 , 并利用公式 (8) 计算均方差
- 通过重参数技巧采样
- 经过以上过程的迭代, 最终恢复 .
源码分析
DDPM 文章以及代码的相关信息如下:
- Denoising Diffusion Probabilistic Models 论文,
- 其 TF 源码位于: https://github.com/hojonathanho/diffusion, 源码介绍以该版本为主
- PyTorch 的开源实现: https://github.com/lucidrains/denoising-diffusion-pytorch, 核心逻辑和上面 Tensorflow 版本是一致的, Stable Diffusion 参考的是 pytorch 版本的代码.
本文以分析 Tensorflow 源码为主, Pytorch 版本的代码和 Tensorflow 版本的实现逻辑大体不差的, 变量名字啥的都类似, 阅读起来不会有啥门槛. Tensorlow 源码对 Diffusion 模型的实现位于 diffusion_utils_2.py, 模型本身的分析以该文件为主.
训练阶段
以 CIFAR 数据集为例.
在 run_cifar.py 中进行前向传播计算 Loss:
- 第 6 行随机选出
- 第 7 行
training_losses
定义在 GaussianDiffusion2 中, 计算噪声间的 MSE Loss.
进入 GaussianDiffusion2 中, 看到初始化函数中定义了诸多变量, 我在注释中使用公式的方式进行了说明:
下面进入到 training_losses
函数中:
- 第 19 行:
self.model_mean_type
默认是eps
, 模型学习的是噪声, 因此target
是第 6 行定义的noise
, 即 - 第 9 行: 调用
self.q_sample
计算 , 即公式 (3) - 第 21 行:
denoise_fn
是定义在 unet.py 中的UNet
模型, 只需知道它的输入和输出大小相同; 结合第 9 行得到的 , 得到模型预估的噪声: - 第 23 行: 计算两个噪声之间的 MSE: , 并利用反向传播算法训练模型
上面第 9 行定义的 self.q_sample
详情如下:
- 第 13 行的
q_sample
已经介绍过, 不多说. - 第 2 行的
_extract
在代码中经常被使用到, 看到它只需知道它是用来提取系数的即可. 引入输入是一个 Batch, 里面的每个样本都会随机采样一个 time step , 因此需要使用tf.gather
来将 之类选出来, 然后将系数 reshape 为[B, 1, 1, ....]
的形式, 目的是为了利用 broadcasting 机制和 这个 Tensor 相乘.
前向的训练阶段代码实现非常简单, 下面看逆向阶段
逆向阶段
逆向阶段代码定义在 GaussianDiffusion2 中:
- 第 5 行生成高斯噪声 , 然后对其不断去噪直至恢复原始图像
- 第 11 行的
self.p_sample
就是公式 (6) 的过程, 使用模型来预估 以及 - 第 12 行的
denoise_fn
在前面说过, 是定义在 unet.py 中的UNet
模型;img_
表示 . - 第 13 行的
noise_fn
则默认是tf.random_normal
, 用于生成高斯噪声.
进入 p_sample
函数:
- 第 7 行调用
self.p_mean_variance
生成 以及 , 其中 通过计算 得到. - 第 11 行从高斯分布中采样
- 第 18 行通过重参数技巧采样 , 其中
进入 self.p_mean_variance
函数:
- 第 6 行调用模型
denoise_fn
, 通过输入 , 输出得到噪声 - 第 19 行
self.model_var_type
默认为fixedlarge
, 但我当时看fixedsmall
比较爽, 因此model_variance
和model_log_variance
分别为 (见公式 8), 以及 - 第 29 行调用
self._predict_xstart_from_eps
函数, 利用公式 (10) 得到 - 第 30 行调用
self.q_posterior_mean_variance
通过公式 (9) 得到
self._predict_xstart_from_eps
函数相亲如下:
- 该函数计算
self.q_posterior_mean_variance
函数详情如下:
- 相关说明见注释, 另外发现对于 的计算使用的是公式 (9) 而不是进一步推导后的公式 (11) .
总结
写文章真的挺累的, 好处是, 我发现写之前我以为理解了, 但写的过程中又发现有些地方理解的不对. 写完后才终于把逻辑理顺.
文章出处登录后可见!