01
引言
本文是手撕Transformer系列的第三篇。它从头开始介绍多头注意力机制。注意力机制是Transformer的核心概念,因为它为输入序列提供了上下文语义信息。
02
背景介绍
在Transformer模型中,注意力为每个序列提供上下文信息。这有助于模型理解不同词语之间的关系,从而创造出有意义的句子。根据维基百科的描述,"注意力层可以访问之前的所有状态,并根据学习到的相关性对其进行加权,从而提供关于远处Tokens的相关信息"。
03
Multi-Head Attention
注意力函数可以描述为将一个查询Query和一组键值对Key-Value映射到一个输出,其中Query、Key、Value和输出都是向量。输出是以Value的加权和来计算的,其中分配给每个Value的权重是通过查询Query与相应键Key的兼容函数来计算的。
论文中称这种特殊的注意力为 "缩放点积注意力"。输入包括查询Query、维度为 d_key 的Key和维度为 d_value 的Value。我们计算查询Query与所有键Key的点积,将每个点积除以 √(d_key),然后应用softmax函数来获得Value的权重。
对模型进行缩放是为了避免 softmax 函数产生极小的梯度而影响训练。
当使用多头注意力时,通常是 d_key = d_value = (d_model / n_heads),其中 n_heads 是heads头的数量。研究人员发现,之所以采用多头注意力层是因为模型能够 "在不同位置注意来自不同表征子空间的信息"。
04
通过线性层传递输入信息
计算注意力的第一步是获取 Q、K 和 V 张量;它们分别是Query、Key和Value张量。它们的计算方法如下:将位置编码嵌入(记为 X),同时将该张量通过三个线性层(记为 Wq、Wk 和 Wv)。这可以从上一节的图像中看到。
Q = XWq
K = XWk
V = XWv
要了解上述乘法操作是如何计算的,最好先将每个组件分解成不同的维度。
X 的大小为(batch_size, seq_length, d_model)。例如,32个长度为10 的序列的嵌入大小为 512,则其维度为(32, 10, 512)。 Wq、Wk 和 Wv 的维度大小为 (d_model,d_model)。按照上面的例子,它们的形状维度为(512,512)。
这样就能更好地理解上述Tensor乘法的输出结果。每个权重矩阵同时在batch中的每个序列上广播,以创建 Q、K 和 V 张量。
Q = XWq K = XWk V = XWv
上述三个线性变换的维度变化如下:
(batch_size, seq_length, d_model) x (d_model, d_model)
= (batch_size, seq_length, d_model)
下图显示了 Q、K 和 V的计算后的图例。每个紫色方框代表一个序列,每个橙色方框代表序列中的一个标记或单词,而灰色椭圆代表每个标记的嵌入。
下面的代码假定位置编码和嵌入已从本系列的前几篇文章中加载。
# convert the sequences to integers
sequences = ["I wonder what will come next!",
"This is a basic example paragraph.",
"Hello, what is a basic split?"]
# tokenize the sequences
tokenized_sequences = [tokenize(seq) for seq in sequences]
# index the sequences
indexed_sequences = [[stoi[word] for word in seq] for seq in tokenized_sequences]
# convert the sequences to a tensor
tensor_sequences = torch.tensor(indexed_sequences).long()
# vocab size
vocab_size = len(stoi)
# embedding dimensions
d_model = 8
# create the embeddings
lut = Embeddings(vocab_size, d_model) # look-up table (lut)
# create the positional encodings
pe = PositionalEncoding(d_model=d_model, dropout=0.1, max_length=10)
# embed the sequence
embeddings = lut(tensor_sequences)
# positionally encode the sequences
X = pe(embeddings)
print(X)
结果如下:
Wq = nn.Linear(d_model, d_model) # query weights (8,8)
Wk = nn.Linear(d_model, d_model) # key weights (8,8)
Wv = nn.Linear(d_model, d_model) # value weights (8,8)
print(Wq.state_dict()['weight'])
结果如下:
Wq 的权重如上图所示。Wk 和 Wv 的形状相同,权重不同。当 X 经过每个线性层时,它的形状保持不变。
Q = Wq(X) # (3,6,8)x(broadcast 8,8) = (3,6,8)
K = Wk(X) # (3,6,8)x(broadcast 8,8) = (3,6,8)
V = Wv(X) # (3,6,8)x(broadcast 8,8) = (3,6,8)
print(Q)
结果如下:
05
拆分到不同head里
创建了 Q、K 和 V 张量后,现在可以通过将 d_model 的维度更改为 (n_heads, d_key),将它们拆分成各自的头。请记住,d_key = (d_model / n_heads)。
在上一幅图像中,每个标记Token都包含一个维度为d_model的嵌入。现在,这个维度被分成行和列,形成一个矩阵;每一行都表示一个head。这可以在上图中看到。于是,每个张量的维度就变成了如下形状:
(batch_size, seq_length, d_model) → (batch_size, seq_length, n_heads, d_key)
这可以通过view来完成,view可用于添加和设置每个维度的大小。
batch_size = Q.size(0)
n_heads = 4
d_key = d_model//n_heads # 8/4 = 2
# query tensor | -1 = query_length | (3, 6, 8) -> (3, 6, 4, 2)
Q = Q.view(batch_size, -1, n_heads, d_key)
# value tensor | -1 = key_length | (3, 6, 8) -> (3, 6, 4, 2)
K = K.view(batch_size, -1, n_heads, d_key)
# value tensor | -1 = value_length | (3, 6, 8) -> (3, 6, 4, 2)
V = V.view(batch_size, -1, n_heads, d_key)
print(Q)
tensor([
[ ],
[ ],
[ ],
[ ]],
[ ],
[ ],
[ ],
[ ]],
[ ],
[ ],
[ ],
[ ]],
[ ],
[ ],
[ ],
[ ]],
[ ],
[ ],
[ ],
[ ]],
[ ],
[ ],
[ ],
[ ]]],
[ ],
[ ],
[ ],
[ ]],
[ ],
[ ],
[ ],
[ ]],
[ ],
[ ],
[ ],
[ ]],
[ ],
[ ],
[ ],
[ ]],
[ ],
[ ],
[ ],
[ ]],
[ ],
[ ],
[ ],
[ ]]],
[ ],
[ ],
[ ],
[ ]],
[ ],
[ ],
[ ],
[ ]],
[ ],
[ ],
[ ],
[ ]],
[ ],
[ ],
[ ],
[ ]],
[ ],
[ ],
[ ],
[ ]],
[ ],
[ ],
[ ],
[ ]]]], grad_fn=<ViewBackward0>)
(batch_size, seq_length, n_heads, d_key) → (batch_size, n_heads, seq_length, d_key)
现在,每个序列被拆分成 n_个头,每个头接收 seq_length 个标记,每个标记中包含 d_key 个元素。这就实现了研究人员 "在不同位置关注来自不同表征子空间的信息 "的目标。
该张量的可视化效果如下图所示。每个序列为紫色,每个head为灰色。在head中,每个标记都是一排 d_key 个元素。
从本质上讲,每个头包含每个序列标记的副本,但它只有 d_key = 2 个元素表示,而不是完整的 d_model = 8 个元素表示。这意味着每个序列同时在 n_head = 4 个不同的子空间中表示。
下面的代码使用 permute 来交换每个张量的第二轴和第三轴。
# query tensor | (3, 6, 4, 2) -> (3, 4, 6, 2)
Q = Q.permute(0, 2, 1, 3)
# key tensor | (3, 6, 4, 2) -> (3, 4, 6, 2)
K = K.permute(0, 2, 1, 3)
# value tensor | (3, 6, 4, 2) -> (3, 4, 6, 2)
V = V.permute(0, 2, 1, 3)
print(Q)
tensor([
[ ],
[ ],
[ ],
[ ],
[ ],
[ ]],
[ ],
[ ],
[ ],
[ ],
[ ],
[ ]],
[ ],
[ ],
[ ],
[ ],
[ ],
[ ]],
[ ],
[ ],
[ ],
[ ],
[ ],
[ ]]],
[ ],
[ ],
[ ],
[ ],
[ ],
[ ]],
[ ],
[ ],
[ ],
[ ],
[ ],
[ ]],
[ ],
[ ],
[ ],
[ ],
[ ],
[ ]],
[ ],
[ ],
[ ],
[ ],
[ ],
[ ]]],
[ ],
[ ],
[ ],
[ ],
[ ],
[ ]],
[ ],
[ ],
[ ],
[ ],
[ ],
[ ]],
[ ],
[ ],
[ ],
[ ],
[ ],
[ ]],
[ ],
[ ],
[ ],
[ ],
[ ],
[ ]]]], grad_fn=<PermuteBackward0>)
# select the first sequence from the Query tensor
print(Q[0])
tensor([
[ ],
[ ],
[ ],
[ ],
[ ],
[ ]],
[ ],
[ ],
[ ],
[ ],
[ ],
[ ]],
[ ],
[ ],
[ ],
[ ],
[ ],
[ ]],
[ ],
[ ],
[ ],
[ ],
[ ],
[ ]]], grad_fn=<SelectBackward0>)
06
计算注意力
计算注意力的公式如下:
将 Q、K 和 V 分解到不同的head后,现在就可以计算 Q 和 K 的缩放点积了。从上式可以看出,第一步是进行张量乘法运算。不过,K 必须先进行转置。
接下来,为了清楚起见,每个张量的 seq_length 形状将分别以Q_length、K_length 或 V_length 进行表示:
Q 的形状为 (batch_size, n_heads, Q_length, d_key)
K 的形状为 (batch_size, n_heads, K_length, d_key)
V 的形状为 (batch_size, n_heads, V_length, d_key)
现在,QK^T 的输出维度将是:
(batch_size, n_heads, Q_length, d_key) x (batch_size, n_heads, d_key, K_length)
= (batch_size, n_heads, Q_length, K_length)
每个张量中的相应序列将相互相乘。Q 中的第一个序列与 K 中的第一个序列相乘,Q 中的第二个序列与 K 中的第二个序列相乘,以此类推。当这些序列相互相乘时,Q 的第一个序列的第一个头与 K 的第一个序列的第一个头相乘,Q 的第一个序列的第二个头与 K 的第一个序列的第二个头相乘,以此类推。在对这些heads进行乘法运算时,Q head中形状为(Q_length,d_key)的每个标记将与 K head中形状为(d_key,K_length)的每个标记进行乘法运算。结果就是一个(Q_length, K_length)矩阵,它显示了每个词与其他每个词(包括它自己)的注意力得分。这就是 "自注意力机制 "名称的由来。
QK^T 按 d_key 缩放,有助于使下一步的 softmax 函数输出不那么集中在 0 和 1 附近。
继续示例,按比例点乘的输出结果为 (3, 4, 6, 2) x (3, 4, 2, 6) = (3, 4, 6, 6)。
# calculate scaled dot product
# (batch_size, n_heads, Q_length, K_length)
scaled_dot_prod = torch.matmul(Q, K.permute(0, 1, 3, 2)) / math.sqrt(d_key)
然后将该张量通过 softmax 函数来创建概率分布。请注意 softmax 是如何应用于每个head矩阵的每一行的。
# apply softmax to get context for each token and others
# (batch_size, n_heads, Q_length, K_length)
attn_probs = torch.softmax(scaled_dot_prod, dim=-1)
可以使用 matplotlib 中的 imshow 将这些注意力概率可视化。附录中提供了一个名为 display_attention 的函数,可以同时显示序列的所有头部。白色更接近 1,黑色更接近 0。
display_attention(["i", "wonder", "what", "will", "come", "next"],
[ ],
attn_probs[0], 4, 2, 2)
结果如下:
display_attention(["this", "is", "a", "basic", "example", "paragraph"],
[ ],
attn_probs[1], 4, 2, 2)
结果如下:
display_attention(["hello", "what", "is", "a", "basic", "split"],
[ ],
attn_probs[2], 4, 2, 2)
结果如下:
这显示了每个Query(行)和Key(列)之间的关系。序列中单词之间的每个交叉点都代表了关系的强度。由于这些值是由随机权重生成的,因此目前还不能显示任何有效的关系。下图展示了编码器经过训练后的注意力得分矩阵的可视化效果。
计算出这些概率后,下一步就是将它们与 V 张量相乘。此时每个词的上下文基本上都被汇总在一起。此时矩阵维度变化如下:
代码实现如下:
# multiply attention and values to get reweighted values
# (batch_size, n_heads, Q_length, d_key)
A = torch.matmul(attn_probs, V)
下面是本例中每个步骤的示意图。
这里究竟发生了什么?首先Q 和 K 都是相同序列的表示,它们被分解成不同head的Query和Key。这将计算序列中每个单词与序列中其他单词之间的关系。这发生在 n_heads 个子空间中。计算每个词的Query表示和每个词的Key表示之间的点积。这反映了每个词与其他词之间的 "强度 "或 "权重"。通过训练,这种强度将有助于模型理解哪些词之间的 "权重 "更高;这将表明哪些词对于上下文和预测最为重要。再次强调一下,Query与Key相乘,以产生每个标记与序列中所有其他标记之间的权重。
softmax 张量中的每一行表示一个标记Token与同一序列中其他标记Token之间的关系。在 V 中,每一列都代表一个序列。将这两个张量相乘,可对Values进行重新加权,并计算出每个头部或子空间中每个标记Token的最重要的上下文信息。
下图显示了单个head在一个序列中的自注意力计算过程:
07
通过输出层
此时,在通过最后的线性层之前,可以将这些head重新concat起来,这就是多头注意力机制。concat操作会逆转原来进行的拆分。第一步是对 n_heads 和 Q_length 进行转置。第二步是将 n_heads 和 d_key 连接起来,得到 d_model。完成后,A 的形状将为(batch_size、Q_length、d_model)。
# transpose from (3, 4, 6, 2) -> (3, 6, 4, 2)
A = A.permute(0, 2, 1, 3).contiguous()
# reshape from (3, 6, 4, 2) -> (3, 6, 8) = (batch_size, Q_length, d_model)
A = A.view(batch_size, -1, n_heads*d_key)
print(A)
结果如下:
最后一步是通过 Wo 传递 A,Wo 的形状为(d_model,d_model)。再次,权重张量在批次中的每个序列中进行广播。最终的输出将保持其形状:
代码如下:
Wo = nn.Linear(d_model, d_model)
# (3, 6, 8) x (broadcast 8, 8) = (3, 6, 8)
output = Wo(A)
print(output)
结果如下:
该输出将被传递到下一层,其中包括残差加法和层归一化。这些内容将在以后的文章中介绍。
08
注意事项
在解释了多头注意力的各个组成部分后,实现方法就简单明了了,只需利用前面列出的相同组件即可。唯一增加的是Dropout层。此外,代码中有一个掩码的实现,暂时可以忽略。它不会对后面的示例产生影响。我们将在介绍编码器和解码器时对其进行解释。
class MultiHeadAttention(nn.Module):
def __init__(self, d_model: int = 512, n_heads: int = 8, dropout: float = 0.1):
"""
Args:
d_model: dimension of embeddings
n_heads: number of self attention heads
dropout: probability of dropout occurring
"""
super().__init__()
assert d_model % n_heads == 0 # ensure an even num of heads
self.d_model = d_model # 512 dim
self.n_heads = n_heads # 8 heads
self.d_key = d_model // n_heads # assume d_value equals d_key | 512/8=64
self.Wq = nn.Linear(d_model, d_model) # query weights
self.Wk = nn.Linear(d_model, d_model) # key weights
self.Wv = nn.Linear(d_model, d_model) # value weights
self.Wo = nn.Linear(d_model, d_model) # output weights
self.dropout = nn.Dropout(p=dropout) # initialize dropout layer
def forward(self, query: Tensor, key: Tensor, value: Tensor, mask: Tensor = None):
"""
Args:
query: query vector (batch_size, q_length, d_model)
key: key vector (batch_size, k_length, d_model)
value: value vector (batch_size, s_length, d_model)
mask: mask for decoder
Returns:
output: attention values (batch_size, q_length, d_model)
attn_probs: softmax scores (batch_size, n_heads, q_length, k_length)
"""
batch_size = key.size(0)
# calculate query, key, and value tensors
Q = self.Wq(query) # (32, 10, 512) x (512, 512) = (32, 10, 512)
K = self.Wk(key) # (32, 10, 512) x (512, 512) = (32, 10, 512)
V = self.Wv(value) # (32, 10, 512) x (512, 512) = (32, 10, 512)
# split each tensor into n-heads to compute attention
# query tensor
Q = Q.view(batch_size, # (32, 10, 512) -> (32, 10, 8, 64)
-1, # -1 = q_length
self.n_heads,
self.d_key
).permute(0, 2, 1, 3) # (32, 10, 8, 64) -> (32, 8, 10, 64) = (batch_size, n_heads, q_length, d_key)
# key tensor
K = K.view(batch_size, # (32, 10, 512) -> (32, 10, 8, 64)
-1, # -1 = k_length
self.n_heads,
self.d_key
).permute(0, 2, 1, 3) # (32, 10, 8, 64) -> (32, 8, 10, 64) = (batch_size, n_heads, k_length, d_key)
# value tensor
V = V.view(batch_size, # (32, 10, 512) -> (32, 10, 8, 64)
-1, # -1 = v_length
self.n_heads,
self.d_key
).permute(0, 2, 1, 3) # (32, 10, 8, 64) -> (32, 8, 10, 64) = (batch_size, n_heads, v_length, d_key)
# computes attention
# scaled dot product -> QK^{T}
scaled_dot_prod = torch.matmul(Q, # (32, 8, 10, 64) x (32, 8, 64, 10) -> (32, 8, 10, 10) = (batch_size, n_heads, q_length, k_length)
K.permute(0, 1, 3, 2)
) / math.sqrt(self.d_key) # sqrt(64)
# fill those positions of product as (-1e10) where mask positions are 0
if mask is not None:
scaled_dot_prod = scaled_dot_prod.masked_fill(mask == 0, -1e10)
# apply softmax
attn_probs = torch.softmax(scaled_dot_prod, dim=-1)
# multiply by values to get attention
A = torch.matmul(self.dropout(attn_probs), V) # (32, 8, 10, 10) x (32, 8, 10, 64) -> (32, 8, 10, 64)
# (batch_size, n_heads, q_length, k_length) x (batch_size, n_heads, v_length, d_key) -> (batch_size, n_heads, q_length, d_key)
# reshape attention back to (32, 10, 512)
A = A.permute(0, 2, 1, 3).contiguous() # (32, 8, 10, 64) -> (32, 10, 8, 64)
A = A.view(batch_size, -1, self.n_heads*self.d_key) # (32, 10, 8, 64) -> (32, 10, 8*64) -> (32, 10, 512) = (batch_size, q_length, d_model)
# push through the final weight layer
output = self.Wo(A) # (32, 10, 512) x (512, 512) = (32, 10, 512)
return output, attn_probs
torch.set_printoptions(precision=2, sci_mode=False)
# convert the sequences to integers
sequences = ["I wonder what will come next!",
"This is a basic example paragraph.",
"Hello, what is a basic split?"]
# tokenize the sequences
tokenized_sequences = [tokenize(seq) for seq in sequences]
# index the sequences
indexed_sequences = [[stoi[word] for word in seq] for seq in tokenized_sequences]
# convert the sequences to a tensor
tensor_sequences = torch.tensor(indexed_sequences).long()
# vocab size
vocab_size = len(stoi)
# embedding dimensions
d_model = 8
# create the embeddings
lut = Embeddings(vocab_size, d_model) # look-up table (lut)
# create the positional encodings
pe = PositionalEncoding(d_model=d_model, dropout=0.1, max_length=10)
# embed the sequence
embeddings = lut(tensor_sequences)
# positionally encode the sequences
X = pe(embeddings)
# set the n_heads
n_heads = 4
# create the attention layer
attention = MultiHeadAttention(d_model, n_heads, dropout=0.1)
# pass X through the attention layer three times to create Q, K, and V
output, attn_probs = attention(X, X, X, mask=None)
print(output)
正如预期的那样,输出与输入的形状相同,即(3,6,8)。
使用 attn_probs 也可以预览注意力的概率。下面是第一个序列的注意力分布。
display_attention(["i", "wonder", "what", "will", "come", "next"],
[ ],
attn_probs[0], 4, 2, 2)
结果如下:
点击上方小卡片关注我
添加个人微信,进专属粉丝群!