01
Introduction
This is the seventh article in the Transformer-from-scratch series. The decoder is the second half of the Transformer architecture, and it contains all of the layers introduced in the previous articles.
02
Background
The decoder layer is a wrapper for the sublayers covered in the previous articles. It takes the positionally embedded target sequences and passes them through masked multi-head attention. The mask prevents the decoder from looking at the next tokens in the sequence, forcing the model to predict each next token using only the previous tokens as context. The output then passes through a second multi-head attention mechanism, cross-attention, which takes the encoder's output as an additional input. Finally, it passes through the position-wise feed-forward network (FFN). After each of these sublayers, a residual addition and layer normalization are performed.
03
The Decoder Layer in the Transformer
class DecoderLayer(nn.Module):
def __init__(self, d_model: int, n_heads: int, d_ffn: int, dropout: float):
"""
Args:
d_model: dimension of embeddings
n_heads: number of heads
d_ffn: dimension of feed-forward network
dropout: probability of dropout occurring
"""
super().__init__()
# masked multi-head attention sublayer
self.masked_attention = MultiHeadAttention(d_model, n_heads, dropout)
# layer norm for masked multi-head attention
self.masked_attn_layer_norm = nn.LayerNorm(d_model)
# multi-head attention sublayer
self.attention = MultiHeadAttention(d_model, n_heads, dropout)
# layer norm for multi-head attention
self.attn_layer_norm = nn.LayerNorm(d_model)
# position-wise feed-forward network
self.positionwise_ffn = PositionwiseFeedForward(d_model, d_ffn, dropout)
# layer norm for position-wise ffn
self.ffn_layer_norm = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, trg: Tensor, src: Tensor, trg_mask: Tensor, src_mask: Tensor):
"""
Args:
trg: embedded sequences (batch_size, trg_seq_length, d_model)
src: embedded sequences (batch_size, src_seq_length, d_model)
trg_mask: mask for the sequences (batch_size, 1, trg_seq_length, trg_seq_length)
src_mask: mask for the sequences (batch_size, 1, 1, src_seq_length)
Returns:
trg: sequences after attention and the feed-forward network (batch_size, trg_seq_length, d_model)
masked_attn_probs: masked self-attention softmax scores
attn_probs: cross-attention softmax scores
"""
# pass trg embeddings through masked multi-head attention
_trg, masked_attn_probs = self.masked_attention(trg, trg, trg, trg_mask)
# residual add and norm
trg = self.masked_attn_layer_norm(trg + self.dropout(_trg))
# pass trg and src embeddings through multi-head attention
_trg, attn_probs = self.attention(trg, src, src, src_mask)
# residual add and norm
trg = self.attn_layer_norm(trg + self.dropout(_trg))
# position-wise feed-forward network
_trg = self.positionwise_ffn(trg)
# residual add and norm
trg = self.ffn_layer_norm(trg + self.dropout(_trg))
return trg, masked_attn_probs, attn_probs
04
Decoder Stack
To take advantage of the multi-head attention sublayers, the input tokens are passed through a stack of decoder layers, as shown in the figure below. This is what is denoted by Nx in the image at the beginning of the article.
This module also contains the final linear layer, which creates the logits. The logits are essentially scores for each word in the vocabulary at each position in the sequence, given the preceding words. They are passed through a softmax function to create a probability distribution indicating the likelihood of each token in the vocabulary. This is done by projecting from d_model to vocab_size, so the output has a shape of (batch_size, seq_length, vocab_size).
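As a quick standalone shape check of this projection (a toy sketch with made-up dimensions, independent of the class below, assuming d_model=32 and vocab_size=100):
import torch
import torch.nn as nn
x = torch.randn(2, 5, 32)               # (batch_size, seq_length, d_model)
proj = nn.Linear(32, 100)               # project d_model -> vocab_size
logits = proj(x)                        # (batch_size, seq_length, vocab_size) = (2, 5, 100)
probs = torch.softmax(logits, dim=-1)   # probability distribution over the vocabulary at each position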
class Decoder(nn.Module):
def __init__(self, vocab_size: int, d_model: int, n_layers: int,
n_heads: int, d_ffn: int, dropout: float = 0.1):
"""
Args:
vocab_size: size of the vocabulary
d_model: dimension of embeddings
n_layers: number of decoder layers
n_heads: number of heads
d_ffn: dimension of feed-forward network
dropout: probability of dropout occurring
"""
super().__init__()
# create n_layers decoder layers
self.layers = nn.ModuleList([DecoderLayer(d_model, n_heads, d_ffn, dropout)
for layer in range(n_layers)])
self.dropout = nn.Dropout(dropout)
# set output layer
self.Wo = nn.Linear(d_model, vocab_size)
def forward(self, trg: Tensor, src: Tensor, trg_mask: Tensor, src_mask: Tensor):
"""
Args:
trg: embedded sequences (batch_size, trg_seq_length, d_model)
src: encoded sequences from encoder (batch_size, src_seq_length, d_model)
trg_mask: mask for the sequences (batch_size, 1, trg_seq_length, trg_seq_length)
src_mask: mask for the sequences (batch_size, 1, 1, src_seq_length)
Returns:
output: sequences after decoder (batch_size, trg_seq_length, vocab_size)
Note: the final layer's attention scores are stored as attributes rather than returned:
attn_probs: cross-attention softmax scores (batch_size, n_heads, trg_seq_length, src_seq_length)
masked_attn_probs: masked self-attention softmax scores (batch_size, n_heads, trg_seq_length, trg_seq_length)
"""
# pass the sequences through each decoder layer
for layer in self.layers:
trg, masked_attn_probs, attn_probs = layer(trg, src, trg_mask, src_mask)
self.masked_attn_probs = masked_attn_probs
self.attn_probs = attn_probs
return self.Wo(trg)
05
Target Mask
To understand why the target mask is necessary, it is best to look at an example of the decoder's input and output. The decoder's goal is to predict the next token in the sequence given the encoded source sequence and part of the target sequence. For this to work, there has to be a "start" token that prompts the model to predict the first token of the sequence; this is the purpose of the <bos> token in the figure below. It is also important to note that the decoder's input and output must be the same size.
Remember that the decoder's input and output must have the same length. Therefore, each target sequence needs its last token removed before it is passed to the decoder. If the target sequences are stored in trg, the decoder's input is trg[:,:-1], which selects everything except the last token, as can be seen in the target input above. The expected output is trg[:,1:], everything except the first token, which is the expected output seen above.
In summary, just like the encoder, the decoder requires masks for its inputs. The source input requires a padding mask, and the target sequence also requires a subsequent (look-ahead) mask. At inference time, the model is given only a start token and must predict the next token from it. Then, given two tokens, it must predict the third. This repeats until the end-of-sequence token is predicted. This is the autoregressive behavior of the Transformer: future tokens are predicted based only on the past tokens and the encoder's embeddings.
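To make this autoregressive loop concrete, here is a minimal greedy-decoding sketch; model_step is a hypothetical callable (not part of this series' code) that returns logits for the tokens generated so far:
import torch
def greedy_decode(model_step, bos_idx: int, eos_idx: int, max_len: int = 20):
    # start with only the <bos> token and repeatedly append the most likely next token
    tokens = [bos_idx]
    for _ in range(max_len):
        logits = model_step(torch.tensor([tokens]))  # (1, len(tokens), vocab_size)
        next_tok = logits[0, -1].argmax().item()     # only the prediction at the last position is used
        tokens.append(next_tok)
        if next_tok == eos_idx:                      # stop once <eos> is predicted
            break
    return tokens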
trg_seq_length = 10
subsequent_mask = torch.tril(torch.ones((trg_seq_length, trg_seq_length))).int()
print(subsequent_mask)
tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int32)
pad_mask = torch.Tensor([[1,1,1,1,1,1,1,0,0,0]]).unsqueeze(1).unsqueeze(2).int()
print(pad_mask)
tensor([[[[1, 1, 1, 1, 1, 1, 1, 0, 0, 0]]]], dtype=torch.int32)
print(subsequent_mask & pad_mask)
tensor([[[[1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
          [1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
          [1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
          [1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
          [1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
          [1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
          [1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
          [1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
          [1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
          [1, 1, 1, 1, 1, 1, 1, 0, 0, 0]]]], dtype=torch.int32)
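This combined mask is exactly what make_trg_mask in the appendix builds from a raw padded sequence; for example, with hypothetical token indices and 0 as the padding index:
raw_trg = torch.tensor([[3, 7, 2, 9, 4, 5, 6, 0, 0, 0]])  # hypothetical indices; 0 = <pad>
combined_mask = make_trg_mask(raw_trg, pad_idx=0)          # (1, 1, 10, 10), the same pattern as above (as booleans)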
06
How Are the Source and Target Masks Used?
Since two multi-head attention mechanisms are used in the decoder, the first one uses the target mask, and the second one uses the source mask when the encoder's embeddings are provided to the decoder.
In the first mechanism, the probability distribution for each token considers only the preceding tokens. This reflects the model's behavior during inference and can be seen in the target mask below, where each token depends only on the tokens that come before it:
In the second mechanism, the target sequence is the query and the source sequence is the key. This creates a probability distribution between each target token and the source tokens. During inference, it helps the model identify which target tokens fit best with the given source tokens. Below is an example of an attention distribution from a trained model:
In this visualization, the relationship between each query and its keys can be seen. For example, <bos> is closely related to eine, a is most closely related to frau, woman is associated with mit, and with is associated with einer. This illustrates the relationship between each query token and the key, or German counterpart, of the English token that should be predicted next.
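The shape of this cross-attention distribution can be sketched with a single unprojected head and made-up lengths (a toy illustration, not the model above):
import torch
trg_len, src_len, d_model = 8, 10, 32
q = torch.randn(1, trg_len, d_model)               # queries come from the target sequence
k = torch.randn(1, src_len, d_model)               # keys come from the encoded source sequence
scores = q @ k.transpose(-2, -1) / d_model ** 0.5  # (1, trg_len, src_len)
attn = torch.softmax(scores, dim=-1)               # one distribution over source tokens per target token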
To reproduce this, the encoder and decoder have to be combined into a model that can be trained to translate German to English.
07
Training a Simple Model
Before building the model, the German and English vocabularies have to be created. The functions in the appendix build on the ones from the encoder article, generalized for both English and German.
This model uses the same English example as the previous articles, along with a German counterpart generated with Google Translate.
de_example = "Hallo! Dies ist ein Beispiel für einen Absatz, der in seine Grundkomponenten aufgeteilt wurde. Ich frage mich, was als nächstes kommt! Irgendwelche Ideen?"
en_example = "Hello! This is an example of a paragraph that has been split into its basic components. I wonder what will come next! Any guesses?"
# build the vocab
de_stoi = build_vocab(de_example)
en_stoi = build_vocab(en_example)
# build integer-to-string decoder for the vocab
de_itos = {v:k for k,v in de_stoi.items()}
en_itos = {v:k for k,v in en_stoi.items()}
de_sequences = ["Ich frage mich, was als nächstes kommt!",
"Dies ist ein Beispiel für einen Absatz.",
"Hallo, was ist ein Grundkomponenten?"]
en_sequences = ["I wonder what will come next!",
"This is a basic example paragraph.",
"Hello, what is a basic split?"]
# tokenize the sequences
de_tokenized_sequences = [tokenize(seq) for seq in de_sequences]
en_tokenized_sequences = [tokenize(seq) for seq in en_sequences]
# index the sequences using the vocabularies
de_indexed_sequences = [[de_stoi[word] for word in seq] for seq in de_tokenized_sequences]
en_indexed_sequences = [[en_stoi[word] for word in seq] for seq in en_tokenized_sequences]
# pad the sequences
max_length = 9
pad_idx = de_stoi['<pad>']
de_padded_seqs = []
en_padded_seqs = []
# pad each sequence
for de_seq, en_seq in zip(de_indexed_sequences, en_indexed_sequences):
de_padded_seqs.append(pad_seq(torch.Tensor(de_seq), max_length, pad_idx))
en_padded_seqs.append(pad_seq(torch.Tensor(en_seq), max_length, pad_idx))
# create a tensor from the padded sequences
de_tensor_sequences = torch.stack(de_padded_seqs).long()
en_tensor_sequences = torch.stack(en_padded_seqs).long()
Now the target sequences can be prepared for training. Each target sequence has nine tokens: six from the original sentence, a start token, an end token, and one padding token. The tokenized representation of the first sequence is shown below:
['<bos>', 'i', 'wonder', 'what', 'will', 'come', 'next', '<eos>', '<pad>']
As mentioned before, the decoder's input is a subset of this sequence. The last token has to be removed from the target sequence so that the decoder predicts the next token at each position:
['<bos>', 'i', 'wonder', 'what', 'will', 'come', 'next', '<eos>']
Likewise, the expected output is another subset of the sequence. To create it, the first token has to be removed:
['i', 'wonder', 'what', 'will', 'come', 'next', '<eos>', '<pad>']
The code for these subsets is below; the source and target masks can be generated from them as well.
# remove last token
trg = en_tensor_sequences[:,:-1]
# remove the first token
expected_output = en_tensor_sequences[:,1:]
# generate masks
src_mask = make_src_mask(de_tensor_sequences, pad_idx)
trg_mask = make_trg_mask(trg, pad_idx)
The target mask for the first sequence can be viewed:
display_mask(trg[0].int().tolist(), trg_mask[0])
The result is shown below:
At this point, the model can be created. nn.Sequential can be used to combine the source embedding with the positional encoding, and the target embedding with the positional encoding, creating a forward pass for each.
# parameters
de_vocab_size = len(de_stoi)
en_vocab_size = len(en_stoi)
d_model = 32
d_ffn = d_model*4 # 128
n_heads = 4
n_layers = 3
dropout = 0.1
max_pe_length = 10
# create the embeddings
de_lut = Embeddings(de_vocab_size, d_model) # look-up table (lut)
en_lut = Embeddings(en_vocab_size, d_model)
# create the positional encodings
pe = PositionalEncoding(d_model=d_model, dropout=0.1, max_length=max_pe_length)
# embed and encode
de_embed = nn.Sequential(de_lut, pe)
en_embed = nn.Sequential(en_lut, pe)
# initialize encoder
encoder = Encoder(d_model, n_layers, n_heads, d_ffn, dropout)
# initialize the decoder
decoder = Decoder(en_vocab_size, d_model, n_layers, n_heads, d_ffn, dropout)
With the layers created, the model can be initialized in an nn.ModuleList, which stores all of the components in a list that can be accessed through Module methods such as parameters().
# initialize the model
model = nn.ModuleList([de_embed, en_embed, encoder, decoder])
# normalize the weights
for p in model.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
The total number of parameters can be previewed with a simple function.
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'The model has {count_parameters(model):,} trainable parameters.')
The result:
The model has 91,675 trainable parameters.
Now a simple forward pass can be performed through the model, and the predictions can be previewed by taking the argmax of the logits.
# pass through encoder
encoded_embeddings = encoder(src=de_embed(de_tensor_sequences),
src_mask=src_mask)
# logits for each output
logits = decoder(trg=en_embed(trg), src=encoded_embeddings,
trg_mask=trg_mask,
src_mask=src_mask)
predictions = [[en_itos[tok] for tok in seq] for seq in logits.argmax(-1).tolist()]
The output:
[['a', '<eos>', 'basic', 'a', 'this', 'a', 'a', 'an'],
['wonder', 'into', 'any', 'wonder', 'i', 'wonder', 'wonder', 'an'],
['that', 'any', 'has', 'basic', 'split', 'wonder', 'example', 'wonder']]
# hyperparameters
LEARNING_RATE = 0.005
EPOCHS = 50
# adam optimizer
optimizer = torch.optim.Adam(model.parameters(), lr = LEARNING_RATE)
# loss function
criterion = nn.CrossEntropyLoss(ignore_index = en_stoi["<pad>"])
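Note that nn.CrossEntropyLoss expects logits of shape (N, vocab_size) and targets of shape (N,), which is why the training loop below flattens both tensors. A small toy check with a hypothetical vocabulary of size 5:
toy_logits = torch.randn(2, 3, 5)                    # (batch_size, seq_length, vocab_size)
toy_targets = torch.tensor([[1, 2, 0], [3, 4, 0]])   # padded targets; 0 = <pad>
toy_loss = nn.CrossEntropyLoss(ignore_index=0)(toy_logits.view(-1, 5), toy_targets.view(-1))  # pads are ignored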
Since only three sequences are used, a training loop can be created that updates the parameters and previews the predictions as training progresses.
# set the model to training mode
model.train()
# loop through each epoch
for i in range(EPOCHS):
epoch_loss = 0
# zero the gradients
optimizer.zero_grad()
# pass through encoder
encoded_embeddings = encoder(src=de_embed(de_tensor_sequences),
src_mask=src_mask)
# logits for each output
logits = decoder(trg=en_embed(trg), src=encoded_embeddings,
trg_mask=trg_mask,
src_mask=src_mask)
# calculate the loss
loss = criterion(logits.contiguous().view(-1, logits.shape[-1]),
expected_output.contiguous().view(-1))
# backpropagation
loss.backward()
# clip the weights
torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
# update the weights
optimizer.step()
# preview the predictions
predictions = [[en_itos[tok] for tok in seq] for seq in logits.argmax(-1).tolist()]
if i % 7 == 0:
print("="*25)
print(f"epoch: {i}")
print(f"loss: {loss.item()}")
print(f"predictions: {predictions}")
The output:
=========================
epoch: 0
loss: 3.8633525371551514
predictions: [['an', 'an', 'an', 'an', 'an', 'an', 'an', 'an'],
['of', 'an', 'an', 'an', 'an', 'an', 'an', 'an'],
['an', 'an', 'an', 'an', 'an', 'an', 'an', 'been']]
=========================
epoch: 7
loss: 2.7589643001556396
predictions: [['i', 'i', 'i', 'i', 'i', 'i', 'i', 'i'],
['is', 'is', 'is', 'is', 'is', 'paragraph', 'is', 'is'],
['is', 'is', 'is', 'a', 'is', 'basic', 'basic', 'basic']]
=========================
epoch: 14
loss: 1.7105616331100464
predictions: [['i', 'i', 'i', 'a', 'will', '<eos>', '<eos>', '<eos>'],
['hello', 'is', 'this', 'is', 'paragraph', 'paragraph', '<eos>', 'paragraph'],
['hello', 'example', 'is', 'is', 'basic', '<eos>', '<eos>', '<eos>']]
=========================
epoch: 21
loss: 1.2171827554702759
predictions: [['i', 'what', 'what', 'next', 'next', 'next', '<eos>', '<eos>'],
['this', 'is', 'a', 'basic', 'paragraph', 'a', 'paragraph', '<eos>'],
['this', 'basic', 'is', 'a', 'basic', 'basic', '<eos>', '<eos>']]
=========================
epoch: 28
loss: 0.8726108074188232
predictions: [['i', 'what', 'what', 'will', 'come', '<eos>', '<eos>', '<eos>'],
['this', 'is', 'a', 'basic', 'paragraph', 'example', '<eos>', 'paragraph'],
['hello', 'what', 'is', 'a', 'basic', 'split', '<eos>', '<eos>']]
=========================
epoch: 35
loss: 0.6604534387588501
predictions: [['i', 'wonder', 'next', 'will', 'come', 'next', '<eos>', 'next'],
['this', 'is', 'a', 'basic', 'example', 'paragraph', '<eos>', 'paragraph'],
['hello', 'what', 'is', 'a', 'basic', 'basic', '<eos>', '<eos>']]
=========================
epoch: 42
loss: 0.3311622142791748
predictions: [['i', 'wonder', 'what', 'will', 'come', 'next', '<eos>', '<eos>'],
['this', 'is', 'a', 'basic', 'paragraph', 'paragraph', '<eos>', '<eos>'],
['hello', 'what', 'is', 'a', 'basic', 'split', '<eos>', '<eos>']]
=========================
epoch: 49
loss: 0.19808804988861084
predictions: [['i', 'wonder', 'what', 'will', 'come', 'next', '<eos>', '<eos>'],
['this', 'is', 'a', 'basic', 'example', 'paragraph', '<eos>', 'paragraph'],
['hello', 'what', 'is', 'a', 'basic', 'split', '<eos>', 'split']]
# convert the indices to strings
decoder_input = [en_itos[i] for i in trg[0].tolist()]
display_attention(de_tokenized_sequences[0], decoder_input, decoder.attn_probs[0],n_heads, n_rows=2, n_cols=2)
display_attention(decoder_input, decoder_input, decoder.masked_attn_probs[0],n_heads, n_rows=2, n_cols=2)
Tokenization
def tokenize(sequence, special_toks=True):
# remove punctuation
for punc in ["!", ".", "?", ","]:
sequence = sequence.replace(punc, "")
# split the sequence on spaces and lowercase each token
sequence = [token.lower() for token in sequence.split(" ")]
# add beginning and end tokens
if special_toks:
sequence = ['<bos>'] + sequence + ['<eos>']
return sequence
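Applied to the first English sequence from earlier, this produces the tokenized form shown in the training section:
tokenize("I wonder what will come next!")
# ['<bos>', 'i', 'wonder', 'what', 'will', 'come', 'next', '<eos>']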
Build Vocabulary
def build_vocab(data):
# tokenize the data and remove duplicates
vocab = list(set(tokenize(data, special_toks=False)))
# sort the vocabulary
vocab.sort()
# add special tokens
vocab = ['<pad>', '<bos>', '<eos>'] + vocab
# assign an integer to each word
stoi = {word:i for i, word in enumerate(vocab)}
return stoi
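A small example of the resulting mapping; the special tokens always take the first three indices:
build_vocab("hello world hello")
# {'<pad>': 0, '<bos>': 1, '<eos>': 2, 'hello': 3, 'world': 4}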
Padding
def pad_seq(seq: Tensor, max_length: int = 10, pad_idx: int = 0):
"""
Args:
seq: raw sequence (batch_size, seq_length)
max_length: maximum length of a sequence
pad_idx: index for padding tokens
Returns:
padded seq: padded sequence (batch_size, max_length)
"""
pad_to_add = max_length - len(seq) # amount of padding to add
return pad(seq,(0, pad_to_add), value=pad_idx,)
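A quick example, assuming pad has been imported from torch.nn.functional:
pad_seq(torch.tensor([5, 3, 8, 1, 4, 9, 2]), max_length=10, pad_idx=0)
# tensor([5, 3, 8, 1, 4, 9, 2, 0, 0, 0])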
Source Mask
def make_src_mask(src: Tensor, pad_idx: int = 0):
"""
Args:
src: raw sequences with padding (batch_size, seq_length)
Returns:
src_mask: mask for each sequence (batch_size, 1, 1, seq_length)
"""
# assign 1 to tokens that need to be attended to and 0 to padding tokens, then add 2 dimensions
src_mask = (src != pad_idx).unsqueeze(1).unsqueeze(2)
return src_mask
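For example, on a single padded sequence with pad_idx = 0:
make_src_mask(torch.tensor([[4, 9, 12, 5, 0, 0]]))
# tensor([[[[ True,  True,  True,  True, False, False]]]])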
Target Mask
def make_trg_mask(trg: Tensor, pad_idx: int = 0):
"""
Args:
trg: raw sequences with padding (batch_size, seq_length)
Returns:
trg_mask: mask for each sequence (batch_size, 1, seq_length, seq_length)
"""
seq_length = trg.shape[1]
# assign True to tokens that need to be attended to and False to padding tokens, then add 2 dimensions
trg_mask = (trg != pad_idx).unsqueeze(1).unsqueeze(2) # (batch_size, 1, 1, seq_length)
# generate subsequent mask
trg_sub_mask = torch.tril(torch.ones((seq_length, seq_length))).bool() # (seq_length, seq_length)
# bitwise "and" operator | 0 & 0 = 0, 1 & 1 = 1, 1 & 0 = 0
trg_mask = trg_mask & trg_sub_mask
return trg_mask
Display Mask
def display_mask(sentence: list, mask: Tensor):
"""
Display the target mask for each sequence.
Args:
sequence: sequence to be masked
mask: target mask for the heads
"""
# figure size
fig = plt.figure(figsize=(8,8))
# create a plot
ax = fig.add_subplot(mask.shape[0], 1, 1)
# select the respective head and make it a numpy array for plotting
mask = mask.squeeze(0).cpu().detach().numpy()
# plot the matrix
cax = ax.matshow(mask, cmap='bone')
# set the size of the labels
ax.tick_params(labelsize=12)
# set the indices for the tick marks
ax.set_xticks(range(len(sentence)))
ax.set_yticks(range(len(sentence)))
# set labels
ax.xaxis.set_label_position('top')
ax.set_ylabel("$Q$")
ax.set_xlabel("$K^T$")
if isinstance(sentence[0], int):
# convert indices to German/English
sentence = [en_itos[tok] for tok in sentence]
ax.set_xticklabels(sentence, rotation=75)
ax.set_yticklabels(sentence)
plt.show()
Display Attention
def display_attention(sentence: list, translation: list, attention: Tensor,
n_heads: int = 8, n_rows: int = 4, n_cols: int = 2):
"""
Display the attention matrix for each head of a sequence.
Args:
sentence: German sentence to be translated to English; list
translation: English sentence predicted by the model
attention: attention scores for the heads
n_heads: number of heads
n_rows: number of rows
n_cols: number of columns
"""
# ensure the number of rows and columns are equal to the number of heads
assert n_rows * n_cols == n_heads
# figure size
fig = plt.figure(figsize=(15,25))
# visualize each head
for i in range(n_heads):
# create a plot
ax = fig.add_subplot(n_rows, n_cols, i+1)
# select the respective head and make it a numpy array for plotting
_attention = attention.squeeze(0)[i,:,:].cpu().detach().numpy()
# plot the matrix
cax = ax.matshow(_attention, cmap='bone')
# set the size of the labels
ax.tick_params(labelsize=12)
# set the indices for the tick marks
ax.set_xticks(range(len(sentence)))
ax.set_yticks(range(len(translation)))
# if the provided sequences are sentences or indices
if isinstance(sentence[0], str):
ax.set_xticklabels([t.lower() for t in sentence], rotation=45)
ax.set_yticklabels(translation)
elif isinstance(sentence[0], int):
ax.set_xticklabels(sentence)
ax.set_yticklabels(translation)
plt.show()