扩散模型 (Diffusion Model) 简要介绍与源码分析

扩散模型 (Diffusion Model) 简要介绍与源码分析

前言

近期同事分享了 Diffusion Model, 这才发现生成模型的发展已经到了如此惊人的地步, OpenAI 推出的 Dall-E 2 可以根据文本描述生成极为逼真的图像, 质量之高直让人惊呼哇塞. 今早公众号给我推送了一篇关于 Stability AI 公司的报道, 他们推出的 AI 文生图扩散模型 Stable Diffusion 已开源, 能够在消费级显卡上实现 Dall-E 2 级别的图像生成, 效率提升了 30 倍.

于是找到他们的开源产品体验了一把, 在线体验地址在 https://huggingface.co/spaces/stabilityai/stable-diffusion (开源代码在 Github 上: https://github.com/CompVis/stable-diffusion), 在搜索框中输入 “A dog flying in the sky” (一只狗在天空飞翔), 生成效果如下:

扩散模型 (Diffusion Model) 简要介绍与源码分析

Amazing! 当然, 不是每一张图片都符合预期, 但好在可以生成无数张图片, 其中总有效果好的. 在震惊之余, 不免对 Diffusion Model (扩散模型) 背后的原理感兴趣, 就想看看是怎么实现的.

当时同事分享时, PPT 上那一堆堆公式扑面而来, 把我给整懵圈了, 但还是得撑起下巴, 表现出似有所悟、深以为然的样子, 在讲到关键处不由暗暗点头以表示理解和赞许. 后面花了个周末专门学习了一下, 公式推导+代码分析, 感觉终于了解了基本概念, 于是记录下来形成此文, 不敢说自己完全懂了, 毕竟我不做这个方向, 但回过头去看 PPT 上的公式就不再发怵了.

广而告之

可以在微信中搜索 “珍妮的算法之路” 或者 “world4458” 关注我的微信公众号, 可以及时获取最新原创技术文章更新.

另外可以看看知乎专栏 PoorMemory-机器学习, 以后文章也会发在知乎专栏中.

总览

本文对 Diffusion Model 扩散模型的原理进行简要介绍, 然后对源码进行分析. 扩散模型的实现有多种形式, 本文关注的是 DDPM (denoising diffusion probabilistic models). 在介绍完基本原理后, 对作者释放的 Tensorflow 源码进行分析, 加深对各种公式的理解.

参考文章

在理解扩散模型的路上, 受到下面这些文章的启发, 强烈推荐阅读:

扩散模型介绍

基本原理

Diffusion Model (扩散模型) 是一类生成模型, 和 VAE (Variational Autoencoder, 变分自动编码器), GAN (Generative Adversarial Network, 生成对抗网络) 等生成网络不同的是, 扩散模型在前向阶段对图像逐步施加噪声, 直至图像被破坏变成完全的高斯噪声, 然后在逆向阶段学习从高斯噪声还原为原始图像的过程.

具体来说, 前向阶段在原始图像 扩散模型 (Diffusion Model) 简要介绍与源码分析 上逐步增加噪声, 每一步得到的图像 扩散模型 (Diffusion Model) 简要介绍与源码分析 只和上一步的结果 扩散模型 (Diffusion Model) 简要介绍与源码分析 相关, 直至第 扩散模型 (Diffusion Model) 简要介绍与源码分析 步的图像 扩散模型 (Diffusion Model) 简要介绍与源码分析 变为纯高斯噪声. 前向阶段图示如下:

而逆向阶段则是不断去除噪声的过程, 首先给定高斯噪声 扩散模型 (Diffusion Model) 简要介绍与源码分析, 通过逐步去噪, 直至最终将原图像 扩散模型 (Diffusion Model) 简要介绍与源码分析 给恢复出来, 逆向阶段图示如下:

模型训练完成后, 只要给定高斯随机噪声, 就可以生成一张从未见过的图像. 下面分别介绍前向阶段和逆向阶段, 只列出重要公式,

前向阶段

