长短期 Transformer :用于语言和视觉的高效 Transformer

科技   2024-12-03 14:12   北京  
摘要

Transformer模型已在语言和视觉领域取得成功。 然而,将其扩展到长序列(例如长文档或高分辨率图像)成本高昂,因为自注意力机制的时间和内存复杂度与输入序列长度呈二次方关系。 在本文中,我们提出了一种高效的自注意力机制——长短Transformer (Transformer-LS),用于对语言和视觉任务中的长序列进行建模,其时间复杂度为线性。 它结合了一种新颖的具有动态投影的长程注意力机制来建模远程关联,以及一种短期注意力机制来捕获细粒度的局部关联。 我们提出了一种双重归一化策略来解决这两种注意力机制之间的尺度不匹配问题。 Transformer-LS 可以应用于自回归模型和双向模型,而不会增加额外的复杂性。 我们的方法在语言和视觉领域的多个任务上都优于现有技术模型,包括远程竞技场基准测试、自回归语言建模和ImageNet分类。 For instance, Transformer-LS achieves 0.97 test BPC on enwik8 using half the number of parameters than previous method, while being faster and is able to handle 3× as long sequences compared to its full-attention version on the same hardware. 在 ImageNet 上,它可以获得最先进的结果(例如,仅在 224×224 ImageNet-1K 上训练的中等大小的 55.8M 模型可以获得 Top-1 准确率 84.1%),同时在高分辨率图像上更具可扩展性。 源代码和模型已发布在 https://github.com/NVIDIA/transformer-ls。

1引言

基于Transformer的模型[1]在自然语言处理 (NLP) [2, 3]和计算机视觉[4, 5, 6]领域取得了巨大成功。 这些模型受益于自注意力模块,该模块可以有效地捕获符元之间相邻和远程的相关性,并在现代硬件上进行高效扩展。 然而,自注意力消耗的时间和内存随着输入长度的增加呈二次方增长,使得处理长序列非常昂贵。 许多语言和视觉任务都受益于对长序列的建模。 在自然语言处理(NLP)中,文档级任务需要处理长篇文章[例如,7, 8],并且语言模型的性能通常随着序列长度的增加而提高[例如,9, 10]。 在计算机视觉中,许多任务都涉及高分辨率图像,这些图像在使用Transformer模型处理之前被转换为长的图像块序列[4, 6, 11]。 因此,设计一种高效的注意力机制来对长序列进行建模,并使其能够很好地泛化到不同的领域,至关重要。

已经提出了各种方法来降低完全注意力机制的二次方成本。 然而,一种在语言和视觉领域都能很好地泛化的有效注意力机制的研究较少。 一种方法是使用预定义的模式(例如滑动窗口[例如,12, 13, 14, 15]和随机稀疏模式[16])来稀疏化注意力矩阵。 这些方法利用强大的归纳偏差来提高计算性能和模型性能,但是它们限制了自注意力层的容量,因为每个特定的符元只能关注符元的一个子集。 另一种方法是利用低秩投影来形成输入序列的低分辨率表示,但是这些方法的成功应用仅限于某些NLP任务[例如,17, 18, 19]。 与稀疏注意力不同,这种方法允许每个符元关注整个输入序列。 然而,由于丢失了高保真逐符元信息,因此在需要细粒度局部信息的任务(包括语言[20]和视觉[21]中的标准基准)上,它们的性能有时不如完全注意力或稀疏注意力。

尽管高效Transformer取得了快速进展,但一些提出的架构只能应用于双向模型[例如,15, 16, 18]。 基于Transformer的自回归模型在语言建模[22]、图像合成[23]和文本到图像合成[24]方面取得了巨大的成功,这些任务也涉及长文本或高分辨率图像。 设计一种可以应用于自回归模型和双向模型的高效Transformer是可取的。

