手写self-attention的四重境界 self-attention

教育   2025-01-03 19:41   江苏  

背景

在 AI 相关的面试中,经常会有面试官让写 self-attention,但是因为 transformer 这篇文章其实包含很多的细节,因此可能面试官对于 self-attention 实现到什么程度是有不同的预期。因此这里想通过写不同版本的 self-attention 实现来达到不同面试官的预期。以此告诉面试官,了解细节,但是于时间考虑,可能只写了简化版本,如果有时间可以把完整的写出来。

> 本文首发于:https://bruceyuan.com/hands-on-code/from-self-attention-to-multi-head-self-attention.html  (更好的阅读体验)

来自:chaofa用代码打点酱油

LLM所有细分方向群+ACL25/ICML25/NAACL25投稿群->LLM所有细分领域群、投稿群从这里进入!


如果对于文字不感冒,可以查看视频号 

Self-Attention

MultiHead Attention 的时候下一章介绍;先熟悉当前这个公式。

Self Attention 的公式

 ,其中Q K V 对应不同的矩阵 W

补充知识点

  1. 1. matmul 和 @ 符号是一样的作用

  2. 2. 为什么要除以 ?a. 防止梯度消失 b. 为了让 QK 的内积分布保持和输入一样

  3. 3. 爱因斯坦方程表达式用法:torch.einsum('bqd,bkd-> bqk', X, X).shape

  4. 4. X.repeat(1, 1, 3) 表示在不同的维度进行 repeat操作,也可以用 tensor.expand 操作

第一重: 简化版本

  • • 直接对着公式实现, 

# 导入相关需要的包
import math
import torch
import torch.nn as nn

import warnings
warnings.filterwarnings(action='ignore')


class SelfAttV1(nn.Module):
    def __init__(self, hidden_dim):
        super(SelfAttV1, self).__init__()
        self.hidden_dim = hidden_dim
        # 一般 Linear 都是默认有 bias
        # 一般来说, input dim 的 hidden dim
        self.query_proj = nn.Linear(hidden_dim, hidden_dim)
        self.key_proj = nn.Linear(hidden_dim, hidden_dim)
        self.value_proj = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, X):
        # X shape is: (batch, seq_len, hidden_dim), 一般是和 hidden_dim 相同
        # 但是 X 的 final dim 可以和 hidden_dim 不同
        Q = self.query_proj(X)
        K = self.key_proj(X)
        V = self.value_proj(X)

        # shape is: (batch, seq_len, seq_len)
        # torch.matmul 可以改成 Q @ K.T
        # 其中 K 需要改成 shape 为: (batch, hidden_dim, seq_len)
        attention_value = torch.matmul(Q, K.transpose(-1, -2))
        attention_wight = torch.softmax(
            attention_value / math.sqrt(self.hidden_dim), dim=-1
        )
        # print(attention_wight)
        # shape is: (batch, seq_len, hidden_dim)
        output = torch.matmul(attention_wight, V)
        return output


X = torch.rand(324)
net = SelfAttV1(4)
net(X)

第二重: 效率优化

  • • 上面哪些操作可以合并矩阵优化呢?- QKV 矩阵计算的时候,可以合并成一个大矩阵计算。

    但是当前 transformers 实现中,其实是三个不同的 Linear 层

class SelfAttV2(nn.Module):
    def __init__(self, dim) -> None:
        super().__init__()
        self.dim = dim
        # 这样可以进行加速, 那么为什么现在 Llama, qwen, gpt 等
        self.proj = nn.Linear(dim, dim * 3)

        self.output_proj = nn.Linear(dim, dim)

    def forward(self, X):
        # X shape is: (batch, seq, dim)

        QKV = self.proj(X)  # (batch, seq, dim * 3)
        # reshape 从希望的 q, k, 的形式
        Q, K, V = torch.split(QKV, self.dim, dim=-1)

        # print(x)
        att_weight = torch.softmax(
            Q @ K.transpose(-1, -2) / math.sqrt(self.dim), dim=-1
        )
        output = att_weight @ V
        return self.output_proj(output)


X = torch.rand(324)
net = SelfAttV2(4)
net(X).shape

第三重: 加入细节

  • • 看上去 self attention 实现很简单,但里面还有一些细节,还有哪些细节呢?

    • • attention 计算的时候有 dropout,而且是比较奇怪的位置

    • • attention 计算的时候一般会加入 attention_mask,因为样本会进行一些 padding 操作;

    • • MultiHeadAttention 过程中,除了 QKV 三个矩阵之外,还有一个 output 对应的投影矩阵,因此虽然面试让你写 SingleHeadAttention,但是依然要问清楚,是否要第四个矩阵?

