[Code Walkthrough] Diffusion Model

This post will not cover the theory behind diffusion models in much detail; instead it takes a "know that it works" approach and goes straight to the code.

1. Code source

DenoisingDiffusionProbabilityModel-ddpm-
Link to my fork with detailed code comments added:
minipuding/DenoisingDiffusionProbabilityModel-ddpm-

2. Code structure

The two main packages are Diffusion and DiffusionFreeGuidence. The former is the most basic Diffusion Model implementation; the latter adds the most common and most effective trick, classifier-free guidance.

├─Diffusion
│  │  Diffusion.py
│  │  Model.py
│  │  Train.py
│  │  __init__.py
│
├─DiffusionFreeGuidence
│      DiffusionCondition.py
│      ModelCondition.py
│      TrainCondition.py
│      __init__.py

3. Diffusion Package

3.1. Diffusion.py

First import the required packages and define an "extract" helper function:

import torch
import torch.nn as nn
import torch.nn.functional as F


# ``extract`` picks the entries of the 1-D sequence v indexed by t, then reshapes the result so it broadcasts against the input x
def extract(v, t, x_shape):
    """
    Extract some coefficients at specified timesteps, then reshape to
    [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
    """
    device = t.device
    # for the usage of ``torch.gather``, see e.g. the first comment under https://zhuanlan.zhihu.com/p/352877584
    # in every call in this repo, v is 1-D, so this amounts to index lookup, i.e. v[t], where t has shape [batch_size]
    out = torch.gather(v, index=t, dim=0).float().to(device)
    # reshape the gathered values to [batch_size, 1, 1, ...], with as many dimensions as x_shape
    return out.view([t.shape[0]] + [1] * (len(x_shape) - 1))
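
A quick toy check of ``extract`` (my own example, not from the repo): gather the coefficient for each sample's timestep and get back a broadcast-ready shape.

v = torch.linspace(0.1, 1.0, 10)             # a 1-D coefficient table, e.g. betas
t = torch.tensor([0, 4, 9])                  # one timestep per sample in the batch
out = extract(v, t, x_shape=(3, 3, 32, 32))
print(out.shape)                             # torch.Size([3, 1, 1, 1])
print(out.view(-1))                          # tensor([0.1000, 0.5000, 1.0000])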

3.1.1. Forward diffusion process

The necessary formulas, with a brief derivation:

$$x_t = \sqrt{1-\beta_t}\, x_{t-1} + \sqrt{\beta_t}\, z_t, \qquad z_t \sim \mathcal{N}(0, I) \tag{1}$$

In words: the conditional distribution $q(x_t \mid x_{t-1})$ is a normal (Gaussian) distribution with mean $\sqrt{1-\beta_t}\, x_{t-1}$ and variance $\beta_t$, where the $\beta_t$ are constants. Written as a recurrence this is Eq. (1), where $z_t$ follows the standard normal distribution (mean 0, standard deviation 1).

$$\alpha_t := 1 - \beta_t, \qquad x_t = \sqrt{\alpha_t}\, x_{t-1} + \sqrt{1-\alpha_t}\, z_t \tag{2}$$

In words: letting $\alpha_t = 1 - \beta_t$, the conditional distribution $q(x_t \mid x_{t-1})$ is a normal distribution with mean $\sqrt{\alpha_t}\, x_{t-1}$ and variance $1-\alpha_t$.

$$x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1-\bar{\alpha}_t}\, z, \qquad \text{where } \bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s \tag{3}$$

In words: unrolling the recurrence (2) lets us compute $x_t$ directly from $x_0$, where $\bar{\alpha}_t$ denotes the cumulative product of the $\alpha_s$.
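
Before moving on to the code, the closed form (3) can be sanity-checked numerically (a toy check of my own, not part of the repo): iterating the recurrence (2) reproduces exactly $\sqrt{\bar{\alpha}_T}$ as the coefficient of $x_0$ and $1-\bar{\alpha}_T$ as the accumulated noise variance.

import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T).double()
alphas = 1. - betas
alphas_bar = torch.cumprod(alphas, dim=0)

# iterate x_t = sqrt(alpha_t) x_{t-1} + sqrt(1 - alpha_t) z_t, tracking the
# coefficient of x_0 and the variance of the accumulated noise
coef_x0, var_noise = 1.0, 0.0
for t in range(T):
    coef_x0 *= alphas[t].sqrt().item()
    var_noise = var_noise * alphas[t].item() + (1. - alphas[t].item())

print(abs(coef_x0 - alphas_bar[-1].sqrt().item()))    # ~0, matches sqrt(alpha_bar_T)
print(abs(var_noise - (1. - alphas_bar[-1]).item()))  # ~0, matches 1 - alpha_bar_T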

# ``GaussianDiffusionTrainer`` implements the forward (noising) process & training step of the Diffusion Model
class GaussianDiffusionTrainer(nn.Module):
    def __init__(self, model, beta_1, beta_T, T):
        """
        初始化前向模型
        Args:
            model: 骨干模型,主流为U-Net+Attention
            beta_1: beta的起始值,本实例中取1e-4
            beta_T: bata在t=T时的值,本实例中取0.2
            T: 时间步数, 本实例中取1000
        """
        super().__init__()
        # store the parameters
        self.model = model
        self.T = T

        # T evenly spaced beta values from beta_1 to beta_T, registered as a buffer (accessible later via ``self.betas``)
        self.register_buffer('betas', torch.linspace(beta_1, beta_T, T).double())
        # following the formulas, let alphas = 1 - betas
        alphas = 1. - self.betas
        # following Eq. (3), compute the cumulative product of alphas, stored as alphas_bar
        # ``torch.cumprod`` multiplies each element of a sequence with all elements before it,
        # returning a sequence of the same length. For example:
        # a = torch.tensor([2,3,1,4])
        # b = torch.cumprod(a, dim=0) equals torch.tensor([2, 2*3, 2*3*1, 2*3*1*4]) = torch.tensor([2, 6, 6, 24])
        alphas_bar = torch.cumprod(alphas, dim=0)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        # following Eq. (3), compute sqrt(alphas_bar) and sqrt(1 - alphas_bar), the mean scale and the standard
        # deviation of the forward process, accessible via ``self.sqrt_alphas_bar`` and ``self.sqrt_one_minus_alphas_bar``
        self.register_buffer(
            'sqrt_alphas_bar', torch.sqrt(alphas_bar))
        self.register_buffer(
            'sqrt_one_minus_alphas_bar', torch.sqrt(1. - alphas_bar))

    def forward(self, x_0):
        """
        Algorithm 1.
        """
        # randomly sample batch_size timesteps from {0, ..., T-1}
        t = torch.randint(self.T, size=(x_0.shape[0], ), device=x_0.device)
        # reparameterization trick: sample standard Gaussian noise first, then scale by the std and shift by the mean to sample indirectly
        noise = torch.randn_like(x_0)
        x_t = (
            extract(self.sqrt_alphas_bar, t, x_0.shape) * x_0 +
            extract(self.sqrt_one_minus_alphas_bar, t, x_0.shape) * noise)
        # one model forward pass: we want the model to predict the injected noise, i.e. z_t in the formulas
        loss = F.mse_loss(self.model(x_t, t), noise, reduction='none')
        return loss
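
A minimal usage sketch (mine, with a dummy stand-in for the U-Net) showing the expected shapes:

class DummyModel(nn.Module):
    def forward(self, x_t, t):
        return torch.zeros_like(x_t)  # pretend to predict the noise

trainer = GaussianDiffusionTrainer(DummyModel(), beta_1=1e-4, beta_T=0.02, T=1000)
x_0 = torch.randn(4, 3, 32, 32)  # a batch of 4 "images"
loss = trainer(x_0)              # per-element MSE, shape [4, 3, 32, 32]
print(loss.shape, loss.mean().item())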

3.1.2. Reverse diffusion process

For the necessary formulas and their derivation I quote (translated) from the article referenced above:

The reverse process flips the conditional probability: we want $q(x_{t-1} \mid x_t)$.
By Bayes' rule, $q(x_{t-1} \mid x_t) = \frac{q(x_t \mid x_{t-1})\, q(x_{t-1})}{q(x_t)}$.
Since every state of the forward process is derived from $x_0$, we condition everything on $x_0$ and the identity becomes:

$$q(x_{t-1} \mid x_t, x_0) = \frac{q(x_t \mid x_{t-1}, x_0)\, q(x_{t-1} \mid x_0)}{q(x_t \mid x_0)}$$

Using the forward-process equations (2) and (3), the three factors on the right expand to:

$$q(x_t \mid x_{t-1}, x_0) = \mathcal{N}\big(\sqrt{\alpha_t}\, x_{t-1},\ (1-\alpha_t) I\big)$$
$$q(x_{t-1} \mid x_0) = \mathcal{N}\big(\sqrt{\bar{\alpha}_{t-1}}\, x_0,\ (1-\bar{\alpha}_{t-1}) I\big)$$
$$q(x_t \mid x_0) = \mathcal{N}\big(\sqrt{\bar{\alpha}_t}\, x_0,\ (1-\bar{\alpha}_t) I\big)$$

where all the noise variables involved are standard normal.
Substituting the three expansions, $q(x_{t-1} \mid x_t, x_0)$ becomes a product/quotient of Gaussian densities; in the exponent, multiplication turns into addition and division into subtraction:

$$q(x_{t-1} \mid x_t, x_0) \propto \exp\left(-\frac{1}{2}\left(\frac{(x_t - \sqrt{\alpha_t}\, x_{t-1})^2}{1-\alpha_t} + \frac{(x_{t-1} - \sqrt{\bar{\alpha}_{t-1}}\, x_0)^2}{1-\bar{\alpha}_{t-1}} - \frac{(x_t - \sqrt{\bar{\alpha}_t}\, x_0)^2}{1-\bar{\alpha}_t}\right)\right)$$

Expanding and collecting the terms in $x_{t-1}$ gives

$$\exp\left(-\frac{1}{2}\left(\Big(\frac{\alpha_t}{1-\alpha_t} + \frac{1}{1-\bar{\alpha}_{t-1}}\Big) x_{t-1}^2 - 2\Big(\frac{\sqrt{\alpha_t}}{1-\alpha_t}\, x_t + \frac{\sqrt{\bar{\alpha}_{t-1}}}{1-\bar{\alpha}_{t-1}}\, x_0\Big) x_{t-1} + C(x_t, x_0)\right)\right)$$

Expanding the standard Gaussian density the same way,

$$\exp\left(-\frac{(x-\mu)^2}{2\sigma^2}\right) = \exp\left(-\frac{1}{2}\Big(\frac{1}{\sigma^2}\, x^2 - \frac{2\mu}{\sigma^2}\, x + \frac{\mu^2}{\sigma^2}\Big)\right)$$

and matching term by term yields:

$$\frac{1}{\tilde{\sigma}_t^2} = \frac{\alpha_t}{1-\alpha_t} + \frac{1}{1-\bar{\alpha}_{t-1}}$$
$$\frac{\tilde{\mu}_t}{\tilde{\sigma}_t^2} = \frac{\sqrt{\alpha_t}}{1-\alpha_t}\, x_t + \frac{\sqrt{\bar{\alpha}_{t-1}}}{1-\bar{\alpha}_{t-1}}\, x_0$$

Simplifying with $\alpha_t = 1 - \beta_t$ and $\bar{\alpha}_t = \alpha_t\, \bar{\alpha}_{t-1}$:

$$\tilde{\sigma}_t^2 = \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\, \beta_t \tag{4}$$
$$\tilde{\mu}_t = \frac{\sqrt{\alpha_t}\,(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}\, x_t + \frac{\sqrt{\bar{\alpha}_{t-1}}\, \beta_t}{1-\bar{\alpha}_t}\, x_0 \tag{5}$$

The $x_0$ in this formula is exactly what we are trying to recover, so it must be replaced by known quantities. From the forward process, $x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1-\bar{\alpha}_t}\, z_t$,
hence $x_0 = \frac{1}{\sqrt{\bar{\alpha}_t}}\big(x_t - \sqrt{1-\bar{\alpha}_t}\, z_t\big)$. Substituting into Eq. (5):

$$\tilde{\mu}_t = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}}\, z_t\right) \tag{6}$$

To make the computation convenient in code, the mean $\tilde{\mu}_t$ is split into two coefficients, i.e. $\tilde{\mu}_t = \text{coeff1} \cdot x_t - \text{coeff2} \cdot z_t$, where:

$$\text{coeff1} = \frac{1}{\sqrt{\alpha_t}} \tag{7}$$
$$\text{coeff2} = \text{coeff1} \cdot \frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}} \tag{8}$$
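
A quick numerical check (my own) that the coefficient form (7)(8) matches Eq. (5) once $x_0$ is substituted:

import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T).double()
alphas = 1. - betas
alphas_bar = torch.cumprod(alphas, dim=0)

t = 500
x_t, z_t = torch.randn(2).double()
x_0 = (x_t - (1 - alphas_bar[t]).sqrt() * z_t) / alphas_bar[t].sqrt()

# Eq. (5): mean written in terms of x_t and x_0
mu_eq5 = (alphas[t].sqrt() * (1 - alphas_bar[t - 1]) * x_t
          + alphas_bar[t - 1].sqrt() * betas[t] * x_0) / (1 - alphas_bar[t])

# Eqs. (7)(8): mean written in terms of x_t and z_t
coeff1 = (1. / alphas[t]).sqrt()
coeff2 = coeff1 * (1. - alphas[t]) / (1. - alphas_bar[t]).sqrt()
mu_eq78 = coeff1 * x_t - coeff2 * z_t

print((mu_eq5 - mu_eq78).abs())  # ~1e-16: the two forms agree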

# ``GaussianDiffusionSampler`` implements the reverse process & inference (sampling) step of the Diffusion Model
class GaussianDiffusionSampler(nn.Module):
    def __init__(self, model, beta_1, beta_T, T):
        """
        所有参数含义和``GaussianDiffusionTrainer``(前向过程)一样
        """
        super().__init__()

        self.model = model
        self.T = T

        # betas, alphas and alphas_bar are obtained exactly as in the forward process
        self.register_buffer('betas', torch.linspace(beta_1, beta_T, T).double())
        alphas = 1. - self.betas
        alphas_bar = torch.cumprod(alphas, dim=0)
        # for convenience later: this effectively builds alphas_bar_{t-1}
        alphas_bar_prev = F.pad(alphas_bar, [1, 0], value=1)[:T]  # prepend a 1 to alphas_bar and truncate to length T, shifting the sequence right by one

        # following Eqs. (7)(8), coeff1 and coeff2 are the coefficients used to compute the mean in the reverse process
        self.register_buffer('coeff1', torch.sqrt(1. / alphas))
        self.register_buffer('coeff2', self.coeff1 * (1. - alphas) / torch.sqrt(1. - alphas_bar))

        # following Eq. (4), compute the variance of the reverse process
        self.register_buffer('posterior_var', self.betas * (1. - alphas_bar_prev) / (1. - alphas_bar))

    def predict_xt_prev_mean_from_eps(self, x_t, t, eps):
        """
        该函数用于反向过程中,条件概率分布q(x_{t-1}|x_t)的均值
        Args:
             x_t: 迭代至当前步骤的图像
             t: 当前步数
             eps: 模型预测的噪声,也就是z_t
        Returns:
            x_{t-1}的均值,mean = coeff1 * x_t + coeff2 * eps
        """
        assert x_t.shape == eps.shape
        return (
            extract(self.coeff1, t, x_t.shape) * x_t -
            extract(self.coeff2, t, x_t.shape) * eps
        )

    def p_mean_variance(self, x_t, t):
        """
        该函数用于反向过程中,计算条件概率分布q(x_{t-1}|x_t)的均值和方差
        Args:
            x_t: 迭代至当前步骤的图像
            t: 当前步数
        Returns:
            xt_prev_mean: 均值
            var: 方差
        """
        # below: only log_variance is used in the KL computations
        # I am a bit puzzled here: why replace most of the precomputed reverse-process variances with betas?
        # My guess: ``posterior_var`` is just betas times the ratio (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t);
        # when 1 - alpha_bar_t is extremely close to 0, the division would produce nan,
        # while the ratio as a whole is very close to 1, so betas is used directly to approximate the variance.
        # At t = 1, however, (1 - alpha_bar_0) / (1 - alpha_bar_1) is not yet close to 1, so that value is kept,
        # hence the concatenation ``torch.cat([self.posterior_var[1:2], self.betas[1:]])``
        var = torch.cat([self.posterior_var[1:2], self.betas[1:]])
        var = extract(var, t, x_t.shape)

        # model forward pass to predict eps (i.e. z_t)
        eps = self.model(x_t, t)
        # compute the mean
        xt_prev_mean = self.predict_xt_prev_mean_from_eps(x_t, t, eps=eps)

        return xt_prev_mean, var

    def forward(self, x_T):
        """
        Algorithm 2.
        """
        # reverse diffusion: iterate from x_T down to x_0
        x_t = x_T
        for time_step in reversed(range(self.T)):
            print(time_step)
            # t = [1, 1, ...] * time_step, length batch_size
            t = x_t.new_ones([x_T.shape[0], ], dtype=torch.long) * time_step
            # compute the mean and variance of the conditional distribution q(x_{t-1}|x_t)
            mean, var = self.p_mean_variance(x_t=x_t, t=t)
            # no noise when t == 0
            # the Gaussian noise of the final step is set to 0 (keeping it would probably not hurt either; in this example the variance at t=0 is already tiny)
            if time_step > 0:
                noise = torch.randn_like(x_t)
            else:
                noise = 0
            x_t = mean + torch.sqrt(var) * noise
            assert torch.isnan(x_t).int().sum() == 0, "nan in tensor."
        x_0 = x_t
        # ``torch.clip(x_0, -1, 1)`` clamps x_0 to [-1, 1], truncating anything outside
        return torch.clip(x_0, -1, 1)
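
To back the guess in ``p_mean_variance`` with numbers (my own quick check, not part of the repo): the ratio (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t) is indeed close to 1 everywhere except the first few steps.

import torch
import torch.nn.functional as F

T = 1000
betas = torch.linspace(1e-4, 0.02, T).double()
alphas_bar = torch.cumprod(1. - betas, dim=0)
alphas_bar_prev = F.pad(alphas_bar, [1, 0], value=1)[:T]

ratio = (1. - alphas_bar_prev) / (1. - alphas_bar)
print(ratio[:3])   # first entry is exactly 0, the next few are still well below 1
print(ratio[-3:])  # all ~1.0, so betas is a good approximation late in the schedule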

3.2. Model.py

Please read the code comments directly:

import math
import torch
from torch import nn
from torch.nn import init
from torch.nn import functional as F


class Swish(nn.Module):
    """
    定义swish激活函数,可参考https://blog.csdn.net/bblingbbling/article/details/107105648
    """
    def forward(self, x):
        return x * torch.sigmoid(x)


class TimeEmbedding(nn.Module):
    """
    定义``时间嵌入``模块
    """
    def __init__(self, T, d_model, dim):
        """
        初始的time-embedding是由一系列不同频率的正弦、余弦函数采样值表示,
        即:[[sin(w_0*x), cos(w_0*x)],
            [sin(w_1*x), cos(w_1*x)],
             ...,
            [sin(w_T)*x, cos(w_T*x)]], 维度为 T * d_model
        在本实例中,频率范围是[0:T], x在1e-4~1范围,共d_model // 2个离散点;将sin, cos并在一起组成d_model个离散点
        Args:
            T: int, 总迭代步数,本实例中T=1000
            d_model: 输入维度(通道数/初始embedding长度)
            dim: 输出维度(通道数)
        """
        assert d_model % 2 == 0
        super().__init__()
        # the next two lines compute the x vector, d_model // 2 (= 64 here) points in total
        emb = torch.arange(0, d_model, step=2) / d_model * math.log(10000)
        emb = torch.exp(-emb)
        # the T timestep positions form the frequency part
        pos = torch.arange(T).float()
        # the outer product gives a T x (d_model//2) matrix; assert its shape
        emb = pos[:, None] * emb[None, :]
        assert list(emb.shape) == [T, d_model // 2]
        # take sin and cos at the different frequencies, check the shape, then reshape to T x d_model
        emb = torch.stack([torch.sin(emb), torch.cos(emb)], dim=-1)
        assert list(emb.shape) == [T, d_model // 2, 2]
        emb = emb.view(T, d_model)

        # MLP that maps the initial encoding to the extracted embedding:
        # two linear layers, the first followed by the Swish activation, the second by no activation
        self.timembedding = nn.Sequential(
            nn.Embedding.from_pretrained(emb),
            nn.Linear(d_model, dim),
            Swish(),
            nn.Linear(dim, dim),
        )
        self.initialize()

    def initialize(self):
        for module in self.modules():
            if isinstance(module, nn.Linear):
                init.xavier_uniform_(module.weight)
                init.zeros_(module.bias)

    def forward(self, t):
        emb = self.timembedding(t)
        return emb


class DownSample(nn.Module):
    """
    通过stride=2的卷积层进行降采样
    """
    def __init__(self, in_ch):
        super().__init__()
        self.main = nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1)
        self.initialize()

    def initialize(self):
        init.xavier_uniform_(self.main.weight)
        init.zeros_(self.main.bias)

    def forward(self, x, temb):
        x = self.main(x)
        return x


class UpSample(nn.Module):
    """
    通过conv+最近邻插值进行上采样
    """
    def __init__(self, in_ch):
        super().__init__()
        self.main = nn.Conv2d(in_ch, in_ch, 3, stride=1, padding=1)
        self.initialize()

    def initialize(self):
        init.xavier_uniform_(self.main.weight)
        init.zeros_(self.main.bias)

    def forward(self, x, temb):
        _, _, H, W = x.shape
        x = F.interpolate(
            x, scale_factor=2, mode='nearest')
        x = self.main(x)
        return x


class AttnBlock(nn.Module):
    """
    自注意力模块,其中线性层均用kernel为1的卷积层表示
    """
    def __init__(self, in_ch):
        # ``self.proj_q``, ``self.proj_k``, ``self.proj_v`` learn the query, key and value respectively
        # ``self.proj`` is the output projection applied after self-attention
        super().__init__()
        self.group_norm = nn.GroupNorm(32, in_ch)
        self.proj_q = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.proj_k = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.proj_v = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.proj = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.initialize()

    def initialize(self):
        for module in [self.proj_q, self.proj_k, self.proj_v, self.proj]:
            init.xavier_uniform_(module.weight)
            init.zeros_(module.bias)
        init.xavier_uniform_(self.proj.weight, gain=1e-5)

    def forward(self, x):
        B, C, H, W = x.shape
        # the input goes through group norm, then the 1x1 conv projections produce query, key and value
        h = self.group_norm(x)
        q = self.proj_q(h)
        k = self.proj_k(h)
        v = self.proj_v(h)

        # compute the similarity weights w between query and key via matrix multiplication
        # ``torch.bmm`` keeps dim 1 fixed and matrix-multiplies dims 2 and 3,
        # e.g. a.shape=(_n, _h, _m), b.shape=(_n, _m, _w) --> torch.bmm(a, b).shape=(_n, _h, _w)
        # the product is divided by sqrt(C) for normalization (removing the effect of the channel count on the magnitude of w)
        q = q.permute(0, 2, 3, 1).view(B, H * W, C)
        k = k.view(B, C, H * W)
        w = torch.bmm(q, k) * (int(C) ** (-0.5))
        assert list(w.shape) == [B, H * W, H * W]
        w = F.softmax(w, dim=-1)

        # then aggregate the values with the freshly computed weights w, again a single matrix multiplication
        v = v.permute(0, 2, 3, 1).view(B, H * W, C)
        h = torch.bmm(w, v)
        assert list(h.shape) == [B, H * W, C]
        h = h.view(B, H, W, C).permute(0, 3, 1, 2)

        # finally apply the output projection; adding the input x forms a skip (residual) connection
        h = self.proj(h)

        return x + h


class ResBlock(nn.Module):
    """
    残差网络模块
    """
    def __init__(self, in_ch, out_ch, tdim, dropout, attn=False):
        """
        Args:
            in_ch: int, 输入通道数
            out_ch: int, 输出通道数
            tdim: int, time-embedding的长度/维数
            dropout: float, dropout的比例
            attn: bool, 是否使用自注意力模块
        """
        super().__init__()
        # block 1: gn -> swish -> conv
        self.block1 = nn.Sequential(
            nn.GroupNorm(32, in_ch),
            Swish(),
            nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1),
        )
        # time-embedding projection: swish -> fc
        self.temb_proj = nn.Sequential(
            Swish(),
            nn.Linear(tdim, out_ch),
        )
        # block 2: gn -> swish -> dropout -> conv
        self.block2 = nn.Sequential(
            nn.GroupNorm(32, out_ch),
            Swish(),
            nn.Dropout(dropout),
            nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1),
        )
        # if the input and output channel counts differ, add a 1x1-conv transition layer ``shortcut``; otherwise do nothing
        if in_ch != out_ch:
            self.shortcut = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0)
        else:
            self.shortcut = nn.Identity()
        # if attention is requested, add an ``AttnBlock``; otherwise do nothing
        if attn:
            self.attn = AttnBlock(out_ch)
        else:
            self.attn = nn.Identity()
        self.initialize()

    def initialize(self):
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                init.xavier_uniform_(module.weight)
                init.zeros_(module.bias)
        init.xavier_uniform_(self.block2[-1].weight, gain=1e-5)

    def forward(self, x, temb):
        h = self.block1(x)                           # encode the input features with block 1
        h += self.temb_proj(temb)[:, :, None, None]  # inject the time embedding
        h = self.block2(h)                           # further encode the fused features with block 2

        h = h + self.shortcut(x)                     # residual connection
        h = self.attn(h)                             # self-attention (if attn=True)
        return h


class UNet(nn.Module):
    def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout):
        """

        Args:
            T: int, 总迭代步数,本实例中T=1000
            ch: int, UNet第一层卷积的通道数,每下采样一次在这基础上翻倍, 本实例中ch=128
            ch_mult: list, UNet每次下采样通道数翻倍的乘数,本实例中ch_mult=[1,2,3,4]
            attn: list, 表示在第几次降采样中使用attention
            num_res_blocks: int, 降采样或者上采样中每一层次的残差模块数目
            dropout: float, dropout比率
        """
        super().__init__()
        # make sure every attention index is smaller than the total number of downsampling stages
        assert all([i < len(ch_mult) for i in attn]), 'attn index out of bound'
        # the time embedding is expanded from its initial length ch to tdim = ch * 4
        tdim = ch * 4
        # instantiate the initial time-embedding layer
        self.time_embedding = TimeEmbedding(T, ch, tdim)
        # instantiate the head conv layer
        self.head = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1)

        # build the U-Net encoder (downsampling path); each stage consists of ``num_res_blocks`` residual blocks
        # chs records the channel counts along the downsampling path, now_ch is the current channel count
        self.downblocks = nn.ModuleList()
        chs = [ch]  # record output channels during downsampling, for use in upsampling
        now_ch = ch
        for i, mult in enumerate(ch_mult):  # i is the index into ch_mult, mult is ch_mult[i]
            out_ch = ch * mult
            for _ in range(num_res_blocks):
                self.downblocks.append(ResBlock(
                    in_ch=now_ch, out_ch=out_ch, tdim=tdim,
                    dropout=dropout, attn=(i in attn)))
                now_ch = out_ch
                chs.append(now_ch)
            if i != len(ch_mult) - 1:
                self.downblocks.append(DownSample(now_ch))
                chs.append(now_ch)

        # the transition between the U-Net encoder and decoder, made of two residual blocks
        # I don't know why the first block uses attention while the second does not... "engineering science", I suppose
        self.middleblocks = nn.ModuleList([
            ResBlock(now_ch, now_ch, tdim, dropout, attn=True),
            ResBlock(now_ch, now_ch, tdim, dropout, attn=False),
        ])

        # build the U-Net decoder (upsampling path), almost symmetric to the encoder
        # the only difference: each stage has one more residual block than the encoder,
        # because the first block fuses the current feature map with the skip connection, and the remaining ones mirror the encoder
        self.upblocks = nn.ModuleList()
        for i, mult in reversed(list(enumerate(ch_mult))):
            out_ch = ch * mult
            for _ in range(num_res_blocks + 1):
                self.upblocks.append(ResBlock(
                    in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim,
                    dropout=dropout, attn=(i in attn)))
                now_ch = out_ch
            if i != 0:
                self.upblocks.append(UpSample(now_ch))
        assert len(chs) == 0

        # tail block: gn -> swish -> conv, mapping back to the original number of image channels
        self.tail = nn.Sequential(
            nn.GroupNorm(32, now_ch),
            Swish(),
            nn.Conv2d(now_ch, 3, 3, stride=1, padding=1)
        )
        # note that only the head and tail are initialized here; the other modules were already initialized at instantiation
        self.initialize()

    def initialize(self):
        init.xavier_uniform_(self.head.weight)
        init.zeros_(self.head.bias)
        init.xavier_uniform_(self.tail[-1].weight, gain=1e-5)
        init.zeros_(self.tail[-1].bias)

    def forward(self, x, t):
        # Timestep embedding
        temb = self.time_embedding(t)
        # Downsampling
        h = self.head(x)
        hs = [h]
        for layer in self.downblocks:
            h = layer(h, temb)
            hs.append(h)
        # Middle
        for layer in self.middleblocks:
            h = layer(h, temb)
        # Upsampling
        for layer in self.upblocks:
            if isinstance(layer, ResBlock):
                h = torch.cat([h, hs.pop()], dim=1)
            h = layer(h, temb)
        h = self.tail(h)

        assert len(hs) == 0
        return h


if __name__ == '__main__':
    batch_size = 8
    model = UNet(
        T=1000, ch=128, ch_mult=[1, 2, 2, 2], attn=[1],
        num_res_blocks=2, dropout=0.1)
    x = torch.randn(batch_size, 3, 32, 32)
    t = torch.randint(1000, (batch_size, ))
    y = model(x, t)
    print(y.shape)
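
For intuition about ``TimeEmbedding``, here is a standalone sketch (my own, mirroring the construction in ``TimeEmbedding.__init__``) of the sinusoidal table and its shape:

import math
import torch

T, d_model = 1000, 128
freqs = torch.exp(-torch.arange(0, d_model, step=2) / d_model * math.log(10000))  # d_model//2 values in (1e-4, 1]
pos = torch.arange(T).float()
angles = pos[:, None] * freqs[None, :]                     # [T, d_model//2]
table = torch.stack([angles.sin(), angles.cos()], dim=-1)  # [T, d_model//2, 2]
table = table.view(T, d_model)                             # [T, d_model]
print(table.shape)   # torch.Size([1000, 128])
print(table[0, :4])  # row 0 interleaves sin(0)=0 and cos(0)=1 -> tensor([0., 1., 0., 1.])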

3.3. Train.py

This part differs little from ordinary model training/validation. I added comments at the important places; please read them directly:


import os
from typing import Dict

import torch
import torch.optim as optim
from tqdm import tqdm
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.utils import save_image

from Diffusion import GaussianDiffusionSampler, GaussianDiffusionTrainer
from Diffusion.Model import UNet
from Scheduler import GradualWarmupScheduler


def train(modelConfig: Dict):
    device = torch.device(modelConfig["device"])
    # dataset
    dataset = CIFAR10(
        root='./CIFAR10', train=True, download=True,
        transform=transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]))
    dataloader = DataLoader(
        dataset, batch_size=modelConfig["batch_size"], shuffle=True, num_workers=4, drop_last=True, pin_memory=True)

    # model setup
    net_model = UNet(T=modelConfig["T"], ch=modelConfig["channel"], ch_mult=modelConfig["channel_mult"], attn=modelConfig["attn"],
                     num_res_blocks=modelConfig["num_res_blocks"], dropout=modelConfig["dropout"]).to(device)
    if modelConfig["training_load_weight"] is not None:
        net_model.load_state_dict(torch.load(os.path.join(
            modelConfig["save_weight_dir"], modelConfig["training_load_weight"]), map_location=device))
    optimizer = torch.optim.AdamW(
        net_model.parameters(), lr=modelConfig["lr"], weight_decay=1e-4)
    # cosine learning-rate decay over half a cosine period, from ``lr`` down to 0
    cosineScheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer=optimizer, T_max=modelConfig["epoch"], eta_min=0, last_epoch=-1)
    # gradual warmup: the learning rate rises from 0 to multiplier * lr over the first 1/10 of the epochs, then follows ``cosineScheduler``
    warmUpScheduler = GradualWarmupScheduler(
        optimizer=optimizer, multiplier=modelConfig["multiplier"], warm_epoch=modelConfig["epoch"] // 10, after_scheduler=cosineScheduler)
    # instantiate the training wrapper
    trainer = GaussianDiffusionTrainer(
        net_model, modelConfig["beta_1"], modelConfig["beta_T"], modelConfig["T"]).to(device)

    # start training
    for e in range(modelConfig["epoch"]):
        with tqdm(dataloader, dynamic_ncols=True) as tqdmDataLoader:
            for images, labels in tqdmDataLoader:
                # train
                optimizer.zero_grad()                                    # clear stale gradients
                x_0 = images.to(device)                                  # move the input images to the compute device
                loss = trainer(x_0).sum() / 1000.                        # forward pass and loss computation
                loss.backward()                                          # backpropagate
                torch.nn.utils.clip_grad_norm_(
                    net_model.parameters(), modelConfig["grad_clip"])    # clip gradients to prevent explosion
                optimizer.step()                                         # update the parameters
                tqdmDataLoader.set_postfix(ordered_dict={
                    "epoch": e,
                    "loss: ": loss.item(),
                    "img shape: ": x_0.shape,
                    "LR": optimizer.state_dict()['param_groups'][0]["lr"]
                })                                                       # progress-bar display
        warmUpScheduler.step()                                           # scheduler step to update the learning rate
        torch.save(net_model.state_dict(), os.path.join(
            modelConfig["save_weight_dir"], 'ckpt_' + str(e) + "_.pt"))  # save the model


def eval(modelConfig: Dict):
    # load model and evaluate
    with torch.no_grad():
        # build and load the model
        device = torch.device(modelConfig["device"])
        model = UNet(T=modelConfig["T"], ch=modelConfig["channel"], ch_mult=modelConfig["channel_mult"], attn=modelConfig["attn"],
                     num_res_blocks=modelConfig["num_res_blocks"], dropout=0.)
        ckpt = torch.load(os.path.join(
            modelConfig["save_weight_dir"], modelConfig["test_load_weight"]), map_location=device)
        model.load_state_dict(ckpt)
        print("model load weight done.")
        # instantiate the reverse-diffusion sampler
        model.eval()
        sampler = GaussianDiffusionSampler(
            model, modelConfig["beta_1"], modelConfig["beta_T"], modelConfig["T"]).to(device)
        # Sampled from standard normal distribution
        # generate random Gaussian-noise images and save them
        noisyImage = torch.randn(
            size=[modelConfig["batch_size"], 3, 32, 32], device=device)
        saveNoisy = torch.clamp(noisyImage * 0.5 + 0.5, 0, 1)
        save_image(saveNoisy, os.path.join(
            modelConfig["sampled_dir"], modelConfig["sampledNoisyImgName"]), nrow=modelConfig["nrow"])
        # run reverse diffusion and save the output images
        sampledImgs = sampler(noisyImage)
        sampledImgs = sampledImgs * 0.5 + 0.5  # [0 ~ 1]
        save_image(sampledImgs, os.path.join(
            modelConfig["sampled_dir"],  modelConfig["sampledImgName"]), nrow=modelConfig["nrow"])

4. DiffusionFreeGuidence Package

Paper: "Classifier-Free Diffusion Guidance".
I recommend reading this side by side with the code above.

4.1. DiffusionCondition.py

import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np


def extract(v, t, x_shape):
    """
    Extract some coefficients at specified timesteps, then reshape to
    [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
    """
    device = t.device
    out = torch.gather(v, index=t, dim=0).float().to(device)
    return out.view([t.shape[0]] + [1] * (len(x_shape) - 1))


class GaussianDiffusionTrainer(nn.Module):
    """
    前向加噪过程和``Diffusion.Diffusion.py``中的``GaussianDiffusionTrainer``几乎完全一样
    不同点在于模型输入,除了需要输入``x_t``, ``t``, 还要输入条件``labels``
    """
    def __init__(self, model, beta_1, beta_T, T):
        super().__init__()

        self.model = model
        self.T = T

        self.register_buffer(
            'betas', torch.linspace(beta_1, beta_T, T).double())
        alphas = 1. - self.betas
        alphas_bar = torch.cumprod(alphas, dim=0)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer(
            'sqrt_alphas_bar', torch.sqrt(alphas_bar))
        self.register_buffer(
            'sqrt_one_minus_alphas_bar', torch.sqrt(1. - alphas_bar))

    def forward(self, x_0, labels):
        """
        Algorithm 1.
        """
        t = torch.randint(self.T, size=(x_0.shape[0],), device=x_0.device)
        noise = torch.randn_like(x_0)
        x_t = extract(self.sqrt_alphas_bar, t, x_0.shape) * x_0 + \
              extract(self.sqrt_one_minus_alphas_bar, t, x_0.shape) * noise
        loss = F.mse_loss(self.model(x_t, t, labels), noise, reduction='none')  # the only difference: the model also takes ``labels``
        return loss


class GaussianDiffusionSampler(nn.Module):
    """
    反向扩散过程和``Diffusion.Diffusion.py``中的``GaussianDiffusionSampler``绝大部分一样,
    所以在此只说明不一样的点
    """
    def __init__(self, model, beta_1, beta_T, T, w=0.):
        super().__init__()

        self.model = model
        self.T = T
        # In the classifier-free guidance paper, w is the key to controlling the guidance.
        # w = 0 and label = 0 means no guidance.
        # w > 0 and label > 0 means guidance; the guidance grows stronger as w gets bigger.
        # Difference 1: initialization takes an extra weight ``w`` that controls the strength of the condition
        self.w = w

        self.register_buffer('betas', torch.linspace(beta_1, beta_T, T).double())
        alphas = 1. - self.betas
        alphas_bar = torch.cumprod(alphas, dim=0)
        alphas_bar_prev = F.pad(alphas_bar, [1, 0], value=1)[:T]
        self.register_buffer('coeff1', torch.sqrt(1. / alphas))
        self.register_buffer('coeff2', self.coeff1 * (1. - alphas) / torch.sqrt(1. - alphas_bar))
        self.register_buffer('posterior_var', self.betas * (1. - alphas_bar_prev) / (1. - alphas_bar))

    def predict_xt_prev_mean_from_eps(self, x_t, t, eps):
        assert x_t.shape == eps.shape
        return (
            extract(self.coeff1, t, x_t.shape) * x_t -
            extract(self.coeff2, t, x_t.shape) * eps
        )

    def p_mean_variance(self, x_t, t, labels):
        # below: only log_variance is used in the KL computations
        var = torch.cat([self.posterior_var[1:2], self.betas[1:]])
        var = extract(var, t, x_t.shape)

        # Difference 2: at inference time the model output is computed both with the condition and without it
        # (zero labels); the two outputs are blended with the weight ``self.w`` to form the final output
        eps = self.model(x_t, t, labels)
        nonEps = self.model(x_t, t, torch.zeros_like(labels).to(labels.device))
        # see Eq. (6) of the original paper
        eps = (1. + self.w) * eps - self.w * nonEps
        xt_prev_mean = self.predict_xt_prev_mean_from_eps(x_t, t, eps=eps)
        return xt_prev_mean, var

    def forward(self, x_T, labels):
        """
        Algorithm 2.
        """
        x_t = x_T
        for time_step in reversed(range(self.T)):
            print(time_step)
            t = x_t.new_ones([x_T.shape[0], ], dtype=torch.long) * time_step
            # apart from the extra ``labels`` input, this matches the plain Diffusion Model
            mean, var = self.p_mean_variance(x_t=x_t, t=t, labels=labels)
            if time_step > 0:
                noise = torch.randn_like(x_t)
            else:
                noise = 0
            x_t = mean + torch.sqrt(var) * noise
            assert torch.isnan(x_t).int().sum() == 0, "nan in tensor."
        x_0 = x_t
        return torch.clip(x_0, -1, 1)
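
The guidance blend is just a linear extrapolation from the unconditional prediction toward the conditional one; a tiny standalone illustration (my own numbers):

import torch

w = 1.8                                # guidance strength (hypothetical value)
eps_cond = torch.tensor([1.0, 2.0])    # model output with the condition
eps_uncond = torch.tensor([0.5, 1.0])  # model output with the "null" (zero) label
eps = (1. + w) * eps_cond - w * eps_uncond
print(eps)  # tensor([1.9000, 3.8000]) -- pushed past eps_cond, away from eps_uncond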

4.2. ModelCondition.py

import math
import torch
from torch import nn
from torch.nn import init
from torch.nn import functional as F


def drop_connect(x, drop_ratio):
    """
    这个函数在整个Project中都没被用到, 暂时先不考虑它的功能
    """
    keep_ratio = 1.0 - drop_ratio
    mask = torch.empty([x.shape[0], 1, 1, 1], dtype=x.dtype, device=x.device)
    mask.bernoulli_(p=keep_ratio)
    x.div_(keep_ratio)
    x.mul_(mask)
    return x


class Swish(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)


class TimeEmbedding(nn.Module):
    """
    和``Diffusion.Model``中的``TimeEmbedding``一模一样
    """
    def __init__(self, T, d_model, dim):
        assert d_model % 2 == 0
        super().__init__()
        emb = torch.arange(0, d_model, step=2) / d_model * math.log(10000)
        emb = torch.exp(-emb)
        pos = torch.arange(T).float()
        emb = pos[:, None] * emb[None, :]
        assert list(emb.shape) == [T, d_model // 2]
        emb = torch.stack([torch.sin(emb), torch.cos(emb)], dim=-1)
        assert list(emb.shape) == [T, d_model // 2, 2]
        emb = emb.view(T, d_model)

        self.timembedding = nn.Sequential(
            nn.Embedding.from_pretrained(emb, freeze=False),
            nn.Linear(d_model, dim),
            Swish(),
            nn.Linear(dim, dim),
        )

    def forward(self, t):
        emb = self.timembedding(t)
        return emb


class ConditionalEmbedding(nn.Module):
    """
    这是一个条件编码模块,将condition编码为embedding
    除了初始化Embedding不同,其他部分与time-embedding无异。
    """
    def __init__(self, num_labels, d_model, dim):
        assert d_model % 2 == 0
        super().__init__()
        # Note a detail of the embedding initialization: ``num_embeddings=num_labels+1``, i.e. 10+1=11.
        # The condition here is the CIFAR10 label: 10 classes, labeled 0~9, so 10 embeddings would normally suffice,
        # but the ``unconditional`` case also needs an embedding, represented by ``0`` in this example;
        # the 10 class labels are therefore shifted up by one to 1~10 (done in ``TrainCondition.py``), for 11 embeddings in total
        self.condEmbedding = nn.Sequential(
            nn.Embedding(num_embeddings=num_labels + 1, embedding_dim=d_model, padding_idx=0),
            nn.Linear(d_model, dim),
            Swish(),
            nn.Linear(dim, dim),
        )

    def forward(self, labels):
        cemb = self.condEmbedding(labels)
        return cemb


class DownSample(nn.Module):
    """
    相比于``Diffusion.Model.DownSample``, 这里的降采样模块多加了一个5x5、stride=2的conv层
    前向过程由3x3和5x5卷积输出相加得来,不知为什么这么做,可能为了融合更多尺度的信息
    查看原文(4.Experiments 3~4行),原文描述所使用的模型与《Diffusion Models Beat GANs on Image Synthesis》所用模型一致,
    但是该文章源码并没有使用这种降采样方式,只是简单的3x3或者avg_pool
    """
    def __init__(self, in_ch):
        super().__init__()
        self.c1 = nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1)
        self.c2 = nn.Conv2d(in_ch, in_ch, 5, stride=2, padding=2)

    def forward(self, x, temb, cemb):
        x = self.c1(x) + self.c2(x)
        return x


class UpSample(nn.Module):
    """
    相比于``Diffusion.Model.UpSample``, 这里的上采样模块使用反卷积而不是最近邻插值
    同``DownSample``也不明白原因,因该两种方式都可以,看个人喜好。
    """
    def __init__(self, in_ch):
        super().__init__()
        self.c = nn.Conv2d(in_ch, in_ch, kernel_size=3, stride=1, padding=1)
        self.t = nn.ConvTranspose2d(in_ch, in_ch, kernel_size=5, stride=2, padding=2, output_padding=1)

    def forward(self, x, temb, cemb):
        _, _, H, W = x.shape
        x = self.t(x)
        x = self.c(x)
        return x


class AttnBlock(nn.Module):
    """
    和``Diffusion.Model``中的``AttnBlock``一模一样
    """
    def __init__(self, in_ch):
        super().__init__()
        self.group_norm = nn.GroupNorm(32, in_ch)
        self.proj_q = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.proj_k = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.proj_v = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.proj = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)

    def forward(self, x):
        B, C, H, W = x.shape
        h = self.group_norm(x)
        q = self.proj_q(h)
        k = self.proj_k(h)
        v = self.proj_v(h)

        q = q.permute(0, 2, 3, 1).view(B, H * W, C)
        k = k.view(B, C, H * W)
        w = torch.bmm(q, k) * (int(C) ** (-0.5))
        assert list(w.shape) == [B, H * W, H * W]
        w = F.softmax(w, dim=-1)

        v = v.permute(0, 2, 3, 1).view(B, H * W, C)
        h = torch.bmm(w, v)
        assert list(h.shape) == [B, H * W, C]
        h = h.view(B, H, W, C).permute(0, 3, 1, 2)
        h = self.proj(h)

        return x + h


class ResBlock(nn.Module):
    """
    相比于``Diffusion.Model.ResBlock``, 这里的残差模块多加了一个条件投射层``self.cond_proj``,
    在这里其实可以直接把它看作另一个time-embedding, 它们参与训练的方式一模一样
    """
    def __init__(self, in_ch, out_ch, tdim, dropout, attn=True):
        super().__init__()
        self.block1 = nn.Sequential(
            nn.GroupNorm(32, in_ch),
            Swish(),
            nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1),
        )
        self.temb_proj = nn.Sequential(
            Swish(),
            nn.Linear(tdim, out_ch),
        )
        self.cond_proj = nn.Sequential(
            Swish(),
            nn.Linear(tdim, out_ch),
        )
        self.block2 = nn.Sequential(
            nn.GroupNorm(32, out_ch),
            Swish(),
            nn.Dropout(dropout),
            nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1),
        )
        if in_ch != out_ch:
            self.shortcut = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0)
        else:
            self.shortcut = nn.Identity()
        if attn:
            self.attn = AttnBlock(out_ch)
        else:
            self.attn = nn.Identity()

    def forward(self, x, temb, cemb):
        h = self.block1(x)
        h += self.temb_proj(temb)[:, :, None, None]  # add the time embedding
        h += self.cond_proj(cemb)[:, :, None, None]  # add the conditional embedding
        h = self.block2(h)                           # fuse the features

        h = h + self.shortcut(x)
        h = self.attn(h)
        return h


class UNet(nn.Module):
    """
    相比于``Diffusion.Model.UNet``, 这里的UNet模块就多加了一个``cond_embedding``,
    还有一个变化是在降采样和上采样阶段没有加自注意力层,只在中间过度的时候加了一次,这我不明白是何用意,
    可能是希望网络不要从自己身上学到太多,多关注condition?(我瞎猜的)
    """
    def __init__(self, T, num_labels, ch, ch_mult, num_res_blocks, dropout):
        super().__init__()
        tdim = ch * 4
        self.time_embedding = TimeEmbedding(T, ch, tdim)
        self.cond_embedding = ConditionalEmbedding(num_labels, ch, tdim)
        self.head = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1)
        self.downblocks = nn.ModuleList()
        chs = [ch]  # record output channels during downsampling, for use in upsampling
        now_ch = ch
        for i, mult in enumerate(ch_mult):
            out_ch = ch * mult
            for _ in range(num_res_blocks):
                self.downblocks.append(ResBlock(in_ch=now_ch, out_ch=out_ch, tdim=tdim, dropout=dropout))
                now_ch = out_ch
                chs.append(now_ch)
            if i != len(ch_mult) - 1:
                self.downblocks.append(DownSample(now_ch))
                chs.append(now_ch)

        self.middleblocks = nn.ModuleList([
            ResBlock(now_ch, now_ch, tdim, dropout, attn=True),
            ResBlock(now_ch, now_ch, tdim, dropout, attn=False),
        ])

        self.upblocks = nn.ModuleList()
        for i, mult in reversed(list(enumerate(ch_mult))):
            out_ch = ch * mult
            for _ in range(num_res_blocks + 1):
                self.upblocks.append(ResBlock(in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim, dropout=dropout, attn=False))
                now_ch = out_ch
            if i != 0:
                self.upblocks.append(UpSample(now_ch))
        assert len(chs) == 0

        self.tail = nn.Sequential(
            nn.GroupNorm(32, now_ch),
            Swish(),
            nn.Conv2d(now_ch, 3, 3, stride=1, padding=1)
        )

    def forward(self, x, t, labels):
        # Timestep embedding
        temb = self.time_embedding(t)
        cemb = self.cond_embedding(labels)
        # Downsampling
        h = self.head(x)
        hs = [h]
        for layer in self.downblocks:
            h = layer(h, temb, cemb)
            hs.append(h)
        # Middle
        for layer in self.middleblocks:
            h = layer(h, temb, cemb)
        # Upsampling
        for layer in self.upblocks:
            if isinstance(layer, ResBlock):
                h = torch.cat([h, hs.pop()], dim=1)
            h = layer(h, temb, cemb)
        h = self.tail(h)

        assert len(hs) == 0
        return h


if __name__ == '__main__':
    batch_size = 8
    model = UNet(
        T=1000, num_labels=10, ch=128, ch_mult=[1, 2, 2, 2],
        num_res_blocks=2, dropout=0.1)
    x = torch.randn(batch_size, 3, 32, 32)
    t = torch.randint(1000, size=[batch_size])
    labels = torch.randint(10, size=[batch_size])
    # resB = ResBlock(128, 256, 64, 0.1)
    # x = torch.randn(batch_size, 128, 32, 32)
    # t = torch.randn(batch_size, 64)
    # labels = torch.randn(batch_size, 64)
    # y = resB(x, t, labels)
    y = model(x, t, labels)
    print(y.shape)
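
A small check (my own) of the label-shift convention used by ``ConditionalEmbedding``: with ``padding_idx=0`` the "unconditional" token 0 maps to an all-zero vector, while the 10 CIFAR10 classes occupy indices 1~10.

import torch
from torch import nn

emb = nn.Embedding(num_embeddings=11, embedding_dim=4, padding_idx=0)
labels = torch.tensor([0, 1, 10])  # 0 = unconditional; 1~10 = classes 0~9 shifted by +1
out = emb(labels)
print(out[0])  # tensor([0., 0., 0., 0.], ...) -- the null-condition embedding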

4.3. TrainCondition.py

import os
from typing import Dict
import numpy as np

import torch
import torch.optim as optim
from tqdm import tqdm
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.utils import save_image

from DiffusionFreeGuidence.DiffusionCondition import GaussianDiffusionSampler, GaussianDiffusionTrainer
from DiffusionFreeGuidence.ModelCondition import UNet
from Scheduler import GradualWarmupScheduler


def train(modelConfig: Dict):
    device = torch.device(modelConfig["device"])
    # dataset
    dataset = CIFAR10(
        root='./CIFAR10', train=True, download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]))
    dataloader = DataLoader(
        dataset, batch_size=modelConfig["batch_size"], shuffle=True, num_workers=4, drop_last=True, pin_memory=True)

    # model setup
    # compared with the unconditional model there is one extra argument, ``num_labels``: the number of classes in the dataset; CIFAR10 has 10 classes, so num_labels=10
    net_model = UNet(T=modelConfig["T"], num_labels=10, ch=modelConfig["channel"], ch_mult=modelConfig["channel_mult"],
                     num_res_blocks=modelConfig["num_res_blocks"], dropout=modelConfig["dropout"]).to(device)
    if modelConfig["training_load_weight"] is not None:
        net_model.load_state_dict(torch.load(os.path.join(
            modelConfig["save_dir"], modelConfig["training_load_weight"]), map_location=device), strict=False)
        print("Model weight load down.")
    optimizer = torch.optim.AdamW(
        net_model.parameters(), lr=modelConfig["lr"], weight_decay=1e-4)
    # cosine learning-rate decay over half a cosine period, from ``lr`` down to 0
    cosineScheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer=optimizer, T_max=modelConfig["epoch"], eta_min=0, last_epoch=-1)
    # gradual warmup: the learning rate rises from 0 to multiplier * lr over the first 1/10 of the epochs, then follows ``cosineScheduler``
    warmUpScheduler = GradualWarmupScheduler(optimizer=optimizer, multiplier=modelConfig["multiplier"],
                                             warm_epoch=modelConfig["epoch"] // 10, after_scheduler=cosineScheduler)
    # instantiate the training wrapper
    trainer = GaussianDiffusionTrainer(
        net_model, modelConfig["beta_1"], modelConfig["beta_T"], modelConfig["T"]).to(device)

    # start training
    for e in range(modelConfig["epoch"]):
        with tqdm(dataloader, dynamic_ncols=True) as tqdmDataLoader:
            for images, labels in tqdmDataLoader:
                # train
                b = images.shape[0]                                     # batch size
                optimizer.zero_grad()                                   # clear stale gradients
                x_0 = images.to(device)                                 # move the input images to the compute device
                labels = labels.to(device) + 1                          # move the labels (the condition) to the device; the +1
                                                                        # matches ``ConditionalEmbedding`` in ``ModelCondition.py``
                if np.random.rand() < 0.1:
                    labels = torch.zeros_like(labels).to(device)        # with 10% probability, replace the condition with 0
                loss = trainer(x_0, labels).sum() / b ** 2.             # forward pass and loss computation
                loss.backward()                                         # backpropagate
                torch.nn.utils.clip_grad_norm_(
                    net_model.parameters(), modelConfig["grad_clip"])   # clip gradients to prevent explosion
                optimizer.step()                                        # update the parameters
                tqdmDataLoader.set_postfix(ordered_dict={
                    "epoch": e,
                    "loss: ": loss.item(),
                    "img shape: ": x_0.shape,
                    "LR": optimizer.state_dict()['param_groups'][0]["lr"]
                })                                                      # progress-bar display
        warmUpScheduler.step()                                          # scheduler step
        torch.save(net_model.state_dict(), os.path.join(
            modelConfig["save_dir"], 'ckpt_' + str(e) + "_.pt"))        # save the model


def eval(modelConfig: Dict):
    device = torch.device(modelConfig["device"])
    # load model and evaluate
    with torch.no_grad():
        # This block generates the labels (the condition) used to guide image generation.
        # The batch is split evenly across the 10 classes: with batch_size=50, step=5,
        # the loop produces labelList = [0,0,0,0,0,1,1,1,1,1,2,...,9,9,9,9,9].
        # Finally 1 is added to every label, for the same reason as before.
        step = int(modelConfig["batch_size"] // 10)
        labelList = []
        k = 0
        for i in range(1, modelConfig["batch_size"] + 1):
            labelList.append(torch.ones(size=[1]).long() * k)
            if i % step == 0:
                if k < 10 - 1:
                    k += 1
        labels = torch.cat(labelList, dim=0).long().to(device) + 1
        print("labels: ", labels)
        # build and load the model
        model = UNet(T=modelConfig["T"], num_labels=10, ch=modelConfig["channel"], ch_mult=modelConfig["channel_mult"],
                     num_res_blocks=modelConfig["num_res_blocks"], dropout=modelConfig["dropout"]).to(device)
        ckpt = torch.load(os.path.join(
            modelConfig["save_dir"], modelConfig["test_load_weight"]), map_location=device)
        model.load_state_dict(ckpt)
        print("model load weight done.")
        # instantiate the reverse-diffusion sampler
        model.eval()
        sampler = GaussianDiffusionSampler(
            model, modelConfig["beta_1"], modelConfig["beta_T"], modelConfig["T"], w=modelConfig["w"]).to(device)
        # Sampled from standard normal distribution
        # generate random Gaussian-noise images and save them
        noisyImage = torch.randn(
            size=[modelConfig["batch_size"], 3, modelConfig["img_size"], modelConfig["img_size"]], device=device)
        saveNoisy = torch.clamp(noisyImage * 0.5 + 0.5, 0, 1)
        save_image(saveNoisy, os.path.join(
            modelConfig["sampled_dir"],  modelConfig["sampledNoisyImgName"]), nrow=modelConfig["nrow"])
        # run reverse diffusion and save the output images
        sampledImgs = sampler(noisyImage, labels)
        sampledImgs = sampledImgs * 0.5 + 0.5  # [0 ~ 1]
        print(sampledImgs)
        save_image(sampledImgs, os.path.join(
            modelConfig["sampled_dir"],  modelConfig["sampledImgName"]), nrow=modelConfig["nrow"])
