即插即用简化自注意力机制SSAN,涨点起飞起飞了!

文摘   2024-12-29 17:20   上海  

论文介绍

题目:Simplified Self-Attention for Transformer-Based end-to-end Speech Recognition

论文地址:https://arxiv.org/pdf/2005.10463

QQ深度学习交流群:994264161

扫描下方二维码,加入深度学习论文指南星球!

加入即可获得,模块缝合、制作、写作技巧,学会“结构”创新、“创新点”创新,从三区到顶会,小论文大论文,毕业一站式服务

创新点

  • 简化自注意力层(SSAN)的设计

    • 论文提出用**前馈序列记忆网络(FSMN)**替代传统自注意力网络(SAN)中的投影层,以生成查询(query)和键(key)向量。

    • 与传统SAN不同,SSAN直接将输入分配给值(value)向量,无需额外计算,从而大幅减少计算复杂度和模型参数。

  • 模型参数显著减少

    • SSAN能够减少约20%的模型参数,同时在公开的AISHELL-1任务中相对于传统SAN实现了6.7%的相对CER(字符错误率)降低。

    • 在大规模数据集(如20,000小时的语音数据)上,SSAN在减少参数的同时保持了与SAN相当的性能。

  • 优越的长时上下文建模能力

    • 借助FSMN的FIR滤波结构,SSAN可以高效地编码长时间上下文信息,从而提升了模型对远场语音的识别性能。

    • 实验表明,SSAN在远场测试集上的表现优于SAN,尤其适用于信号质量较低的复杂语音场景。

  • 模块化与通用性

    该方法保留了Transformer的整体结构,仅对自注意力部分进行了修改,其他模块如解码器和前向反馈层未受影响,易于集成到现有的Transformer框架中。

方法

整体架构

     该论文的模型基于Transformer的编码器-解码器结构,引入了简化自注意力网络(SSAN)以替代传统的自注意力网络(SAN)。模型通过FSMN生成查询(Query)和键(Key)向量,直接使用输入作为值(Value),大幅减少了参数量,同时增强了对长时上下文的建模能力。编码器负责提取高层声学特征,解码器逐步生成文本序列,并通过交叉注意力实现特征对齐,整体结构简单高效,性能优越

1. 整体架构

  • 编码器-解码器结构

    • 编码器(Encoder):将帧级别的语音特征映射为高层次表示,用于提取语音中的声学特征。

    • 解码器(Decoder):从编码器输出中提取语言特征,逐步生成文本序列。

    • 注意力模块(Attention Module):在编码器和解码器之间建立对齐关系,用于学习声学特征与语言特征之间的映射。

