Openai神作Dalle2理论和代码复现

Openai神作Dalle2

注:大家觉得博客好的话,别忘了点赞收藏呀,本人每周都会更新关于人工智能和大数据相关的内容,内容多为原创,Python Java Scala SQL 代码,CV NLP 推荐系统等,Spark Flink Kafka Hbase Hive Flume等等~写的都是纯干货,各种顶会的论文解读,一起进步。
今天和大家分享一下Openai神作Dalle2理论和代码复现
论文:https://cdn.openai.com/papers/dall-e-2.pdf
代码:https://github.com/lucidrains/DALLE2-pytorch
#博学谷IT学习技术支持#

前言

今天和大家分享的是一篇2022Openai神作Dalle2,如何输入一句话,生成非常有趣的图片过程。
先来看看论文的最终效果图。

在这里插入图片描述

是不是非常有趣?让我们看看Openai团队是如何做到的。

一、Dalle2模型整体框架

在这里插入图片描述
其实论文整体架构并不是很复杂,主要是运用
1.对比学习为主的CLIP
2.生成模型Diffusion Model
把这两个技术栈结合起来,这两个技术栈在我之前的博客中都已经有说明。大家可以先看一下。否则可能有一点抽象。
https://blog.csdn.net/weixin_53280379/article/details/125585445?spm=1001.2014.3001.5502
https://blog.csdn.net/weixin_53280379/article/details/126250598?spm=1001.2014.3001.5502

二、预训练模型CLIP

在这里插入图片描述
这里模型的上半部分,就是用预训练模型CLIP的结果,得到2个向量

  1. 第一个是text_embed,他是一个(2乘512)的tensor,2表示batch_size,512表示特征向量的维度,text_embed可以看作是对输入这句话的总结。
  2. 第二个是text_encodings,他是一个(2乘256乘512)的tensor,他是输入这句话每个词的特征。2表示batch_size,256表示这句话的长度不够用0代替,512表示特征向量的维度。

对应原文中的:
在这里插入图片描述

所以这里预训练模型CLIP的作用就是:输入一句话,通过CLIP,得到2个更接近图片的向量。把这两个向量运用到实际的下游任务中。

1.主要代码

text_embed, text_encodings = self.clip.embed_text(text)

三、先验模型

有了上面通过CLIP得到的2个文本特征向量以后,通过生成模型Diffusion Model,进一步对特征进行处理。
在这里插入图片描述
这一步主要就是运用Diffusion Model,但是和传统的Diffusion Model又有一点不一样,原文当中有说,传统的Diffusion Model是通过unet等网络,学到一个噪音,通过噪音一步步进行迭代,而Dalle2没有学噪音,而是直接学习得到x0,省去了中间的计算过程。
在这里插入图片描述

在这里插入图片描述
通过CLIP得到的2个文本特征向量,1个随机初始的噪音和一个时间步特征学习这个x0,有了这个x0,就和Diffusion Model模型一样,通过xt可以推出xt-1的分布了,在对xt-1的分布经过正太分布重采样技巧,就可以一步步迭代往前推,得到最终的特征向量了。
如何训练得到x0,主要就是通过一个transformer网络:
在这里插入图片描述

1.主要代码

这里x是随机初始的噪音,t是时间步特征,text_cond是CLIP文本特征。得到pred就是预测上一步所需要的x0特征。

pred = self.net.forward_with_cond_scale(x, t, cond_scale = cond_scale, **text_cond)

有了x0特征以后,根据Diffusion Model公式推出xt-1的分布,正太分布主要是一个期望和方差。和Diffusion Model公式完全一样。

def q_posterior(self, x_start, x_t, t):
    posterior_mean = (
        extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
        extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
    )
    posterior_variance = extract(self.posterior_variance, t, x_t.shape)
    posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped

model_mean, posterior_variance, posterior_log_variance = self.noise_scheduler.q_posterior(x_start=x_recon, x_t=x, t=t)

这里就是重采样操作,得到xt-1特征图的image_embed

def p_sample(self, x, t, text_cond = None, clip_denoised = True, cond_scale = 1.):
    b, *_, device = *x.shape, x.device
    model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, text_cond = text_cond, clip_denoised = clip_denoised, cond_scale = cond_scale)
    noise = torch.randn_like(x)
    # no noise when t == 0
    nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
    return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise#采样 (0.5 * model_log_variance).exp()计算得到标准差
image_embed = self.p_sample(image_embed, times, text_cond = text_cond, cond_scale = cond_scale)

最终不停迭代n次得到先验模型最终的特征图输出

for i in tqdm(reversed(range(0, self.noise_scheduler.num_timesteps)), desc='sampling loop time step', total=self.noise_scheduler.num_timesteps):
    times = torch.full((batch,), i, device = device, dtype = torch.long)
    image_embed = self.p_sample(image_embed, times, text_cond = text_cond, cond_scale = cond_scale)

四、Decoder解码模型

在这里插入图片描述
Decoder解码模型和先验模型类似也是运用Diffusion Model,但是运用了2个Diffusion Model,可能一个效果一般,但是这样速度就变慢了。

在这里插入图片描述
模型最后做了各种对比实验发现,Diffusion Model会比自回归的方法效果更好。

在这里插入图片描述

总结

今天和大家分享一下Openai神作Dalle2理论和代码复现,有点难度,有点抽象,主要是需要有对比学习为主的CLIP和生成模型Diffusion Model的基础,然后把两者所结合起来。文章主要是给大家分享了一下主要思想,还有细节没写。有时间可以读一下,有疑问可以留言进行讨论。

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
扎眼的阳光的头像扎眼的阳光普通用户
上一篇 2023年2月25日 上午9:15
下一篇 2023年2月25日 上午9:16

相关推荐