01
引言
Causal Self-Attention 通常又被称之为基于mask的Self-Attention , 是Transformer 模型的一个基本概念,尤其是在语言建模等自回归任务中。其目的是确保序列中的每个位置上的Token只关注它之前的位置(包括它自己),而不关注它之后位置上的Token 。这种机制可以防止未来位置的信息泄露,这对于预测句子中的下一个单词等任务至关重要。让我们深入了解Causal Self-Attention的工作原理以及如何实现它。
02
Self-Attention
在自注意力机制中,序列中的每个Token都可以关注其他每个位置上的Token。这需要计算一组注意力得分,用以表明每个位置对其他位置的关注程度。
注意力得分是通过Query 和Key 向量的点积,再乘以缩放因子经过softmax后计算得出的。这里,Q、K和 V均由输入序列经过线性层变换后得到。
03
Causal Self-Attention中的mask
为了加强因果关系的约束,我们对注意力得分进行了掩码处理。该掩码会将当前位置之后所有位置上的注意力得分设置为-∞(或一个非常大的负数),从而有效地在softmax操作后将其贡献值置为零。
其中,M是屏蔽mask矩阵,其中需要被屏蔽的位置处的值为-∞ ,其他位置的值为0。
04
掩码矩阵mask的作用
防止信息泄漏:确保模型在预测当前标记时不会使用未来信息。 支持自回归生成:允许模型一次生成一个文本标记,只关注已经生成的标记。
05
代码实现
让我们用 PyTorch 在 Python 中实现Causal Self-Attention 。
import torch
import torch.nn as nn
import torch.nn.functional as F
class CausalSelfAttention(nn.Module):
def __init__(self, embed_size, num_heads):
super(CausalSelfAttention, self).__init__()
self.num_heads = num_heads
self.embed_size = embed_size
self.head_dim = embed_size // num_heads
assert self.head_dim * num_heads == embed_size, "Embedding size must be divisible by number of heads"
self.values = nn.Linear(embed_size, embed_size, bias=False)
self.keys = nn.Linear(embed_size, embed_size, bias=False)
self.queries = nn.Linear(embed_size, embed_size, bias=False)
self.fc_out = nn.Linear(embed_size, embed_size)
def forward(self, x):
N, seq_length, embed_size = x.shape
# Split the embedding into num_heads different pieces
values = self.values(x).view(N, seq_length, self.num_heads, self.head_dim)
keys = self.keys(x).view(N, seq_length, self.num_heads, self.head_dim)
queries = self.queries(x).view(N, seq_length, self.num_heads, self.head_dim)
values = values.transpose(1, 2)
keys = keys.transpose(1, 2)
queries = queries.transpose(1, 2)
# Scaled dot-product attention
energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
mask = torch.tril(torch.ones((seq_length, seq_length))).expand(N, 1, seq_length, seq_length)
energy = energy.masked_fill(mask == 0, float("-1e20"))
attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=3)
out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(N, seq_length, self.embed_size)
out = self.fc_out(out)
return out
测试代码如下:
# Example usage
embed_size = 512
num_heads = 8
seq_length = 10
x = torch.rand((1, seq_length, embed_size))
causal_self_attention = CausalSelfAttention(embed_size, num_heads)
output = causal_self_attention(x)
print(output.shape) # Output: torch.Size([1, seq_length, embed_size])
上述过程可以总结如下:
初始化: CausalSelfAttention
类使用embed_size
和num_heads
进行初始化。嵌入embeddings
在各头之间平均分配。线性层:我们为
key
、value
和query
创建线性层。重塑:对输入张量进行
reshape
操作,以分离头部。基于点积的注意力:计算注意力得分,应用掩码
mask
以确保因果关系,并用 softmax 进行归一化处理。合并heads: 合并各头并通过最后的线性层后输出。
06
结论
Causal SelfAttention是一种强大的注意力机制,可确保模型在自回归任务中保证数据的顺序性。通过实施该机制,大家可以构建模型,在不窥探未来的情况下一步步生成新的序列。通过本文,希望大家可以对因果自注意力机制有更加扎实的了解,并知道如何在自己的模型中实现它。祝大家编码愉快!
点击上方小卡片关注我
添加个人微信,进专属粉丝群!