Diffusion扩散模型学习1——Pytorch搭建DDPM实现图片生成

学习前言

我又死了我又死了我又死了!
在这里插入图片描述

源码下载地址

https://github.com/bubbliiiing/ddpm-pytorch

喜欢的可以点个star噢。

网络构建

一、什么是Diffusion

在这里插入图片描述
如上图所示。DDPM模型主要分为两个过程:
1、Forward加噪过程(从右往左),数据集的真实图片中逐步加入高斯噪声,最终变成一个杂乱无章的高斯噪声,这个过程一般发生在训练的时候。加噪过程满足一定的数学规律。
2、Reverse去噪过程(从左往右),指对加了噪声的图片逐步去噪,从而还原出真实图片,这个过程一般发生在预测生成的时候。尽管在这里说的是加了噪声的图片,但实际去预测生成的时候,是随机生成一个高斯噪声来去噪。去噪的时候不断根据Diffusion扩散模型学习1——Pytorch搭建DDPM实现图片生成的图片生成Diffusion扩散模型学习1——Pytorch搭建DDPM实现图片生成的噪声,从而实现图片的还原。

1、加噪过程

在这里插入图片描述

Diffusion扩散模型学习1——Pytorch搭建DDPM实现图片生成

其中每次加入的噪声都服从高斯分布 Diffusion扩散模型学习1——Pytorch搭建DDPM实现图片生成,两个高斯分布的相加高斯分布满足公式:Diffusion扩散模型学习1——Pytorch搭建DDPM实现图片生成,因此,得到Diffusion扩散模型学习1——Pytorch搭建DDPM实现图片生成的公式为:
Diffusion扩散模型学习1——Pytorch搭建DDPM实现图片生成
因此不断往里面套,就能发现规律了,其实就是累乘
可以直接得出Diffusion扩散模型学习1——Pytorch搭建DDPM实现图片生成Diffusion扩散模型学习1——Pytorch搭建DDPM实现图片生成的公式:
Diffusion扩散模型学习1——Pytorch搭建DDPM实现图片生成

其中Diffusion扩散模型学习1——Pytorch搭建DDPM实现图片生成,这是随Noise schedule设定好的超参数,Diffusion扩散模型学习1——Pytorch搭建DDPM实现图片生成也是一个高斯噪声。通过上述两个公式,我们可以不断的将图片进行破坏加噪。

2、去噪过程

在这里插入图片描述

此时的均值为:Diffusion扩散模型学习1——Pytorch搭建DDPM实现图片生成。根据之前的公式,Diffusion扩散模型学习1——Pytorch搭建DDPM实现图片生成,我们可以使用Diffusion扩散模型学习1——Pytorch搭建DDPM实现图片生成反向估计Diffusion扩散模型学习1——Pytorch搭建DDPM实现图片生成得到Diffusion扩散模型学习1——Pytorch搭建DDPM实现图片生成满足分布Diffusion扩散模型学习1——Pytorch搭建DDPM实现图片生成。最终得到均值为Diffusion扩散模型学习1——Pytorch搭建DDPM实现图片生成Diffusion扩散模型学习1——Pytorch搭建DDPM实现图片生成代表t时刻的噪音是什么。由Diffusion扩散模型学习1——Pytorch搭建DDPM实现图片生成无法直接获得,网络便通过当前时刻的Diffusion扩散模型学习1——Pytorch搭建DDPM实现图片生成经过神经网络计算Diffusion扩散模型学习1——Pytorch搭建DDPM实现图片生成Diffusion扩散模型学习1——Pytorch搭建DDPM实现图片生成也就是上面提到的Diffusion扩散模型学习1——Pytorch搭建DDPM实现图片生成Diffusion扩散模型学习1——Pytorch搭建DDPM实现图片生成代表神经网络。
Diffusion扩散模型学习1——Pytorch搭建DDPM实现图片生成
由于加噪过程中的真实噪声Diffusion扩散模型学习1——Pytorch搭建DDPM实现图片生成在复原过程中是无法获得的,因此DDPM的关键就是训练一个由Diffusion扩散模型学习1——Pytorch搭建DDPM实现图片生成Diffusion扩散模型学习1——Pytorch搭建DDPM实现图片生成估测橾声的模型 Diffusion扩散模型学习1——Pytorch搭建DDPM实现图片生成,其中Diffusion扩散模型学习1——Pytorch搭建DDPM实现图片生成就是模型的训练参数,Diffusion扩散模型学习1——Pytorch搭建DDPM实现图片生成 也是一个高斯噪声 Diffusion扩散模型学习1——Pytorch搭建DDPM实现图片生成,用于表示估测与实际的差距。在DDPM中,使用U-Net作为估测噪声的模型。

