Transformer from Scratch: The Decoder

Digest   Technology   2024-10-08 06:40   Jiangsu




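This is the seventh article in the Transformer from Scratch series. The decoder is the second half of the Transformer architecture, and it is built from the same layers introduced in the earlier articles.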






The Decoder Layer in the Transformer

import torch
import torch.nn as nn
from torch import Tensor

# MultiHeadAttention and PositionwiseFeedForward are the modules built in the
# earlier articles of this series


class DecoderLayer(nn.Module):

    def __init__(self, d_model: int, n_heads: int, d_ffn: int, dropout: float):
        """
        Args:
            d_model:      dimension of embeddings
            n_heads:      number of heads
            d_ffn:        dimension of feed-forward network
            dropout:      probability of dropout occurring
        """
        super().__init__()

        # masked multi-head attention sublayer
        self.masked_attention = MultiHeadAttention(d_model, n_heads, dropout)
        # layer norm for masked multi-head attention
        self.masked_attn_layer_norm = nn.LayerNorm(d_model)

        # multi-head attention sublayer
        self.attention = MultiHeadAttention(d_model, n_heads, dropout)
        # layer norm for multi-head attention
        self.attn_layer_norm = nn.LayerNorm(d_model)

        # position-wise feed-forward network
        self.positionwise_ffn = PositionwiseFeedForward(d_model, d_ffn, dropout)
        # layer norm for position-wise ffn
        self.ffn_layer_norm = nn.LayerNorm(d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(self, trg: Tensor, src: Tensor, trg_mask: Tensor, src_mask: Tensor):
        """
        Args:
            trg:          embedded sequences                (batch_size, trg_seq_length, d_model)
            src:          embedded sequences                (batch_size, src_seq_length, d_model)
            trg_mask:     mask for the sequences            (batch_size, 1, trg_seq_length, trg_seq_length)
            src_mask:     mask for the sequences            (batch_size, 1, 1, src_seq_length)

        Returns:
            trg:                sequences after the decoder layer   (batch_size, trg_seq_length, d_model)
            masked_attn_probs:  masked self-attention softmax scores
            attn_probs:         encoder-decoder attention softmax scores
        """
        # pass trg embeddings through masked multi-head attention
        _trg, masked_attn_probs = self.masked_attention(trg, trg, trg, trg_mask)

        # residual add and norm
        trg = self.masked_attn_layer_norm(trg + self.dropout(_trg))

        # pass trg and src embeddings through multi-head attention
        _trg, attn_probs = self.attention(trg, src, src, src_mask)

        # residual add and norm
        trg = self.attn_layer_norm(trg + self.dropout(_trg))

        # position-wise feed-forward network
        _trg = self.positionwise_ffn(trg)

        # residual add and norm
        trg = self.ffn_layer_norm(trg + self.dropout(_trg))

        return trg, masked_attn_probs, attn_probs


Decoder Stack

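To take advantage of the multi-head attention sublayers, the input tokens pass through a stack of decoder layers one after another, as shown in the figure below. This stack is labeled Nx in the image at the beginning of the article.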

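This module also contains the final linear layer, which creates the logits. The logits are unnormalized scores: for each position in the sequence, they hold one score per vocabulary word, computed from the preceding words. They are then passed through a softmax function to create a probability distribution over the vocabulary for each position. In practice this amounts to projecting d_model to vocab_size. The output has a shape of (batch_size, seq_length, vocab_size).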
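The step from logits to probabilities can be illustrated with a minimal sketch in plain Python (no PyTorch). The four-word vocabulary and the logit values are made up for illustration; in the model, the logits come from the final linear layer.

```python
import math


def softmax(logits):
    """Turn a list of raw scores into probabilities that sum to 1."""
    # subtracting the max keeps exp() from overflowing; the result is unchanged
    largest = max(logits)
    exps = [math.exp(x - largest) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


# hypothetical logits for a single position over a 4-word vocabulary
toy_vocab = ["<eos>", "what", "is", "name"]
toy_logits = [0.5, 2.0, 1.0, -1.0]

probs = softmax(toy_logits)

# the predicted token is the one with the highest probability (the argmax)
predicted = toy_vocab[probs.index(max(probs))]
print(predicted)
```

Because softmax preserves the ordering of the scores, taking the argmax of the logits directly gives the same predicted token.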

class Decoder(nn.Module):

    def __init__(self, vocab_size: int, d_model: int, n_layers: int,
                 n_heads: int, d_ffn: int, dropout: float = 0.1):
        """
        Args:
            vocab_size:   size of the vocabulary
            d_model:      dimension of embeddings
            n_layers:     number of decoder layers
            n_heads:      number of heads
            d_ffn:        dimension of feed-forward network
            dropout:      probability of dropout occurring
        """
        super().__init__()

        # create n_layers decoder layers
        self.layers = nn.ModuleList([DecoderLayer(d_model, n_heads, d_ffn, dropout)
                                     for layer in range(n_layers)])

        self.dropout = nn.Dropout(dropout)

        # set output layer
        self.Wo = nn.Linear(d_model, vocab_size)

    def forward(self, trg: Tensor, src: Tensor, trg_mask: Tensor, src_mask: Tensor):
        """
        Args:
            trg:          embedded sequences                (batch_size, trg_seq_length, d_model)
            src:          encoded sequences from encoder    (batch_size, src_seq_length, d_model)
            trg_mask:     mask for the sequences            (batch_size, 1, trg_seq_length, trg_seq_length)
            src_mask:     mask for the sequences            (batch_size, 1, 1, src_seq_length)

        Returns:
            output:       sequences after decoder           (batch_size, trg_seq_length, vocab_size)

        The attention scores of the last layer are stored as attributes:
            attn_probs:         (batch_size, n_heads, trg_seq_length, src_seq_length)
            masked_attn_probs:  (batch_size, n_heads, trg_seq_length, trg_seq_length)
        """
        # pass the sequences through each decoder layer
        for layer in self.layers:
            trg, masked_attn_probs, attn_probs = layer(trg, src, trg_mask, src_mask)

        self.masked_attn_probs = masked_attn_probs
        self.attn_probs = attn_probs

        return self.Wo(trg)


 Target Mask

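To understand why the target mask is needed, it helps to look at an example of the decoder's inputs and outputs. The decoder's goal is to predict the next token in the sequence, given the encoded source sequence and a partial target sequence. For this to work, there has to be a "start" token that prompts the model to predict the next token. That is the role of the <bos> token in the figure below. Note also that the decoder's input and output must have the same dimensions.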

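If the goal is for the model to translate "Wie heißt du?" into "What is your name?", the encoder encodes the meaning of the source sequence and passes it to the decoder. Given the <bos> token and the encoded source, the decoder should predict "What". "What" is then appended to <bos> to form the new input, "<bos> What". This is why the decoder input is described as "shifted right". The new input is passed to the decoder to predict "What is". The new token is appended to the previous input to create "<bos> What is", which is passed to the decoder to predict "What is your". The process repeats until the model predicts the <eos> token.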
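The token-by-token loop described above can be sketched in plain Python. The toy_model function and its lookup table are hypothetical stand-ins for the trained encoder and decoder; a real model would return logits, and the next token would be their argmax.

```python
# hypothetical stand-in for the trained model: it maps the tokens decoded so
# far to the next token, using a fixed table instead of a neural network
TOY_NEXT_TOKEN = {
    ("<bos>",): "what",
    ("<bos>", "what"): "is",
    ("<bos>", "what", "is"): "your",
    ("<bos>", "what", "is", "your"): "name",
    ("<bos>", "what", "is", "your", "name"): "<eos>",
}


def toy_model(decoded_so_far):
    return TOY_NEXT_TOKEN[tuple(decoded_so_far)]


def greedy_decode(model, max_length=10):
    """Start from <bos> and append one predicted token per iteration."""
    decoded = ["<bos>"]
    while len(decoded) < max_length:
        next_token = model(decoded)
        decoded.append(next_token)
        # stop as soon as the model predicts the end-of-sequence token
        if next_token == "<eos>":
            break
    return decoded


print(greedy_decode(toy_model))
```

The max_length argument is a safety limit added for this sketch, so the loop still terminates if <eos> is never predicted.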
Given the target sequence "<bos> What is your name? <eos>", the model can learn every one of these iterations simultaneously by using a target mask:


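Remember that the decoder's input and output must have the same length. The last token of each target sequence therefore has to be removed before the sequence is passed to the decoder. If the target sequences are stored in trg, the decoder input is trg[:,:-1], which selects everything except the last token; this is the target input shown above. The expected output is trg[:,1:], which is everything except the first token; this is the expected output shown above.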
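The same slicing can be tried on a single sequence using an ordinary Python list. This is only an illustration: the article's trg is a 2-D tensor of token indices, while the list below holds readable tokens.

```python
# a single target sequence, written as tokens instead of indices
full_target = ["<bos>", "what", "is", "your", "name", "<eos>"]

# decoder input: everything except the last token
decoder_input = full_target[:-1]

# expected output: everything except the first token
expected_output = full_target[1:]

# at each position, the decoder sees the input up to that position
# and has to predict the token at the same position of the expected output
for position, target in enumerate(expected_output):
    print(decoder_input[:position + 1], "->", target)
```

Both slices have the same length, so every input position has exactly one token to predict.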


To mimic this behavior, the model uses a subsequent mask to learn all of these iterations at once. torch.tril in PyTorch can be used to create the subsequent mask, which has a shape of (trg_seq_length, trg_seq_length).

trg_seq_length = 10
subsequent_mask = torch.tril(torch.ones((trg_seq_length, trg_seq_length))).int()
print(subsequent_mask)

tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int32)

The padding mask, which is 0 at the padding positions, has to be taken into account as well:

pad_mask = torch.Tensor([[1,1,1,1,1,1,1,0,0,0]]).unsqueeze(1).unsqueeze(2).int()
print(pad_mask)

tensor([[[[1, 1, 1, 1, 1, 1, 1, 0, 0, 0]]]], dtype=torch.int32)

The two masks are easily combined with the & operator, which returns 1 only where both masks are 1.

print(subsequent_mask & pad_mask)

tensor([[[[1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
          [1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
          [1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
          [1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
          [1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
          [1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
          [1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
          [1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
          [1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
          [1, 1, 1, 1, 1, 1, 1, 0, 0, 0]]]], dtype=torch.int32)

This final target mask must be created for each sequence in the batch, which means it has a shape of (batch_size, 1, trg_seq_length, trg_seq_length).
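To see why each sequence in a batch needs its own mask, here is a small plain-Python sketch that builds the combined mask with nested lists. The function make_trg_mask_lists and the token indices are hypothetical, and the extra dimension of size 1 (which broadcasts over the attention heads) is left out for readability.

```python
def make_trg_mask_lists(batch, pad_idx=0):
    """Build one (seq_length x seq_length) mask per sequence in the batch."""
    masks = []
    for seq in batch:
        n = len(seq)
        # query position q may attend to key position k only if k is not in
        # the future (k <= q) and the key token is not padding
        mask = [[int(k <= q and seq[k] != pad_idx) for k in range(n)]
                for q in range(n)]
        masks.append(mask)
    return masks


# hypothetical token indices; 0 is assumed to be the <pad> index
toy_batch = [
    [1, 5, 6, 2, 0],   # one padding token
    [1, 7, 2, 0, 0],   # two padding tokens
]

masks = make_trg_mask_lists(toy_batch)
for row in masks[1]:
    print(row)
```

The two sequences are padded differently, so their masks differ in the columns that are zeroed out, even though the lower-triangular part is the same for both.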


 How Are the Source and Target Masks Used?




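In this visualization, the relationship between each query and its keys can be seen. For example, <bos> is closely related to eine; a is most closely related to frau; woman is associated with mit; and with is associated with einer. In other words, each query token attends to the key that is the German counterpart of the next English token it should predict.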






de_example = "Hallo! Dies ist ein Beispiel für einen Absatz, der in seine Grundkomponenten aufgeteilt wurde. Ich frage mich, was als nächstes kommt! Irgendwelche Ideen?"
en_example = "Hello! This is an example of a paragraph that has been split into its basic components. I wonder what will come next! Any guesses?"

# build the vocab
de_stoi = build_vocab(de_example)
en_stoi = build_vocab(en_example)

# build integer-to-string decoder for the vocab
de_itos = {v:k for k,v in de_stoi.items()}
en_itos = {v:k for k,v in en_stoi.items()}

de_sequences = ["Ich frage mich, was als nächstes kommt!",
                "Dies ist ein Beispiel für einen Absatz.",
                "Hallo, was ist ein Grundkomponenten?"]

en_sequences = ["I wonder what will come next!",
                "This is a basic example paragraph.",
                "Hello, what is a basic split?"]

# tokenize and index the sequences; these two steps are not shown in the text,
# but the resulting names are used by the code that follows
de_tokenized_sequences = [tokenize(seq) for seq in de_sequences]
en_tokenized_sequences = [tokenize(seq) for seq in en_sequences]
de_indexed_sequences = [[de_stoi[word] for word in seq] for seq in de_tokenized_sequences]
en_indexed_sequences = [[en_stoi[word] for word in seq] for seq in en_tokenized_sequences]

# pad the sequences
max_length = 9
pad_idx = de_stoi['<pad>']

de_padded_seqs = []
en_padded_seqs = []

# pad each sequence
for de_seq, en_seq in zip(de_indexed_sequences, en_indexed_sequences):
    de_padded_seqs.append(pad_seq(torch.Tensor(de_seq), max_length, pad_idx))
    en_padded_seqs.append(pad_seq(torch.Tensor(en_seq), max_length, pad_idx))

# create a tensor from the padded sequences
de_tensor_sequences = torch.stack(de_padded_seqs).long()
en_tensor_sequences = torch.stack(en_padded_seqs).long()


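The full padded target sequence:

['<bos>', 'i', 'wonder', 'what', 'will', 'come', 'next', '<eos>', '<pad>']

The decoder input, with the last token removed:

['<bos>', 'i', 'wonder', 'what', 'will', 'come', 'next', '<eos>']

The expected output, with the first token removed:

['i', 'wonder', 'what', 'will', 'come', 'next', '<eos>', '<pad>']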


# remove last token
trg = en_tensor_sequences[:,:-1]

# remove the first token
expected_output = en_tensor_sequences[:,1:]

# generate masks
src_mask = make_src_mask(de_tensor_sequences, pad_idx)
trg_mask = make_trg_mask(trg, pad_idx)


display_mask(trg[0].int().tolist(), trg_mask[0])


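At this point, the model can be created. nn.Sequential chains each embedding with the positional encoding, so the source and target sequences each get a single forward pass through both; the encoder and decoder are then initialized.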

# parameters
de_vocab_size = len(de_stoi)
en_vocab_size = len(en_stoi)
d_model = 32
d_ffn = d_model*4 # 128
n_heads = 4
n_layers = 3
dropout = 0.1
max_pe_length = 10

# create the embeddings
de_lut = Embeddings(de_vocab_size, d_model) # look-up table (lut)
en_lut = Embeddings(en_vocab_size, d_model)

# create the positional encodings
pe = PositionalEncoding(d_model=d_model, dropout=0.1, max_length=max_pe_length)

# embed and encode
de_embed = nn.Sequential(de_lut, pe)
en_embed = nn.Sequential(en_lut, pe)

# initialize encoder
encoder = Encoder(d_model, n_layers, n_heads, d_ffn, dropout)

# initialize the decoder
decoder = Decoder(en_vocab_size, d_model, n_layers, n_heads, d_ffn, dropout)

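Once the layers are created, the model can be initialized in an nn.ModuleList, which stores all of the components in a list and exposes them through Module methods such as parameters().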

# initialize the model
model = nn.ModuleList([de_embed, en_embed, encoder, decoder])

# initialize the weight matrices (parameters with more than one dimension)
for p in model.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)
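For reference, Xavier (Glorot) uniform initialization draws each weight from a uniform distribution on [-a, a], where a = gain * sqrt(6 / (fan_in + fan_out)). The plain-Python sketch below only evaluates that bound; the helper name xavier_uniform_bound is made up for illustration, and the two shapes are the d_model and d_ffn values used in this article.

```python
import math


def xavier_uniform_bound(fan_in, fan_out, gain=1.0):
    """Half-width a of the uniform range [-a, a] used by Xavier initialization."""
    return gain * math.sqrt(6.0 / (fan_in + fan_out))


# a d_model x d_model projection (32 x 32)
attention_bound = xavier_uniform_bound(32, 32)

# the first feed-forward projection, d_model -> d_ffn (32 -> 128)
ffn_bound = xavier_uniform_bound(32, 128)

print(round(attention_bound, 4), round(ffn_bound, 4))
```

Larger layers get a narrower range. The p.dim() > 1 check in the loop above means that one-dimensional parameters, such as biases and LayerNorm weights, keep their default initialization.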


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model):,} trainable parameters.')


The model has 91,675 trainable parameters.
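The total can be cross-checked by hand. The sketch below rests on assumptions about modules defined in the earlier articles: each multi-head attention block is taken to have four biased d_model x d_model linear layers (query, key, value and output), each encoder layer is taken to have one attention block, one feed-forward network and two layer norms, and the positional encoding is taken to have no trainable parameters. The vocabulary sizes (26 German and 27 English entries, including the three special tokens) were counted from the example paragraphs.

```python
def linear(n_in, n_out):
    # weight matrix plus bias vector
    return n_in * n_out + n_out


def layer_norm(d):
    # one scale and one shift parameter per feature
    return 2 * d


def attention(d):
    # assumed: query, key, value and output projections, all with bias
    return 4 * linear(d, d)


def ffn(d, d_ffn):
    return linear(d, d_ffn) + linear(d_ffn, d)


d_model, d_ffn, n_layers = 32, 128, 3
de_vocab_size, en_vocab_size = 26, 27

embeddings = de_vocab_size * d_model + en_vocab_size * d_model
encoder_layer = attention(d_model) + ffn(d_model, d_ffn) + 2 * layer_norm(d_model)
decoder_layer = 2 * attention(d_model) + ffn(d_model, d_ffn) + 3 * layer_norm(d_model)
output_layer = linear(d_model, en_vocab_size)

total = embeddings + n_layers * (encoder_layer + decoder_layer) + output_layer
print(f"{total:,}")
```

Under these assumptions the arithmetic reproduces the reported count, with the three decoder layers accounting for a little over half of the parameters.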

Now a simple forward pass can be run through the model, and the predictions can be previewed by taking the argmax of the logits.

# pass through encoder
encoded_embeddings = encoder(src=de_embed(de_tensor_sequences),
                             src_mask=src_mask)

# logits for each output
logits = decoder(trg=en_embed(trg), src=encoded_embeddings,
                 trg_mask=trg_mask,
                 src_mask=src_mask)

predictions = [[en_itos[tok] for tok in seq] for seq in logits.argmax(-1).tolist()]


[['a', '<eos>', 'basic', 'a', 'this', 'a', 'a', 'an'], ['wonder', 'into', 'any', 'wonder', 'i', 'wonder', 'wonder', 'an'], ['that', 'any', 'has', 'basic', 'split', 'wonder', 'example', 'wonder']]
# hyperparameters
LEARNING_RATE = 0.005
EPOCHS = 50

# adam optimizer
optimizer = torch.optim.Adam(model.parameters(), lr = LEARNING_RATE)

# loss function
criterion = nn.CrossEntropyLoss(ignore_index = en_stoi["<pad>"])
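The effect of ignore_index can be illustrated with a minimal plain-Python sketch of mean cross-entropy that skips padding targets. The function name, logits and targets are made up for illustration, and the averaging over the non-ignored positions mirrors the default mean reduction of nn.CrossEntropyLoss.

```python
import math


def cross_entropy_ignore_pad(logits, targets, pad_idx):
    """Mean cross-entropy over the positions whose target is not pad_idx."""
    total, counted = 0.0, 0
    for row, target in zip(logits, targets):
        if target == pad_idx:
            continue  # padding positions contribute nothing to the loss
        # negative log-softmax of the target class
        log_sum_exp = math.log(sum(math.exp(x) for x in row))
        total += log_sum_exp - row[target]
        counted += 1
    # this sketch assumes at least one target is not padding
    return total / counted


# hypothetical logits for 3 positions over a 3-word vocabulary; index 0 is <pad>
toy_logits = [[0.1, 2.0, 0.3],
              [0.2, 0.1, 1.5],
              [3.0, 0.0, 0.0]]
toy_targets = [1, 2, 0]

loss_with_padding = cross_entropy_ignore_pad(toy_logits, toy_targets, pad_idx=0)
loss_without_padding = cross_entropy_ignore_pad(toy_logits[:2], toy_targets[:2], pad_idx=0)
print(loss_with_padding, loss_without_padding)
```

Both calls print the same value, so whatever the model predicts at a padding position has no influence on training.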


# set the model to training mode
model.train()

# loop through each epoch
for i in range(EPOCHS):
    epoch_loss = 0

    # zero the gradients
    optimizer.zero_grad()

    # pass through encoder
    encoded_embeddings = encoder(src=de_embed(de_tensor_sequences),
                                 src_mask=src_mask)

    # logits for each output
    logits = decoder(trg=en_embed(trg), src=encoded_embeddings,
                     trg_mask=trg_mask,
                     src_mask=src_mask)

    # calculate the loss
    loss = criterion(logits.contiguous().view(-1, logits.shape[-1]),
                     expected_output.contiguous().view(-1))

    # backpropagation
    loss.backward()

    # clip the gradients
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1)

    # update the weights
    optimizer.step()

    # preview the predictions
    predictions = [[en_itos[tok] for tok in seq] for seq in logits.argmax(-1).tolist()]

    if i % 7 == 0:
        print("="*25)
        print(f"epoch: {i}")
        print(f"loss: {loss.item()}")
        print(f"predictions: {predictions}")


=========================
epoch: 0
loss: 3.8633525371551514
predictions: [['an', 'an', 'an', 'an', 'an', 'an', 'an', 'an'],
              ['of', 'an', 'an', 'an', 'an', 'an', 'an', 'an'],
              ['an', 'an', 'an', 'an', 'an', 'an', 'an', 'been']]
=========================
epoch: 7
loss: 2.7589643001556396
predictions: [['i', 'i', 'i', 'i', 'i', 'i', 'i', 'i'],
              ['is', 'is', 'is', 'is', 'is', 'paragraph', 'is', 'is'],
              ['is', 'is', 'is', 'a', 'is', 'basic', 'basic', 'basic']]
=========================
epoch: 14
loss: 1.7105616331100464
predictions: [['i', 'i', 'i', 'a', 'will', '<eos>', '<eos>', '<eos>'],
              ['hello', 'is', 'this', 'is', 'paragraph', 'paragraph', '<eos>', 'paragraph'],
              ['hello', 'example', 'is', 'is', 'basic', '<eos>', '<eos>', '<eos>']]
=========================
epoch: 21
loss: 1.2171827554702759
predictions: [['i', 'what', 'what', 'next', 'next', 'next', '<eos>', '<eos>'],
              ['this', 'is', 'a', 'basic', 'paragraph', 'a', 'paragraph', '<eos>'],
              ['this', 'basic', 'is', 'a', 'basic', 'basic', '<eos>', '<eos>']]
=========================
epoch: 28
loss: 0.8726108074188232
predictions: [['i', 'what', 'what', 'will', 'come', '<eos>', '<eos>', '<eos>'],
              ['this', 'is', 'a', 'basic', 'paragraph', 'example', '<eos>', 'paragraph'],
              ['hello', 'what', 'is', 'a', 'basic', 'split', '<eos>', '<eos>']]
=========================
epoch: 35
loss: 0.6604534387588501
predictions: [['i', 'wonder', 'next', 'will', 'come', 'next', '<eos>', 'next'],
              ['this', 'is', 'a', 'basic', 'example', 'paragraph', '<eos>', 'paragraph'],
              ['hello', 'what', 'is', 'a', 'basic', 'basic', '<eos>', '<eos>']]
=========================
epoch: 42
loss: 0.3311622142791748
predictions: [['i', 'wonder', 'what', 'will', 'come', 'next', '<eos>', '<eos>'],
              ['this', 'is', 'a', 'basic', 'paragraph', 'paragraph', '<eos>', '<eos>'],
              ['hello', 'what', 'is', 'a', 'basic', 'split', '<eos>', '<eos>']]
=========================
epoch: 49
loss: 0.19808804988861084
predictions: [['i', 'wonder', 'what', 'will', 'come', 'next', '<eos>', '<eos>'],
              ['this', 'is', 'a', 'basic', 'example', 'paragraph', '<eos>', 'paragraph'],
              ['hello', 'what', 'is', 'a', 'basic', 'split', '<eos>', 'split']]
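By the 50th epoch (epoch 49 in the output), the model predicts all three sequences correctly up to and including the <eos> token. The prediction after <eos> falls on a padding position, which the loss ignores. The decoder attention between the source and the target of the first sequence can now be visualized.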
# convert the indices to strings
decoder_input = [en_itos[i] for i in trg[0].tolist()]
display_attention(de_tokenized_sequences[0], decoder_input, decoder.attn_probs[0],n_heads, n_rows=2, n_cols=2)

display_attention(decoder_input, decoder_input, decoder.masked_attn_probs[0],n_heads, n_rows=2, n_cols=2)



  • Tokenization

def tokenize(sequence, special_toks=True):
    # remove punctuation
    for punc in ["!", ".", "?", ","]:
        sequence = sequence.replace(punc, "")

    # split the sequence on spaces and lowercase each token
    sequence = [token.lower() for token in sequence.split(" ")]

    # add beginning and end tokens
    if special_toks:
        sequence = ['<bos>'] + sequence + ['<eos>']

    return sequence
  • Build Vocabulary

def build_vocab(data):
    # tokenize the data and remove duplicates
    vocab = list(set(tokenize(data, special_toks=False)))

    # sort the vocabulary
    vocab.sort()

    # add special tokens
    vocab = ['<pad>', '<bos>', '<eos>'] + vocab

    # assign an integer to each word
    stoi = {word:i for i, word in enumerate(vocab)}

    return stoi
  • Padding

from torch.nn.functional import pad

def pad_seq(seq: Tensor, max_length: int = 10, pad_idx: int = 0):
    """
    Args:
        seq:          raw sequence (seq_length,)
        max_length:   maximum length of a sequence
        pad_idx:      index for padding tokens

    Returns:
        padded seq:   padded sequence (max_length,)
    """
    pad_to_add = max_length - len(seq) # amount of padding to add

    return pad(seq, (0, pad_to_add), value=pad_idx)
  • Source Mask

def make_src_mask(src: Tensor, pad_idx: int = 0):
    """
    Args:
        src:          raw sequences with padding        (batch_size, seq_length)

    Returns:
        src_mask:     mask for each sequence            (batch_size, 1, 1, seq_length)
    """
    # assign True to tokens that need to be attended to and False to padding tokens, then add 2 dimensions
    src_mask = (src != pad_idx).unsqueeze(1).unsqueeze(2)

    return src_mask
  • Target Mask

def make_trg_mask(trg: Tensor, pad_idx: int = 0):
    """
    Args:
        trg:          raw sequences with padding        (batch_size, seq_length)

    Returns:
        trg_mask:     mask for each sequence            (batch_size, 1, seq_length, seq_length)
    """
    seq_length = trg.shape[1]

    # assign True to tokens that need to be attended to and False to padding tokens, then add 2 dimensions
    trg_mask = (trg != pad_idx).unsqueeze(1).unsqueeze(2) # (batch_size, 1, 1, seq_length)

    # generate subsequent mask
    trg_sub_mask = torch.tril(torch.ones((seq_length, seq_length))).bool() # (seq_length, seq_length)

    # bitwise "and" operator | 0 & 0 = 0, 1 & 1 = 1, 1 & 0 = 0
    # broadcasting gives the result the shape (batch_size, 1, seq_length, seq_length)
    trg_mask = trg_mask & trg_sub_mask

    return trg_mask
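How such a mask is consumed is not shown in this article, since it happens inside the MultiHeadAttention module from the earlier articles. A common approach, assumed here, is to replace the attention scores at masked positions with negative infinity before the softmax, so that those positions receive a weight of exactly zero. The plain-Python sketch below does this for one query row; the scores and the function name masked_softmax are made up for illustration.

```python
import math


def masked_softmax(scores, mask_row):
    """Softmax over one row of attention scores, skipping masked-out keys."""
    # positions where the mask is False are pushed to -inf, so exp() gives 0
    masked = [s if keep else float("-inf") for s, keep in zip(scores, mask_row)]
    # assumes at least one position is kept, otherwise the max would be -inf
    largest = max(masked)
    exps = [math.exp(s - largest) for s in masked]
    total = sum(exps)
    return [e / total for e in exps]


# hypothetical scores of the second query token against four key tokens
scores = [0.2, 1.5, -0.3, 0.9]

# second row of a subsequent mask: only the first two keys are visible
mask_row = [True, True, False, False]

weights = masked_softmax(scores, mask_row)
print(weights)
```

The masked keys end up with zero weight, and the remaining weights still sum to 1, so future and padding tokens cannot contribute to the output for this query.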
  • Display Mask

import matplotlib.pyplot as plt

def display_mask(sentence: list, mask: Tensor):
    """
    Display the target mask for each sequence.

    Args:
        sentence:     sequence to be masked
        mask:         target mask for the heads
    """
    # figure size
    fig = plt.figure(figsize=(8,8))

    # create a plot
    ax = fig.add_subplot(mask.shape[0], 1, 1)

    # select the respective head and make it a numpy array for plotting
    mask = mask.squeeze(0).cpu().detach().numpy()

    # plot the matrix
    cax = ax.matshow(mask, cmap='bone')

    # set the size of the labels
    ax.tick_params(labelsize=12)

    # set the indices for the tick marks
    ax.set_xticks(range(len(sentence)))
    ax.set_yticks(range(len(sentence)))

    # set labels
    ax.xaxis.set_label_position('top')
    ax.set_ylabel("$Q$")
    ax.set_xlabel("$K^T$")

    if isinstance(sentence[0], int):
        # convert indices to English tokens (uses the global en_itos)
        sentence = [en_itos[tok] for tok in sentence]

    ax.set_xticklabels(sentence, rotation=75)
    ax.set_yticklabels(sentence)
  • Display Attention

def display_attention(sentence: list, translation: list, attention: Tensor,
                      n_heads: int = 8, n_rows: int = 4, n_cols: int = 2):
    """
    Display the attention matrix for each head of a sequence.

    Args:
        sentence:     German sentence to be translated to English; list
        translation:  English sentence predicted by the model
        attention:    attention scores for the heads
        n_heads:      number of heads
        n_rows:       number of rows
        n_cols:       number of columns
    """
    # ensure the number of rows and columns are equal to the number of heads
    assert n_rows * n_cols == n_heads

    # figure size
    fig = plt.figure(figsize=(15,25))

    # visualize each head
    for i in range(n_heads):

        # create a plot
        ax = fig.add_subplot(n_rows, n_cols, i+1)

        # select the respective head and make it a numpy array for plotting
        _attention = attention.squeeze(0)[i,:,:].cpu().detach().numpy()

        # plot the matrix
        cax = ax.matshow(_attention, cmap='bone')

        # set the size of the labels
        ax.tick_params(labelsize=12)

        # set the indices for the tick marks
        ax.set_xticks(range(len(sentence)))
        ax.set_yticks(range(len(translation)))

        # if the provided sequences are sentences or indices
        if isinstance(sentence[0], str):
            ax.set_xticklabels([t.lower() for t in sentence], rotation=45)
            ax.set_yticklabels(translation)
        elif isinstance(sentence[0], int):
            ax.set_xticklabels(sentence)
            ax.set_yticklabels(translation)