在这项工作中,我们将局部窗口注意力和一种新颖的长程注意力统一到一个高效的注意力机制中。 我们证明了这两种注意力具有互补作用,它们共同在语言和视觉的一系列任务中,对于自回归模型和双向模型都产生了最先进的结果。 具体而言,我们的贡献如下:

  •  

    我们提出了一种高效的Transformer模型——长短Transformer (Transformer-LS),它集成了基于动态投影的注意力机制来建模长程相关性,以及局部窗口注意力机制来捕捉细粒度相关性。 Transformer-LS可以应用于自回归和双向模型,其时间和内存复杂度均为线性。

  •  

    我们计算一个动态低秩投影,它取决于输入序列的内容。 与之前的低秩投影方法相比,我们的动态投影方法更灵活,并且对保持语义的位置变化(例如,插入、释义)更鲁棒。 我们证明它在Long Range Arena基准测试[20]上优于之前的低秩方法[17, 18]

  •  

    我们识别出了长程注意力和短程注意力的嵌入之间存在规模不匹配的问题,并设计了一种简单但有效的双重归一化策略,称为DualLN,以解决这种不匹配并增强聚合的有效性。

  •  

    我们证明,尽管长短Transformer具有较低的内存和运行时复杂度,但在Long Range Arena的一组任务以及enwik8和text8上的自回归语言建模方面,它仍然优于最先进的模型。 此外,所提出的高效注意力机制可以轻松地应用于最新的视觉Transformer架构[6, 11],并提供最先进的结果,同时更易于扩展到高分辨率图像。 我们还研究了Transformer-LS在不同ImageNet数据集上的鲁棒性。

2相关工作

2.1高效Transformer

近年来,已经引入了许多方法来处理完全注意力机制的二次成本。 一般来说,它们可以分为以下几类: i) 具有预定义模式的稀疏注意力机制(例如,滑动窗口),包括用于建模图像的Sparse Transformer[12]、Image Transformer[13]、Axial Transformer[25],以及用于建模语言的Longformer[14]、分块自注意力[26]、ETC[15]、Big Bird[16]。 ii) 低秩投影注意力,包括Linformer[17]、Nyströmformer[18]、Synthesizer[19]。 例如,Linformer使用线性层将原始高分辨率键(K)和值(V)(长度为n)投影到大小为r的低分辨率(rn),并允许所有查询符元(Q)关注这些压缩表示。 iii) 基于记忆的机制,例如压缩Transformer[10]和集合Transformer[27],它们使用额外的存储器来缓存全局长程信息,用于计算远距离符元之间的注意力。 iv) 注意力矩阵的基于核的近似方法,包括Performer[28]、线性Transformer[29]和随机特征注意力[30]。 vi) 基于相似性和聚类的方法,包括Reformer[31]、路由Transformer[32]和Sinkhorn Transformer[33]

我们的方法无缝地整合了低秩投影和局部窗口注意力,以利用它们的优势来建模长程和短期相关性。 特别地,我们的长程注意力使用动态低秩投影来编码输入序列,并且优于Linformer[17]使用的先前低秩投影方法。 同样,其他一些方法也尝试结合不同方法的优势。 例如,Longformer[14]和ETC[15]用任务驱动的全局符元增强局部窗口注意力。 这种全局符元可能不适用于某些任务(例如,自回归建模)。 BigBird[16]进一步结合了局部窗口和全局符元注意力以及随机稀疏注意力。 它不适用于自回归任务,因为引入了全局符元和随机稀疏模式。 为了压缩边缘设备上的模型占用空间,Lite Transformer[34]结合了卷积和自注意力,但对于长序列,它仍然具有二次复杂度。

2.2视觉Transformer

视觉Transformer (ViT)[4]将图像分割成小的图像块,并将这些图像块视为输入词符。 它使用标准的Transformer进行图像分类,并且在有足够训练数据的情况下,已被证明优于卷积神经网络(例如,ResNet[35])。 DeiT[36]应用了师生策略来缓解ViT的数据效率问题,并且仅使用标准ImageNet数据集[37]就显示出强大的可比性能。 最近的一些工作,包括金字塔视觉Transformer (PVT)[5]、Swin-Transformer[38]、T2T-ViT[39]、视觉Longformer (ViL)[11]和卷积视觉Transformer (CvT)[6],并没有在单个低分辨率的图像块(例如,16×16图像块)上应用Transformer,而是堆叠了一个ViT金字塔来形成一个多尺度架构,并在更高的分辨率(例如,56×56=3136图像块用于具有224×224像素的图像)上对长序列的图像块进行建模。 这些方法大多数的自注意力计算复杂度与输入图像大小成二次方关系。

