即插即用频域增强通道注意力机制EFCAttention,涨点启动!

文摘   2024-11-22 17:20   上海  

论文介绍

题目:FECAM: Frequency Enhanced Channel Attention Mechanism for Time Series Forecasting

论文地址:https://arxiv.org/abs/2212.01209

QQ深度学习交流群:719278780

扫描下方二维码,加入深度学习论文指南星球!

加入即可获得,模块缝合、制作、写作技巧,学会“结构”创新、“创新点”创新,从三区到顶会,小论文大论文,毕业一站式服务

创新点

  • 频域增强的通道注意机制 (FECAM):提出了一种新的频域增强通道注意力机制,通过离散余弦变换 (DCT) 代替传统的傅里叶变换 (FT) 提取频率信息。相比傅里叶变换,DCT天然避免了因周期性问题带来的吉布斯现象 (Gibbs Phenomenon),从而减少了高频噪声的引入。

  • 模块的通用性:FECAM不仅可以作为独立模型用于时间序列预测,还可以无缝嵌入主流的时间序列模型(如基于Transformer的模型和LSTM),提升这些模型在时间序列预测任务中的表现。这种通用性使得FECAM具有较高的实用价值和灵活性。

  • 理论证明与验证:论文中通过理论分析和实验证明了FECAM在频域建模的有效性,尤其是利用DCT进行频率信息提取,显著提升了模型的预测性能。

  • 实验结果表现优异:在六个真实世界的时间序列数据集上,FECAM在预测准确性方面达到了最新的最佳效果,并且相比其他方法具有更少的参数增量和计算开销。

方法

整体结构

       论文中的模型结构主要是利用离散余弦变换(DCT)提取时间序列数据的频域信息,通过频域增强通道注意力机制(FECAM)在不同通道和频率分量之间自适应建模,从而捕捉到更多关键特征,最终结合全连接层或投影层生成增强的预测输出。这一结构既可独立用于预测,也能无缝集成到其他模型中,提升其预测性能。

  • 输入处理与通道划分:首先将多变量时间序列数据按通道维度拆分成多个子序列,每个子序列包含不同的变量特征。

  • 离散余弦变换 (DCT):对每个通道进行DCT变换,提取出对应的频率分量。这一步骤能够避免传统傅里叶变换带来的周期性问题(即吉布斯现象),从而更高效地捕捉低频信息,同时避免高频噪声的干扰。

  • 通道注意力机制:通过频域的特征图,FECAM可以在不同通道和频率分量之间自适应地建模。使用全连接层对频率增强后的特征图进行加权学习,从而获得每个通道和频率分量的重要性。

  • 重建与输出:最后,将学习到的频域信息和通道注意力机制的结果重新组合,生成增强后的时间序列预测。这一步通过全连接层或投影层来进行,确保频域和时间域信息能够在预测中得到充分利用。

即插即用模块作用

EFCAttention 作为一个即插即用模块,主要适用于:

  • 时间序列预测场景

    • 用于电力负荷预测、气象数据预测、金融数据分析、交通流量预测等领域,帮助模型处理周期性和趋势性强的数据。

  • 频域信息的增强

    • FECAttention通过离散余弦变换(DCT)获取数据的频域特征,有效捕捉低频和高频信息,避免传统傅里叶变换带来的高频噪声问题。

  • 增强特征重要性

    • 该模块在不同通道和频率分量之间自适应地建模,提升特征的表达能力,使模型能够更精准地学习到时间序列数据中的重要模式。

  • 提升模型鲁棒性与预测精度

    • 通过引入频域信息,FECAttention可以显著提升各种时序模型(如LSTM、Transformer等)的预测性能,尤其在包含丰富低频信息的数据集上效果尤为显著。

消融实验结果

  • 该表显示了在不同数据集上,将FECAM模块嵌入到主流的Transformer和RNN模型(如LSTM、Reformer、Informer、Autoformer等)后所带来的性能提升情况。实验结果表明,FECAM模块显著提升了各个模型的预测精度,尤其是在Exchange、ETTm2和Weather等包含丰富低频信息的数据集上效果尤为显著,而在Traffic数据集上提升相对较小,表明FECAM对频率信息较为敏感的数据集更具优势。

即插即用模块

import torch.nn as nn
import numpy as np
import torch
#论文:FECAM: Frequency Enhanced Channel Attention Mechanism for Time Series Forecasting
#论文地址:https://arxiv.org/abs/2212.01209

try:
    from torch import irfft
    from torch import rfft
except ImportError:
    def rfft(x, d):
        t = torch.fft.fft(x, dim=(-d))
        r = torch.stack((t.real, t.imag), -1)
        return r


    def irfft(x, d):
        t = torch.fft.ifft(torch.complex(x[:, :, 0], x[:, :, 1]), dim=(-d))
        return t.real


def dct(x, norm=None):
    """
    Discrete Cosine Transform, Type II (a.k.a. the DCT)

    For the meaning of the parameter `norm`, see:
    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html

    :param x: the input signal
    :param norm: the normalization, None or 'ortho'
    :return: the DCT-II of the signal over the last dimension
    """

    x_shape = x.shape
    N = x_shape[-1]
    x = x.contiguous().view(-1, N)

    v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1)

    Vc = rfft(v, 1)

    k = - torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N)
    W_r = torch.cos(k)
    W_i = torch.sin(k)

    V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i

    if norm == 'ortho':
        V[:, 0] /= np.sqrt(N) * 2
        V[:, 1:] /= np.sqrt(N / 2) * 2

    V = 2 * V.view(*x_shape)

    return V


class dct_channel_block(nn.Module):
    def __init__(self, channel):
        super(dct_channel_block, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(channel, channel * 2, bias=False),
            nn.Dropout(p=0.1),
            nn.ReLU(inplace=True),
            nn.Linear(channel * 2, channel, bias=False),
            nn.Sigmoid()
        )

        self.dct_norm = nn.LayerNorm([96], eps=1e-6) # for lstm on length-wise

    def forward(self, x):
        b, c, l = x.size() # (B,C,L) (32,96,512)
        list = []
        for i in range(c):
            freq = dct(x[:, i, :])
            list.append(freq)

        stack_dct = torch.stack(list, dim=1)

        lr_weight = self.dct_norm(stack_dct)
        lr_weight = self.fc(lr_weight)
        lr_weight = self.dct_norm(lr_weight)

        return x * lr_weight # result


if __name__ == '__main__':
    input = torch.rand(8, 7, 96)
    block = dct_channel_block(96)
    result = block(input)
    print(input.size())    print(result.size())

便捷下载方式

浏览打开网址:https://github.com/ai-dawang/PlugNPlay-Modules

更多分析可见原文


ai缝合大王
聚焦AI前沿,分享相关技术、论文,研究生自救指南
 最新文章