本质上,我们就是训练这个Unet模型,该模型输入为Diffusion扩散模型学习1——Pytorch搭建DDPM实现图片生成Diffusion扩散模型学习1——Pytorch搭建DDPM实现图片生成,输出为Diffusion扩散模型学习1——Pytorch搭建DDPM实现图片生成时刻的高斯噪声。即利用Diffusion扩散模型学习1——Pytorch搭建DDPM实现图片生成Diffusion扩散模型学习1——Pytorch搭建DDPM实现图片生成预测这一时刻的高斯噪声。这样就可以一步一步的再从噪声回到真实图像。

二、DDPM网络的构建(Unet网络的构建)

在这里插入图片描述

本质上,DDPM最重要的工作就是训练Unet模型,该模型输入为Diffusion扩散模型学习1——Pytorch搭建DDPM实现图片生成Diffusion扩散模型学习1——Pytorch搭建DDPM实现图片生成,输出为Diffusion扩散模型学习1——Pytorch搭建DDPM实现图片生成时刻的高斯噪声。即利用Diffusion扩散模型学习1——Pytorch搭建DDPM实现图片生成Diffusion扩散模型学习1——Pytorch搭建DDPM实现图片生成预测上一时刻的高斯噪声。这样就可以一步一步的再从噪声回到真实图像。

假设我们需要生成一个[64, 64, 3]的图像,在Diffusion扩散模型学习1——Pytorch搭建DDPM实现图片生成时刻,我们有一个Diffusion扩散模型学习1——Pytorch搭建DDPM实现图片生成噪声图,该噪声图的的shape也为[64, 64, 3],我们将它和Diffusion扩散模型学习1——Pytorch搭建DDPM实现图片生成一起输入到Unet中。Unet的输出为Diffusion扩散模型学习1——Pytorch搭建DDPM实现图片生成时刻的[64, 64, 3]的噪声。

实现代码如下,代码中的特征提取模块为残差结构,方便优化:

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def get_norm(norm, num_channels, num_groups):
    if norm == "in":
        return nn.InstanceNorm2d(num_channels, affine=True)
    elif norm == "bn":
        return nn.BatchNorm2d(num_channels)
    elif norm == "gn":
        return nn.GroupNorm(num_groups, num_channels)
    elif norm is None:
        return nn.Identity()
    else:
        raise ValueError("unknown normalization type")
    
#------------------------------------------#
#   计算时间步长的位置嵌入。
#   一半为sin,一半为cos。
#------------------------------------------#
class PositionalEmbedding(nn.Module):
    def __init__(self, dim, scale=1.0):
        super().__init__()
        assert dim % 2 == 0
        self.dim = dim
        self.scale = scale

    def forward(self, x):
        device      = x.device
        half_dim    = self.dim // 2
        emb = math.log(10000) / half_dim
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        # x * self.scale和emb外积
        emb = torch.outer(x * self.scale, emb)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb

#------------------------------------------#
#   下采样层,一个步长为2x2的卷积
#------------------------------------------#
class Downsample(nn.Module):
    def __init__(self, in_channels):
        super().__init__()

        self.downsample = nn.Conv2d(in_channels, in_channels, 3, stride=2, padding=1)
    
    def forward(self, x, time_emb, y):
        if x.shape[2] % 2 == 1:
            raise ValueError("downsampling tensor height should be even")
        if x.shape[3] % 2 == 1:
            raise ValueError("downsampling tensor width should be even")

        return self.downsample(x)

#------------------------------------------#
#   上采样层,Upsample+卷积
#------------------------------------------#
class Upsample(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.upsample = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(in_channels, in_channels, 3, padding=1),
        )
        
    def forward(self, x, time_emb, y):
        return self.upsample(x)