由于前向过程中图像 扩散模型 (Diffusion Model) 简要介绍与源码分析 只和上一时刻的 扩散模型 (Diffusion Model) 简要介绍与源码分析 有关, 该过程可以视为马尔科夫过程, 满足:

扩散模型 (Diffusion Model) 简要介绍与源码分析

其中 扩散模型 (Diffusion Model) 简要介绍与源码分析 为高斯分布的方差超参, 并满足 扩散模型 (Diffusion Model) 简要介绍与源码分析. 另外公式 (2) 中为何均值 扩散模型 (Diffusion Model) 简要介绍与源码分析 前乘上系数 扩散模型 (Diffusion Model) 简要介绍与源码分析 的原因将在后面的推导介绍. 上述过程的一个美妙性质是我们可以在任意 time step 下通过 重参数技巧 采样得到 扩散模型 (Diffusion Model) 简要介绍与源码分析.

重参数技巧 (reparameterization trick) 是为了解决随机采样样本这一过程无法求导的问题. 比如要从高斯分布 扩散模型 (Diffusion Model) 简要介绍与源码分析 中采样样本 扩散模型 (Diffusion Model) 简要介绍与源码分析, 可以通过引入随机变量 扩散模型 (Diffusion Model) 简要介绍与源码分析, 使得 扩散模型 (Diffusion Model) 简要介绍与源码分析, 此时 扩散模型 (Diffusion Model) 简要介绍与源码分析 依旧具有随机性, 且服从高斯分布 扩散模型 (Diffusion Model) 简要介绍与源码分析, 同时 扩散模型 (Diffusion Model) 简要介绍与源码分析扩散模型 (Diffusion Model) 简要介绍与源码分析 (通常由网络生成) 可导.

简要了解了重参数技巧后, 再回到上面通过公式 (2) 采样 扩散模型 (Diffusion Model) 简要介绍与源码分析 的方法, 即生成随机变量 扩散模型 (Diffusion Model) 简要介绍与源码分析,
然后令 扩散模型 (Diffusion Model) 简要介绍与源码分析, 以及 扩散模型 (Diffusion Model) 简要介绍与源码分析, 从而可以得到:

扩散模型 (Diffusion Model) 简要介绍与源码分析

其中公式 (3-1) 到公式 (3-2) 的推导是由于独立高斯分布的可见性, 有 扩散模型 (Diffusion Model) 简要介绍与源码分析, 因此:

扩散模型 (Diffusion Model) 简要介绍与源码分析

注意公式 (3-2) 中 扩散模型 (Diffusion Model) 简要介绍与源码分析, 因此还需乘上 扩散模型 (Diffusion Model) 简要介绍与源码分析. 从公式 (3) 可以看出

扩散模型 (Diffusion Model) 简要介绍与源码分析

注意由于 扩散模型 (Diffusion Model) 简要介绍与源码分析扩散模型 (Diffusion Model) 简要介绍与源码分析, 而 扩散模型 (Diffusion Model) 简要介绍与源码分析, 因此 扩散模型 (Diffusion Model) 简要介绍与源码分析 并且有 扩散模型 (Diffusion Model) 简要介绍与源码分析, 另外由于 扩散模型 (Diffusion Model) 简要介绍与源码分析, 因此当 扩散模型 (Diffusion Model) 简要介绍与源码分析 时, 扩散模型 (Diffusion Model) 简要介绍与源码分析 以及 扩散模型 (Diffusion Model) 简要介绍与源码分析, 此时 扩散模型 (Diffusion Model) 简要介绍与源码分析. 从这里的推导来看, 在公式 (2) 中的均值 扩散模型 (Diffusion Model) 简要介绍与源码分析 前乘上系数 扩散模型 (Diffusion Model) 简要介绍与源码分析 会使得 扩散模型 (Diffusion Model) 简要介绍与源码分析 最后收敛到标准高斯分布.

逆向阶段

