ACM 即插即用TSConformerBlock 模块,涨点启动!

文摘   2024-12-11 17:20   上海  

论文介绍

题目:CMGAN: Conformer-Based Metric-GAN for Monaural Speech Enhancement

论文地址:https://arxiv.org/pdf/2209.11112v3

QQ深度学习交流群:719278780

扫描下方二维码,加入深度学习论文指南星球!

加入即可获得,模块缝合、制作、写作技巧,学会“结构”创新、“创新点”创新,从三区到顶会,小论文大论文,毕业一站式服务

创新点

  • CMGAN架构的引入提出了第一个基于Conformer的Metric-GAN架构(CMGAN),用于单通道语音增强,包括去噪(denoising)、去混响(dereverberation)和超分辨率(super-resolution)任务。


  • 双阶段Conformer模块在生成器中利用双阶段Conformer模块,分别捕获时间和频率维度的依赖性,有效结合了局部特征和全局特征。

  • 生成器设计采用共享编码器来处理幅度和复数部分的联合表示。生成器的解码器分为掩码解码器和复数解码器,用于分别优化幅度和复数谱图的细化。

  • 度量判别器引入度量判别器(metric discriminator),可以优化感知评估语音质量(PESQ)等非可微度量,提升语音质量。

  • 新颖的任务扩展该方法不仅局限于语音去噪,还扩展到去混响和超分辨率任务,尤其在复杂时频域中进行了超分辨率探索。

  • 综合性实验验证通过消融实验深入研究了输入特征、架构设计和损失函数的影响。在多个基准数据集上的实验结果表明,CMGAN在语音去噪、去混响和超分辨率任务中优于现有的最先进方法。

  • 性能与参数效率的平衡CMGAN在仅使用1.83M参数的情况下,达到了显著的性能提升,相较于某些依赖更大模型的框架更具效率。

方法

整体架构

     CMGAN 的整体结构由编码器、双阶段 Conformer 模块、解码器和度量判别器组成。编码器提取输入语音的时频域特征,双阶段 Conformer 模块分别在时间和频率维度捕捉依赖性,解码器通过掩码和复数解码器优化幅度和复数部分,生成增强后的复数谱图。最终,度量判别器通过模拟感知语音质量(如 PESQ)引导生成器优化,提升语音增强的整体效果。

生成器(Generator)

生成器采用编码器-双阶段Conformer-解码器的结构,输入为时频域的复数谱图,包含幅度(magnitude)和复数部分(实部与虚部)。编码器通过两个卷积块和一个膨胀DenseNet模块提取输入的时频特征,映射到共享的表示空间。双阶段Conformer模块分为两部分:第一阶段建模时间维度依赖性,捕捉语音信号的时间序列模式;第二阶段建模频率维度依赖性,提取频率分布特性。解码器则包含两个分支:掩码解码器用于生成幅度掩码,与输入幅度相乘以改善幅度表示;复数解码器直接细化复数谱图的实部和虚部,最终整合幅度和复数部分生成增强后的复数谱图。


判别器(Discriminator)

判别器基于 Metric-GAN 的设计,旨在评估生成器生成的增强语音的质量,并通过对抗训练优化生成器。判别器模拟感知语音质量评分(如 PESQ),以此作为度量标准,为生成器提供反馈指导。通过优化生成器,使其输出的增强语音更加接近目标语音,提升感知质量。


输入与输出处理

输入语音信号通过短时傅里叶变换(STFT)转化为时频域复数谱图,提供幅度和复数部分的细节信息,作为模型的输入特征。在增强语音生成后,通过逆STFT(ISTFT)将复数谱图转化回时域信号,生成最终的增强语音。输入与输出处理环节为模型的时频域到时域转换提供了基础支撑。

即插即用模块作用

TSConformerBlock 作为一个即插即用模块

  • 时间与频率联合建模通过双阶段设计,分别在时间维度和频率维度捕捉依赖性,能够处理语音信号中复杂的时间动态模式和频率分布特性。


  • 局部与全局特征提取结合 Transformer 和卷积模块,既能建模长距离依赖(全局特征),又能捕捉局部的短期变化。

  • 增强模型的鲁棒性对语音或音频数据的噪声、混响或分辨率变化更具适应性,适合多种复杂环境。


  • 模块化和易于集成TSConformerBlock 是一个独立模块,易于嵌入到其他深度学习架构中,不需要对现有系统大幅改动。

  • 计算效率高通过有效的设计(如残差连接和膨胀卷积),在捕捉依赖性的同时降低计算复杂度

