手写self-attention的四重境界 self-attention

学术   2025-01-03 18:22   江苏  

背景

在 AI 相关的面试中,经常会有面试官让写 self-attention,但是因为 transformer 这篇文章其实包含很多的细节,因此可能面试官对于 self-attention 实现到什么程度是有不同的预期。因此这里想通过写不同版本的 self-attention 实现来达到不同面试官的预期。以此告诉面试官,了解细节,但是于时间考虑,可能只写了简化版本,如果有时间可以把完整的写出来。

> 本文首发于:https://bruceyuan.com/hands-on-code/from-self-attention-to-multi-head-self-attention.html  (更好的阅读体验)


如果对于文字不感冒,可以查看视频号 

Self-Attention

MultiHead Attention 的时候下一章介绍;先熟悉当前这个公式。

Self Attention 的公式

 ,其中Q K V 对应不同的矩阵 W

补充知识点

  1. 1. matmul 和 @ 符号是一样的作用

  2. 2. 为什么要除以 ?a. 防止梯度消失 b. 为了让 QK 的内积分布保持和输入一样

  3. 3. 爱因斯坦方程表达式用法:torch.einsum("bqd,bkd-> bqk", X, X).shape

  4. 4. X.repeat(1, 1, 3) 表示在不同的维度进行 repeat操作,也可以用 tensor.expand 操作

第一重: 简化版本

  • • 直接对着公式实现, 

# 导入相关需要的包
import math
import torch
import torch.nn as nn

import warnings
warnings.filterwarnings(action="ignore")


class SelfAttV1(nn.Module):
    def __init__(self, hidden_dim):
        super(SelfAttV1, self).__init__()
        self.hidden_dim = hidden_dim
        # 一般 Linear 都是默认有 bias
        # 一般来说, input dim 的 hidden dim
        self.query_proj = nn.Linear(hidden_dim, hidden_dim)
        self.key_proj = nn.Linear(hidden_dim, hidden_dim)
        self.value_proj = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, X):
        # X shape is: (batch, seq_len, hidden_dim), 一般是和 hidden_dim 相同
        # 但是 X 的 final dim 可以和 hidden_dim 不同
        Q = self.query_proj(X)
        K = self.key_proj(X)
        V = self.value_proj(X)

        # shape is: (batch, seq_len, seq_len)
        # torch.matmul 可以改成 Q @ K.T
        # 其中 K 需要改成 shape 为: (batch, hidden_dim, seq_len)
        attention_value = torch.matmul(Q, K.transpose(-1, -2))
        attention_wight = torch.softmax(
            attention_value / math.sqrt(self.hidden_dim), dim=-1
        )
        # print(attention_wight)
        # shape is: (batch, seq_len, hidden_dim)
        output = torch.matmul(attention_wight, V)
        return output


X = torch.rand(324)
net = SelfAttV1(4)
net(X)

第二重: 效率优化

  • • 上面哪些操作可以合并矩阵优化呢?- QKV 矩阵计算的时候,可以合并成一个大矩阵计算。

    但是当前 transformers 实现中,其实是三个不同的 Linear 层

class SelfAttV2(nn.Module):
    def __init__(self, dim) -> None:
        super().__init__()
        self.dim = dim
        # 这样可以进行加速, 那么为什么现在 Llama, qwen, gpt 等
        self.proj = nn.Linear(dim, dim * 3)

        self.output_proj = nn.Linear(dim, dim)

    def forward(self, X):
        # X shape is: (batch, seq, dim)

        QKV = self.proj(X)  # (batch, seq, dim * 3)
        # reshape 从希望的 q, k, 的形式
        Q, K, V = torch.split(QKV, self.dim, dim=-1)

        # print(x)
        att_weight = torch.softmax(
            Q @ K.transpose(-1, -2) / math.sqrt(self.dim), dim=-1
        )
        output = att_weight @ V
        return self.output_proj(output)


X = torch.rand(324)
net = SelfAttV2(4)
net(X).shape

第三重: 加入细节

  • • 看上去 self attention 实现很简单,但里面还有一些细节,还有哪些细节呢?

    • • attention 计算的时候有 dropout,而且是比较奇怪的位置

    • • attention 计算的时候一般会加入 attention_mask,因为样本会进行一些 padding 操作;

    • • MultiHeadAttention 过程中,除了 QKV 三个矩阵之外,还有一个 output 对应的投影矩阵,因此虽然面试让你写 SingleHeadAttention,但是依然要问清楚,是否要第四个矩阵?

