融合 Mamba 与 Transformer | MaskMamba 引领非自回归图像合成,推理速度提升 54.44% !

科技   2024-11-03 09:02   上海  

点击下方卡片,关注「AI视界引擎」公众号


( 添加时备注:方向+学校/公司+昵称/姓名 )

图像生成模型遇到了与可扩展性和二次复杂性相关的挑战,主要原因是依赖于基于Transformer的 Backbone 网络。

在本研究中,作者引入了一种新颖的混合模型MaskMamba,它结合了Mambo和Transformer架构,使用Masked Image Modeling进行非自回归图像合成。

作者仔细重新设计了双向Mamba架构,通过实现两个关键的修改:

(1)用标准卷积替换因果卷积,以更好地捕捉全局上下文;

(2)用 ConCat 而不是乘法,这显著提高了性能,同时加快了推理速度。此外,作者还探索了MaskMamba的各种混合方案,包括串行和分组并行排列。

此外,作者引入了一个在语境中的条件,使得作者的模型可以执行分类到图像和文本到图像生成任务。

MaskMamba 在生成质量上超过了基于Mamba和Transformer的模型。

值得注意的是,它实现了在2048x2048分辨率下推理速度的54.44%的显著提升。

1 Introduction

近年来,计算机视觉领域生成图像模型的研究取得了显著进展,特别是在类别到图像 ;Sun等人(2024);Sauer等人(2022))和文本到图像任务。传统的自回归生成模型,如VQGAN ,在条件生成方面表现出色。在文本条件生成领域,模型如Parti(Yu等人和 DALL-E 使用图像分词器和附加的 MLP 将图像转换为离散 Token ,并将编码的文本特征通过另一个MLP(Chen等人(2023))投射到描述嵌入中,以自回归方式在训练和推理中进行。

同时,非自回归方法,包括MAGE(Li等人(2023)和MUSE(Chang等人(2023))),利用Masked Image Modeling,在训练期间将图像转换为离散 Token ,并随机预测被遮挡的 Token 。

另一种在图像生成中具有突出地位的方法是扩散模型,例如LDM(Rombach et al. (2022))带有UNet Backbone 网络。尽管这些模型展示了很高的生成质量,但它们的卷积神经网络架构给可扩展性带来了限制。

为了解决这个问题,基于Transformer的生成模型,如DiT(Peebles和Xie (2023)),通过注意力机制显著增强了全局建模能力,同时显著提高了生成质量。然而,注意力机制的计算复杂性随序列长度呈平方关系增加,这限制了训练和推理效率。

Mamba (Gu和Dao (2023)) 提出了一个状态空间模型(Gu等人(2022,2021)),其具有线性时间复杂度,在处理长序列任务方面具有显著优势。当前的图像生成努力,包括 DiM 和 diffuSSM,主要用 Mamba 模块替代了原始的 Transformer 模块。这些模型在提高效率和可扩展性方面都有所提升。

然而,基于扩散模型的图像生成通常需要数百次迭代,这可能非常耗时。

为了消除Transformer模型中序列长度带来的二次复杂度增长和自动回归模型中生成迭代过多的问题,作者提出了MaskMamba,它整合了Mamba和Transformer架构,并利用非自回归式遮挡图像建模(Ni等人(2024);Lezama等人(2022))进行图像合成。作者精心重新设计了Bi-Mamba(Mo和Tian(2024);Zhu等人(2024)),通过用标准卷积替换因果卷积,使其适用于遮挡图像生成。

同时,在Bi-Mamba的最后阶段,作者选择连接而非乘法来降低计算复杂度,与Bi-Mamba(Zhu等人(2024))相比,显著提高了推理速度,提高了17.77%。

作者进一步研究了各种MaskMamba混合方案,包括串行和分组并行方案(Shaker等人,2024年)。在串行方案中,作者探索了层与层交替的安排,以及将Transformer放在最后层。对于分组并行方案,作者评估了将模型沿着通道维度分为两组或四组的影响。

作者的发现表明,将Transformer放在最后层可以显著提高模型捕捉全局上下文的能力。

此外,作者实现了一个在语境中的条件,使得作者的模型可以在单个框架中同时执行从类别到图像的生成和从文本到图像的生成任务,如图1所示。

同时,作者研究了条件嵌入(Zhu等人,2024年)的放置位置,包括输入序列的不同位置,包括 Head 、中部和尾部。结果表明,将条件嵌入放在中部可以获得最佳性能。

在实验部分,作者通过两个不同的任务来验证MaskMamba的生成能力:条件生成和文本生成,每个任务使用各种大小的模型。对于条件生成到图像的任务,作者在ImageNet1k(Deng等人(2009))数据集上训练300个周期,将作者的MaskMamba与类似大小的基于Transformer和Mamba的模型进行比较。结果表明,在生成质量和推理速度方面,作者的MaskMamba都优于这两个对照组。此外,作者在CC3M(Sharma等人(2018))数据集上进行训练和评估,在CC3M和MS-COCO(Lin等人(2014))验证数据集上取得了优越的性能。

总结起来,作者的贡献包括:

  1. 作者重新设计了Bi-Mamba,通过用标准卷积代替因果卷积来提高其对遮挡图像生成任务适用性。此外,在最后阶段用 ConCat 代替乘法,从而显著提高了性能并比Bi-Mamba快了17.77%的推理速度。

  2. 作者提出了MaskMamba,这是一个统一的生成模型,它集成了重新设计的Bi-Mamba和Transformer层,使得通过在语境中的条件,可以在同一模型中执行类到图像和文本到图像的生成任务。

  3. 作者的MaskMamba模型在ImageNet1k和CC3M数据集上,无论是在生成质量还是推理速度方面,都超过了基于Transformer和基于Mamba的模型。

相关工作图像生成。图像生成的领域正在当前研究中取得重大进展。最初的自动回归图像生成模型,如 VQGAN(Esser等人(2021))和LlamaGen(Sun等人(2024)),证明了通过将图像转换为离散 Token 并应用自动回归模型生成图像 Token ,可以生成高保真图像的潜力。文本到图像生成模型的出现,如Parti;2022)和 DALL-E(Ramesh等人(2021)),进一步推动了这一领域的进展。然而,这些模型在生成过程中存在特定效率问题。为了解决这些问题,非自动回归生成模型如MaskGIT、MAGE 和 MUSE(Chang等人(2023))通过 Mask 图像建模提高了生成效率。同时,扩散模型;Song等人(2020);Ho等人(2020);Dhariwal和Nichol;Saharia等人(2022),如LDM(Rombach等人(2022)),在生成质量方面尽管受到与卷积神经网络基础架构相关的可伸缩性限制,但在生成质量上表现出色。为克服这些限制,Transformer基生成模型,包括DiT(Peebles和Xie(2023)),通过引入注意力机制提高了全局建模能力。然而,当处理大量序列时,这些模型仍面临着计算复杂性随平方增加的挑战。

