本文约9000字,建议阅读9分钟 本文深入探讨Transformer模型中三种关键的注意力机制:自注意力、交叉注意力和因果自注意力。
文章目录
自注意力机制
理论基础 PyTorch实现 多头注意力扩展
概念介绍 与自注意力的区别 PyTorch实现
在语言模型中的应用 实现细节 优化技巧
自注意力概述
输入句子嵌入
sentence = 'The sun rises in the east'
dc = {s:i for i,s in enumerate(sorted(sentence.split()))}
print(dc)
{'The': 0, 'east': 1, 'in': 2, 'rises': 3, 'sun': 4, 'the': 5}
import torch
sentence_int = torch.tensor(
[dc[s] for s in sentence.split()]
)
print(sentence_int)
tensor([0, 4, 3, 2, 5, 1])
有了这个输入句子的整数表示,可以使用嵌入层将每个单词转换为向量。为简化演示,我们这里使用3维嵌入,但在实际应用中,嵌入维度通常要大得多(例如,Llama 2模型中使用4,096维)。较小的维度有助于直观理解向量而不会使页面充满数字。
vocab_size = 50_000
torch.manual_seed(123)
embed = torch.nn.Embedding(vocab_size, 3)
embedded_sentence = embed(sentence_int).detach()
print(embedded_sentence)
print(embedded_sentence.shape)
tensor([[ 0.3374, -0.1778, -0.3035],
[ 0.1794, 1.8951, 0.4954],
[ 0.2692, -0.0770, -1.0205],
[-0.2196, -0.3792, 0.7671],
[-0.5880, 0.3486, 0.6603],
[-1.1925, 0.6984, -1.4097]])
torch.Size([6, 3])
缩放点积注意力的权重矩阵
查询、键和值的转换
查询 (q) 键 (k) 值 (v)
查询:q(i) = x(i)Wq 键:k(i) = x(i)Wk 值:v(i) = x(i)Wv
torch.manual_seed(123)
d = embedded_sentence.shape[1]
d_q, d_k, d_v = 2, 2, 4
W_query = torch.nn.Parameter(torch.rand(d, d_q))
W_key = torch.nn.Parameter(torch.rand(d, d_k))
W_value = torch.nn.Parameter(torch.rand(d, d_v))
计算自注意力机制中的非归一化注意力权重
x_3 = embedded_sentence[2] # 第三个元素(索引2)
query_3 = x_3 @ W_query
key_3 = x_3 @ W_key
value_3 = x_3 @ W_value
print("Query shape:", query_3.shape)
print("Key shape:", key_3.shape)
print("Value shape:", value_3.shape)
Query shape: torch.Size([2])
Key shape: torch.Size([2])
Value shape: torch.Size([4])
keys = embedded_sentence @ W_key
values = embedded_sentence @ W_value
print("All keys shape:", keys.shape)
print("All values shape:", values.shape)
All keys shape: torch.Size([6, 2])
All values shape: torch.Size([6, 4])
omega_3 = query_3 @ keys.T
print("Unnormalized attention weights for query 3:")
print(omega_3)
Unnormalized attention weights for query 3:
tensor([ 0.8721, -0.5302, 2.1436, -1.7589, 0.9103, 1.3245])
max_score = omega_3.max()
min_score = omega_3.min()
max_index = omega_3.argmax()
min_index = omega_3.argmin()
print(f"Highest compatibility: {max_score:.4f} with input {max_index+1}")
print(f"Lowest compatibility: {min_score:.4f} with input {min_index+1}")
Highest compatibility: 2.1436 with input 3
Lowest compatibility: -1.7589 with input 4
注意力权重归一化与上下文向量计算
import torch.nn.functional as F
d_k = 2 # 键向量的维度
omega_3 = query_3 @ keys.T # 使用前面的例子
attention_weights_3 = F.softmax(omega_3 / d_k**0.5, dim=0)
print("Normalized attention weights for input 3:")
print(attention_weights_3)
Normalized attention weights for input 3:
tensor([0.1834, 0.0452, 0.6561, 0.0133, 0.1906, 0.2885])
max_weight = attention_weights_3.max()
max_weight_index = attention_weights_3.argmax()
print(f"Input {max_weight_index+1} has the highest attention weight: {max_weight:.4f}")
Input 3 has the highest attention weight: 0.6561
context_vector_3 = attention_weights_3 @ values
print("Context vector shape:", context_vector_3.shape)
print("Context vector:")
print(context_vector_3)
Context vector shape: torch.Size([4])
Context vector:
tensor([0.6237, 0.9845, 1.0523, 1.2654])
自注意力的PyTorch实现
import torch
import torch.nn as nn
class SelfAttention(nn.Module):
def __init__(self, d_in, d_out_kq, d_out_v):
super().__init__()
self.d_out_kq = d_out_kq
self.W_query = nn.Parameter(torch.rand(d_in, d_out_kq))
self.W_key = nn.Parameter(torch.rand(d_in, d_out_kq))
self.W_value = nn.Parameter(torch.rand(d_in, d_out_v))
def forward(self, x):
keys = x @ self.W_key
queries = x @ self.W_query
values = x @ self.W_value
attn_scores = queries @ keys.T
attn_weights = torch.softmax(
attn_scores / self.d_out_kq**0.5, dim=-1
)
context_vec = attn_weights @ values
return context_vec
这个类封装了以下步骤:
将输入投影到键、查询和值空间 计算注意力分数 缩放和归一化注意力权重 生成最终的上下文向量
在__init__中,我们将权重矩阵初始化为nn.Parameter对象,使PyTorch能够在训练过程中自动跟踪和更新它们。 forward方法以简洁的方式实现了整个自注意力过程。 我们使用@运算符进行矩阵乘法,这等同于torch.matmul。 缩放因子self.d_out_kq**0.5在softmax之前应用,如前所述。
torch.manual_seed(123)
d_in, d_out_kq, d_out_v = 3, 2, 4
sa = SelfAttention(d_in, d_out_kq, d_out_v)
# 假设embedded_sentence是我们的输入张量
output = sa(embedded_sentence)
print(output)
tensor([[-0.1564, 0.1028, -0.0763, -0.0764],
[ 0.5313, 1.3607, 0.7891, 1.3110],
[-0.3542, -0.1234, -0.2627, -0.3706],
[ 0.0071, 0.3345, 0.0969, 0.1998],
[ 0.1008, 0.4780, 0.2021, 0.3674],
[-0.5296, -0.2799, -0.4107, -0.6006]], grad_fn=<MmBackward0>)
多头注意力机制:自注意力的高级扩展
多头注意力的核心概念
创建多组查询、键和值权重矩阵。 每组矩阵形成一个"注意力头"。 每个头可能关注输入序列的不同方面。 所有头的输出被连接并进行线性变换,生成最终输出。
多头注意力的实现
class MultiHeadAttentionWrapper(nn.Module):
def __init__(self, d_in, d_out_kq, d_out_v, num_heads):
super().__init__()
self.heads = nn.ModuleList(
[SelfAttention(d_in, d_out_kq, d_out_v)
for _ in range(num_heads)]
)
def forward(self, x):
return torch.cat([head(x) for head in self.heads], dim=-1)
torch.manual_seed(123)
d_in, d_out_kq, d_out_v = 3, 2, 1
num_heads = 4
mha = MultiHeadAttentionWrapper(d_in, d_out_kq, d_out_v, num_heads)
context_vecs = mha(embedded_sentence)
print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)
tensor([[-0.0185, 0.0170, 0.1999, -0.0860],
[ 0.4003, 1.7137, 1.3981, 1.0497],
[-0.1103, -0.1609, 0.0079, -0.2416],
[ 0.0668, 0.3534, 0.2322, 0.1008],
[ 0.1180, 0.6949, 0.3157, 0.2807],
[-0.1827, -0.2060, -0.2393, -0.3167]], grad_fn=<CatBackward0>)
context_vecs.shape: torch.Size([6, 4])
多头注意力的优势
多样化特征学习:每个头可以学习关注输入的不同方面。例如,一个头可能专注于局部关系而另一个可能捕捉长距离依赖。 增强模型容量:多个头允许模型表示数据中更复杂的关系,而不显著增加参数数量。 并行处理效率:每个头的独立性使得在GPU或TPU上能进行高效的并行计算。 提高模型稳定性和鲁棒性:使用多个头可以使模型更加鲁棒,因为它不太可能过度拟合单一注意力机制捕捉到的特定模式。
多头注意力与单头大输出的比较
独立学习能力:多头注意力中的每个头学习自己的查询、键和值投影集,允许更多样化的特征提取。 计算效率优势:多头注意力可以更高效地并行化,可能导致更快的训练和推理速度。 集成学习效果:多个头的作用类似于注意力机制的集成,每个头可能专门处理输入的不同方面。
实际应用考虑
交叉注意力:连接不同输入序列的桥梁
交叉注意力的核心概念
处理两个不同的输入序列。 查询由一个序列生成,而键和值来自另一个序列。 允许模型基于另一个序列的内容有选择地关注一个序列的部分。
交叉注意力的实现
class CrossAttention(nn.Module):
def __init__(self, d_in, d_out_kq, d_out_v):
super().__init__()
self.d_out_kq = d_out_kq
self.W_query = nn.Parameter(torch.rand(d_in, d_out_kq))
self.W_key = nn.Parameter(torch.rand(d_in, d_out_kq))
self.W_value = nn.Parameter(torch.rand(d_in, d_out_v))
def forward(self, x_1, x_2):
queries_1 = x_1 @ self.W_query
keys_2 = x_2 @ self.W_key
values_2 = x_2 @ self.W_value
attn_scores = queries_1 @ keys_2.T
attn_weights = torch.softmax(
attn_scores / self.d_out_kq**0.5, dim=-1)
context_vec = attn_weights @ values_2
return context_vec
torch.manual_seed(123)
d_in, d_out_kq, d_out_v = 3, 2, 4
crossattn = CrossAttention(d_in, d_out_kq, d_out_v)
first_input = embedded_sentence
second_input = torch.rand(8, d_in)
print("First input shape:", first_input.shape)
print("Second input shape:", second_input.shape)
context_vectors = crossattn(first_input, second_input)
print(context_vectors)
print("Output shape:", context_vectors.shape)
First input shape: torch.Size([6, 3])
Second input shape: torch.Size([8, 3])
tensor([[0.4231, 0.8665, 0.6503, 1.0042],
[0.4874, 0.9718, 0.7359, 1.1353],
[0.4054, 0.8359, 0.6258, 0.9667],
[0.4357, 0.8886, 0.6678, 1.0311],
[0.4429, 0.9006, 0.6775, 1.0460],
[0.3860, 0.8021, 0.5985, 0.9250]], grad_fn=<MmBackward0>)
Output shape: torch.Size([6, 4])
交叉注意力与自注意力的主要区别
双输入序列:交叉注意力接受两个输入,x_1和x_2,而不是单一输入。 查询-键交互方式:查询来自x_1,而键和值来自x_2。 序列长度灵活性:两个输入序列可以具有不同的长度。
交叉注意力的应用领域
机器翻译:在原始Transformer模型中,交叉注意力允许解码器在生成翻译时关注源句子的相关部分。 图像描述生成:模型可以在生成描述的每个词时关注图像的不同部分(表示为图像特征序列)。 Stable Diffusion模型:交叉注意力用于将图像生成与文本提示相关联,允许模型将文本信息整合到视觉生成过程中。 问答系统:模型可以根据问题的内容关注上下文段落的不同部分。
交叉注意力的优势
信息整合能力:允许模型有选择地将一个序列的信息整合到另一个序列的处理中。 处理多模态输入的灵活性:可以处理不同长度和模态的输入。 增强可解释性:注意力权重可以提供洞察,说明模型如何关联两个序列的不同部分。
实际应用中的考虑因素
嵌入维度(d_in)必须对两个输入序列保持一致,即使它们的长度不同。 对于长序列,交叉注意力可能计算密集,需要考虑计算效率。 与自注意力类似,交叉注意力也可以扩展到多头版本,以获得更强的表达能力。
Stable Diffusion模型也利用了交叉注意力机制。在该模型中交叉注意力发生在U-Net架构内生成的图像特征和用于指导的文本提示之间。这种技术最初在介绍Stable Diffusion概念的论文《High-Resolution Image Synthesis with Latent Diffusion Models》中被提出。随后Stability AI采用了这种方法来实现广受欢迎的Stable Diffusion模型。
因果自注意力
"The" → "cat""The cat" → "sits""The cat sits" → "on""The cat sits on" → "the""The cat sits on the" → "mat"
torch.manual_seed(123)
d_in, d_out_kq, d_out_v = 3, 2, 4
W_query = nn.Parameter(torch.rand(d_in, d_out_kq))
W_key = nn.Parameter(torch.rand(d_in, d_out_kq))
W_value = nn.Parameter(torch.rand(d_in, d_out_v))
x = embedded_sentence
keys = x @ W_key
queries = x @ W_query
values = x @ W_value
attn_scores = queries @ keys.T
print(attn_scores)
print(attn_scores.shape)
tensor([[ 0.0613, -0.3491, 0.1443, -0.0437, -0.1303, 0.1076],
[ ],
[ ],
[ ],
[ ],
[ ]],
grad_fn=<MmBackward0>)
torch.Size([6, 6])
attn_weights = torch.softmax(attn_scores / d_out_kq**0.5, dim=1)
print(attn_weights)
tensor([[0.1772, 0.1326, 0.1879, 0.1645, 0.1547, 0.1831],
[ ],
[ ],
[ ],
[ ],
[ ]],
grad_fn=<SoftmaxBackward0>)
block_size = attn_scores.shape[0]
mask_simple = torch.tril(torch.ones(block_size, block_size))
print(mask_simple)
tensor([[1., 0., 0., 0., 0., 0.],
[1., 1., 0., 0., 0., 0.],
[1., 1., 1., 0., 0., 0.],
[1., 1., 1., 1., 0., 0.],
[1., 1., 1., 1., 1., 0.],
[1., 1., 1., 1., 1., 1.]])
masked_simple = attn_weights * mask_simple
print(masked_simple)
tensor([[0.1772, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[ ],
[ ],
[ ],
[ ],
[ ]],
grad_fn=<MulBackward0>)
row_sums = masked_simple.sum(dim=1, keepdim=True)
masked_simple_norm = masked_simple / row_sums
print(masked_simple_norm)
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[ ],
[ ],
[ ],
[ ],
[ ]],
grad_fn=<DivBackward0>)
mask = torch.triu(torch.ones(block_size, block_size), diagonal=1)
masked = attn_scores.masked_fill(mask.bool(), float('-inf'))
print(masked)
tensor([[ 0.0613, -inf, -inf, -inf, -inf, -inf],
[ ],
[ ],
[ ],
[ ],
[ ]],
grad_fn=<MaskedFillBackward0>)
attn_weights = torch.softmax(masked / d_out_kq**0.5, dim=1)
print(attn_weights)
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[ ],
[ ],
[ ],
[ ],
[ ]],
grad_fn=<SoftmaxBackward0>)