class SelfAttV3(nn.Module):
    def __init__(self, dim) -> None:
        super().__init__()
        self.dim = dim
        # 这样可以进行加速
        self.proj = nn.Linear(dim, dim * 3)
        # 一般是 0.1 的 dropout,一般写作 config.attention_probs_dropout_prob
        # hidden_dropout_prob 一般也是 0.1
        self.att_drop = nn.Dropout(0.1)

        # 不写这个应该也没人怪,应该好像是 MultiHeadAttention 中的产物,这个留给 MultiHeadAttention 也没有问题;
        self.output_proj = nn.Linear(dim, dim)

    def forward(self, X, attention_mask=None):
        # attention_mask shape is: (batch, seq)
        # X shape is: (batch, seq, dim)

        QKV = self.proj(X)  # (batch, seq, dim * 3)
        # reshape 从希望的 q, k, 的形式
        Q, K, V = torch.split(QKV, self.dim, dim=-1)

        att_weight = Q @ K.transpose(-1, -2) / math.sqrt(self.dim)
        if attention_mask is not None:
            # 给 weight 填充一个极小的值
            att_weight = att_weight.masked_fill(attention_mask == 0float("-1e20"))

        att_weight = torch.softmax(att_weight, dim=-1)

        # 这里在 BERT中的官方代码也说很奇怪,但是原文中这么用了,所以继承了下来
        # (用于 output 后面会更符合直觉?)
        att_weight = self.att_drop(att_weight)

        output = att_weight @ V
        ret = self.output_proj(output)
        return ret


X = torch.rand(342)
b = torch.tensor(
    [
        [1110],
        [1100],
        [1000],
    ]
)
print(b.shape)
mask = b.unsqueeze(dim=1).repeat(141)

net = SelfAttV3(2)
net(X, mask).shape

面试写法 (完整版)--注意注释

# 导入相关需要的包
import math
import torch
import torch.nn as nn

import warnings

warnings.filterwarnings(action="ignore")

class SelfAttV4(nn.Module):
    def __init__(self, dim) -> None:
        super().__init__()
        self.dim = dim

        # 这样很清晰
        self.query_proj = nn.Linear(dim, dim)
        self.key_proj = nn.Linear(dim, dim)
        self.value_proj = nn.Linear(dim, dim)
        # 一般是 0.1 的 dropout,一般写作 config.attention_probs_dropout_prob
        # hidden_dropout_prob 一般也是 0.1
        self.att_drop = nn.Dropout(0.1)

        # 可以不写;具体和面试官沟通。
        # 这是 MultiHeadAttention 中的产物,这个留给 MultiHeadAttention 也没有问题;
        self.output_proj = nn.Linear(dim, dim)

    def forward(self, X, attention_mask=None):
        # attention_mask shape is: (batch, seq)
        # X shape is: (batch, seq, dim)

        Q = self.query_proj(X)
        K = self.key_proj(X)
        V = self.value_proj(X)

        att_weight = Q @ K.transpose(-1, -2) / math.sqrt(self.dim)
        if attention_mask is not None:
            # 给 weight 填充一个极小的值
            att_weight = att_weight.masked_fill(attention_mask == 0float("-1e20"))

        att_weight = torch.softmax(att_weight, dim=-1)
        print(att_weight)

        # 这里在 BERT中的官方代码也说很奇怪,但是原文中这么用了,所以继承了下来
        # (用于 output 后面会更符合直觉?)
        att_weight = self.att_drop(att_weight)

        output = att_weight @ V
        ret = self.output_proj(output)
        return ret


X = torch.rand(342)
b = torch.tensor(
    [
        [1110],
        [1100],
        [1000],
    ]
)
print(b.shape)
mask = b.unsqueeze(dim=1).repeat(141)

net = SelfAttV4(2)
net(X, mask).shape

这里再次解释一下,为什么现在现在的代码实现都是 q k v 的投影矩阵都是分开写的,这是因为现在的模型很大,本身可能会做 张量并行,流水线并行等方式,所以分开写问题也不大(分开写很清晰),可能是加速效果并不明显。

进技术交流群请添加AINLP小助手微信(id: ainlp2)

请备注具体方向+所用到的相关技术点

关于AINLP

AINLP 是一个有趣有AI的自然语言处理社区,专注于 AI、NLP、机器学习、深度学习、推荐算法等相关技术的分享,主题包括LLM、预训练模型、自动生成、文本摘要、智能问答、聊天机器人、机器翻译、知识图谱、推荐系统、计算广告、招聘信息、求职经验分享等,欢迎关注!加技术交流群请添加AINLP小助手微信(id:ainlp2),备注工作/研究方向+加群目的。

AINLP
一个有趣有AI的自然语言处理公众号:关注AI、NLP、大模型LLM、机器学习、推荐系统、计算广告等相关技术。公众号可直接对话双语聊天机器人,尝试对对联、作诗机、藏头诗生成器、自动写作等,查询相似词,测试NLP相关工具包。
 最新文章