(AAAI 2024) 即插即用时空动态特征提取和优化的轻量型注意力模块GAU,涨点涨爆了

文摘   2024-11-24 17:20   中国香港  

论文介绍

题目:Gated Attention Coding for Training High-performance and Efficient Spiking Neural Networks

论文地址:https://arxiv.org/pdf/2308.06582

QQ深度学习交流群:719278780

扫描下方二维码,加入深度学习论文指南星球!

加入即可获得,模块缝合、制作、写作技巧,学会“结构”创新、“创新点”创新,从三区到顶会,小论文大论文,毕业一站式服务

创新点

  • 提出了门控注意力编码(Gated Attention Coding, GAC)

    • 这是一个插件式模块,利用多维度门控注意力单元(Gated Attention Unit, GAU)对输入进行高效编码,从而生成更强大的时空动态表征。

    • GAC作为一个预处理层,不破坏脉冲神经网络(Spiking Neural Networks, SNN)的脉冲驱动特性,可以在神经形态硬件上以最小修改实现高效的部署。

  • 改进了编码效率和时空动态特性

    • 与传统的直接编码(Direct Coding)相比,GAC通过门控注意力机制引入时空动态性,克服了传统编码生成周期性无效脉冲表征的局限性。

    • 提升了在静态数据集上的动态信息利用率,同时降低了冗余信息。

  • 创新性地将注意力机制应用于深度SNN编码

    • 首次在深度SNN中探索基于注意力的动态编码方案,而不是将注意力机制直接应用于SNN架构的每一层。

    • 避免了传统注意力机制破坏脉冲驱动通信的弊端,并显著提升了SNN的性能和效率。

  • 提升了性能和能源效率

    • 在CIFAR10/100和ImageNet数据集上实现了当前最先进的准确率,且显著降低了能耗。例如,在CIFAR100数据集上,GAC-SNN以仅6个时间步实现了80.45%的准确率,相比现有方法提高了3.10%的准确率,同时将能耗降低到之前方法的66.9%。

    • 提出了一种理论能耗计算模型,量化了不同架构的能源效率。

  • 理论分析与实验验证

    • 提出了基于观测模型的理论分析框架,用以量化编码方案的动态持续时间和信息熵,从理论上证明了GAC在动态持续时间上的优势。

    • 通过实验展示了GAC的时空动态特性和编码效果,与直接编码相比具有显著的改进。

方法

整体结构

       论文的整体结构包括一个 门控注意力编码(GAC) 和一个基于 残差连接的深度脉冲神经网络(SNN)。GAC作为编码器,通过时间注意力、空间通道注意力和门控机制,将静态输入图像转换为动态的时空特征,同时保留硬件友好的脉冲驱动特性。随后,这些编码特征被输入到深度SNN(如MS-ResNet),通过多层残差结构进行特征提取和分类。整个框架高效融合了注意力机制和脉冲神经网络的特点,实现了高性能和低能耗的目标。

1. 编码器部分:门控注意力编码(Gated Attention Coding, GAC)

GAC是模型的核心创新,用于对输入数据进行预处理,生成时空动态的表征。它由以下三个模块组成:

(1) 门控注意力单元(Gated Attention Unit, GAU)

  • 时间注意力(Temporal Attention)

    • 对输入的时间维度进行注意力处理,捕获时间上的相关性。

    • 使用平均池化和最大池化计算时间权重,然后通过共享多层感知器(MLP)生成时间权重向量。

  • 空间通道注意力(Spatial Channel Attention)

    • 针对每个时间步,使用2D卷积提取空间通道动态特性,生成空间通道矩阵。

  • 门控(Gating)

    • 将时间注意力权重和空间通道矩阵结合,通过逐元素乘积融合时空动态特性,最终生成GAU的输出。

(2) GAC编码过程

  • 输入静态图像数据,通过卷积层生成特征后,将结果重复多次以引入时间维度。

  • 将上述结果分别送入脉冲神经元模型和GAU模块,并对两者的输出进行门控操作,生成最终编码特征。


2. SNN部分:基于残差连接的深度SNN架构

GAC处理后的特征作为输入送入深度SNN架构,进一步提取特征并完成分类任务。论文中的SNN架构基于 ResNet,采用以下设计:

(1) 残差连接方式

  • Membrane Shortcut (MS) ResNet:

    • 残差连接在不同层之间传递脉冲神经元的膜电位,而非直接传递脉冲信号。

    • 保持了脉冲神经网络的脉冲驱动特性(Spike-driven),使其更适合在神经形态硬件上实现。

(2) 网络结构

  • 结合经典的卷积层(Conv)、批量归一化(BN)和脉冲神经元(如LIF模型)。

  • 通过残差结构实现深度SNN的稳定训练和特征提取能力。

(3) 注意力机制的独立性

  • 注意力模块仅作用于编码器部分,不影响SNN主体架构的脉冲驱动特性。

  • 避免了其他SNN注意力机制方法需要动态调整每层注意力权重的问题,从而保持硬件友好性。