为了降低复杂度,Swin-Transformer [38] 通过仅在每个局部窗口内限制自注意力的计算,实现了线性复杂度。 HaloNet [40] 对分块图像应用局部注意力,其计算复杂度仅与块的大小成二次方关系。 Perceiver [41] 使用数据和潜在数组之间的交叉注意力来替换数据上的自注意力,从而消除二次复杂度的瓶颈。 另一项同时进行的工作,视觉Longformer (ViL) [11],通过将Longformer [14] 应用于视觉,实现了线性复杂度。 ViL 使用特定任务的全局符元增强局部窗口注意力,但全局符元不适用于解码任务(例如,图像合成 [23, 24])。 相反,我们的方法通过结合局部窗口注意力和全局动态投影注意力将二次成本降低到线性成本,这可以应用于编码和解码任务。

3长短期Transformer

图1: 单个注意力头的长短期注意力。 其中,序列长度 n=8,隐藏维度 d=3,局部窗口段大小 w=2,以及动态投影的秩 r=3。 在图中,K(V) 表示键 K 或值 V。在左图中,我们将 K 或 Vn×d 虚拟地复制成 n 行,并突出显示所有 n 查询 Q 短期注意力范围内(表示为 K~(V~))的键和值。 在中间图中,所有查询都关注长期注意力中相同的投影键 K¯ 和值 V¯。 在右图中,K~(V~) 和 K¯(V¯) 首先用两组 LayerNorm 进行归一化,查询同时关注其注意力范围内的归一化 K~(V~) 和 K¯(V¯)

Transformer-LS 通过聚合长程和短程注意力来近似完全注意力,同时保持其捕获所有输入符元之间相关性的能力。 在本节中,我们首先介绍Transformer中多头注意力的预备知识。 然后,我们分别通过滑动窗口呈现短期注意力,并通过动态投影呈现长期注意力。 之后,我们提出了聚合方法和双重归一化 (DualLN) 策略。 参见图1,其中说明了我们的长短期注意力。

3.1预备知识和符号

多头注意力是 Transformer [1] 的核心组成部分,它通过关注不同表示子空间中的整个输入序列来计算每个符元的上下文表示。 其定义为

其中Q,K,Vn×d是查询、键和值嵌入,WOd×d是输出的投影矩阵,第i个头Hin×dk是缩放点积注意力,dk=d/h是每个头的嵌入维度,

其中WiQ,WiK,WiVd×dk是学习到的投影矩阵,Ain×n表示每个注意力头的完整注意力矩阵。 计算和存储Ai的复杂度为O(n2),当n很大时,这可能是难以承受的。 为简便起见,我们下面的讨论基于一维输入序列的情况。 给定预定的顺序,将其扩展到二维图像数据是很简单的。

3.2通过分段滑动窗口实现短期注意力

我们使用简单而有效的滑动窗口注意力来捕获细粒度的局部相关性,其中每个查询都关注固定大小邻域内的附近符元。 类似的技术也已在[14, 16, 11]中采用。 具体来说,为了提高效率,我们将输入序列划分为长度为w的不相交段。 一个段内的所有符元都关注其所属段内的所有符元,以及其所属段左右两侧w/2个连续的符元(必要时进行零填充),从而导致对总共2w个键值对的注意力跨度。 请参见附录中的图5。 对于第i个头部中位置t处的每个查询Qt,我们将它窗口内的2w键值对表示为K~t,V~t2w×d。 使用PyTorch实现时,这种分段滑动窗口注意力比每个符元滑动窗口注意力更快,其中每个符元都关注自身及其左右w个符元,并且其内存消耗随序列长度线性缩放;更多细节请参见[14]和我们的图3

通过为滑动窗口注意力的不同头部引入不同的扩张率[14],可以增强滑动窗口注意力以部分捕获远程相关性。 但是,不同头部的扩张率配置需要进一步调整,并且高效实现具有不同扩张率的多头注意力并非易事。 一个更有效的替代方案是使用随机稀疏注意力[16]增强滑动窗口注意力,但这并不能保证像全注意力那样在每一层都捕获远程相关性。 在下一节中,我们将提出我们的远程注意力来解决这个问题。

3.3通过动态投影实现远程注意力

先前的研究表明,自注意力矩阵可以用低秩矩阵的乘积很好地近似[17]。 通过用低秩矩阵的乘积[42, 19, 18, 43, 28]替换全注意力,每个查询能够关注所有符元。 Linformer[17]是此类别中最具代表性的模型之一。 它学习一个固定的投影矩阵来减少键和值的长度,但是固定投影对于保持语义的位置变化不够灵活。

