自回归模型的关键:Causal self-Attention

文摘   科技   2024-11-08 09:49   江苏  
点击蓝字
 
关注我们










01


引言



Causal Self-Attention 通常又被称之为基于mask的Self-Attention , 是Transformer 模型的一个基本概念,尤其是在语言建模等自回归任务中。其目的是确保序列中的每个位置上的Token只关注它之前的位置(包括它自己),而不关注它之后位置上的Token 。这种机制可以防止未来位置的信息泄露,这对于预测句子中的下一个单词等任务至关重要。让我们深入了解Causal Self-Attention的工作原理以及如何实现它。






02


Self-Attention


在自注意力机制中,序列中的每个Token都可以关注其他每个位置上的Token。这需要计算一组注意力得分,用以表明每个位置对其他位置的关注程度。


注意力得分是通过Query 和Key 向量的点积,再乘以缩放因子经过softmax后计算得出的。这里,Q、K和 V均由输入序列经过线性层变换后得到。





03


 Causal Self-Attention中的mask


为了加强因果关系的约束,我们对注意力得分进行了掩码处理。该掩码会将当前位置之后所有位置上的注意力得分设置为-∞(或一个非常大的负数),从而有效地在softmax操作后将其贡献值置为零。 

其中,M是屏蔽mask矩阵,其中需要被屏蔽的位置处的值为-∞ ,其他位置的值为0。





04


  掩码矩阵mask的作用

在Causal Self-Attention  中的掩码mask的作用可以总结为以下亮点:
  • 防止信息泄漏:确保模型在预测当前标记时不会使用未来信息。
  • 支持自回归生成:允许模型一次生成一个文本标记,只关注已经生成的标记。






05


代码实现


让我们用 PyTorch 在 Python 中实现Causal Self-Attention 。

import torchimport torch.nn as nnimport torch.nn.functional as F
class CausalSelfAttention(nn.Module): def __init__(self, embed_size, num_heads): super(CausalSelfAttention, self).__init__() self.num_heads = num_heads self.embed_size = embed_size self.head_dim = embed_size // num_heads
assert self.head_dim * num_heads == embed_size, "Embedding size must be divisible by number of heads"
self.values = nn.Linear(embed_size, embed_size, bias=False) self.keys = nn.Linear(embed_size, embed_size, bias=False) self.queries = nn.Linear(embed_size, embed_size, bias=False) self.fc_out = nn.Linear(embed_size, embed_size)
def forward(self, x): N, seq_length, embed_size = x.shape
# Split the embedding into num_heads different pieces values = self.values(x).view(N, seq_length, self.num_heads, self.head_dim) keys = self.keys(x).view(N, seq_length, self.num_heads, self.head_dim) queries = self.queries(x).view(N, seq_length, self.num_heads, self.head_dim)
values = values.transpose(1, 2) keys = keys.transpose(1, 2) queries = queries.transpose(1, 2)
# Scaled dot-product attention energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys]) mask = torch.tril(torch.ones((seq_length, seq_length))).expand(N, 1, seq_length, seq_length) energy = energy.masked_fill(mask == 0, float("-1e20"))
attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=3)
out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(N, seq_length, self.embed_size)
out = self.fc_out(out) return out


测试代码如下:

# Example usageembed_size = 512num_heads = 8seq_length = 10x = torch.rand((1, seq_length, embed_size))
causal_self_attention = CausalSelfAttention(embed_size, num_heads)output = causal_self_attention(x)
print(output.shape) # Output: torch.Size([1, seq_length, embed_size])


上述过程可以总结如下:

  • 初始化: CausalSelfAttention类使用embed_sizenum_heads进行初始化。嵌入embeddings 在各头之间平均分配。
  • 线性层:我们为keyvaluequery创建线性层。

  • 重塑:对输入张量进行reshape操作,以分离头部。

  • 基于点积的注意力:计算注意力得分,应用掩码mask以确保因果关系,并用 softmax 进行归一化处理。

  • 合并heads:  合并各头并通过最后的线性层后输出。







06


结论


Causal SelfAttention是一种强大的注意力机制,可确保模型在自回归任务中保证数据的顺序性。通过实施该机制,大家可以构建模型,在不窥探未来的情况下一步步生成新的序列。通过本文,希望大家可以对因果自注意力机制有更加扎实的了解,并知道如何在自己的模型中实现它。祝大家编码愉快!





点击上方小卡片关注我




添加个人微信,进专属粉丝群!



AI算法之道
一个专注于深度学习、计算机视觉和自动驾驶感知算法的公众号,涵盖视觉CV、神经网络、模式识别等方面,包括相应的硬件和软件配置,以及开源项目等。
 最新文章