前向阶段是加噪声的过程, 而逆向阶段则是将噪声去除, 如果能得到逆向过程的分布 扩散模型 (Diffusion Model) 简要介绍与源码分析, 那么通过输入高斯噪声 扩散模型 (Diffusion Model) 简要介绍与源码分析, 我们将生成一个真实的样本. 注意到当 扩散模型 (Diffusion Model) 简要介绍与源码分析 足够小时, 扩散模型 (Diffusion Model) 简要介绍与源码分析 也是高斯分布, 具体的证明在 ewrfcas 的知乎文章: 由浅入深了解Diffusion Model 推荐的论文中: On the theory of stochastic processes, with particular reference to applications. 我大致看了一下, 哈哈, 没太看明白, 不过想到这个不是我关注的重点, 因此 pass. 由于我们无法直接推断 扩散模型 (Diffusion Model) 简要介绍与源码分析, 因此我们将使用深度学习模型 扩散模型 (Diffusion Model) 简要介绍与源码分析 去拟合分布 扩散模型 (Diffusion Model) 简要介绍与源码分析, 模型参数为 扩散模型 (Diffusion Model) 简要介绍与源码分析:

扩散模型 (Diffusion Model) 简要介绍与源码分析

注意到, 虽然我们无法直接求得 扩散模型 (Diffusion Model) 简要介绍与源码分析 (注意这里是 扩散模型 (Diffusion Model) 简要介绍与源码分析 而不是模型 扩散模型 (Diffusion Model) 简要介绍与源码分析), 但在知道 扩散模型 (Diffusion Model) 简要介绍与源码分析 的情况下, 可以通过贝叶斯公式得到 扩散模型 (Diffusion Model) 简要介绍与源码分析 为:

扩散模型 (Diffusion Model) 简要介绍与源码分析

推导过程如下:

扩散模型 (Diffusion Model) 简要介绍与源码分析

上面推导过程中, 通过贝叶斯公式巧妙的将逆向过程转换为前向过程, 且最终得到的概率密度函数和高斯概率密度函数的指数部分 扩散模型 (Diffusion Model) 简要介绍与源码分析 能对应, 即有:

扩散模型 (Diffusion Model) 简要介绍与源码分析

通过公式 (8) 和公式 (9), 我们能得到 扩散模型 (Diffusion Model) 简要介绍与源码分析 (见公式 (7)) 的分布. 此外由于公式 (3) 揭示的 扩散模型 (Diffusion Model) 简要介绍与源码分析扩散模型 (Diffusion Model) 简要介绍与源码分析 之间的关系: 扩散模型 (Diffusion Model) 简要介绍与源码分析, 可以得到

扩散模型 (Diffusion Model) 简要介绍与源码分析

代入公式 (9) 中得到:

扩散模型 (Diffusion Model) 简要介绍与源码分析

补充一下公式 (11) 的详细推导过程:

前面说到, 我们将使用深度学习模型 扩散模型 (Diffusion Model) 简要介绍与源码分析 去拟合逆向过程的分布 扩散模型 (Diffusion Model) 简要介绍与源码分析, 由公式 (6) 知 扩散模型 (Diffusion Model) 简要介绍与源码分析, 我们希望训练模型 扩散模型 (Diffusion Model) 简要介绍与源码分析 以预估 扩散模型 (Diffusion Model) 简要介绍与源码分析. 由于 扩散模型 (Diffusion Model) 简要介绍与源码分析 在训练阶段会作为输入, 因此它是已知的, 我们可以转而让模型去预估噪声 扩散模型 (Diffusion Model) 简要介绍与源码分析, 即令:

扩散模型 (Diffusion Model) 简要介绍与源码分析

模型训练

前面谈到, 逆向阶段让模型去预估噪声 扩散模型 (Diffusion Model) 简要介绍与源码分析, 那么应该如何设计 Loss 函数 ? 我们的目标是在真实数据分布下, 最大化模型预测分布的对数似然, 即优化在 扩散模型 (Diffusion Model) 简要介绍与源码分析 下的 扩散模型 (Diffusion Model) 简要介绍与源码分析 交叉熵:

扩散模型 (Diffusion Model) 简要介绍与源码分析

变分自动编码器 VAE 类似, 使用 Variational Lower Bound 来优化: 扩散模型 (Diffusion Model) 简要介绍与源码分析 :

扩散模型 (Diffusion Model) 简要介绍与源码分析

对公式 (15) 左右两边取期望 扩散模型 (Diffusion Model) 简要介绍与源码分析, 利用到重积分中的 Fubini 定理 可得:

扩散模型 (Diffusion Model) 简要介绍与源码分析

因此最小化 扩散模型 (Diffusion Model) 简要介绍与源码分析 就可以优化公式 (14) 中的目标函数. 之后对 扩散模型 (Diffusion Model) 简要介绍与源码分析 做进一步的推导, 这部分的详细推导见上面的参考文章, 最终的结论是:

扩散模型 (Diffusion Model) 简要介绍与源码分析

最终是优化两个高斯分布 扩散模型 (Diffusion Model) 简要介绍与源码分析 (详见公式 (7)) 与 扩散模型 (Diffusion Model) 简要介绍与源码分析 (详见公式(6), 此为模型预估的分布)之间的 KL 散度. 由于多元高斯分布的 KL 散度存在闭式解, 详见: Multivariate_normal_distributions, 从而可以得到:

扩散模型 (Diffusion Model) 简要介绍与源码分析

DDPM 将 Loss 简化为如下形式:

扩散模型 (Diffusion Model) 简要介绍与源码分析

因此 Diffusion 模型的目标函数即是学习高斯噪声 扩散模型 (Diffusion Model) 简要介绍与源码分析扩散模型 (Diffusion Model) 简要介绍与源码分析 (来自模型输出) 之间的 MSE loss.

最终算法

最终 DDPM 的算法流程如下:

训练阶段重复如下步骤:

逆向阶段采用如下步骤进行采样:

源码分析

DDPM 文章以及代码的相关信息如下:

本文以分析 Tensorflow 源码为主, Pytorch 版本的代码和 Tensorflow 版本的实现逻辑大体不差的, 变量名字啥的都类似, 阅读起来不会有啥门槛. Tensorlow 源码对 Diffusion 模型的实现位于 diffusion_utils_2.py, 模型本身的分析以该文件为主.

训练阶段

以 CIFAR 数据集为例.

run_cifar.py 中进行前向传播计算 Loss:

  • 第 6 行随机选出 扩散模型 (Diffusion Model) 简要介绍与源码分析
  • 第 7 行 training_losses 定义在 GaussianDiffusion2 中, 计算噪声间的 MSE Loss.

进入 GaussianDiffusion2 中, 看到初始化函数中定义了诸多变量, 我在注释中使用公式的方式进行了说明:

下面进入到 training_losses 函数中:

  • 第 19 行: self.model_mean_type 默认是 eps, 模型学习的是噪声, 因此 target 是第 6 行定义的 noise, 即 扩散模型 (Diffusion Model) 简要介绍与源码分析
  • 第 9 行: 调用 self.q_sample 计算 扩散模型 (Diffusion Model) 简要介绍与源码分析, 即公式 (3) 扩散模型 (Diffusion Model) 简要介绍与源码分析
  • 第 21 行: denoise_fn 是定义在 unet.py 中的 UNet 模型, 只需知道它的输入和输出大小相同; 结合第 9 行得到的 扩散模型 (Diffusion Model) 简要介绍与源码分析, 得到模型预估的噪声: 扩散模型 (Diffusion Model) 简要介绍与源码分析
  • 第 23 行: 计算两个噪声之间的 MSE: 扩散模型 (Diffusion Model) 简要介绍与源码分析, 并利用反向传播算法训练模型