从这些观察结果出发,我们将第i个头部的动态低秩投影参数化为Pi=f(K)n×r,其中rn是低秩大小,而Pi取决于输入序列的所有键Kn×d。 它将(n×dk)维键嵌入KWiK和值嵌入VWiV投影到更短的(r×dk)维键K¯i和值V¯i嵌入。 与Linformer[17]不同,低秩投影矩阵是动态的,它取决于输入序列,旨在更灵活且更能适应例如插入、删除、释义以及其他改变序列长度的操作。 请参见表2中的示例。 注意,查询嵌入QWiQn×dk保持相同的长度,我们让每个查询都关注K¯iV¯i。 通过这种方式,完整的(n×n)注意力矩阵可以分解为两个矩阵的乘积,这两个矩阵具有r列或行。 具体来说,我们将动态投影矩阵Pin×r和低秩注意力的键值嵌入K¯i,V¯ir×dk定义为

其中WiPd×r是可学习的参数,1 softmax对所有n符元的第一维上的投影权重进行归一化,这在我们的实验中稳定了训练。 注意,在所有我们考虑的实验中K=V,所以如果Pi依赖于V,它将保持不变。公式3的计算复杂度为O(rn)

为了了解完整的注意力是如何被低秩矩阵的乘积所取代的,我们将每个长程注意力的头部Hin×dk计算为:

因此,完整的注意力现在被两个低秩矩阵A¯in×rPir×n的隐式乘积所取代,计算复杂度降低到O(rn)。 注意,查询在所有符元上的有效注意力权重之和仍然为1。 我们的全局注意力允许每个查询关注同一自注意力层内的所有符元嵌入。 相反,稀疏注意力机制[14, 16]需要堆叠多层来构建这种相关性。

应用于自回归模型:  在自回归模型中,每个符元只能关注之前的符元,因此长程注意力对于不同的符元应该具有不同的范围。 实现我们全局注意力的一个直接方法是循环更新每个查询的K¯i,V¯i,但这需要由于softmax的非线性而为每个符元重新计算公式(3)中的投影,这导致O(rn2)的计算复杂度。 为了保持线性复杂度,对于自回归模型,我们首先将输入序列划分成长度为l的等长段,并应用我们的动态投影从每个段中提取K¯i,V¯i。 每个符元只能关注K¯i,V¯i不包含其未来符元的片段。 形式上,设Qt为位置t处的查询,K(l1)s:ls,V(l1)s:ls为来自第s个片段的键值对,以及st=t/l。 对于自回归模型,我们通过关注Ki,t,Vi,t来计算Qt的远程注意力,定义为

通过这种方式,动态低秩投影仅对每个片段并行应用一次,从而保持线性复杂度和较高的训练速度。 相比之下,由于需要递归,随机特征注意力[30]的训练速度较慢。

3.4聚合远程和短期注意力

为了聚合局部和远程注意力,我们没有为不同的头采用不同的注意力机制[12, 14, 34],而是让第i个头的每个查询都关注来自局部窗口和全局低秩投影的键和值的并集,因此它可以学习选择来自两者中的重要信息。 在我们对自回归语言模型的初步试验中,我们发现这种聚合策略比分离头部效果更好。 具体来说,对于第i个头,我们将全局低秩投影的键和值表示为K¯i,V¯ir×dk,并将局部键和值表示为K~t,V~t2w×d,它们位于查询Qt位置t的局部窗口内。 然后,位置t处的第i个注意力Hi,t

其中[;]表示沿第一维连接矩阵。 此外,我们发现K~tWiKK¯i的初始范数之间存在尺度不匹配,这使得在语言和视觉任务的初始化阶段,注意力偏向局部窗口。 我们引入一种归一化策略(DualLN)来对齐范数并提高聚合的有效性。

图2: 左:初始化时局部窗口平均2范数与全局低秩键/值嵌入的比率。 没有DualLN,稀疏和低秩嵌入存在幅度不匹配。 使用双层归一化(DualLN),每一层的比率将为1.0,这将有助于优化。 右图:在enwik8和text8数据集上,使用和未使用双层归一化(DualLN)的Transformer-LS模型的验证损失。

