论文介绍
题目:Gated Attention Coding for Training High-performance and Efficient Spiking Neural Networks
论文地址:https://arxiv.org/pdf/2308.06582
QQ深度学习交流群:719278780
扫描下方二维码,加入深度学习论文指南星球!
加入即可获得,模块缝合、制作、写作技巧,学会“结构”创新、“创新点”创新,从三区到顶会,小论文大论文,毕业一站式服务
创新点
提出了门控注意力编码(Gated Attention Coding, GAC):
这是一个插件式模块,利用多维度门控注意力单元(Gated Attention Unit, GAU)对输入进行高效编码,从而生成更强大的时空动态表征。
GAC作为一个预处理层,不破坏脉冲神经网络(Spiking Neural Networks, SNN)的脉冲驱动特性,可以在神经形态硬件上以最小修改实现高效的部署。
改进了编码效率和时空动态特性:
与传统的直接编码(Direct Coding)相比,GAC通过门控注意力机制引入时空动态性,克服了传统编码生成周期性无效脉冲表征的局限性。
提升了在静态数据集上的动态信息利用率,同时降低了冗余信息。
创新性地将注意力机制应用于深度SNN编码:
首次在深度SNN中探索基于注意力的动态编码方案,而不是将注意力机制直接应用于SNN架构的每一层。
避免了传统注意力机制破坏脉冲驱动通信的弊端,并显著提升了SNN的性能和效率。
提升了性能和能源效率:
在CIFAR10/100和ImageNet数据集上实现了当前最先进的准确率,且显著降低了能耗。例如,在CIFAR100数据集上,GAC-SNN以仅6个时间步实现了80.45%的准确率,相比现有方法提高了3.10%的准确率,同时将能耗降低到之前方法的66.9%。
提出了一种理论能耗计算模型,量化了不同架构的能源效率。
理论分析与实验验证:
提出了基于观测模型的理论分析框架,用以量化编码方案的动态持续时间和信息熵,从理论上证明了GAC在动态持续时间上的优势。
通过实验展示了GAC的时空动态特性和编码效果,与直接编码相比具有显著的改进。
方法
整体结构
1. 编码器部分:门控注意力编码(Gated Attention Coding, GAC)
GAC是模型的核心创新,用于对输入数据进行预处理,生成时空动态的表征。它由以下三个模块组成:
(1) 门控注意力单元(Gated Attention Unit, GAU)
时间注意力(Temporal Attention):
对输入的时间维度进行注意力处理,捕获时间上的相关性。
使用平均池化和最大池化计算时间权重,然后通过共享多层感知器(MLP)生成时间权重向量。
空间通道注意力(Spatial Channel Attention):
针对每个时间步,使用2D卷积提取空间通道动态特性,生成空间通道矩阵。
门控(Gating):
将时间注意力权重和空间通道矩阵结合,通过逐元素乘积融合时空动态特性,最终生成GAU的输出。
(2) GAC编码过程
输入静态图像数据,通过卷积层生成特征后,将结果重复多次以引入时间维度。
将上述结果分别送入脉冲神经元模型和GAU模块,并对两者的输出进行门控操作,生成最终编码特征。
2. SNN部分:基于残差连接的深度SNN架构
GAC处理后的特征作为输入送入深度SNN架构,进一步提取特征并完成分类任务。论文中的SNN架构基于 ResNet,采用以下设计:
(1) 残差连接方式
Membrane Shortcut (MS) ResNet:
残差连接在不同层之间传递脉冲神经元的膜电位,而非直接传递脉冲信号。
保持了脉冲神经网络的脉冲驱动特性(Spike-driven),使其更适合在神经形态硬件上实现。
(2) 网络结构
结合经典的卷积层(Conv)、批量归一化(BN)和脉冲神经元(如LIF模型)。
通过残差结构实现深度SNN的稳定训练和特征提取能力。
(3) 注意力机制的独立性
注意力模块仅作用于编码器部分,不影响SNN主体架构的脉冲驱动特性。
避免了其他SNN注意力机制方法需要动态调整每层注意力权重的问题,从而保持硬件友好性。
整体流程
静态输入图像 → 编码器(GAC):生成动态的时空编码特征。
GAC编码特征 → 深度SNN(MS-ResNet):通过多层残差网络提取时空特征并完成分类任务。
即插即用模块作用
GAU 作为一个即插即用模块,主要适用于:
提升时空动态特性:
通过时间注意力(Temporal Attention)和空间通道注意力(Spatial Channel Attention),GAU捕获输入数据在时间和空间维度的动态特性,使得生成的特征更具表达能力。
克服了传统直接编码(Direct Coding)生成周期性无效表征的问题,有效延长了编码的动态持续时间。
兼容性与模块化:
GAU是一个独立的编码模块,可作为预处理层无缝集成到各种SNN架构中,而不影响其原有的脉冲驱动特性。
对于传统的SNN设计,GAU能够显著增强其特征提取能力而无需大幅修改架构。
降低冗余与提高效率:
通过门控机制(Gating)融合时间和空间特征,减少冗余信息,优化了信息利用率,从而降低了计算复杂性和能源消耗。
硬件友好性:
设计考虑了神经形态硬件的实现特点,将注意力模块限制在编码层,从而保留了后续SNN架构的脉冲驱动特性(Spike-driven),便于在低能耗硬件上高效运行。
消融实验结果
展示了空间通道注意力模块中2D卷积核大小
对性能的影响,实验表明随着K K 的增大,模型的分类准确率逐渐提升,但当K K 时性能趋于平稳,表明适当的感受野能够有效提升特征提取能力,而过大的卷积核可能增加计算开销。因此,最终选择K > 4 K > 4 作为最优配置,在性能与效率间取得了良好平衡。K = 4 K = 4
评估了时间注意力(TA)和空间通道注意力(SCA)模块的独立贡献。实验发现,空间通道注意力对性能的提升作用更显著,这是由于SNN中通道数量通常多于时间步,SCA能捕获更丰富的特征信息。然而,无论移除哪个模块,性能均有所下降,表明时间和空间注意力的结合能够协同作用,实现最佳效果。
比较了不同编码方案(Phase Coding、Temporal Coding、Rate Coding、Direct Coding、GAC)的性能,在CIFAR10数据集上,GAC以96.46%的准确率显著优于其他方案,并且仅需6个时间步,展现了生成动态时空表征的能力和高效性。相比传统编码方案,GAC能够更充分地挖掘时间和空间信息,表现出卓越的综合性能
即插即用模块
import torch
import torch.nn as nn
class TA(nn.Module):
def __init__(self, T,ratio=2):
super(TA, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool3d(1)
self.max_pool = nn.AdaptiveMaxPool3d(1)
self.sharedMLP = nn.Sequential(
nn.Conv3d(T, T // ratio, 1, bias=False),
nn.ReLU(),
nn.Conv3d(T // ratio, T, 1,bias=False),
)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
avg = self.avg_pool(x)
# B,T,C
out1 = self.sharedMLP(avg)
max = self.max_pool(x)
# B,T,C
out2 = self.sharedMLP(max)
out = out1+out2
return out
# task classifictaion or generation
class SCA(nn.Module):
def __init__(self, in_planes, kerenel_size,ratio = 1):
super(SCA, self).__init__()
self.sharedMLP = nn.Sequential(
nn.Conv2d(in_planes, in_planes // ratio, kerenel_size, padding='same', bias=False),
nn.ReLU(),
nn.Conv2d(in_planes // ratio, in_planes, kerenel_size, padding='same', bias=False),)
def forward(self, x):
b,t, c, h, w = x.shape
x = x.flatten(0,1)
x = self.sharedMLP(x)
out = x.reshape(b,t, c, h, w)
return out
if __name__ == '__main__':
block1 = TA(T=10) # 假设输入有10个时间步长
print("TA模型结构:\n", block1)
# 创建SCA模型
block2 = SCA(in_planes=64, kerenel_size=3) # 假设输入通道数为64
print("\nSCA模型结构:\n", block2)
# 创建随机输入数据
batch_size = 4
time_steps = 10
channels = 64
height = 32
width = 32
input = torch.randn(batch_size, time_steps, channels, height, width)
print("\n输入数据形状:", input.size())
# 测试TA模型
output = block1(input)
print("TA模型输出形状:", output.shape)
# 测试SCA模型
output2 = block2(input)
print("SCA模型输出形状:", output2.shape)
便捷下载方式
浏览打开网址:https://github.com/ai-dawang/PlugNPlay-Modules
更多分析可见原文