无注意力Transformer模块AFT,涨点起飞起飞了!

文摘   2024-12-28 17:20   中国香港  

论文介绍

题目:An Attention Free Transformer

论文地址:https://arxiv.org/pdf/2105.14103v1

QQ深度学习交流群:994264161

扫描下方二维码,加入深度学习论文指南星球!

加入即可获得,模块缝合、制作、写作技巧,学会“结构”创新、“创新点”创新,从三区到顶会,小论文大论文,毕业一站式服务

创新点

  • 完全去除了点积自注意力机制:提出了一种称为Attention-Free Transformer (AFT) 的新模型,完全摆脱了传统Transformer模型中复杂的点积自注意力操作,采用一种更高效的基于查询(Query)、键(Key)和值(Value)的交互方式。

  • 线性复杂度设计:AFT通过重新排列Query、Key、Value的计算顺序,显著降低了计算和内存的复杂度,从传统Transformer的二次复杂度(O(T²d))降为线性复杂度(O(Td))。这使得AFT能够更高效地处理大规模输入和模型。

  • 局部性和空间共享权重变体:引入了AFT-local和AFT-conv模型,这些变体利用了局部性和空间权重共享的概念,同时保留了全局连接的特点,在性能和效率上都表现出显著的提升。

  • 强大的通用性和插拔设计:AFT被设计为对Transformer架构的模块化替代,能够轻松替换自注意力模块,且无需改变其他架构组件,这使得它可以直接应用于多种任务。

  • 广泛的实验验证:论文通过在多个任务上的实验(包括图像生成、语言建模、图像分类等)验证了AFT的有效性,展示了它在效率和性能上的竞争优势。例如,AFT在CIFAR10和ImageNet-1K上均达到了与或优于标准Transformer的性能,同时显著降低了内存和计算开销。

  • 参数高效性:AFT通过位置偏差参数的因子分解,减少了参数数量,提高了模型的可扩展性和训练效率。

方法

整体架构

     Attention-Free Transformer (AFT) 是一种改进的 Transformer 架构,去除了传统的点积自注意力机制,采用基于逐元素加权平均的高效注意力计算,同时通过学习相对位置偏差和因子分解减少参数量。其结构包括线性变换生成的 Q/K/V 交互模块、残差连接、层归一化和前馈网络,并设计了 AFT-local 和 AFT-conv 等变体,分别通过局部窗口限制和卷积权重共享进一步优化性能和效率,适用于语言建模、图像生成和分类等任务。

1. 基本模块组成

