论文介绍
题目:An Attention Free Transformer
论文地址:https://arxiv.org/pdf/2105.14103v1
QQ深度学习交流群:994264161
扫描下方二维码,加入深度学习论文指南星球!
加入即可获得,模块缝合、制作、写作技巧,学会“结构”创新、“创新点”创新,从三区到顶会,小论文大论文,毕业一站式服务
创新点
完全去除了点积自注意力机制:提出了一种称为Attention-Free Transformer (AFT) 的新模型,完全摆脱了传统Transformer模型中复杂的点积自注意力操作,采用一种更高效的基于查询(Query)、键(Key)和值(Value)的交互方式。
线性复杂度设计:AFT通过重新排列Query、Key、Value的计算顺序,显著降低了计算和内存的复杂度,从传统Transformer的二次复杂度(O(T²d))降为线性复杂度(O(Td))。这使得AFT能够更高效地处理大规模输入和模型。
局部性和空间共享权重变体:引入了AFT-local和AFT-conv模型,这些变体利用了局部性和空间权重共享的概念,同时保留了全局连接的特点,在性能和效率上都表现出显著的提升。
强大的通用性和插拔设计:AFT被设计为对Transformer架构的模块化替代,能够轻松替换自注意力模块,且无需改变其他架构组件,这使得它可以直接应用于多种任务。
广泛的实验验证:论文通过在多个任务上的实验(包括图像生成、语言建模、图像分类等)验证了AFT的有效性,展示了它在效率和性能上的竞争优势。例如,AFT在CIFAR10和ImageNet-1K上均达到了与或优于标准Transformer的性能,同时显著降低了内存和计算开销。
参数高效性:AFT通过位置偏差参数的因子分解,减少了参数数量,提高了模型的可扩展性和训练效率。
方法
整体架构
Attention-Free Transformer (AFT) 是一种改进的 Transformer 架构,去除了传统的点积自注意力机制,采用基于逐元素加权平均的高效注意力计算,同时通过学习相对位置偏差和因子分解减少参数量。其结构包括线性变换生成的 Q/K/V 交互模块、残差连接、层归一化和前馈网络,并设计了 AFT-local 和 AFT-conv 等变体,分别通过局部窗口限制和卷积权重共享进一步优化性能和效率,适用于语言建模、图像生成和分类等任务。
1. 基本模块组成
AFT 模型的基本结构是围绕Transformer的经典模块设计的,包含以下主要组件:
输入变换:首先将输入序列
通过线性变换生成Query (Q)、Key (K) 和 Value (V)。X X Q = X W Q , K = X W K , V = X W V 其中,
、W Q W_Q 和W K W_K 是线性变换的权重矩阵。W V W_V 新型注意力计算:不同于传统的点积自注意力,AFT对
、Q Q 、K K 的交互采用了一种基于加权平均的计算方式:V V Y t = σ q ( Q t ) ⊙ ∑ t ′ = 1 T exp ( K t ′ + w t , t ′ ) ⊙ V t ′ ∑ t ′ = 1 T exp ( K t ′ + w t , t ′ ) 其中,
是学习到的相对位置偏差,w t , t ′ w_{t, t'} 是对Query的非线性变换(默认使用Sigmoid函数),σ q \sigma_q 表示逐元素相乘。⊙ \odot 残差连接和归一化:如同经典Transformer架构,AFT在每个层中包含残差连接和层归一化,确保训练的稳定性。
两层前馈网络 (FFN):每个AFT层后连接一个两层的前馈网络,类似于标准Transformer中的结构。
2. 模型变体
AFT 提出了几种变体,针对不同任务进行了优化:
AFT-full:完整版本,计算全局范围的注意力,适用于较小的输入序列。
AFT-local:限制位置偏差
的范围,仅在局部窗口内计算,提升了计算效率,同时仍然保持全局连接。w t , t ′ w_{t, t'} AFT-conv:结合了卷积的思想,进一步引入了空间权重共享,尤其适合图像处理任务。
3. 创新点细节
位置偏差学习:AFT通过学习相对位置偏差
来增强模型对位置相关性的建模能力。对于AFT-local和AFT-conv,这些偏差可以限制在局部窗口中。w t , t ′ w_{t, t'} 参数化技巧:为了减少模型参数,AFT对位置偏差
w t , t 使用了因子分解: w t , t ′ = u t T v t ′ w_{t, t'} = u_t^T v_{t'} 其中,
和u t u_t 是低维嵌入向量。v t ′ v_{t'} 卷积扩展:在AFT-conv中,位置偏差被扩展为卷积核,通过卷积操作实现权重共享。
4. 整体结构可视化
AFT的整体结构保留了Transformer的模块化特点,同时用更高效的AFT模块替代了传统的自注意力模块。这种结构可以灵活应用于多种任务,包括语言建模、图像生成和分类等
即插即用模块作用
AFT 作为一个即插即用模块:
长序列建模任务
例如:自然语言处理中的字符级语言建模(如 Enwik8 数据集)。
作用:AFT 的线性复杂度能够处理传统 Transformer 难以扩展的大规模序列,显著降低了内存和计算成本。
图像生成与分类任务
在图像生成中,AFT-local 和 AFT-conv 能有效捕获局部和全局信息,平衡计算效率与性能。
在图像分类中,AFT-conv 模块结合卷积操作,兼具卷积神经网络(CNN)的局部感知优势和 Transformer 的全局连接能力。
例如:CIFAR-10 图像生成和 ImageNet-1K 图像分类任务。
作用:
资源受限场景
例如:需要在低内存占用和高速度需求下完成推理或训练的场景。
作用:相比标准 Transformer,AFT 模块极大地降低了内存消耗和计算复杂度,适合嵌入式设备或大规模分布式训练。
可扩展模型设计
作用:AFT-conv 模块完全卷积化,能够灵活适应不同输入大小的变化,无需重新设计模型架构
例如:需要处理可变输入大小的任务(如不同分辨率的图像分类)。
消融实验结果
内容:
表 3 比较了在 CIFAR-10 上位置偏差参数是否采用因子分解对训练和测试损失的影响,因子分解显著减少了参数量(每层从 9.6M 降至 0.6M),同时提升了训练和测试性能。
表 9 说明在 ImageNet-1K 上,因子分解后的模型训练损失从 3.17 降至 3.08,Top-1 准确率从 78.2% 提升至 79.8%。
说明:位置偏差的因子分解不仅减少了参数量,还有效提升了模型性能。
内容:
表 5 比较了不同窗口大小
对 AFT-local 在 Enwik8 任务上的训练和测试损失的影响,结果表明窗口大小为 32 时性能最佳。s s 表 11 比较了 AFT-conv 在 ImageNet-1K 上不同卷积核大小对训练损失和准确率的影响,发现 7×7 和 11×11 的核大小表现最优。
说明:局部窗口的设置可以作为有效的归纳偏置(inductive bias),在性能和计算效率之间找到平衡。
即插即用模块
import numpy as np
import torch
from torch import nn
from torch.nn import init
# 论文地址:https://arxiv.org/pdf/2105.14103v1
# 论文:An Attention Free Transformer
class AFT_FULL(nn.Module):
def __init__(self, d_model,n=49,simple=False):
super(AFT_FULL, self).__init__()
self.fc_q = nn.Linear(d_model, d_model)
self.fc_k = nn.Linear(d_model, d_model)
self.fc_v = nn.Linear(d_model,d_model)
if(simple):
self.position_biases=torch.zeros((n,n))
else:
self.position_biases=nn.Parameter(torch.ones((n,n)))
self.d_model = d_model
self.n=n
self.sigmoid=nn.Sigmoid()
self.init_weights()
def init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
init.kaiming_normal_(m.weight, mode='fan_out')
if m.bias is not None:
init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
init.constant_(m.weight, 1)
init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
init.normal_(m.weight, std=0.001)
if m.bias is not None:
init.constant_(m.bias, 0)
def forward(self, input):
bs, n,dim = input.shape
q = self.fc_q(input) #bs,n,dim
k = self.fc_k(input).view(1,bs,n,dim) #1,bs,n,dim
v = self.fc_v(input).view(1,bs,n,dim) #1,bs,n,dim
numerator=torch.sum(torch.exp(k+self.position_biases.view(n,1,-1,1))*v,dim=2) #n,bs,dim
denominator=torch.sum(torch.exp(k+self.position_biases.view(n,1,-1,1)),dim=2) #n,bs,dim
out=(numerator/denominator) #n,bs,dim
out=self.sigmoid(q)*(out.permute(1,0,2)) #bs,n,dim
return out
if __name__ == '__main__':
input=torch.randn(50,49,512)
block = AFT_FULL(d_model=512, n=49)
output=block(input) print(output.shape)
便捷下载方式
浏览打开网址:https://github.com/ai-dawang/PlugNPlay-Modules
更多分析可见原文