class SelfAttV3(nn.Module):
    def __init__(self, dim) -> None:
        super().__init__()
        self.dim = dim
        # 这样可以进行加速
        self.proj = nn.Linear(dim, dim * 3)
        # 一般是 0.1 的 dropout,一般写作 config.attention_probs_dropout_prob
        # hidden_dropout_prob 一般也是 0.1
        self.att_drop = nn.Dropout(0.1)

        # 不写这个应该也没人怪,应该好像是 MultiHeadAttention 中的产物,这个留给 MultiHeadAttention 也没有问题;
        self.output_proj = nn.Linear(dim, dim)

    def forward(self, X, attention_mask=None):
        # attention_mask shape is: (batch, seq)
        # X shape is: (batch, seq, dim)

        QKV = self.proj(X)  # (batch, seq, dim * 3)
        # reshape 从希望的 q, k, 的形式
        Q, K, V = torch.split(QKV, self.dim, dim=-1)

        att_weight = Q @ K.transpose(-1, -2) / math.sqrt(self.dim)
        if attention_mask is not None:
            # 给 weight 填充一个极小的值
            att_weight = att_weight.masked_fill(attention_mask == 0float('-1e20'))

        att_weight = torch.softmax(att_weight, dim=-1)

        # 这里在 BERT中的官方代码也说很奇怪,但是原文中这么用了,所以继承了下来
        # (用于 output 后面会更符合直觉?)
        att_weight = self.att_drop(att_weight)

        output = att_weight @ V
        ret = self.output_proj(output)
        return ret


X = torch.rand(342)
b = torch.tensor(
    [
        [1110],
        [1100],
        [1000],
    ]
)
print(b.shape)
mask = b.unsqueeze(dim=1).repeat(141)

net = SelfAttV3(2)
net(X, mask).shape

面试写法 (完整版)--注意注释

# 导入相关需要的包
import math
import torch
import torch.nn as nn

import warnings

warnings.filterwarnings(action='ignore')

class SelfAttV4(nn.Module):
    def __init__(self, dim) -> None:
        super().__init__()
        self.dim = dim

        # 这样很清晰
        self.query_proj = nn.Linear(dim, dim)
        self.key_proj = nn.Linear(dim, dim)
        self.value_proj = nn.Linear(dim, dim)
        # 一般是 0.1 的 dropout,一般写作 config.attention_probs_dropout_prob
        # hidden_dropout_prob 一般也是 0.1
        self.att_drop = nn.Dropout(0.1)

        # 可以不写;具体和面试官沟通。
        # 这是 MultiHeadAttention 中的产物,这个留给 MultiHeadAttention 也没有问题;
        self.output_proj = nn.Linear(dim, dim)

    def forward(self, X, attention_mask=None):
        # attention_mask shape is: (batch, seq)
        # X shape is: (batch, seq, dim)

        Q = self.query_proj(X)
        K = self.key_proj(X)
        V = self.value_proj(X)

        att_weight = Q @ K.transpose(-1, -2) / math.sqrt(self.dim)
        if attention_mask is not None:
            # 给 weight 填充一个极小的值
            att_weight = att_weight.masked_fill(attention_mask == 0float('-1e20'))

        att_weight = torch.softmax(att_weight, dim=-1)
        print(att_weight)

        # 这里在 BERT中的官方代码也说很奇怪,但是原文中这么用了,所以继承了下来
        # (用于 output 后面会更符合直觉?)
        att_weight = self.att_drop(att_weight)

        output = att_weight @ V
        ret = self.output_proj(output)
        return ret


X = torch.rand(342)
b = torch.tensor(
    [
        [1110],
        [1100],
        [1000],
    ]
)
print(b.shape)
mask = b.unsqueeze(dim=1).repeat(141)

net = SelfAttV4(2)
net(X, mask).shape

这里再次解释一下,为什么现在现在的代码实现都是 q k v 的投影矩阵都是分开写的,这是因为现在的模型很大,本身可能会做 张量并行,流水线并行等方式,所以分开写问题也不大(分开写很清晰),可能是加速效果并不明显。



备注:昵称-学校/公司-方向/会议(eg.ACL),进入技术/投稿群


id:DLNLPer,记得备注呦

深度学习自然语言处理
一个热衷于深度学习与NLP前沿技术的平台,期待在知识的殿堂与你相遇~
 最新文章