2. 简化自注意力网络(SSAN

  • 在编码器和解码器的自注意力层中,使用简化自注意力网络(SSAN)代替传统的自注意力网络(SAN)。

    • 查询(Query)和键(Key)向量由前馈序列记忆网络(FSMN)生成,而非线性投影。

    • 输入向量直接作为值(Value),无需额外计算。

    • 通过FSMN增强了对长时间上下文的建模能力。

    • SSAN的设计

3. 编码器的结构

  • 编码器包含两大模块:

    • 多头自注意力层(Multi-head Self-Attention):捕捉不同子空间的长时上下文依赖。

    • 位置前馈网络(Position-wise Feedforward Network):对每一时间步的特征进行非线性变换,提升表示能力。

  • 在SSAN中,自注意力的输入向量直接赋予值(Value),而查询(Query)和键(Key)通过FSMN生成,减少了模型参数。

4. 解码器的结构

  • 解码器包括三个子模块:

    • 掩码多头自注意力层(Masked Multi-head Self-Attention):处理解码器内部的时间步依赖。

    • 编码器-解码器交叉注意力层(Cross-Attention):捕捉编码器和解码器之间的信息交互。

    • 位置前馈网络:与编码器类似。

5. 整体优化

  • 每层后面均添加跳跃连接(Skip Connection)和层归一化(Layer Normalization)。

  • 模型通过基于交叉熵(CE)的损失函数训练,不使用外部语言模型。


即插即用模块作用

SSAN 作为一个即插即用模块

  • 长时上下文建模场景

    • 在需要捕捉长时间上下文依赖的任务中(如语音识别、自然语言处理等),SSAN的FSMN结构能够更高效地编码长时序信息。

    • 特别是在复杂语音场景(如远场语音、噪声环境和信号质量较差的条件下)表现更优。

  • 资源受限场景

    • 在计算资源有限的场景(如移动设备或边缘设备)中,SSAN大幅减少模型参数和计算复杂度,同时保持较高性能,非常适合部署轻量化的语音识别系统。

  • 端到端任务场景

    • 在端到端的序列建模任务(如语音到文本、翻译等)中,SSAN可以作为替代传统自注意力网络(SAN)的模块,用于提升模型的整体效率和精度。

消融实验结果

  • 内容:将传统SAN(Self-Attention Network)与提出的SSAN(Simplified Self-Attention Network)进行对比,包括模型层数(编码器和解码器层数)、参数数量(百万级)和字符错误率(CER%)。

  • 结论

    • SSAN相比SAN在相同层数下可以显著减少参数量(减少约21.7%)并降低CER(减少6.7%)。

    • 当编码器层数为10层时,SSAN表现最优,CER降至6.84%。


  • 内容:比较了两种模型在近场和远场测试集上的CER,以及参数数量。

  • 结论

    • SSAN相比SAN在远场测试集上的CER相对降低了6.0%,近场测试集也略有改善,同时减少了约20.4%的模型参数。

    • 结果表明,SSAN在处理复杂的远场语音场景时具有显著优势。

即插即用模块

import numpy as np
import torch
from torch import nn
from torch.nn import init

# 论文地址:https://arxiv.org/pdf/2005.10463
# 论文:Simplified Self-Attention for Transformer-Based end-to-end Speech Recognition


class SimplifiedScaledDotProductAttention(nn.Module):
    '''
    Scaled dot-product attention
    '''


    def __init__(self, d_model, h,dropout=.1):
        '''
        :param d_model: Output dimensionality of the model
        :param d_k: Dimensionality of queries and keys
        :param d_v: Dimensionality of values
        :param h: Number of heads
        '''

        super(SimplifiedScaledDotProductAttention, self).__init__()

        self.d_model = d_model
        self.d_k = d_model//h
        self.d_v = d_model//h
        self.h = h

        self.fc_o = nn.Linear(h * self.d_v, d_model)
        self.dropout=nn.Dropout(dropout)



        self.init_weights()


    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, queries, keys, values, attention_mask=None, attention_weights=None):
        '''
        Computes
        :param queries: Queries (b_s, nq, d_model)
        :param keys: Keys (b_s, nk, d_model)
        :param values: Values (b_s, nk, d_model)
        :param attention_mask: Mask over attention values (b_s, h, nq, nk). True indicates masking.
        :param attention_weights: Multiplicative weights for attention values (b_s, h, nq, nk).
        :return:
        '''

        b_s, nq = queries.shape[:2]
        nk = keys.shape[1]

        q = queries.view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3) # (b_s, h, nq, d_k)
        k = keys.view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1) # (b_s, h, d_k, nk)
        v = values.view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3) # (b_s, h, nk, d_v)

        att = torch.matmul(q, k) / np.sqrt(self.d_k) # (b_s, h, nq, nk)
        if attention_weights is not None:
            att = att * attention_weights
        if attention_mask is not None:
            att = att.masked_fill(attention_mask, -np.inf)
        att = torch.softmax(att, -1)
        att=self.dropout(att)

        out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v) # (b_s, nq, h*d_v)
        out = self.fc_o(out) # (b_s, nq, d_model)
        return out


if __name__ == '__main__':
    input=torch.randn(50,49,512)
    block = SimplifiedScaledDotProductAttention(d_model=512, h=8)
    output=block(input,input,input)    print(output.shape)

便捷下载方式

浏览打开网址:https://github.com/ai-dawang/PlugNPlay-Modules

更多分析可见原文


ai缝合大王
聚焦AI前沿,分享相关技术、论文,研究生自救指南
 最新文章