#------------------------------------------#
#   使用Self-Attention注意力机制
#   做一个全局的Self-Attention
#------------------------------------------#
class AttentionBlock(nn.Module):
    def __init__(self, in_channels, norm="gn", num_groups=32):
        super().__init__()
        
        self.in_channels = in_channels
        self.norm = get_norm(norm, in_channels, num_groups)
        self.to_qkv = nn.Conv2d(in_channels, in_channels * 3, 1)
        self.to_out = nn.Conv2d(in_channels, in_channels, 1)

    def forward(self, x):
        b, c, h, w  = x.shape
        q, k, v     = torch.split(self.to_qkv(self.norm(x)), self.in_channels, dim=1)

        q = q.permute(0, 2, 3, 1).view(b, h * w, c)
        k = k.view(b, c, h * w)
        v = v.permute(0, 2, 3, 1).view(b, h * w, c)

        dot_products = torch.bmm(q, k) * (c ** (-0.5))
        assert dot_products.shape == (b, h * w, h * w)

        attention   = torch.softmax(dot_products, dim=-1)
        out         = torch.bmm(attention, v)
        assert out.shape == (b, h * w, c)
        out         = out.view(b, h, w, c).permute(0, 3, 1, 2)

        return self.to_out(out) + x
    
#------------------------------------------#
#   用于特征提取的残差结构
#------------------------------------------#
class ResidualBlock(nn.Module):
    def __init__(
        self, in_channels, out_channels, dropout, time_emb_dim=None, num_classes=None, activation=F.relu,
        norm="gn", num_groups=32, use_attention=False,
    ):
        super().__init__()

        self.activation = activation

        self.norm_1 = get_norm(norm, in_channels, num_groups)
        self.conv_1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)

        self.norm_2 = get_norm(norm, out_channels, num_groups)
        self.conv_2 = nn.Sequential(
            nn.Dropout(p=dropout), 
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
        )

        self.time_bias  = nn.Linear(time_emb_dim, out_channels) if time_emb_dim is not None else None
        self.class_bias = nn.Embedding(num_classes, out_channels) if num_classes is not None else None

        self.residual_connection    = nn.Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
        self.attention              = nn.Identity() if not use_attention else AttentionBlock(out_channels, norm, num_groups)
    
    def forward(self, x, time_emb=None, y=None):
        out = self.activation(self.norm_1(x))
        # 第一个卷积
        out = self.conv_1(out)
        
        # 对时间time_emb做一个全连接,施加在通道上
        if self.time_bias is not None:
            if time_emb is None:
                raise ValueError("time conditioning was specified but time_emb is not passed")
            out += self.time_bias(self.activation(time_emb))[:, :, None, None]

        # 对种类y_emb做一个全连接,施加在通道上
        if self.class_bias is not None:
            if y is None:
                raise ValueError("class conditioning was specified but y is not passed")

            out += self.class_bias(y)[:, :, None, None]

        out = self.activation(self.norm_2(out))
        # 第二个卷积+残差边
        out = self.conv_2(out) + self.residual_connection(x)
        # 最后做个Attention
        out = self.attention(out)
        return out