双层归一化(DualLN): 对于具有层归一化(LN)的Transformer(参见[44]的图示),Ki,Vi嵌入是LN层的输出,因此它们在初始化时均值为零,方差为一。 均值为零的向量的2范数与其方差成比例。 我们注意到,加权平均值将降低此类均值为零向量的方差,从而降低其范数。 结果,公式(3)中加权平均值K¯i,V¯i的低秩注意力嵌入向量的范数将小于来自滑动窗口注意力的常规键和值嵌入(参见图2左图示)。 这种尺度不匹配会导致两个副作用。 首先,局部秩分量的内积QtWiQK¯i的幅度往往小于局部窗口的幅度,因此长程注意力的注意力分数系统性地较小。 其次,即使低秩和局部窗口分配相同的注意力分数,低秩注意力的键值对K¯i,V¯iHi方向的影响也会自然减小,因为V¯i的范数较小。 这两种效应都会导致低秩分量上的梯度较小,并阻碍模型学习有效利用长程相关性。

为避免此类问题,我们在局部窗口和全局低秩注意力的键和值投影之后添加了两组层归一化,以便它们的尺度在初始化时对齐,但网络仍然可以在训练后学习重新加权范数。 具体来说,聚合注意力现在计算为

其中LN()L,LN()G分别表示局部和全局注意力的层归一化。 在实践中,为了保持局部注意力和动态投影之间的一致性,我们使用LNL(K),LNL(V)而不是K,V来计算公式3中的K¯i,V¯i。 如图2右图所示,使用双层归一化(DualLN)训练的Transformer-LS模型的验证损失始终低于未使用双层归一化(DualLN)的模型。

4 实验

本节,我们将演示我们的方法在语言和视觉领域中的有效性和效率。 我们使用 PyTorch 进行实现,并使用 fvcore [45] 统计 FLOPs。

4.1长程竞技场和 IMDb 上的双向建模

表 1: 长程竞技场 (LRA) 上的准确率 (%) 和 FLOPs (G),模型配置已标注(更多信息请参见表 7)。 所有结果均为使用不同随机种子进行的 4 次运行的平均值。

TaskListOpsTextRetrievalAverage
(mean ± std.) of sequence length(888 ± 339)(1296 ± 893)(3987 ± 560)
ModelAcc.FLOPsAcc.FLOPsAcc.FLOPsAcc.
Full Attention [1] 37.131.2165.354.5782.309.1461.59
Reformer [31] (2)36.440.2764.880.5878.641.1559.99
Linformer [17] (k=256)37.380.4156.120.8179.371.6257.62
Performer [28] (r=256)32.780.4165.210.8281.701.6359.90
Nyströmformer [18] (l=128)37.340.6165.751.0281.292.0361.46
Transformer-LS (w,r=8,32)37.500.2066.010.4081.790.8061.77
Dynamic Projection (best)37.790.1566.280.6981.862.1761.98
Transformer-LS (best)38.360.1668.400.2981.952.1762.90


表 2: 比较模型在测试时插入和删除操作下的鲁棒性。 DP 指的是通过动态投影实现的长程注意力,Win. 指的是滑动窗口注意力。

TaskTextRetrieval
Test PerturbNoneInsertionDeletionNoneInsertionDeletion
Linformer56.1255.9454.9179.3753.6651.75
DP66.2863.1658.9581.8670.0164.98
Linformer + Win.59.6356.6956.2979.6852.8352.13
DP + Win. (ours)68.4066.3462.6281.9569.9364.19


表 3: 比较在 IMDb 上微调的预训练语言模型的结果。

ModelRoBERTa-baseRoBERTa-largeLongformer-baseLS-baseLS-large
Accuracy95.396.595.796.096.8


为了评估长短 Transformer 作为长文本的双向编码器,我们在三个 NLP 任务上训练我们的模型,即最近提出的长程竞技场 (LRA) 基准[20]中的ListOpsTextRetrieval,遵循Peng 等人 [30]Tay 等人 [46]的设置。 为了进行公平比较,我们使用 PyTorch 实现以及与[18]中相同的数据预处理/分割、训练超参数和模型大小,除了Retrieval,我们意外地使用了更多预热步骤,并改进了所有模型的结果。 更多细节请参见附录B。 这三个任务的结果在表1中给出。 LRA其他两个基于图像的任务的结果,以及在JAX中实现的模型结果,见附录CC.2

此外,我们遵循Longformer[14]的预训练流程,基于RoBERTa-base和RoBERTa-large[47]对我们的模型进行预训练,并在IMDb情感分类数据集上对其进行微调。 结果见表3