Mamba Vision。Transformer 作为一种领先的网络架构,在各种任务中得到了广泛应用。然而,其平方的计算复杂度为长序列任务的有效处理带来了巨大障碍。在最近的发展中,一种新的状态-空间模型(Gu等人,2021年;Gu和Dao,2023年;Dao和Gu,2024年)——被称为Mamba(Gu和Dao,2023年;Dao和Gu,2024年),在处理长序列任务方面展现出巨大的潜力,并在研究社区中引起了广泛关注。Mamba架构已经有效地替代了传统的Transformer框架,在多个领域取得了显著的成果。Mamba 家族涵盖了广泛的应用,包括文本生成、物体识别、3D点云处理、推荐系统以及图像生成,并有许多基于如Vision-Mamba 、U-Mamba 和Rec-Mamba 等框架的实现。Vision-Mamba采用双向状态-空间模型结构,并与混合 Transformer 相结合。

然而,Mamba在非自回归图像生成方面的应用尚未得到探索。目前,大多数基于Mamba的生成任务遵循扩散模型范式,这涉及到训练和推理次数的复杂性。

为了解决这些挑战,作者设计了一种新颖的混合Mamba结构,旨在将Mamba应用于非自回归图像生成任务,并将其与Masked Image Modeling(He等人,2022年)相结合,用于训练和推理,从而提高这些过程的效率。

3 Method

MaskMamba Model: Overview