#------------------------------------------#
#   Unet模型
#------------------------------------------#
class UNet(nn.Module):
    def __init__(
        self, img_channels, base_channels=128, channel_mults=(1, 2, 2, 2),
        num_res_blocks=2, time_emb_dim=128 * 4, time_emb_scale=1.0, num_classes=None, activation=F.silu,
        dropout=0.1, attention_resolutions=(1,), norm="gn", num_groups=32, initial_pad=0,
    ):
        super().__init__()
        # 使用到的激活函数,一般为SILU
        self.activation = activation
        # 是否对输入进行padding
        self.initial_pad = initial_pad
        # 需要去区分的类别数
        self.num_classes = num_classes
        
        # 对时间轴输入的全连接层
        self.time_mlp = nn.Sequential(
            PositionalEmbedding(base_channels, time_emb_scale),
            nn.Linear(base_channels, time_emb_dim),
            nn.SiLU(),
            nn.Linear(time_emb_dim, time_emb_dim),
        ) if time_emb_dim is not None else None
    
        # 对输入图片的第一个卷积
        self.init_conv  = nn.Conv2d(img_channels, base_channels, 3, padding=1)

        # self.downs用于存储下采样用到的层,首先利用ResidualBlock提取特征
        # 然后利用Downsample降低特征图的高宽
        self.downs      = nn.ModuleList()
        self.ups        = nn.ModuleList()
        
        # channels指的是每一个模块处理后的通道数
        # now_channels是一个中间变量,代表中间的通道数
        channels        = [base_channels]
        now_channels    = base_channels
        for i, mult in enumerate(channel_mults):
            out_channels = base_channels * mult
            for _ in range(num_res_blocks):
                self.downs.append(
                    ResidualBlock(
                        now_channels, out_channels, dropout,
                        time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation,
                        norm=norm, num_groups=num_groups, use_attention=i in attention_resolutions,
                    )
                )
                now_channels = out_channels
                channels.append(now_channels)
            
            if i != len(channel_mults) - 1:
                self.downs.append(Downsample(now_channels))
                channels.append(now_channels)

        # 可以看作是特征整合,中间的一个特征提取模块
        self.mid = nn.ModuleList(
            [
                ResidualBlock(
                    now_channels, now_channels, dropout,
                    time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation,
                    norm=norm, num_groups=num_groups, use_attention=True,
                ),
                ResidualBlock(
                    now_channels, now_channels, dropout,
                    time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation, 
                    norm=norm, num_groups=num_groups, use_attention=False,
                ),
            ]
        )

        # 进行上采样,进行特征融合
        for i, mult in reversed(list(enumerate(channel_mults))):
            out_channels = base_channels * mult

            for _ in range(num_res_blocks + 1):
                self.ups.append(ResidualBlock(
                    channels.pop() + now_channels, out_channels, dropout, 
                    time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation, 
                    norm=norm, num_groups=num_groups, use_attention=i in attention_resolutions,
                ))
                now_channels = out_channels
            
            if i != 0:
                self.ups.append(Upsample(now_channels))
        
        assert len(channels) == 0
        
        self.out_norm = get_norm(norm, base_channels, num_groups)
        self.out_conv = nn.Conv2d(base_channels, img_channels, 3, padding=1)
    
    def forward(self, x, time=None, y=None):
        # 是否对输入进行padding
        ip = self.initial_pad
        if ip != 0:
            x = F.pad(x, (ip,) * 4)

        # 对时间轴输入的全连接层
        if self.time_mlp is not None:
            if time is None:
                raise ValueError("time conditioning was specified but tim is not passed")
            time_emb = self.time_mlp(time)
        else:
            time_emb = None
        
        if self.num_classes is not None and y is None:
            raise ValueError("class conditioning was specified but y is not passed")
        
        # 对输入图片的第一个卷积
        x = self.init_conv(x)

        # skips用于存放下采样的中间层
        skips = [x]
        for layer in self.downs:
            x = layer(x, time_emb, y)
            skips.append(x)
        
        # 特征整合与提取
        for layer in self.mid:
            x = layer(x, time_emb, y)
        
        # 上采样并进行特征融合
        for layer in self.ups:
            if isinstance(layer, ResidualBlock):
                x = torch.cat([x, skips.pop()], dim=1)
            x = layer(x, time_emb, y)

        # 上采样并进行特征融合
        x = self.activation(self.out_norm(x))
        x = self.out_conv(x)
        
        if self.initial_pad != 0:
            return x[:, :, ip:-ip, ip:-ip]
        else:
            return x

三、Diffusion的训练思路

Diffusion的训练思路比较简单,首先随机给每个batch里每张图片都生成一个t,代表我选择这个batch里面第t个时刻的噪声进行拟合。代码如下:

t = torch.randint(0, self.num_timesteps, (b,), device=device)

生成batch_size个噪声,计算施加这个噪声后模型在t个时刻的噪声图片是怎么样的,如下所示:

def perturb_x(self, x, t, noise):
    return (
        extract(self.sqrt_alphas_cumprod, t,  x.shape) * x +
        extract(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * noise
    )   

def get_losses(self, x, t, y):
    # x, noise [batch_size, 3, 64, 64]
    noise           = torch.randn_like(x)

    perturbed_x     = self.perturb_x(x, t, noise)

之后利用这个噪声图片、t和网络模型计算预测噪声,利用预测噪声和实际噪声进行拟合。

def get_losses(self, x, t, y):
    # x, noise [batch_size, 3, 64, 64]
    noise           = torch.randn_like(x)

    perturbed_x     = self.perturb_x(x, t, noise)
    estimated_noise = self.model(perturbed_x, t, y)

    if self.loss_type == "l1":
        loss = F.l1_loss(estimated_noise, noise)
    elif self.loss_type == "l2":
        loss = F.mse_loss(estimated_noise, noise)
    return loss

利用DDPM生成图片

DDPM的库整体结构如下:
在这里插入图片描述

一、数据集的准备

在训练前需要准备好数据集,数据集保存在datasets文件夹里面。
在这里插入图片描述

二、数据集的处理

打开txt_annotation.py,默认指向根目录下的datasets。运行txt_annotation.py。
此时生成根目录下面的train_lines.txt。
在这里插入图片描述

三、模型训练

在完成数据集处理后,运行train.py即可开始训练。
在这里插入图片描述

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
扎眼的阳光的头像扎眼的阳光普通用户
上一篇 2023年3月1日
下一篇 2023年3月1日

相关推荐