结果。 从表3可以看出,我们的基础模型优于Longformer-base,我们的大型模型比RoBERTa-large有所改进,这证明了学习对长序列建模的益处。 与LRA上模型的比较结果见表1。 具有每个任务最佳配置的Transformer-LS(最佳)结果见附录B中的表7w,r 我们还报告了在所有任务上使用固定超参数w=8,r=32的结果。 总的来说,我们的Transformer-LS(最佳)明显优于其他高效Transformer,并且在所有三个任务上,使用w,r=8,32的模型表现良好,同时计算量仅为其他高效Transformer的约50%到70%。 聚合局部和远程注意力的好处在ListOps上最为显著,这需要模型理解涉及长期和短期关系的树状结构。 在Retrieval中,其中测试了文档级编码能力,我们发现我们的全局注意力比窗口注意力更有效。 仅使用动态投影的测试精度比Linformer在Text上高约10%(即66.28对56.12),后者在序列长度上的方差最大(即标准差893)。 这表明与Linformer学习到的但固定的投影相比,动态投影在学习具有序列长度高方差的数据的表示方面具有更高的灵活性。 同样,Linformer、Nyströmformer和我们的模型在ListOps上优于全注意力,这表明它们可能具有更好的归纳偏差,并且高效Transformer的有效性可能超越效率本身。

动态投影的鲁棒性。 在表2中,我们将Linformer和提出的动态投影(DP)针对LRA的文本和检索任务上的插入和删除的鲁棒性进行了比较。 我们使用原始的、干净的训练集训练模型,只扰动它们的测试集。 对于插入操作,我们在每个测试样本的10个随机位置插入10个随机标点符号。 对于删除操作,我们删除测试样本中的所有标点符号。 这两种变换在大多数情况下都保留标签。 通过设计,动态投影对位置变化更鲁棒。

4.2自回归语言建模

我们将我们的方法与其他高效的Transformer模型在字符级语言建模任务上进行了比较,其中每个输入符元都是一个字符。

设置。   我们在enwik8和text8数据集上训练和评估我们的模型,每个数据集包含1亿个字符,并按照[48]中的方法将其划分为9000万、500万和500万用于训练、验证和测试。 我们较小的12层模型和较大的30层模型都是预归一化Transformer,其宽度和深度与Longformer [20]相同,不同之处在于我们在每一层的投影片段中添加了相对位置编码。 我们采用了Transformer-XL [9]的缓存机制,将缓存大小设置为与输入序列长度相同。 我们遵循与Longformer类似的训练计划,并分三个阶段训练我们的模型,每个阶段的序列长度递增。 三个阶段的输入序列长度分别为2048、4096和8192。 相比之下,Longformer使用具有48GB内存的GPU分五个阶段训练其模型(我们的最大内存为32GB),其中最后一个阶段的序列长度为23040。 Longformer的窗口大小随着深度的增加而增加,其第五阶段的平均窗口大小为4352,而我们最后一个阶段平均有效的被关注符元数量为1280。 每个实验在大约8个V100 GPU上运行大约8天才能完成。 详细的超参数见附录D。在测试中,与Longformer相同,我们将数据集分割成长度为32K、步长为512的重叠序列,并评估在给定前面32K个字符的情况下预测接下来512个符元的BPC。

图3:Transformer-XL(全注意力)和我们在Char-LM上的Transformer-LS的运行时间和内存消耗。 我们将序列长度增加到V100 GPU的32GB内存用完为止。 Transformer-LS与表4中的较小模型相同。 我们使用虚线表示全注意力Transformer,使用实线表示我们的模型。 我们使用不同的颜色表示不同的批量大小。

结果 表4显示了在text8和enwik8上的比较。 我们的方法取得了最先进的结果。 在text8上,我们使用较小模型实现了1.09的测试BPC。 在enwik8上,我们的较小模型实现了0.99的测试BPC,并且优于参数数量相当的最先进模型。 我们的较大模型获得了0.97的测试BPC,与具有2×个参数的压缩Transformer不相上下。 我们的结果始终优于Longformer,后者是在更长的序列上训练的,具有5个阶段和48个GPU内存。 在图3中,我们展示了我们的模型比全注意力模型在内存和计算方面效率更高。

表4:较小模型在enwik8和text8上的BPC()(左),以及较大模型在enwik8上的BPC(右)。

Method#Paramtext8enwik8
DevTestDevTest
T12 [49] 44M-1.18-1.11
Transformer-XL [9] 41M---1.06
Reformer [31] ----1.05
Adaptive [50] 38M1.051.111.041.02
BP-Transformer [51] 38M-1.11-1.02
Longformer [20] 41M1.041.101.021.00
Transformer-LS44M1.031.091.010.99


