Plug-and-Play xLSTM for NLP and Time Series: an Easy Way to Boost Performance

Digest   2025-01-21 17:20   Hong Kong, China

Paper Overview

Title: xLSTM: Extended Long Short-Term Memory

Paper: https://arxiv.org/pdf/2405.04517


Key Innovations

  • Exponential gating:

    • Revisits the storage decisions of the classic LSTM by equipping the gating mechanism (input and forget gates) with exponential activation functions.

    • Adds appropriate normalization and stabilization techniques to avoid the numerical overflow that exponential functions can cause (a small sketch of this stabilization follows this list).

  • New memory structures, in the form of two new LSTM variants:

    • sLSTM: a scalar memory cell with a new memory-mixing scheme and support for multiple heads.

    • mLSTM: a matrix-based memory that is fully parallelizable and uses a covariance update rule to increase storage capacity.

  • Improved residual structure:

    • The revised LSTM cells are integrated into residual blocks to form xLSTM blocks, and stacking these blocks yields the complete xLSTM architecture.

    • Two residual block designs are provided: a post up-projection block for sLSTM and a pre up-projection block for mLSTM.

  • Improved parallelizability:

    • mLSTM drops the hidden-to-hidden connections of the classic LSTM, which makes the recurrence parallelizable.

    • CUDA optimizations make the computation on GPUs more efficient.

  • Better storage capacity and sequence handling:

    • The matrix memory of mLSTM significantly increases storage capacity, which helps with rare-token prediction and long-context modeling.

    • The new memory-mixing scheme of sLSTM improves the model's state-tracking ability.

  • Competitive language-modeling performance:

    • Training and comparison experiments on large datasets show that xLSTM beats existing Transformer and state-space models in validation perplexity and in extrapolation to longer contexts.
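
To make the exponential gating concrete, here is a minimal sketch of the log-domain stabilization described in the paper, written for a single time step. The helper name stabilized_gates and the choice of a sigmoid forget gate are my own; the reference implementation further below folds a simplified version of this trick into sLSTMCell.forward.

    import torch

    def stabilized_gates(i_tilde: torch.Tensor, f_tilde: torch.Tensor, m_prev: torch.Tensor):
        """Stabilized exponential input gate and sigmoid forget gate.

        i_tilde, f_tilde: gate pre-activations; m_prev: stabilizer state from the previous step.
        Returns the new stabilizer m_t and the rescaled gates i_prime, f_prime.
        """
        log_f = torch.log(torch.sigmoid(f_tilde))     # forget gate in log space
        m_t = torch.maximum(log_f + m_prev, i_tilde)  # log(exp(i_tilde)) = i_tilde
        i_prime = torch.exp(i_tilde - m_t)            # stabilized exponential input gate
        f_prime = torch.exp(log_f + m_prev - m_t)     # stabilized forget gate
        return m_t, i_prime, f_prime

    # The stabilized cell update then reads:
    #   c_t = f_prime * c_prev + i_prime * z_t
    #   n_t = f_prime * n_prev + i_prime
    #   h_t = o_t * (c_t / n_t)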

Method

Overall Architecture

xLSTM is a deep neural network architecture built from residual blocks, obtained by stacking sLSTM blocks (scalar memory with memory mixing) and mLSTM blocks (matrix memory, fully parallelizable). Its core innovations are the exponential gating mechanism and the covariance update rule, which increase storage capacity and improve long-sequence processing. The architecture relies on pre-normalization and residual connections to keep deep stacks stable, and its linear computational complexity makes it efficient for language modeling and long-sequence tasks (a minimal residual-block skeleton is sketched below).
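
As a rough illustration of how a recurrent mixer sits inside a pre-normalized residual block, here is a minimal skeleton. The wrapper class ResidualXLSTMBlock and the use of nn.LayerNorm are my own assumptions for illustration; the paper's actual blocks additionally contain up/down projections, convolutions, and gating that are omitted here.

    import torch
    import torch.nn as nn

    class ResidualXLSTMBlock(nn.Module):
        """Pre-LayerNorm residual wrapper around a sequence mixer (e.g. sLSTM or mLSTM)."""

        def __init__(self, mixer: nn.Module, d_model: int) -> None:
            super().__init__()
            self.norm = nn.LayerNorm(d_model)
            self.mixer = mixer  # e.g. the sLSTM / mLSTM modules defined later in this post

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # x: (seq_len, batch, d_model); the mixers below return (output, hidden_state)
            y, _ = self.mixer(self.norm(x))
            return x + y  # residual connection keeps deep stacks stable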

1. Basic components: the xLSTM block

An xLSTM block comes in two main variants:

a. sLSTM (Scalar LSTM) block:

  • Memory cell: scalar memory.

  • Gating: exponential gating, i.e. the input and forget gates may use exponential activation functions.

  • Memory mixing: supports multiple heads; memory cells are mixed within each head but not across heads.

  • Residual block: post up-projection block, i.e. the non-linear (recurrent) transformation is applied in the original, lower-dimensional space and the result is then linearly projected into a higher-dimensional space.

b. mLSTM (Matrix LSTM) block:

  • Memory cell: matrix memory, which increases storage capacity.

  • Update rule: covariance update rule, which stores and retrieves key-value pairs (a minimal sketch follows this list).

  • Parallelism: hidden-to-hidden connections are removed, so the computation can be fully parallelized.

  • Residual block: pre up-projection block, i.e. the input is first projected into a higher-dimensional space, processed non-linearly there, and then projected back to the original dimension.
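
To make the covariance update rule concrete, here is a minimal single-step sketch of how mLSTM writes a key-value pair into its matrix memory and reads it back with a query. The function name mlstm_memory_step and the omission of gate pre-activations and stabilization are my own simplifications; the full, gated version appears in mLSTMCell.forward further below.

    import torch

    def mlstm_memory_step(C_prev, n_prev, f_t, i_t, k_t, v_t, q_t):
        """One mLSTM memory step.

        C_prev: (batch, d, d) matrix memory, n_prev: (batch, d) normalizer state,
        f_t / i_t: activated forget / input gates, k_t / v_t / q_t: key / value / query, all (batch, d).
        """
        # Covariance-style write: C_t = f * C_{t-1} + i * (v k^T)
        C_t = f_t.unsqueeze(-1) * C_prev + i_t.unsqueeze(-1) * torch.einsum("bi,bk->bik", v_t, k_t)
        n_t = f_t * n_prev + i_t * k_t
        # Read-out: h_t = C_t q_t / max(|n_t . q_t|, 1)
        denom = torch.clamp((n_t * q_t).sum(-1, keepdim=True).abs(), min=1.0)
        h_t = torch.einsum("bik,bk->bi", C_t, q_t) / denom
        return h_t, C_t, n_t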


2. Building the architecture: xLSTM blocks

Stacking multiple xLSTM blocks forms the backbone of the architecture:

  • Each block is built on pre-LayerNorm and a residual connection, which keeps deep networks stable.

  • The stack mixes sLSTM and mLSTM blocks; for example, xLSTM[7:1] denotes a combination of 7 mLSTM blocks and 1 sLSTM block (a small composition sketch follows this list).
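
As an illustration of the xLSTM[m:s] ratio notation, the sketch below composes a stack in which, out of every eight blocks, seven are mLSTM blocks and one is an sLSTM block. The helper build_xlstm_stack, its arguments, and the placement of the sLSTM block at the end of each period are my own assumptions for illustration.

    import torch.nn as nn

    def build_xlstm_stack(num_blocks: int, m_ratio: int, s_ratio: int,
                          make_mlstm, make_slstm) -> nn.ModuleList:
        """Compose a stack in the xLSTM[m:s] style, e.g. xLSTM[7:1].

        make_mlstm / make_slstm are zero-argument factories returning the two block types
        (for instance, residual-wrapped mLSTM / sLSTM modules).
        """
        period = m_ratio + s_ratio
        blocks = []
        for idx in range(num_blocks):
            # place sLSTM blocks at the end of each period, mLSTM blocks everywhere else
            blocks.append(make_slstm() if idx % period >= m_ratio else make_mlstm())
        return nn.ModuleList(blocks)

    # e.g. a 48-block xLSTM[7:1] stack (with placeholder factories):
    # stack = build_xlstm_stack(48, 7, 1, make_mlstm=nn.Identity, make_slstm=nn.Identity)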


3. Full architecture: xLSTM

  • Stacking many xLSTM blocks (e.g. 48) yields the complete deep model.

  • The model supports pre-training and fine-tuning on downstream tasks and is optimized for language modeling.

What the Plug-and-Play Module Provides

As a plug-and-play module, xLSTM offers the following:

  1. Improved storage:

    • The exponential gating of sLSTM allows new information to be stored and updated more flexibly, which strengthens the ability to retain rare events and long-range dependencies.

  2. Stronger expressiveness:

    • The memory-mixing mechanism lets sLSTM integrate information across multiple heads or cells, improving the model's contextual understanding.

  3. Extensibility and compatibility:

    • sLSTM can be embedded as a plug-and-play module into existing deep learning architectures (such as Transformers or other RNNs) to strengthen their state tracking and memory management (see the usage sketch after this list).

  4. Stability and ease of use:

    • The normalization and stabilization applied to the exponential gates avoid numerical overflow, making the module stable and easy to use in both training and inference.
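
As a hedged illustration of the plug-and-play claim, the sketch below swaps the sLSTM module from the code section further down into a small sequence classifier in place of nn.LSTM. The file name slstm.py, the TinySequenceClassifier class, and all hyperparameters are my own assumptions.

    import torch
    import torch.nn as nn
    # assumes the sLSTM class from the code section below has been saved as slstm.py
    from slstm import sLSTM

    class TinySequenceClassifier(nn.Module):
        def __init__(self, vocab_size: int = 1000, d_model: int = 128, num_classes: int = 5) -> None:
            super().__init__()
            self.embed = nn.Embedding(vocab_size, d_model)
            # drop-in replacement for nn.LSTM(d_model, d_model, num_layers=1, batch_first=True)
            self.rnn = sLSTM(d_model, d_model, num_layers=1, batch_first=True)
            self.head = nn.Linear(d_model, num_classes)

        def forward(self, tokens: torch.Tensor) -> torch.Tensor:
            x = self.embed(tokens)   # (batch, seq, d_model)
            h, _ = self.rnn(x)       # (seq, batch, d_model) for this implementation
            return self.head(h[-1])  # classify from the last time step

    if __name__ == "__main__":
        model = TinySequenceClassifier()
        logits = model(torch.randint(0, 1000, (8, 20)))  # batch of 8 sequences of length 20
        print(logits.shape)                              # torch.Size([8, 5])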

Ablation Results

  • Location: top of Table 2

  • Content: measures how each step on the path from a vanilla multi-layer LSTM to xLSTM affects performance. The steps are:

    • adding a residual (ResNet-style) backbone;

    • adding up-projection blocks;

    • adding exponential gating;

    • adding the matrix memory.

  • Findings: every step lowers validation perplexity, with exponential gating and the matrix memory contributing the largest gains.

  • Location: bottom of Table 2

  • Content: compares different gating configurations:

    • no gates;

    • forget gate only;

    • input gate only;

    • both input and forget gates, combined with exponential activation.

  • Findings: the model without gates performs worst, and the combination of input and forget gates performs best, confirming how important gating is for the model's storage capacity and expressiveness.

Plug-and-Play Module Code

    import torch
    import torch.nn as nn
    from typing import Tuple, Optional, List


    class sLSTMCell(nn.Module):
        def __init__(self, input_size: int, hidden_size: int, bias: bool = True) -> None:
            super().__init__()

            # Store the input and hidden size
            self.input_size = input_size
            self.hidden_size = hidden_size
            self.bias = bias

            # Combine the Weights and Recurrent weights into a single matrix
            self.W = nn.Parameter(
                nn.init.xavier_uniform_(
                    torch.randn(self.input_size + self.hidden_size, 4 * self.hidden_size)
                ),
                requires_grad=True,
            )
            # Combine the Bias into a single matrix
            if self.bias:
                self.B = nn.Parameter(
                    (torch.zeros(4 * self.hidden_size)), requires_grad=True
                )

        def forward(
            self,
            x: torch.Tensor,
            internal_state: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
        ) -> Tuple[
            torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
        ]:
            # Unpack the internal state
            h, c, n, m = internal_state # (batch_size, hidden_size)

            # Combine the weights and the input
            combined = torch.cat((x, h), dim=1) # (batch_size, input_size + hidden_size)
            # Calculate the linear transformation
            gates = torch.matmul(combined, self.W) # (batch_size, 4 * hidden_size)

            # Add the bias if included
            if self.bias:
                gates += self.B

            # Split the pre-activations into the cell input (z) and the input, forget and output gates
            z_tilda, i_tilda, f_tilda, o_tilda = torch.split(gates, self.hidden_size, dim=1)

            # Calculate the activation of the states
            z_t = torch.tanh(z_tilda) # (batch_size, hidden_size)
            # Exponential activation of the input gate
            i_t = torch.exp(i_tilda) # (batch_size, hidden_size)
            # Sigmoid activation of the forget gate (the paper also allows an exponential forget gate)
            f_t = torch.sigmoid(f_tilda) # (batch_size, hidden_size)

            # Sigmoid activation of the output gate
            o_t = torch.sigmoid(o_tilda) # (batch_size, hidden_size)
            # Calculate the stabilization state
            m_t = torch.max(torch.log(f_t) + m, torch.log(i_t)) # (batch_size, hidden_size)
            # Calculate the input stabilization state
            i_prime = torch.exp(i_tilda - m_t) # (batch_size, hidden_size)

            # Calculate the new internal states
            c_t = f_t * c + i_prime * z_t # (batch_size, hidden_size)
            n_t = f_t * n + i_prime # (batch_size, hidden_size)

            # Calculate the stabilized hidden state
            h_tilda = c_t / n_t # (batch_size, hidden_size)

            # Calculate the new hidden state
            h_t = o_t * h_tilda # (batch_size, hidden_size)
            return h_t, (
                h_t,
                c_t,
                n_t,
                m_t,
            ) # (batch_size, hidden_size), (batch_size, hidden_size), (batch_size, hidden_size), (batch_size, hidden_size)

        def init_hidden(
            self, batch_size: int, **kwargs
        ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
            return (
                torch.zeros(batch_size, self.hidden_size, **kwargs),
                torch.zeros(batch_size, self.hidden_size, **kwargs),
                torch.zeros(batch_size, self.hidden_size, **kwargs),
                torch.zeros(batch_size, self.hidden_size, **kwargs),
            )


    class sLSTM(nn.Module):
        def __init__(
            self,
            input_size: int,
            hidden_size: int,
            num_layers: int,
            bias: bool = True,
            batch_first: bool = False,
        ) -> None:
            super().__init__()
            self.input_size = input_size
            self.hidden_size = hidden_size
            self.num_layers = num_layers
            self.bias = bias
            self.batch_first = batch_first

            self.cells = nn.ModuleList(
                [
                    sLSTMCell(input_size if layer == 0 else hidden_size, hidden_size, bias)
                    for layer in range(num_layers)
                ]
            )

        def forward(
            self,
            x: torch.Tensor,
            hidden_states: Optional[
                List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]
            ] = None,
        ) -> Tuple[
            torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
        ]:
            # Permute the input tensor if batch_first is True
            if self.batch_first:
                x = x.permute(1, 0, 2)

            # Initialize the hidden states if not provided
            if hidden_states is None:
                hidden_states = self.init_hidden(x.size(1), device=x.device, dtype=x.dtype)
            else:
                # Check if the hidden states are of the correct length
                if len(hidden_states) != self.num_layers:
                    raise ValueError(
                        f"Expected hidden states of length {self.num_layers}, but got {len(hidden_states)}"
                    )
                if any(state[0].size(0) != x.size(1) for state in hidden_states):
                    raise ValueError(
                        f"Expected hidden states of batch size {x.size(1)}, but got {hidden_states[0][0].size(0)}"
                    )

            H, C, N, M = [], [], [], []

            for layer, cell in enumerate(self.cells):
                lh, lc, ln, lm = [], [], [], []
                for t in range(x.size(0)):
                    h_t, hidden_states[layer] = (
                        cell(x[t], hidden_states[layer])
                        if layer == 0
                        else cell(H[layer - 1][t], hidden_states[layer])
                    )
                    lh.append(h_t)
                    lc.append(hidden_states[layer][1])
                    ln.append(hidden_states[layer][2])
                    lm.append(hidden_states[layer][3])

                H.append(torch.stack(lh, dim=0))
                C.append(torch.stack(lc, dim=0))
                N.append(torch.stack(ln, dim=0))
                M.append(torch.stack(lm, dim=0))

            H = torch.stack(H, dim=0)
            C = torch.stack(C, dim=0)
            N = torch.stack(N, dim=0)
            M = torch.stack(M, dim=0)

            return H[-1], (H, C, N, M)

        def init_hidden(
            self, batch_size: int, **kwargs
        ) -> List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:

            return [cell.init_hidden(batch_size, **kwargs) for cell in self.cells]



    if __name__ == '__main__':
        # Define the input tensor parameters
        input_size = 128
        hidden_size = 128
        num_layers = 2
        seq_length = 10
        batch_size = 32

        # Initialize the sLSTM module (batch_first=True because the input below is (batch, seq, feature))
        block = sLSTM(input_size, hidden_size, num_layers, batch_first=True)

        # Randomly generate an input tensor
        input_seq = torch.rand(batch_size, seq_length, input_size)

        # Run the forward pass
        output, hidden_state = block(input_seq)

        # Print the input and output tensor shapes (the output is (seq_length, batch_size, hidden_size))
        print(" sLSTM.Input size:", input_seq.size())
        print("sLSTM.Output size:", output.size())


    class mLSTMCell(nn.Module):
        def __init__(self, input_size: int, hidden_size: int, bias: bool = True) -> None:

            super().__init__()

            self.input_size = input_size
            self.hidden_size = hidden_size
            self.bias = bias

            # Initialize weights and biases
            self.W_i = nn.Parameter(
                nn.init.xavier_uniform_(torch.zeros(input_size, hidden_size)),
                requires_grad=True,
            )
            self.W_f = nn.Parameter(
                nn.init.xavier_uniform_(torch.zeros(input_size, hidden_size)),
                requires_grad=True,
            )
            self.W_o = nn.Parameter(
                nn.init.xavier_uniform_(torch.zeros(input_size, hidden_size)),
                requires_grad=True,
            )
            self.W_q = nn.Parameter(
                nn.init.xavier_uniform_(torch.zeros(input_size, hidden_size)),
                requires_grad=True,
            )
            self.W_k = nn.Parameter(
                nn.init.xavier_uniform_(torch.zeros(input_size, hidden_size)),
                requires_grad=True,
            )
            self.W_v = nn.Parameter(
                nn.init.xavier_uniform_(torch.zeros(input_size, hidden_size)),
                requires_grad=True,
            )

            if self.bias:
                self.B_i = nn.Parameter(torch.zeros(hidden_size), requires_grad=True)
                self.B_f = nn.Parameter(torch.zeros(hidden_size), requires_grad=True)
                self.B_o = nn.Parameter(torch.zeros(hidden_size), requires_grad=True)
                self.B_q = nn.Parameter(torch.zeros(hidden_size), requires_grad=True)
                self.B_k = nn.Parameter(torch.zeros(hidden_size), requires_grad=True)
                self.B_v = nn.Parameter(torch.zeros(hidden_size), requires_grad=True)

        def forward(
            self,
            x: torch.Tensor,
            internal_state: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
        ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
            # Get the internal state
            C, n, m = internal_state

            # Calculate the input, forget, output, query, key and value gates
            i_tilda = (
                torch.matmul(x, self.W_i) + self.B_i
                if self.bias
                else torch.matmul(x, self.W_i)
            )
            f_tilda = (
                torch.matmul(x, self.W_f) + self.B_f
                if self.bias
                else torch.matmul(x, self.W_f)
            )
            o_tilda = (
                torch.matmul(x, self.W_o) + self.B_o
                if self.bias
                else torch.matmul(x, self.W_o)
            )
            q_t = (
                torch.matmul(x, self.W_q) + self.B_q
                if self.bias
                else torch.matmul(x, self.W_q)
            )
            k_t = (
                torch.matmul(x, self.W_k) / (self.hidden_size ** 0.5) + self.B_k
                if self.bias
                else torch.matmul(x, self.W_k) / (self.hidden_size ** 0.5)
            )
            v_t = (
                torch.matmul(x, self.W_v) + self.B_v
                if self.bias
                else torch.matmul(x, self.W_v)
            )

            # Exponential activation of the input gate
            i_t = torch.exp(i_tilda)
            f_t = torch.sigmoid(f_tilda)
            o_t = torch.sigmoid(o_tilda)

            # Stabilization state
            m_t = torch.max(torch.log(f_t) + m, torch.log(i_t))
            i_prime = torch.exp(i_tilda - m_t)

            # Covariance matrix and normalization state
            C_t = f_t.unsqueeze(-1) * C + i_prime.unsqueeze(-1) * torch.einsum(
                "bi, bk -> bik", v_t, k_t
            )
            n_t = f_t * n + i_prime * k_t

            normalize_inner = torch.diagonal(torch.matmul(n_t, q_t.T))
            divisor = torch.max(
                torch.abs(normalize_inner), torch.ones_like(normalize_inner)
            )
            h_tilda = torch.einsum("bkj,bj -> bk", C_t, q_t) / divisor.view(-1, 1)
            h_t = o_t * h_tilda

            return h_t, (C_t, n_t, m_t)

        def init_hidden(
            self, batch_size: int, **kwargs
        ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            return (
                torch.zeros(batch_size, self.hidden_size, self.hidden_size, **kwargs),
                torch.zeros(batch_size, self.hidden_size, **kwargs),
                torch.zeros(batch_size, self.hidden_size, **kwargs),
            )


    class mLSTM(nn.Module):
        def __init__(
            self,
            input_size: int,
            hidden_size: int,
            num_layers: int,
            bias: bool = True,
            batch_first: bool = False,
        ) -> None:
            super().__init__()
            self.input_size = input_size
            self.hidden_size = hidden_size
            self.num_layers = num_layers
            self.bias = bias
            self.batch_first = batch_first

            self.cells = nn.ModuleList(
                [
                    mLSTMCell(input_size if layer == 0 else hidden_size, hidden_size, bias)
                    for layer in range(num_layers)
                ]
            )

        def forward(
            self,
            x: torch.Tensor,
            hidden_states: Optional[
                List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]
            ] = None,
        ) -> Tuple[
            torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
        ]:
            # Permute the input tensor if batch_first is True
            if self.batch_first:
                x = x.permute(1, 0, 2)

            if hidden_states is None:
                hidden_states = self.init_hidden(x.size(1), device=x.device, dtype=x.dtype)
            else:
                # Check if the hidden states are of the correct length
                if len(hidden_states) != self.num_layers:
                    raise ValueError(
                        f"Expected hidden states of length {self.num_layers}, but got {len(hidden_states)}"
                    )
                if any(state[0].size(0) != x.size(1) for state in hidden_states):
                    raise ValueError(
                        f"Expected hidden states of batch size {x.size(1)}, but got {hidden_states[0][0].size(0)}"
                    )

            H, C, N, M = [], [], [], []

            for layer, cell in enumerate(self.cells):
                lh, lc, ln, lm = [], [], [], []
                for t in range(x.size(0)):
                    h_t, hidden_states[layer] = (
                        cell(x[t], hidden_states[layer])
                        if layer == 0
                        else cell(H[layer - 1][t], hidden_states[layer])
                    )
                    lh.append(h_t)
                    lc.append(hidden_states[layer][0])
                    ln.append(hidden_states[layer][1])
                    lm.append(hidden_states[layer][2])

                H.append(torch.stack(lh, dim=0))
                C.append(torch.stack(lc, dim=0))
                N.append(torch.stack(ln, dim=0))
                M.append(torch.stack(lm, dim=0))

            H = torch.stack(H, dim=0)
            C = torch.stack(C, dim=0)
            N = torch.stack(N, dim=0)
            M = torch.stack(M, dim=0)

            return H[-1], (H, C, N, M)

        def init_hidden(
            self, batch_size: int, **kwargs
        ) -> List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
            return [cell.init_hidden(batch_size, **kwargs) for cell in self.cells]


    if __name__ == '__main__':
        # Initialize the mLSTM module (batch_first=True because the input below is (batch, seq, feature))
        block = mLSTM(128, 128, 2, batch_first=True)

        # Randomly generate an input tensor and run the forward pass
        input_seq = torch.rand(32, 10, 128)
        output, hidden_state = block(input_seq)

        # Print the input and output tensor shapes
        print(" mLSTM.Input size:", input_seq.size())
        print("mLSTM.Output size:", output.size())

Download

Open this URL in a browser: https://github.com/ai-dawang/PlugNPlay-Modules

See the original article for more analysis.