上面第 9 行定义的 self.q_sample 详情如下:

  • 第 13 行的 q_sample 已经介绍过, 不多说.
  • 第 2 行的 _extract 在代码中经常被使用到, 看到它只需知道它是用来提取系数的即可. 引入输入是一个 Batch, 里面的每个样本都会随机采样一个 time step 扩散模型 (Diffusion Model) 简要介绍与源码分析, 因此需要使用 tf.gather 来将 扩散模型 (Diffusion Model) 简要介绍与源码分析 之类选出来, 然后将系数 reshape 为 [B, 1, 1, ....] 的形式, 目的是为了利用 broadcasting 机制和 扩散模型 (Diffusion Model) 简要介绍与源码分析 这个 Tensor 相乘.

前向的训练阶段代码实现非常简单, 下面看逆向阶段

逆向阶段

逆向阶段代码定义在 GaussianDiffusion2 中:

  • 第 5 行生成高斯噪声 扩散模型 (Diffusion Model) 简要介绍与源码分析, 然后对其不断去噪直至恢复原始图像
  • 第 11 行的 self.p_sample 就是公式 (6) 扩散模型 (Diffusion Model) 简要介绍与源码分析 的过程, 使用模型来预估 扩散模型 (Diffusion Model) 简要介绍与源码分析 以及 扩散模型 (Diffusion Model) 简要介绍与源码分析
  • 第 12 行的 denoise_fn 在前面说过, 是定义在 unet.py 中的 UNet 模型; img_ 表示 扩散模型 (Diffusion Model) 简要介绍与源码分析.
  • 第 13 行的 noise_fn 则默认是 tf.random_normal, 用于生成高斯噪声.

进入 p_sample 函数:

  • 第 7 行调用 self.p_mean_variance 生成 扩散模型 (Diffusion Model) 简要介绍与源码分析 以及 扩散模型 (Diffusion Model) 简要介绍与源码分析, 其中 扩散模型 (Diffusion Model) 简要介绍与源码分析 通过计算 扩散模型 (Diffusion Model) 简要介绍与源码分析 得到.
  • 第 11 行从高斯分布中采样 扩散模型 (Diffusion Model) 简要介绍与源码分析
  • 第 18 行通过重参数技巧采样 扩散模型 (Diffusion Model) 简要介绍与源码分析, 其中 扩散模型 (Diffusion Model) 简要介绍与源码分析

进入 self.p_mean_variance 函数:

  • 第 6 行调用模型 denoise_fn, 通过输入 扩散模型 (Diffusion Model) 简要介绍与源码分析, 输出得到噪声 扩散模型 (Diffusion Model) 简要介绍与源码分析
  • 第 19 行 self.model_var_type 默认为 fixedlarge, 但我当时看 fixedsmall 比较爽, 因此 model_variancemodel_log_variance 分别为 扩散模型 (Diffusion Model) 简要介绍与源码分析 (见公式 8), 以及 扩散模型 (Diffusion Model) 简要介绍与源码分析
  • 第 29 行调用 self._predict_xstart_from_eps 函数, 利用公式 (10) 得到 扩散模型 (Diffusion Model) 简要介绍与源码分析
  • 第 30 行调用 self.q_posterior_mean_variance 通过公式 (9) 得到 扩散模型 (Diffusion Model) 简要介绍与源码分析

self._predict_xstart_from_eps 函数相亲如下:

  • 该函数计算 扩散模型 (Diffusion Model) 简要介绍与源码分析

self.q_posterior_mean_variance 函数详情如下:

  • 相关说明见注释, 另外发现对于 扩散模型 (Diffusion Model) 简要介绍与源码分析 的计算使用的是公式 (9) 扩散模型 (Diffusion Model) 简要介绍与源码分析 而不是进一步推导后的公式 (11) 扩散模型 (Diffusion Model) 简要介绍与源码分析.

总结

写文章真的挺累的, 好处是, 我发现写之前我以为理解了, 但写的过程中又发现有些地方理解的不对. 写完后才终于把逻辑理顺.

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
青葱年少的头像青葱年少普通用户
上一篇 2023年3月22日
下一篇 2023年3月22日

相关推荐