本文介绍了一种名为CLEAR的卷积式线性化方法,用于将预训练的扩散变换器的注意力机制线性化,从而显著提高高分辨率图像生成的效率。通过限制特征交互到局部窗口,CLEAR在保持与原始模型相当的性能的同时,将注意力计算减少了99.5%,并在生成8K分辨率图像时加速了6.3倍。
太长不看版
端侧文生图扩散模型的成功范式。
Diffusion Transformer (DiT) 已经成为图像生成的主要架构。然而,Self-Attention 的二次复杂度负责对 token 之间的关系进行建模,在生成高分辨率图像时会产生显著的时延。为了解决这个问题,本文的目标是引入线性注意力机制,将预训练的 DiT 的复杂度降低到线性。
作者对现有的高效注意机制做了全面的总结开始,并确定了 4 个关键因素,这些因素对于成功线性化预训练的 DiT 至关重要:局部性 (locality),表达一致性 (formulation consistency),高阶注意力图 (high-rank attention maps),和特征完整性 (feature integrity)。
基于以上观察,本文提出了一种称为 CLEAR 的类卷积的局部注意力策略,该策略将特征交互限制为每个 query 标记周围的局部窗口,从而实现线性复杂度。
实验表明,仅在 10K 个样本上微调注意力层进行 10K 次迭代,就可以有效地将知识从预训练的 DiT 转移到具有线性复杂度的学生模型,产生的结果与教师模型相当。
同时,CLEAR 将注意力计算减少了 99.5%,为生成 8K 分辨率的图像生成加速 6.3 倍。此外,本文研究了蒸馏注意力层中的一些好的性质,比如 Zero-Shot 的泛化性 (跨越各种模型和插件),改进支持了多 GPU 并行推理的。
下面是对本文的详细介绍。
本文目录
1 CLEAR:类卷积线性扩散 Transformer
(来自 NUS)
1 CLEAR 论文解读
1.1 CLEAR 研究背景
1.2 高效注意力机制:分类概述
1.3 线性化 DiT 什么比较重要?
1.4 类卷积线性化
1.5 训练和优化
1.6 多 GPU 并行推理
1.7 实验设置
1.8 实验结果
1 CLEAR:类卷积线性扩散 Transformer
论文名称:CLEAR: Conv-Like Linearization Revs Pre-Trained Diffusion Transformers Up
论文地址:
http://arxiv.org/pdf/2412.16112
Project Page:
http://github.com/Huage001/CLEAR
1.1 CLEAR 研究背景
扩散模型在文生图领域得到了广泛的关注,其证明了从文本提示生成高质量,多样化的图像非常有效。传统基于 U-Net 的架构因其强大的生成能力主导了该领域。近年来,Diffusion Transformer (DiT) 已成为一种很有前途的替代方案,在该领域取得了领先的性能。DiTs 利用 Self-Attention 灵活地对复杂的 token 之间的关系进行建模,使其能够捕获图像和文本中所有 token 的细微依赖关系,产生视觉上丰富且连贯的输出。
尽管 Self-Attention 的性能令人印象深刻,但其二次复杂度的复杂成对 token 关系进行建模,在高分辨率图像生成中引入大量延迟。如图 2 所示,FLUX.1-dev[1]是最先进的文生图 DiT 模型,在生成 8K 分辨率的图像时,即使使用像 FlashAttention[2][3]这样的硬件感知优化,也有 20 个 denoising steps,时延超过 30min。
针对这些缺点,本文希望把预训练的 DiT 转化成线性复杂度的。尚不清楚现有的高效注意力机制可以有效地应用于预训练的 DiT。
为了回答这个问题,作者总结了以前致力于高效注意力的方法,将它们分为 3 种主要策略:Formulation Variation,Key-value Compression,和 Key-value Sampling。然后,作者尝试通过用这些高效的替代方案替换原始注意力层来微调模型。结果表明,虽然 Formulation Variation 策略已被证明在基于注意力的 U-Net[4]和从头开始训练的 DiTs 方面是有效的[5],但它们与预训练的 DiTs 并没有适配得很成功。Key-value Compression 通常会导致细节失真,Key-value Sampling 突出了 local token 对每个 Query 生成视觉上连贯的结果的必要性。
基于这些观察,本文找出了 4 个对线性化预训练 DiT 至关重要的组件,包括局部性、公式一致性、高阶注意力图和特征完整性。作者又提出了一种类似卷积的线性化策略 CLEAR,其中每个 Query 只与预定义距离 r 内的 token 交互。由于每个 Query 交互的 Key-value token 的数量是固定的,因此生成的 DiT 在图像分辨率方面实现了线性复杂度。
令人惊讶的是,这样简洁的设计得到了与原始 FLUX.1-dev 相当的结果,在 10K 个样本上只需 10K 次微调迭代。如图 1 所示,CLEAR 表现出令人满意的交叉分辨率泛化能力,该特性也反映在基于 UNet 的扩散模型[6]。对于 8K 等超高分辨率生成,将注意力计算减少了 99.5%,将原始 DiT 加速了 6.3 倍,如图 2 所示。蒸馏的局部注意力也与教师模型的不同变体兼容,例如 FLUX.1-dev 和 FLUX.1-schnell,以及各种预训练的插件,如 ControlNet。
1.2 高效注意力机制:分类概述
Self-Attention 机制在建模 token 关系方面很灵活。给定矩阵,生成输出矩阵为:
其中, 和 分别是 Query 和 Key token 的数量, 和 是 Query 和 Value 的特征维度,本文假定 。
如式 1 所示,Self-Attention 需要计算 个 token 与 token 的关系,导致时间和内存的复杂性。为了解决这个问题,许多研究侧重于开发高效注意力机制。本文将现有方法分为 3 个主要类别:Formulation Variation,Key-value Compression,和 Key-value Sampling。
Formulation Variation
回顾式 1,如果省略 Softmax 操作,可以首先计算 ,线性注意力机制分别对 和 应用核函数 和 来模拟 Softmax 的影响:
例如,Mamba2[7]、Gated Linear Attention[8],和 Generalized Linear Attention[9]。另一种主流方法试图将 softmax 操作替换为有效的替代方案,例如 Sigmoid[10]、ReLU2[11]和基于 Nystrom[12]的近似。
Key-value Compression
在 Self-Attention 的默认设置中,Query 和 Key,Value token 的数量是一致的,即 ,注意力图的形状为 。因此,压缩 key value token 有望使 可以小于 以降低复杂度。按照这个流程,PixArt-Sigma[13]使用下采样 Conv2d 算子在本地压缩 KV token。Agent Attention[14]首先对下采样 得到 Agent tokens,再与 交互。Linformer[15]引入了可学习的映射,从原始映射中获取压缩 tokens。
Key-value Sampling
基于 Key-value Sampling 的高效注意力假设:并非所有 Key-Value 对 Query 是同等重要的,且注意力矩阵是高度稀疏的。与 Key-value Compression 相比,Key-value Sampling 会 prune 每个 token 的原始 key-value token,而不是生成新的 key-value token。比如,Routing Attention[16]基于分组对 key-value token 进行采样。Swin Transformer[17]将特征图划分为不重叠的局部 window,并为每个 window 独立执行注意力。BigBird[18]使用结合邻域注意力和随机注意力的 token 选择策略,LongFormer[19]将邻域注意力与全局 tokens 相结合,这些全局 tokens 对所有 token 可见。
1.3 线性化 DiT 什么比较重要?
基于以上高效注意力机制概述,作者这里探索了一个关键问题:什么对于线性化预训练的 DiT 至关重要?
作者在本节中用各种替代方案替换 FLUX.1-dev 中的所有注意力层。初步的文本到图像结果如图 3 所示,作者找出了 4 个关键元素:局部性 (locality),表达一致性 (formulation consistency),高阶注意力图 (high-rank attention maps),和特征完整性 (feature integrity)。根据这些点,作者总结了之前一些高效注意力方法,如图 4 所示。
局部性
局部性表明,对于 Attention 中的 Query,只包含一个邻域的 Key,Value。从图 3 中,可以观察到许多有此功能的方法可以产生合理的结果,比如 PixArt-Sigma、Swin Transformer 和 Neighborhood Attention。特别是,比较 Neighborhood Attention 和 Strided Attention 的结果,作者发现结合局部 key-value token 会减少很多失真模式。
这些现象的原因是预训练的 DiT,例如 FLUX,严重依赖局部特征来管理 token 之间的关系。为了验证这一点,作者在图 5 中可视化了注意力图,观察到最显著的注意力分数落在每个 Query 周围的局部区域中。
图 6 提供了进一步的证据来说明局部特征的重要性,即扰动远程特征不会过多损害 FLUX.1-dev 的质量。具体来说,FLUX.1-dev 依靠 RoPE 感知空间关系,并且对 2 D 特征图的两个轴上的相对距离 很敏感,其中索引 和 分别表示 Q 和 K 的 token 索引。作者一这样的方式扰动远程特征 ,即当 RoPE 的相对距离超过阈值 时,将距离 clip 到最大值。当 小到 8 时, 特征映射的结果仍是合理的。相反,如果扰动局部特征,将最小绝对距离 设置为 2 ,结果就崩溃,如图 6 所示。这些结果强调局部性的重要性。
表达一致性
表达一致性的意思是还需要使用基于 Softmax 的 Scaled Dot-product Attention。LinFusion 表明,Linear Attention 等方法在基于注意力的 U-Net 中取得了成果。然而,本文发现预训练的 DiT 并非如此,如图 3 所示。作者推测这是由于注意力层是 DiT 中令牌交互的唯一模块,与 U-Net 的情况不同。替换所有这些会对最终输出产生重大影响。Sigmoid Attention 等公式无法在有限次的迭代中收敛,无法减轻原始公式和修改后的公式之间的差异。因此,保持与原始注意力功能的一致性是有益的。
高阶注意力图
高阶注意力图意味着通过高效的注意力替代方案计算的注意力图应该足以捕获复杂的 token 关系。如图 7 所示,注意力大多集中在对角线,表明注意力图没有表现出许多先前工作假设的 low-rank 属性。这就是为什么 Linear Attention 和 Swin Transformer 等方法在很大程度上会产生 Block 状模式。
特征完整性
特征完整性意味着原始 Q,K,V 特征比压缩之后的特征更有利。尽管 PixArt-Sigma 已经证明对深层中 KV 应用压缩不会对性能造成太大影响,但这种方法不适合完全线性化预训练的 DiT。如图 3 所示,与 Swin Transformer 和 Neighborhood Attention 的结果相比,基于 KV 压缩的方法 (如 PixArt-Sigma 和 Agent Attention) 往往会使得纹理失真,这个结果突出了保留原始 Q,K,V token 的完整性的必要性。
1.4 类卷积线性化
基于对线性化 DiT 的上述分析,Neighborhood Attention 是满足所有约束的唯一方案。基于此,作者提出了 CLEAR,一种为预训练 DiT 定制的类卷积线性化策略。
鉴于最先进的用于文生图的 DiT,如 FLUX 和 StableDiffusion 3 系列,通常采用文本-图像联合的 Self-Attention 进行特征交互,对于每个 text Query,从所有 text 和 image 的 key-value tokens 中收集特征。对于每个 image query,与所有 text token 交互,还与周围局部窗口中的 key-value tokens 进行交互。由于 text token 的数量和局部窗口大小随着分辨率的增加保持不变,因此整体复杂度与图像 token 的数量成线性关系。
与使用方形滑动局部窗口的 Neighborhood Attention 和标准 2D 卷积不同,CLEAR 采用圆形窗口,其中每个 query 考虑欧几里得距离小于半径 的 key-value token。与相应的方形窗口相比,这种设计引入的计算开销约为 倍。注意力掩码如下:
其中, 表示 text token 的数量。图 8 说明了这种范式。
1.5 训练和优化
尽管每个 query 只能访问本地窗口内的 tokens,但堆叠多个 Transformer Block 使每个 token 逐渐能够捕获整体信息:类似于卷积神经网络运行的方式。为了提高微调前后模型之间的功能一致性,作者在微调过程中采用了知识蒸馏目标。具体来讲,包括传统的 Flow Matching 损失函数[20][21]:
其中, 表示使用预训练的 VAE 编码器 编码的图像 的特征,而 是第 个时间步的噪声版本, 是文本条件, 是参数为 的 DiT Backbone。除此之外,作者在预测和注意力输出方面鼓励线性化学生模型与原始教师模型之间的一致性:
其中, 表示原始教师 DiT 的参数, 是注意力层应用损失项的数量。上标 表示层索引。训练目标可以写成:
其中, α 和 β 是控制相应损失项权重的超参数。只有注意力层中的参数是可训练的。对于训练数据,本文发现使用原始 DiT 模型生成的样本进行训练,比在真实数据集上作训练得到更好的结果,即便真实数据集包含更高质量的数据。
1.6 多 GPU 并行推理
由于注意力仅限于每个 query 周围的局部窗口,与原始 DiT 的注意力相比,CLEAR 为多 GPU patch-wise 并行推理提供了更高的效率,这对于生成超高分辨率图像特别有价值。具体来说,每个 GPU 负责处理一个图像 patch。换句话说,如果将 特征图沿垂直维度划分为 patches,每个 GPU 处理 patch,则每个相邻 GPU 之间图像标记的通信成本在 CLEAR中为 ,原始 DiT 中为 。
然而,由于每个 text token 都需要来自所有 image token 的信息,CLEAR 中进行精确注意力计算仍然需要专门为 text tokens 同步所有的 key-value tokens,损害了它的潜力。幸运的是,如图 9 所示,作者发现在没有任何训练的情况下,text tokens 的原始注意力计算可以通过 patch 的平均有效近似,同时不会对性能造成太大影响:
其中, 是 patch/GPU 索引。因此,只需要聚合 text token 的注意力输出,从而消除传输所有 key-value 对的需要。
此外,本文方法与现有的 Patch 并行策略正交,例如 Distrifusion[22],该策略通过使用陈旧的特征图应用异步计算和通信。在这些优化之上构建 CLEAR 可实现更大的加速。
1.7 实验设置
本文主要使用 FLUX 系列模型进行实验,因为它在文生图方面有最先进的性能。作者将 FLUX- 1.dev 中的所有注意力层替换为 CLEAR,并尝试 3 种不同的窗口大小,。依靠 PyTorch 中的 FlexAttention,CLEAR 作为一种稀疏注意力机制,可以使用 GPU 通过底层优化高效地实现。
作者使用式 6 中定义的损失函数,在总 Batch Size 为 32 的 10 K 个分辨率为 样本上微调注意层中的参数。 应用于 FLUX 的 single transformer blocks,层索引为 20~57。继之前关于扩散模型的架构蒸馏的工作 LinFusion 之后,超参数 和 都设置为 0.5 。其他超参数遵循 Diffusers[23]的默认设置。训练是在 DeepSpeed ZeRO-2[24]支持的 4 个 H100 GPU 上进行的,这需要约 1 天才能完成。除非另有说明,所有推理都是在单个 H100 GPU 上进行的。
继之前的工作 LinFusion 之后,作者在 COCO2014 的验证集上定量研究了所提出的方法,并随机抽样 5000 张图像及其提示进行评估。作者使用 FID、LPIPS、CLIP 图像相似度和 DINO 图像相似度作为指标。对于需要像素级对齐的设置,如图像上采样和 ControlNet,还加入了 PSNR 和多尺度 SSIM 作为参考。在与 COCO 中的真实图像进行比较时,只包括分布距离的 FID 和 LPIPS。此外,采用 CLIP 文本相似度、Inception Score (IS) 和浮点运算次数 (FLOPs) 分别反映文本对齐、一般图像质量和计算负担。
1.8 实验结果
本文的目标是线性化预训练的 DiT,并且线性化模型有望与原始模型相媲美。如图 10 中的结果,高效的注意力算法会导致对目标问题性能次优。相比之下,本文提出的类卷积的线性化策略实现了与原始 FLUX-1.dev 相当或更好的性能,同时需要更少的计算。
利用式 5中定义的知识蒸馏损失项,进一步最小化线性化模型的输出与原始模型的输出之间的差异。当 时,CLIP 图像分数超过 90。定性地,如图 11 所示,CLEAR 的线性化模型保留了原始输出的整体布局、纹理和色调。
分辨率外推
线性化扩散模型的一个关键优势是它能够高效地生成超高分辨率图像。然而,之前许多研究表明,扩散模型在训练期间生成超出其原生分辨率的图像具有挑战性。因此,他们应用一种实用的解决方案,以从粗到细的方式生成高分辨率图像,并为位置嵌入和注意力尺度等组件设计自适应策略。另一方面,所提出的 CLEAR 对预训练的扩散主干进行了架构修改,使其无缝地适用于它们。
本文采用 SDEdit[25],一种简单有效的基线,将图像调整到更大的尺度,以生成高分辨率图像。通过调整 SDEdit 中的编辑强度,如图 12 所示,本文可以有效地控制精细细节和内容保存之间的权衡。
CLEAR 测量结果与原始 FLUX-1.dev 的结果之间的依赖关系。如图 13 所示,本文实现了高达 0.9 的 MS-SSIM 分数,展示了使用原始 FLUX 的有效替代方案作为 CLEAR 的线性化模型的有效性。
参考
^Flux: Official inference repository for flux.1 models ^FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness ^FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning ^LinFusion: 1 GPU, 1 Minute, 16K Image ^Sana: Efficient high-resolution image synthesis with linear diffusion transformers ^ScaleCrafter: Tuning-free Higher-Resolution Visual Generation with Diffusion Models ^Transformers are ssms: Generalized models and efficient algorithms through structured state space duality ^Gated linear attention transformers with hardware-efficient training ^Linfusion: 1 gpu, 1 minute, 16k image ^Theory, analysis, and best practices for sigmoid selfattention ^Transformer quality in linear time ^Nystr ̈omformer: A nystr ̈om-based algorithm for approximating self-attention ^PixArt-Σ: Weak-to-Strong Training of Diffusion Transformer for 4K Text-to-Image Generation ^Agent Attention: On the Integration of Softmax and Linear Attention ^Linformer: Self-Attention with Linear Complexity ^Efficient Content-Based Sparse Attention with Routing Transformers ^Swin transformer: Hierarchical vision transformer using shifted windows ^Big Bird: Transformers for Longer Sequences ^Longformer: The Long-Document Transformer ^Flow Matching for Generative Modeling ^Scaling Rectified Flow Transformers for High-Resolution Image Synthesis ^DistriFusion: Distributed Parallel Inference for High-Resolution Diffusion Models ^https://github.com/huggingface/diffusers/blob/main/examples/ dreambooth/README flux.md ^ZeRO: Memory optimizations Toward Training Trillion Parameter Models ^SDEdit: Guided Image Synthesis and Editing with Stochastic Differential Equations