消融实验结果

  • 输入与解码器结构的影响

    • 比较了单一输入(幅度或复数部分)和联合输入的效果。

    • 结果:仅使用幅度(Single-Mask)或仅使用复数(Single-Complex)的性能均低于联合输入(CMGAN 原始模型)。这表明联合建模幅度和复数部分对于语音增强的效果至关重要。

  • 解码器分离的影响

    • 比较了单路径解码器(Single-Path)与双解码器(Mask Decoder 和 Complex Decoder)。

    • 结果:双解码器比单路径解码器表现更好,证明分别优化幅度和复数部分更有效。

  • 损失函数与判别器的影响

    • 移除时间域损失(w/o Time Loss)或判别器(w/o Disc.)会导致性能下降。

      结果:时间域损失有助于提升 SSNR 分数,判别器的引入则显著改善 PESQ 和 COVL 等感知质量分数。

  • 两阶段 Conformer 设计的影响

    • 比较并行与顺序结构(Parallel-Conf. vs Sequential)。

    • 结果:顺序结构(时间维度 → 频率维度)优于并行结构,证明时间-频率的逐步建模更为高效。

  • 掩码激活函数的影响

    • 比较了 Sigmoid、ReLU、Softplus 和 PReLU 等不同激活函数。

    • 结果:PReLU 激活函数的性能最佳,因为它能够动态学习适应每个频率带的不同掩码范围(0 到 1)。

即插即用模块

import torch
import torch.nn as nn
from einops.layers.torch import Rearrange
#论文:CMGAN: Conformer-Based Metric-GAN for Monaural Speech Enhancement
#论文地址:https://arxiv.org/pdf/2209.11112v3

def get_padding(kernel_size, dilation=1):
    return int((kernel_size*dilation - dilation)/2)

class FeedForwardModule(nn.Module):
    def __init__(self, dim, mult=4, dropout=0):
        super(FeedForwardModule, self).__init__()
        self.ffm = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, dim * mult),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(dim * mult, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.ffm(x)


class ConformerConvModule(nn.Module):
    def __init__(self, dim, expansion_factor=2, kernel_size=31, dropout=0.):
        super(ConformerConvModule, self).__init__()
        inner_dim = dim * expansion_factor
        self.ccm = nn.Sequential(
            nn.LayerNorm(dim),
            Rearrange('b n c -> b c n'),
            nn.Conv1d(dim, inner_dim*2, 1),
            nn.GLU(dim=1),
            nn.Conv1d(inner_dim, inner_dim, kernel_size=kernel_size,
                      padding=get_padding(kernel_size), groups=inner_dim), # DepthWiseConv1d
            nn.BatchNorm1d(inner_dim),
            nn.SiLU(),
            nn.Conv1d(inner_dim, dim, 1),
            Rearrange('b c n -> b n c'),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.ccm(x)


class AttentionModule(nn.Module):
    def __init__(self, dim, n_head=8, dropout=0.):
        super(AttentionModule, self).__init__()
        self.attn = nn.MultiheadAttention(dim, n_head, dropout=dropout)
        self.layernorm = nn.LayerNorm(dim)

    def forward(self, x, attn_mask=None, key_padding_mask=None):
        x = self.layernorm(x)
        x, _ = self.attn(x, x, x,
                         attn_mask=attn_mask,
                         key_padding_mask=key_padding_mask)
        return x


class ConformerBlock(nn.Module):
    def __init__(self, dim, n_head=8, ffm_mult=4, ccm_expansion_factor=2, ccm_kernel_size=31,
                 ffm_dropout=0., attn_dropout=0., ccm_dropout=0.)
:
        super(ConformerBlock, self).__init__()
        self.ffm1 = FeedForwardModule(dim, ffm_mult, dropout=ffm_dropout)
        self.attn = AttentionModule(dim, n_head, dropout=attn_dropout)
        self.ccm = ConformerConvModule(dim, ccm_expansion_factor, ccm_kernel_size, dropout=ccm_dropout)
        self.ffm2 = FeedForwardModule(dim, ffm_mult, dropout=ffm_dropout)
        self.post_norm = nn.LayerNorm(dim)

    def forward(self, x):
        x = x + 0.5 * self.ffm1(x)
        x = x + self.attn(x)
        x = x + self.ccm(x)
        x = x + 0.5 * self.ffm2(x)
        x = self.post_norm(x)
        return x

#Two-stage conformer (TS-Conformer)
class TSConformerBlock(nn.Module):
    def __init__(self, num_channel=64):
        super(TSConformerBlock, self).__init__()
        self.time_conformer = ConformerBlock(dim=num_channel, n_head=4, ccm_kernel_size=31,
                                             ffm_dropout=0.2, attn_dropout=0.2)
        self.freq_conformer = ConformerBlock(dim=num_channel, n_head=4, ccm_kernel_size=31,
                                             ffm_dropout=0.2, attn_dropout=0.2)

    def forward(self, x):
        b, c, t, f = x.size()
        x = x.permute(0, 3, 2, 1).contiguous().view(b*f, t, c)
        x = self.time_conformer(x) + x
        x = x.view(b, f, t, c).permute(0, 2, 1, 3).contiguous().view(b*t, f, c)
        x = self.freq_conformer(x) + x
        x = x.view(b, t, f, c).permute(0, 3, 1, 2)
        return x


if __name__ == '__main__':

    block = TSConformerBlock(num_channel=64)
    input = torch.randn(1, 64, 100, 80)
    output =block(input)

    print(input.size())    print(output.size())

便捷下载方式

浏览打开网址:https://github.com/ai-dawang/PlugNPlay-Modules

更多分析可见原文


ai缝合大王
聚焦AI前沿,分享相关技术、论文,研究生自救指南
 最新文章