AFT 模型的基本结构是围绕Transformer的经典模块设计的,包含以下主要组件:

  • 输入变换:首先将输入序列XX 通过线性变换生成Query (Q)、Key (K) 和 Value (V)。

    Q=XWQ,K=XWK,V=XWV

    其中,WQW_QWKW_K 和WVW_V 是线性变换的权重矩阵。

  • 新型注意力计算:不同于传统的点积自注意力,AFT对QQKKVV 的交互采用了一种基于加权平均的计算方式:

    Yt=σq(Qt)t=1Texp(Kt+wt,t)Vtt=1Texp(Kt+wt,t)

    其中,wt,tw_{t, t'}是学习到的相对位置偏差,σq\sigma_q是对Query的非线性变换(默认使用Sigmoid函数),\odot 表示逐元素相乘。

  • 残差连接和归一化:如同经典Transformer架构,AFT在每个层中包含残差连接和层归一化,确保训练的稳定性。

  • 两层前馈网络 (FFN):每个AFT层后连接一个两层的前馈网络,类似于标准Transformer中的结构。

2. 模型变体

AFT 提出了几种变体,针对不同任务进行了优化:

  • AFT-full:完整版本,计算全局范围的注意力,适用于较小的输入序列。

  • AFT-local:限制位置偏差wt,tw_{t, t'} 的范围,仅在局部窗口内计算,提升了计算效率,同时仍然保持全局连接。

  • AFT-conv:结合了卷积的思想,进一步引入了空间权重共享,尤其适合图像处理任务。

3. 创新点细节

  • 位置偏差学习:AFT通过学习相对位置偏差wt,tw_{t, t'}来增强模型对位置相关性的建模能力。对于AFT-local和AFT-conv,这些偏差可以限制在局部窗口中。

  • 参数化技巧:为了减少模型参数,AFT对位置偏差wt,t 使用了因子分解:

    wt,t=utTvtw_{t, t'} = u_t^T v_{t'}

    其中,utu_t 和vtv_{t'} 是低维嵌入向量。

  • 卷积扩展:在AFT-conv中,位置偏差被扩展为卷积核,通过卷积操作实现权重共享。

4. 整体结构可视化

AFT的整体结构保留了Transformer的模块化特点,同时用更高效的AFT模块替代了传统的自注意力模块。这种结构可以灵活应用于多种任务,包括语言建模、图像生成和分类等

即插即用模块作用

AFT 作为一个即插即用模块

  • 长序列建模任务

    • 例如:自然语言处理中的字符级语言建模(如 Enwik8 数据集)。

    • 作用:AFT 的线性复杂度能够处理传统 Transformer 难以扩展的大规模序列,显著降低了内存和计算成本。

  • 图像生成与分类任务

    • 在图像生成中,AFT-local 和 AFT-conv 能有效捕获局部和全局信息,平衡计算效率与性能。

    • 在图像分类中,AFT-conv 模块结合卷积操作,兼具卷积神经网络(CNN)的局部感知优势和 Transformer 的全局连接能力。

    • 例如:CIFAR-10 图像生成和 ImageNet-1K 图像分类任务。

    • 作用

  • 资源受限场景

    • 例如:需要在低内存占用和高速度需求下完成推理或训练的场景。

    • 作用:相比标准 Transformer,AFT 模块极大地降低了内存消耗和计算复杂度,适合嵌入式设备或大规模分布式训练。

  • 可扩展模型设计

  •     例如:需要处理可变输入大小的任务(如不同分辨率的图像分类)。

    • 作用:AFT-conv 模块完全卷积化,能够灵活适应不同输入大小的变化,无需重新设计模型架构

消融实验结果

  • 内容

    • 表 3 比较了在 CIFAR-10 上位置偏差参数是否采用因子分解对训练和测试损失的影响,因子分解显著减少了参数量(每层从 9.6M 降至 0.6M),同时提升了训练和测试性能。

    • 表 9 说明在 ImageNet-1K 上,因子分解后的模型训练损失从 3.17 降至 3.08,Top-1 准确率从 78.2% 提升至 79.8%。

  • 说明:位置偏差的因子分解不仅减少了参数量,还有效提升了模型性能。


  • 内容

    • 表 5 比较了不同窗口大小ss 对 AFT-local 在 Enwik8 任务上的训练和测试损失的影响,结果表明窗口大小为 32 时性能最佳。

    • 表 11 比较了 AFT-conv 在 ImageNet-1K 上不同卷积核大小对训练损失和准确率的影响,发现 7×7 和 11×11 的核大小表现最优。

  • 说明:局部窗口的设置可以作为有效的归纳偏置(inductive bias),在性能和计算效率之间找到平衡。

即插即用模块

import numpy as np
import torch
from torch import nn
from torch.nn import init

# 论文地址:https://arxiv.org/pdf/2105.14103v1
# 论文:An Attention Free Transformer


class AFT_FULL(nn.Module):

    def __init__(self, d_model,n=49,simple=False):

        super(AFT_FULL, self).__init__()
        self.fc_q = nn.Linear(d_model, d_model)
        self.fc_k = nn.Linear(d_model, d_model)
        self.fc_v = nn.Linear(d_model,d_model)
        if(simple):
            self.position_biases=torch.zeros((n,n))
        else:
            self.position_biases=nn.Parameter(torch.ones((n,n)))
        self.d_model = d_model
        self.n=n
        self.sigmoid=nn.Sigmoid()

        self.init_weights()


    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, input):

        bs, n,dim = input.shape

        q = self.fc_q(input) #bs,n,dim
        k = self.fc_k(input).view(1,bs,n,dim) #1,bs,n,dim
        v = self.fc_v(input).view(1,bs,n,dim) #1,bs,n,dim
        
        numerator=torch.sum(torch.exp(k+self.position_biases.view(n,1,-1,1))*v,dim=2) #n,bs,dim
        denominator=torch.sum(torch.exp(k+self.position_biases.view(n,1,-1,1)),dim=2) #n,bs,dim

        out=(numerator/denominator) #n,bs,dim
        out=self.sigmoid(q)*(out.permute(1,0,2)) #bs,n,dim

        return out


if __name__ == '__main__':
    input=torch.randn(50,49,512)
    block = AFT_FULL(d_model=512, n=49)
    output=block(input)    print(output.shape)

便捷下载方式

浏览打开网址:https://github.com/ai-dawang/PlugNPlay-Modules

更多分析可见原文


ai缝合大王
聚焦AI前沿,分享相关技术、论文,研究生自救指南
 最新文章