[自注意力神经网络]Segment Anything(SAM)论文阅读

论文地址https://arxiv.org/abs/2304.02643源码地址https://github.com/facebookresearch/segment-anything强烈建议大家试试Demo,效果真的很好:https://segment-anything.com/

 一、概述

        本文建立了一个基础图像分割模型,并将其在一个巨大的数据集上进行训练,目的是解决一系列下游任务。本文的关键点有3个:taskmodeldata

                Task

                        本文定义了一个可提示的通用分割任务,可以提供的预训练目标来支持下游任务的应用。提示仅指定在图像中分割的什么(如对象的空间/文本信息),这个提示可以是模糊的,输出会针对这些对象输出至少一个合理的掩码。

                Model

                        本任务需要支持灵活提示的模型,且这个模型需要提示时实时输出分割掩码。满足三个约束:①图像编码器(负责图像嵌入);②提示编码器(负责提示嵌入);③结合前面两个信源的轻量级掩码解码器;该模型称为:Segment Anything Model(SAM);该模型还有歧义意识,能自然的处理歧义。

                Data

                        本项目使用了一个大规模数据源SA-1B;同时为了管理好这些数据,本项目建立了一个数据引擎(Data engine),分为三个阶段:辅助-手动半自动全自动

 二、网络结构

        本文从NLP中获得启发,将token用于基础模型的训练,通过即时的工程解决各种下游任务。SAM会从单个模糊点提示生成3个有效掩码,分别为:整体、部分、子部分,如下图所示:

         Segment Anything Model(SAM)的网络结构如下图所示:

         1.Image encoder(图像编码器)

                使用了一个MAE预训练的Vision Transformer(ViT)作为图像编码器。ViT网络可以参见往期博文[自注意力神经网络]Transfomer架构。 

        2.Prompt Encoder(提示编码器)

                本设计中有2组提示,分为:稀疏的(点、框、文本)和密集的(掩码)。点和框可以由位置编码表示,位置编码综合了来自每种提示的学习嵌入和任意形式的文字(使用CLIP处理)。而掩码则通过卷积嵌入后与图像诸元素求和。

        3.Lightweight mask decoder(轻量化掩码解码器)

                掩码解码器可以有效的将图嵌入提示嵌入输出标记映射到掩码。本模型的解码器基于Transformer的解码器块修改,在解码器后添加了动态掩码预测头。解码器使用了提示自注意力和交叉注意力在提示到图嵌入prompt-to-image embedding)和副反转vice-versa,这个翻译不好 )两个方面进行了修改。完成这两个部分后,对图像进行上采样再使用MLP将输出标记映射到动态线性分类器上,最终得出每个图像位置的蒙板前景概率。

                         每层解码器执行以下4个步骤:

                                ①对token进行自注意力

                                ②从token(作为查询)和图(嵌入向量)进行交叉注意

                                ③MLP逐点更新到每个token

                                ④从图(嵌入向量)到token(作为查询)进行交叉注意

                        为了保证解码器能访问到关键几何信息,当他们参与注意层运算时,位置编码都会被添加到图嵌入向量中;此外,整个原始token(包含位置信息)也会被重新添加到图嵌入向量中。

        4.其他技术细节

                ①歧义感知(Ambiguity-aware)

                        对于一个不确定的提示,模型会给出多个有效掩码,经过修改SAM可以由单个提示预测输出多个掩码(一般是3个–整体部分子部分)。训练时,仅掩码进行反向传播。为了对掩码进行排名,模型会预测每个掩码的置信分数(使用IOU度量)

                ②损失函数和训练(Losses and training)

                        本项目使用焦点损失函数(focal loss)和筛子损失函数(dice loss)的线性组合作为损失函数来监督掩码预测。对于文本提示,采用几何提示的方法进行混合训练。随后,随机抽取11轮中的每种掩码进行模拟交互,由于这种设计,SAM可以被无缝的接入数据引擎

三、SAD数据引擎(Segment Anything Data Engine)

        本小节对应前面的Data部分,主要用来从互联网上收集蒙板数据集SA-1B。数据引擎分为三个阶段:

        1)辅助手动阶段(Assisted-manual stage)

                进行人工标注,此阶段SAM先使用常见数据集进行分割训练,在标记了足够多数据时,SAM仅使用新注释的掩码进行再训练。

        2)半自动阶段(Semi-automatic stage)

                这个阶段的目标主要是增加遮罩的多样性以提高模型的切割能力。首先SAM会预测一些掩码,这些掩码会被展示给标注者,由标注者对任何其他未标注的地方进行标注,一增强模型在不显眼位置的能力。

        3)全自动阶段(Fully automatic stage)

                这个阶段标注是全自动的,由于歧义感知的存在,即是在模棱两可的情况下也能预测出有效掩码。对于最终得到的预测掩码,需要使用非极大抑制(NMS)来过滤重复项。

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
xiaoxingxing的头像xiaoxingxing管理团队
上一篇 2023年6月12日
下一篇 2023年6月12日

相关推荐