整体流程

  1. 静态输入图像 → 编码器(GAC):生成动态的时空编码特征。

  2. GAC编码特征 → 深度SNN(MS-ResNet):通过多层残差网络提取时空特征并完成分类任务。

即插即用模块作用

GAU 作为一个即插即用模块,主要适用于:

  • 提升时空动态特性

    • 通过时间注意力(Temporal Attention)和空间通道注意力(Spatial Channel Attention),GAU捕获输入数据在时间和空间维度的动态特性,使得生成的特征更具表达能力。

    • 克服了传统直接编码(Direct Coding)生成周期性无效表征的问题,有效延长了编码的动态持续时间。

  • 兼容性与模块化

    • GAU是一个独立的编码模块,可作为预处理层无缝集成到各种SNN架构中,而不影响其原有的脉冲驱动特性。

    • 对于传统的SNN设计,GAU能够显著增强其特征提取能力而无需大幅修改架构。

  • 降低冗余与提高效率

    • 通过门控机制(Gating)融合时间和空间特征,减少冗余信息,优化了信息利用率,从而降低了计算复杂性和能源消耗。

  • 硬件友好性

    • 设计考虑了神经形态硬件的实现特点,将注意力模块限制在编码层,从而保留了后续SNN架构的脉冲驱动特性(Spike-driven),便于在低能耗硬件上高效运行。

消融实验结果

  • 展示了空间通道注意力模块中2D卷积核大小KK 对性能的影响,实验表明随着KK的增大,模型的分类准确率逐渐提升,但当K>4K > 4 时性能趋于平稳,表明适当的感受野能够有效提升特征提取能力,而过大的卷积核可能增加计算开销。因此,最终选择K=4K = 4 作为最优配置,在性能与效率间取得了良好平衡。

  • 评估了时间注意力(TA)和空间通道注意力(SCA)模块的独立贡献。实验发现,空间通道注意力对性能的提升作用更显著,这是由于SNN中通道数量通常多于时间步,SCA能捕获更丰富的特征信息。然而,无论移除哪个模块,性能均有所下降,表明时间和空间注意力的结合能够协同作用,实现最佳效果。


  • 比较了不同编码方案(Phase Coding、Temporal Coding、Rate Coding、Direct Coding、GAC)的性能,在CIFAR10数据集上,GAC以96.46%的准确率显著优于其他方案,并且仅需6个时间步,展现了生成动态时空表征的能力和高效性。相比传统编码方案,GAC能够更充分地挖掘时间和空间信息,表现出卓越的综合性能

即插即用模块

import torch
import torch.nn as nn

class TA(nn.Module):
    def __init__(self, T,ratio=2):

        super(TA, self).__init__()

        self.avg_pool = nn.AdaptiveAvgPool3d(1)
        self.max_pool = nn.AdaptiveMaxPool3d(1)
        self.sharedMLP = nn.Sequential(
            nn.Conv3d(T, T // ratio, 1, bias=False),
            nn.ReLU(),
            nn.Conv3d(T // ratio, T, 1,bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg = self.avg_pool(x)
        # B,T,C
        out1 = self.sharedMLP(avg)
        max = self.max_pool(x)
        # B,T,C
        out2 = self.sharedMLP(max)
        out = out1+out2

        return out

# task classifictaion or generation
class SCA(nn.Module):
    def __init__(self, in_planes, kerenel_size,ratio = 1):
        super(SCA, self).__init__()
        self.sharedMLP = nn.Sequential(
                nn.Conv2d(in_planes, in_planes // ratio, kerenel_size, padding='same', bias=False),
                nn.ReLU(),
                nn.Conv2d(in_planes // ratio, in_planes, kerenel_size, padding='same', bias=False),)
    def forward(self, x):
        b,t, c, h, w = x.shape
        x = x.flatten(0,1)
        x = self.sharedMLP(x)
        out = x.reshape(b,t, c, h, w)
        return out
if __name__ == '__main__':

    block1 = TA(T=10) # 假设输入有10个时间步长
    print("TA模型结构:\n", block1)

    # 创建SCA模型
    block2 = SCA(in_planes=64, kerenel_size=3) # 假设输入通道数为64
    print("\nSCA模型结构:\n", block2)

    # 创建随机输入数据
    batch_size = 4
    time_steps = 10
    channels = 64
    height = 32
    width = 32
    input = torch.randn(batch_size, time_steps, channels, height, width)
    print("\n输入数据形状:", input.size())

    # 测试TA模型
    output = block1(input)
    print("TA模型输出形状:", output.shape)

    # 测试SCA模型
    output2 = block2(input)
    print("SCA模型输出形状:", output2.shape)

便捷下载方式

浏览打开网址:https://github.com/ai-dawang/PlugNPlay-Modules

更多分析可见原文


ai缝合大王
聚焦AI前沿,分享相关技术、论文,研究生自救指南
 最新文章