Method#ParamTest BPC
Transformer-XL [9] 88M1.03
Transformer-XL [9] 277M0.99
Routing [32] 223M0.99
Longformer [14] 102M0.99
Sparse [12] 95M0.99
Adaptive [50] 209M0.98
Compressive [10] 227M0.97
Transformer-LS110M0.97


表 5: 在 ImageNet、ImageNet Real [52] 和 ImageNet V2 [53] 上基于 ImageNet-1K 训练的模型的测试准确率。 灰色行表示我们的结果。 CvT-LS 表示我们基于非官方 CvT 实现的基于长短期注意力的模型。 带有 LS 后缀的 ViL 模型是基于官方 ViL 实现和相对位置偏差的基于长短期注意力的模型。 我们还在同一 V100 GPU 上使用批量大小为 32 的测试对模型的延迟进行了测试。 我们对 ViL 的改进主要来自对短期注意力的更好实现。

Model#ParamImageFLOPsImageNetRealV2Latency

(M)Size(G)top-1 (%)top-1 (%)top-1 (%)(s)
ResNet-502522424.176.282.563.3-
ResNet-1014522427.977.483.765.7-
ResNet-1526022421178.384.167.0-
DeiT-S [36] 2222424.679.885.768.5-
DeiT-B [36] 86224217.681.886.770.9-
PVT-Medium [5] 4422426.781.2---
PVT-Large [5] 6122429.881.7---
Swin-S [38] 5022428.783.2---
Swin-B [38] 88224215.483.5--0.115
PVTv2-B4 [54] 62.6224210.183.6---
PVTv2-B5 [54] 82.0224211.883.8---
ViT-B/16 [4] 86384255.577.9---
ViT-L/16 [4] 3073842191.176.5---
DeiT-B [36] 86384255.583.1---
Swin-B [38] 88384247.184.5--0.378
CvT-13 [6] 2022426.781.686.770.40.122
CvT-21 [6] 32224210.182.587.271.30.165
CvT-LS-1320.322424.981.987.070.50.083
CvT-LS-1723.722429.882.587.271.6-
CvT-LS-2132.122427.982.787.571.90.122
CvT-LS-21S30.1224211.382.987.471.7-
CvT-13 [6] 20384231.983.087.971.9-
CvT-21 [6] 32384245.083.387.771.9-
CvT-LS-2132.1384223.983.288.072.5-
CvT-LS-2132.1448234.283.688.272.9-
ViL-Small [14] 24.622424.982.4---
ViL-Medium [14] 39.722428.783.5--0.106
ViL-Base [14] 55.7224213.483.7--0.164
ViL-LS-Medium39.822428.783.8--0.075
ViL-LS-Base55.8224213.484.1--0.113
ViL-LS-Medium39.9384228.784.4--0.271


4.3ImageNet 分类

我们在包含 130 万张图像和 1000 个类别的 ImageNet-1K 数据集上训练和评估模型。 我们使用 CvT [6] 和 ViL [11](最先进的视觉 Transformer 架构)作为骨干网络,并将其注意力机制替换为我们的长短期注意力机制,在表 5 中分别表示为 CvT-LS 和 ViL-size-LS。 CvT 使用重叠卷积从输入图像和特征图中提取密集的 patch embedding,从而在早期阶段产生较长的序列长度(例如,对于具有 2242 像素的图像,有 56×56=3136 个 patches)。 对于 ViL,我们的滑动窗口使用相同的组大小 w,但每个符元最多关注窗口内的 2w×2w 个符元(必要时进行四舍五入),而不是像 ViL 中那样关注 3w×3w 个符元,这使得可以在不增加 FLOPs 的情况下添加我们的动态投影。 我们为 ViL-LS-Medium 和 ViL-LS-Base 的动态投影设置 r=8。 请注意,我们的高效注意力机制不依赖于特定的架构,它可以应用于其他视觉 Transformer [例如,4, 36, 5]。 请参考附录 E 获取更多详细信息。