如图2所示,作者的MaskMamba核心包括三个部分。首先,将图像像素通过图像分词器(Yu等人(2021);Van Den Oord等人(2017);Esser等人(2021))量化为离散的 Token ,其中表示图像分词器的下采样比。这些离散 Token 作为图像词表的索引。然后,作者随机选择 Mask 比例(范围为0.55至1.0),并从 Token 中进行 Mask ,用可学习的 Mask  Token 替换它们。其次,作者将类别ID转换为可学习的标签嵌入(Peebles和Xie(2023);Esser等人(2021)》,表示为。另一方面,关于文本条件,作者首先使用T5-Large Encoder(Colin(2020))提取特征,然后将提取的特征映射到描述嵌入(Chen等人(2023)》,表示为

最后,作者将条件嵌入与图像 Token Embedding 在中部拼接,其中表示,并添加位置嵌入到这些。训练目标是利用交叉熵损失(Zhang & Sabuncu(2018))预测被 Mask 区域的 Token 索引。

模型配置作者提出两种图像生成模型:条件分类模型和条件文本模型。遵循先前的研究工作(Radford等人(2019);Touvron等人(2023)的标准),作者遵循Mamba的标准配置。如Tab.1所示,作者提供了三种条件分类模型的不同版本,参数大小从103M到741M不等。生成的图像分辨率为256x256,经过16倍下采样因子后,图像 Token Embedding 的长度设置为256。类别条件嵌入的长度设置为1,文本条件嵌入的长度N设置为120。

MaskMamba Model: Architecture

3.2.1 Bi-Mamba-V2 Layer.

卷积替换。如图3(c)所示,作者将原始的Bi-Mamba(朱等,2024年)架构进行了重新设计,以便更好地适应与遮挡图像生成相关的任务。作者将原始的因果卷积替换为标准卷积。由于遮挡图像生成的非自回归性质,因果卷积只允许单向的 Token 混合,这限制了非自回归图像生成的潜力。相反,标准卷积使 Token 可以在输入序列中的所有位置双向互动,有效地捕获全局上下文。

对称SSM分枝设计

作者将对称SSM分枝引入,以更好地适应 Mask 图像生成。在对称分枝中,作者在Backward SSM之前先将输入翻转,然后在其之后再翻转,将其与Forward SSM的结果合并。此外,与Bi-Mamba右侧分枝相比,作者使用额外的卷积层来减少特征损失。为了充分利用所有分枝的优势,作者将输入映射到大小为的特征空间,从而确保最终拼接维数一致。作者的输出可以表示为,其计算使用以下公式1。

3.2.2 Maskmaba Hybrid Scheme.

群体方案设计。如图4(a)和图4(b)所示,作者设计了两组群体混合方案。在群体方案v1中,作者将输入数据沿通道维度分为两组,然后分别由作者的Bi-Mamba-v2层和Transformer层处理。接下来,作者将处理结果沿通道维度进行 ConCat ,并最终输入到Norm和Project层。在群体方案v2中,作者将输入数据沿通道维度分为四组。其中两组由作者 的Bi-Mamba-v2层在前向和后向SSM中处理,而另外两组由Transformer层处理。

串行方案设计。如图4(c)和图4(d)所示,作者还设计了两种串行混合方案。在串行方案v1中,作者依次层叠地排列作者的Bi-Mamba-v2和Transformer。在串行方案v2中,作者将Bi-Mamba-v2放在前层,将Transformer放在后层。由于Transformer的注意力机制可以更好地增强特征表示,作者在所有串行模式中,将Transformer层放在Mamba层之后。

Image Generation By MaskLambda

作者利用 Mask 图像生成(Li等人(2023);Chang等人(2022))方法进行图像合成。对于生成分辨率为256×256,下采样因子为16的情况下,在正向传播过程中,作者首先初始化256个 Mask  Token 。然后,作者将条件嵌入与中间位置的 Mask  Token 连接。受到MUSE(Chang等人(2023))迭代生成方法的启发,作者的解码过程也采用余弦计划(Chang等人(2022)),在每一步选择最高置信度的 Mask  Token 进行预测。这些 Token 随后被设置为无 Mask ,剩余步骤中的 Mask  Token 集合相应减少。通过这种方法,作者可以在20个解码步骤中推理256个 Token ,而自动回归方法(Touvron等人(2023);Sun等人(2024))需要256个步骤。

条件图像生成。类别标签嵌入基于每个类别的索引。这些类别标签嵌入与 Mask  Token  ConCat ,MaskMamba通过余弦进度表逐渐预测这些 Mask  Token 。

文本条件图像生成。 首先,作者使用Colin(2020)的T5-Large Encoder提取文本特征,然后将这些特征转换为描述符嵌入。与标签嵌入类似,作者将提取的描述符嵌入与 Mask  Token Embedding  ConCat 。MaskMamba通过余弦时间表逐渐预测这些 Mask  Token 。

无分类引导图像生成 扩散模型(Ho 和 Salimans,2022年)提出的无分类引导(CFG)方法是一种非常有效的技术,可增强模型在处理文本和图像特征时的条件生成能力。因此,作者将这种方法应用到作者的模型中。在训练阶段,为了模拟无条件图像生成的过程,作者以0.1的概率随机删除条件嵌入。在推理阶段,每个 Token 的logit 由以下方程确定:,其中是无条件logit,是条件logit,是CFG的缩放。

4 Experimental Results

Class-conditional Image Generation

训练设置所有类到图像生成模型都在ImageNet 数据集上训练300个epoch,所有模型的训练参数设置保持一致。具体而言,基本学习率设置为每256个批次大小为1e-4,全局批量大小为1024。此外,作者使用AdamW优化器,其中β1 = 0.9,β2 = 0.95。正则化率保持一致,包括在特定条件下。在训练期间, Mask 率从0.5变化到1。所有模型和推理都将在具有32GB内存的V100 GPU上进行训练和推理。

评估指标作者使用FID-50K(Heusel等人(2017))作为主要的评估指标,同时采用Inception Score(Salimans等人(2016))(IS)和Inception Score标准差(IS-std)作为评估标准。在ImageNet验证数据集上,作者根据CFG生成50,000张图像,并使用上述指标对所有模型进行评估。

4.1.1 Qualitative Results

与其他图像生成方法的比较 如图2所示,作者将MaskMamba模型与流行的图像生成模型进行了比较,包括自回归(AR)方法(Esser等人(2021);Sun等人(2024))、 Mask 预测模型(Mask)(Li等人(2023);Chang等人(2022))和基于Transformer的模型(Masked Image Modeling训练使用相同的超参数),重点关注它们的基础网络差异。MaskMamba使用串行方案v2模式。在不同模型大小的比较中,MaskMamba表现出竞争力的性能。如图5所示,作者从MaskMamba-XL模型中随机选择的图像仅在ImageNet上训练即可获得高质量的结果。

4.1.2 Experiment Analysis

无类别指导(CFG)和生成迭代的影响。图6(a)显示了在图像生成中,当CFG设置为3时,随着迭代次数的增加,FID和IS变化。模型在25迭代时达到最佳性能,进一步增加迭代次数将降低FID。图6(b)显示了不同CFG设置下的FID和IS分数,表明无类别指导可以提高视觉质量,而当CFG=3时,模型达到最佳性能。

有效性分析 作者进行了一系列实验来评估作者重新设计的Bi-Mamba-v2层、原始Bi-Mamba层和Transformer层的有效性。为了评估在更高分辨率图像上的推理实验,作者主要关注单层推理速度和内存使用。所有关于效率分析的实验都是在A100 40G设备上进行的,并且作者比较了这些模型在不同分辨率下的推理速度,如图7所示。结果表明,当分辨率小于时,作者的Bi-Mamba-v2层和原始Bi-Mamba层比Transformer层稍慢。然而,当分辨率超过时,作者的Bi-Mamba-v2层比Transformer层和原始Bi-Mamba层都要快。值得注意的是,在的分辨率下,作者的Bi-Mamba-v2层比Transformer层快1.5倍。作者还比较了不同批处理大小的GPU内存使用情况。

作者的Bi-Mamba-v2层的内存使用与Bi-Mamba层的内存使用相当,而Transformer层由于其二次复杂度,随着批处理大小的增加,内存使用呈指数增长。当批处理大小达到6时,Transformer层消耗了63GB的GPU内存,导致内存不足,而作者的Bi-Mamba-v2层只需要38GB。这些实验结果表明,作者的Bi-Mamba-v2层可以在更快的速度下生成图像,并且内存使用更低。

不同混合方案的影响. 如图3所示,作者对MaskMamba在各种混合配置下的图像生成结果进行了比较分析,这些混合配置分为两类:并行和串行。如图4所示,在分组并行配置中,作者研究将模型分为两组和四组的效果。在分层串行配置中,作者设计了一种交错结构,包括Bi-Mamba-v2和Transformer {MSMS...MSMS},以及另一种配置{MMMM...SSSS},其中前层是Mamba,后层是Transformer。这些实验的结果揭示了不同混合配置的性能和效率。

不同 Backbone 的影响。 作者在不同的 Backbone 上进行了消融实验:VisionMamba (Zhu等人,2024年)提出的Bi-Mamba,重新设计的 Bi-Mamba-V2 ,以及Transformer (Vaswani,2017年)。Bi-Mamba-L只使用原始的Bi-Mamba作为层,而 Bi-Mamba-V2 使用作者重新设计的Bi-Mamba-v2。Transformer只使用Transformer架构。在(Bi-Mamba + Transformer)-L中,前层是原始的Bi-Mamba,然后是层的Transformer。在(Bi-Mamba-V2 + Transformer)-L中,前层是Bi-Mamba-v2,然后是层的Transformer。

结果表明,作者重新设计的Bi-Mamba-v2在原始Bi-Mamba之上提高了性能,将Mamba和Transformer结合进一步提高了结果。因此,作者选择(Bi-Mamba-V2 + Transformer)用于MaskMamba。

不同条件嵌入位置的影响。作者进行消融实验来评估条件嵌入 的放置对模型性能的影响。具体而言,作者研究了条件嵌入在序列的 Head 、中间和尾部的不同位置的组合效果。实验结果表明,当条件嵌入置于中间位置时,性能最优。这一结果主要归因于选择性扫描的机制。由于作者随机遮挡图像 Token 的一部分,将条件嵌入置于 Head 或尾部会导致由于注意力距离增加而无法提供足够的监督信息进行条件生成控制。

Text-conditional Image Generation

训练设置

与分类训练策略类似,作者为文本数据采用一种遮挡生成非自回归训练策略。在Sharma等人(2018年)的CC3M(256×256)数据集上训练模型30个周期。训练参数与之前的实验保持一致,基本学习率设置为每256个批次的1e-4,全局批处理大小为1024。此外,作者使用AdamW优化器,其中β1 = 0.9,β2 = 0.95。

基于CC3M的模型训练。如表6所示,作者在CC3M和MS-COCO的验证集上比较了Transformer-XL和作者的MaskMamba-XL在文本到图像生成的性能,评估了FID和IS。作者的结果始终优于基于Transformer的模型。如图8所示,作者使用CC3M中的文本作为 Prompt 来生成图像。MaskMamba-XL能够生成高质量的图像。然而,由于训练数据的有限性和CC3M数据集中文本描述的精度不精确,一些生成的图像存在局限性。

5 Conclusion.


在本工作中,作者提出了一种新颖的混合模型 MaskMamba,该模型结合了Mamba 和 Transformer 架构,利用 Masked Image Modeling进行非自回归图像合成。

作者不仅重新设计了一种新的Bi-Mamba结构,使其更适合图像生成,而且还研究了不同的模型混合策略和条件嵌入的放置,最终确定了最佳设置。

此外,作者在一个包含上下文的条件下,提供了一系列类别条件图像生成模型和文本条件图像生成模型。

作者的实验结果表明,在生成质量和推理速度方面,作者的MaskMamba模型超过了基于Transformer和基于Mamba的模型。

作者希望Masked Image Modeling for non-autoregressive image synthesis在MaskMamba中的应用可以激发对Mamba图像生成任务进行进一步探索。




点击上方卡片,关注「AI视界引擎」公众号


集智书童
书童带你领略视觉前沿之美,精选科研前沿、工业实用的知识供你我进步与学习!
 最新文章