分类结果。 结果如表5所示,其中我们还列出了在ImageNet Real和ImageNet V2上的测试准确率。 除CvT外,我们将我们的方法与原始ViT[4]和改进的DeiT[36]、PVT[5](也使用多尺度策略)以及ViL[11](使用窗口注意力和全局符元来提高效率)进行了比较。 高分辨率训练通常会提高视觉Transformer的测试准确率。 通过我们的长短期注意力机制,我们可以轻松地将训练扩展到更高分辨率,CvT-LS和ViL-LS的性能也得到了提高。 我们最好的CvT模型(CvT-LS-21 at 4482)在使用相同数量的参数和76%的FLOPs的情况下,比CvT的最佳报告结果提高了0.3%的准确率。 在CvT架构中,早期阶段特征图的空间维度较大,代表了图像更细粒度的细节。 与高分辨率图像训练类似,该模型也应该受益于更密集的特征图。 通过我们高效的长短期注意力机制,我们可以更好地利用这些细粒度的特征图,而无需过多考虑计算预算。 通过这种方式,我们的CvT-LS-17在224分辨率下取得比CvT-21更好的结果,同时使用了更少的参数和FLOPs,而我们的CvT-LS-21S模型进一步改进了我们的CvT-LS-21模型。

我们的ViL-LS-Medium和ViL-LS-Base采用长短期注意力机制,分别将ViL-Medium和ViL-Base的准确率从83.5和83.7提高到83.8和84.1,而FLOPs没有增加。 当将ViL-LS-Medium的训练分辨率从2242提高到3842时,FLOPs近似线性增加,准确率提高了0.6%,这表明我们的方法仍然可以从更高的分辨率中获益良多,同时在实践中保持线性复杂度。

短期注意力抑制过度平滑。 通过限制不同片段的符元关注不同的窗口,我们的短期稀疏局部注意力促进了特征表示的多样性,并有助于缓解过度平滑问题[55](其中所有查询在更深层提取类似的信息,注意力机制的重要性降低),从而可以充分利用网络的深度。 如[55]所示,我们在附录的图6中提供了我们的CvT-LS-13和重新实现的CvT-13(81.1精度)的块嵌入的余弦相似度。 这是我们的高效注意力机制在相同设置下甚至能够获得比全注意力CvT模型更好结果的原因之一。

在多样化ImageNet数据集上的鲁棒性评估。

表 6: 各种 ImageNet 数据集上的鲁棒性评估。 Top-1/准确率。:Top-1 准确率。 mCE:平均损坏误差。 Mixed-same/Mixed-rand:Mixed-Same/Mixed-Rand 子集上的准确率。

ModelParamsImageNetIN-C [56]IN-A [57]IN-R [58]ImageNet-9 [59]

(M)Top-1mCE ()Acc.Acc.Mixed-sameMixed-rand
ResNet-50 [35] 25.676.278.96.235.387.181.6
DeiT-S [36] 22.179.857.119.041.989.184.2
CvT-132081.659.625.442.990.585.7
CvT-213282.556.231.142.690.585.0
CvT-LS-1320.381.958.727.042.690.785.6
CvT-LS-2132.182.755.229.345.091.585.8


由于视觉模型已广泛应用于安全关键型应用(例如自动驾驶),因此它们的鲁棒性至关重要。 除了分布外鲁棒性(ImageNet-Real 和 ImageNet-v2)之外,我们还进一步研究了我们的视觉 Transformer 对常见损坏(ImageNet-C)、语义变化(ImageNet-R)、背景依赖性(ImageNet-9)和自然对抗样本(ImageNet-A)的鲁棒性。 我们将我们的方法与标准分类方法进行比较,包括基于 CNN 的模型(ResNet [35])和参数数量相似的基于 Transformer 的模型(DeiT [36])。 如表 6 所示,我们观察到我们的方法明显优于基于 CNN 的方法(ResNet-50)。 与 DeiT 相比,我们的模型也取得了可喜的改进。 这些结果表明,不同注意力机制的设计对模型的鲁棒性起着重要作用,这为鲁棒视觉 Transformer 的设计提供了新的思路。 更多细节和结果可在附录 E 中找到。

5结论

在本文中,我们介绍了 Long-Short Transformer,这是一种用于语言和视觉领域长序列建模的高效 Transformer,包括双向和自回归模型。 我们设计了一种新颖的全局注意力机制,其计算和内存复杂度在序列长度上呈线性关系,基于动态投影。 We identify the scale mismatch issue and propose the DualLN technique to eliminate the mismatch at initialization and more effectively aggregate the local and global attentions. We demonstrate that our method obtains the state-of-the-art results on the Long Range Arena, char-level language modeling and ImageNet classification. We look forward to extending our methods to more domains, including document QA, object detection and semantic segmentation on high-resolution images.



arXiv每日学术速递
工作日更新学术速递!官网www.arxivdaily.com。
 最新文章