Transformer from Scratch: The Decoder

01


Introduction



This is the seventh article in the Transformer-from-scratch series. The decoder is the second half of the Transformer architecture, and it contains all of the layers introduced in the previous articles.

Without further ado, let's get started!







02


Background


The decoder layer is a wrapper around the sublayers covered in the previous articles. It takes the positionally embedded target sequence and passes it through a masked multi-head attention mechanism. The mask prevents the decoder from looking at later tokens in the sequence, forcing the model to predict each token using only the preceding tokens as context. The output then passes through a second, cross-attention multi-head mechanism, which takes the encoder's output as an additional input. Finally, it passes through a position-wise feed-forward network (FFN). A residual addition and layer normalization follow each of these sublayers.




03


The Decoder Layer


As described above, the decoder layer is nothing more than a wrapper around its sublayers. It implements two multi-head attention sublayers and a position-wise feed-forward network, each followed by residual addition and layer normalization.
import torch
from torch import nn, Tensor

# MultiHeadAttention and PositionwiseFeedForward are the sublayer modules
# implemented in the earlier articles of this series

class DecoderLayer(nn.Module):
    def __init__(self, d_model: int, n_heads: int, d_ffn: int, dropout: float):
        """
        Args:
            d_model:  dimension of embeddings
            n_heads:  number of heads
            d_ffn:    dimension of feed-forward network
            dropout:  probability of dropout occurring
        """
        super().__init__()

        # masked multi-head attention sublayer
        self.masked_attention = MultiHeadAttention(d_model, n_heads, dropout)
        # layer norm for masked multi-head attention
        self.masked_attn_layer_norm = nn.LayerNorm(d_model)

        # multi-head attention sublayer
        self.attention = MultiHeadAttention(d_model, n_heads, dropout)
        # layer norm for multi-head attention
        self.attn_layer_norm = nn.LayerNorm(d_model)

        # position-wise feed-forward network
        self.positionwise_ffn = PositionwiseFeedForward(d_model, d_ffn, dropout)
        # layer norm for position-wise ffn
        self.ffn_layer_norm = nn.LayerNorm(d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(self, trg: Tensor, src: Tensor, trg_mask: Tensor, src_mask: Tensor):
        """
        Args:
            trg:       embedded sequences      (batch_size, trg_seq_length, d_model)
            src:       embedded sequences      (batch_size, src_seq_length, d_model)
            trg_mask:  mask for the sequences  (batch_size, 1, trg_seq_length, trg_seq_length)
            src_mask:  mask for the sequences  (batch_size, 1, 1, src_seq_length)

        Returns:
            trg:                sequences after attention  (batch_size, trg_seq_length, d_model)
            masked_attn_probs:  masked attention softmax scores
            attn_probs:         attention softmax scores
        """
        # pass trg embeddings through masked multi-head attention
        _trg, masked_attn_probs = self.masked_attention(trg, trg, trg, trg_mask)

        # residual add and norm
        trg = self.masked_attn_layer_norm(trg + self.dropout(_trg))

        # pass trg and src embeddings through multi-head attention
        _trg, attn_probs = self.attention(trg, src, src, src_mask)

        # residual add and norm
        trg = self.attn_layer_norm(trg + self.dropout(_trg))

        # position-wise feed-forward network
        _trg = self.positionwise_ffn(trg)

        # residual add and norm
        trg = self.ffn_layer_norm(trg + self.dropout(_trg))

        return trg, masked_attn_probs, attn_probs





04


Decoder Stack


To take advantage of the multi-head attention sublayers, the input tokens are passed through a stack of decoder layers, as shown in the image below; this is what the Nx notation in the diagram at the beginning of the article denotes.

This module also contains the final linear layer, which produces the logits. The logits are unnormalized scores over the vocabulary at each position, conditioned on the preceding words. They can be passed through a softmax function to create a probability distribution showing how likely each token is at each position in the sequence. This is accomplished by projecting from d_model to vocab_size, so the output has shape (batch_size, seq_length, vocab_size).
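As a quick sanity check of those shapes, here is a standalone toy snippet (random tensors only; it is not part of the article's modules):

import torch
from torch import nn

# toy check of the final projection: d_model -> vocab_size
batch_size, seq_length, d_model, vocab_size = 2, 8, 32, 24
hidden = torch.randn(batch_size, seq_length, d_model)

Wo = nn.Linear(d_model, vocab_size)
logits = Wo(hidden)                        # (2, 8, 24)
probs = torch.softmax(logits, dim=-1)      # probability distribution over the vocabulary
print(logits.shape, probs.sum(-1))         # each position's probabilities sum to 1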

class Decoder(nn.Module):
    def __init__(self, vocab_size: int, d_model: int, n_layers: int,
                 n_heads: int, d_ffn: int, dropout: float = 0.1):
        """
        Args:
            vocab_size:  size of the vocabulary
            d_model:     dimension of embeddings
            n_layers:    number of decoder layers
            n_heads:     number of heads
            d_ffn:       dimension of feed-forward network
            dropout:     probability of dropout occurring
        """
        super().__init__()

        # create n_layers decoder layers
        self.layers = nn.ModuleList([DecoderLayer(d_model, n_heads, d_ffn, dropout)
                                     for layer in range(n_layers)])
        self.dropout = nn.Dropout(dropout)

        # set output layer
        self.Wo = nn.Linear(d_model, vocab_size)

    def forward(self, trg: Tensor, src: Tensor, trg_mask: Tensor, src_mask: Tensor):
        """
        Args:
            trg:       embedded sequences                  (batch_size, trg_seq_length, d_model)
            src:       encoded sequences from the encoder  (batch_size, src_seq_length, d_model)
            trg_mask:  mask for the sequences              (batch_size, 1, trg_seq_length, trg_seq_length)
            src_mask:  mask for the sequences              (batch_size, 1, 1, src_seq_length)

        Returns:
            output:  sequences after the decoder  (batch_size, trg_seq_length, vocab_size)

        The attention scores from the last layer are stored as attributes:
            self.attn_probs         (batch_size, n_heads, trg_seq_length, src_seq_length)
            self.masked_attn_probs  (batch_size, n_heads, trg_seq_length, trg_seq_length)
        """
        # pass the sequences through each decoder layer
        for layer in self.layers:
            trg, masked_attn_probs, attn_probs = layer(trg, src, trg_mask, src_mask)

        # store the attention scores from the last layer for visualization
        self.masked_attn_probs = masked_attn_probs
        self.attn_probs = attn_probs

        return self.Wo(trg)






05


 Target Mask


To understand why the target mask is necessary, it is best to look at an example of the decoder's inputs and outputs. The decoder's goal is to predict the next token in the sequence, given the encoded source sequence and a partial target sequence. For this to work, there has to be a "start" token that prompts the model to predict the next token in the sequence; this is how the <bos> token is used in the image below. It is also important to note that the decoder's input and output must have the same size.


If the goal is for the model to translate "Wie heißt du?" into "What is your name?", the encoder encodes the meaning of the source sequence and passes it to the decoder. Given the <bos> token and the encoded embeddings, the decoder should predict "What". "What" is then appended to <bos> to form the new input, "<bos> What"; this is why the decoder's input is described as "shifted right". This can be passed to the decoder to predict "What is". The predicted token is appended to the previous input to create the new input, "<bos> What is", which is passed to the decoder to predict "What is your". This process repeats until the model predicts the <eos> token.
Given the target sequence "<bos> What is your name? <eos>", the target mask lets the model learn all of these iterations at once. The decoder inputs for each step are:

<bos>
<bos> What
<bos> What is
<bos> What is your
<bos> What is your name?

During inference, the expected output for each of these inputs is:

What
What is
What is your
What is your name?
What is your name? <eos>

Remember that the decoder's input and output must have the same length. Therefore, the last token of each target sequence has to be removed before it is passed to the decoder. If the target sequences are stored in trg, the decoder's input is trg[:, :-1], which selects everything except the last token, as seen in the target inputs above. The expected output is trg[:, 1:], everything except the first token, which matches the expected outputs seen above.


In summary, like the encoder, the decoder requires its inputs to be masked. The inputs need a padding mask, and the target sequences also need a subsequent mask. At inference time, the model is given only a start token and must predict the next token from it. Then, given two tokens, it must predict the third. This process repeats until the end-of-sequence token is predicted. This is the autoregressive behavior of the Transformer: future tokens are predicted only from past tokens and the encoder's embeddings.
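To make that loop concrete, here is a minimal greedy-decoding sketch. It is only an illustration: it assumes the trained encoder, decoder, embedding pipelines (de_embed, en_embed), vocabularies (en_stoi, en_itos), and the mask helpers (make_src_mask, make_trg_mask) that are built later in this article and its appendix.

def greedy_decode(de_tensor, pad_idx, max_len=9):
    # encode the source sequence once; de_tensor has shape (1, src_seq_length)
    src_mask = make_src_mask(de_tensor, pad_idx)
    encoded = encoder(src=de_embed(de_tensor), src_mask=src_mask)

    # start with <bos> and append one predicted token per step
    # (max_len stays within the positional encoding's max_length)
    trg_indices = [en_stoi['<bos>']]
    for _ in range(max_len):
        trg_tensor = torch.tensor(trg_indices).unsqueeze(0)      # (1, current_length)
        trg_mask = make_trg_mask(trg_tensor, en_stoi['<pad>'])
        logits = decoder(trg=en_embed(trg_tensor), src=encoded,
                         trg_mask=trg_mask, src_mask=src_mask)    # (1, current_length, vocab_size)

        next_token = logits[:, -1, :].argmax(-1).item()           # most likely next token
        trg_indices.append(next_token)
        if next_token == en_stoi['<eos>']:
            break

    return [en_itos[i] for i in trg_indices]

# example usage once the model below is trained:
# greedy_decode(de_tensor_sequences[0].unsqueeze(0), pad_idx)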


To imitate this behavior during training, the model learns all of these iterations at once using a subsequent mask. torch.tril in PyTorch can be used to create the subsequent mask, which has shape (trg_seq_length, trg_seq_length).
trg_seq_length = 10
subsequent_mask = torch.tril(torch.ones((trg_seq_length, trg_seq_length))).int()
The output is:

tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int32)
With this mask, the probability distribution for each token can only consider the tokens that precede it. However, since the target sequences also have to be padded, the padding mask and the subsequent mask must be combined.
pad_mask = torch.Tensor([[1,1,1,1,1,1,1,0,0,0]]).unsqueeze(1).unsqueeze(2).int()
print(pad_mask)
The output is:
tensor([[[[1, 1, 1, 1, 1, 1, 1, 0, 0, 0]]]], dtype=torch.int32)
This combination is easy to perform with the & operator, which returns 1 only when both masks are 1.
print(subsequent_mask & pad_mask)
The output is:
tensor([[[[1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
          [1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
          [1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
          [1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
          [1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
          [1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
          [1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
          [1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
          [1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
          [1, 1, 1, 1, 1, 1, 1, 0, 0, 0]]]], dtype=torch.int32)
This final target mask has to be created for every sequence in the batch, which means it has shape (batch_size, 1, trg_seq_length, trg_seq_length).
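For reference, inside scaled dot-product attention a 0/1 mask like this is typically applied by filling the masked positions with -inf before the softmax, which mirrors how the mask is consumed by the attention mechanism covered earlier in the series. A minimal sketch using the tensors from the snippets above:

# toy attention scores: (batch_size, n_heads, trg_seq_length, trg_seq_length)
scores = torch.randn(1, 1, 10, 10)

# combine the masks exactly as above; broadcasting gives shape (1, 1, 10, 10)
trg_mask = subsequent_mask & pad_mask

# masked positions get -inf, so they receive zero probability after the softmax
masked_scores = scores.masked_fill(trg_mask == 0, float('-inf'))
attn_probs = torch.softmax(masked_scores, dim=-1)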







06


How Are the Source and Target Masks Used?


Since two multi-head attention mechanisms are used in the decoder, the first one uses the target mask, and the second one, which is fed the encoder's embeddings, uses the source mask.


In the first mechanism, the probability distribution for each token considers only the preceding tokens. This reflects the model's behavior during inference, and it can be seen in the target mask below, where each token depends only on the tokens before it:

In the second mechanism, the target sequence provides the queries and the source sequence provides the keys (and values). This creates a probability distribution between each target token and each source token. During inference, this helps the model identify which target tokens best fit the given source tokens. Below is an example of attention distributions from a trained model:


In this visualization, the relationship between each query and its keys can be seen. For example, <bos> is closely related to "eine", "a" is most closely related to "frau", "woman" is associated with "mit", and "with" is associated with "einer". This illustrates the relationship between each query token and the key, or German counterpart, of the next English token the model should predict.
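The shape bookkeeping behind such a cross-attention distribution can be checked with a toy example (random tensors stand in for the projected queries, keys, and values; this is not the trained model above):

import torch

batch_size, n_heads, trg_len, src_len, d_head = 1, 4, 8, 9, 8

q = torch.randn(batch_size, n_heads, trg_len, d_head)   # queries come from the target sequence
k = torch.randn(batch_size, n_heads, src_len, d_head)   # keys come from the encoded source
v = torch.randn(batch_size, n_heads, src_len, d_head)   # values come from the encoded source

scores = q @ k.transpose(-2, -1) / d_head ** 0.5         # (1, 4, trg_len, src_len)
attn_probs = torch.softmax(scores, dim=-1)               # one distribution over source tokens per target token
out = attn_probs @ v                                      # (1, 4, trg_len, d_head)
print(attn_probs.shape, out.shape)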


To reproduce this, the encoder and decoder need to be combined into a model that can be trained to translate German to English.







07


Training a Simple Model


Before building the model, German and English vocabularies have to be created. The functions in the appendix are based on the ones from the encoder article, generalized to handle both English and German.


The model uses the same English example as the previous articles, with a corresponding German example generated with Google Translate.

de_example = "Hallo! Dies ist ein Beispiel für einen Absatz, der in seine Grundkomponenten aufgeteilt wurde. Ich frage mich, was als nächstes kommt! Irgendwelche Ideen?"
en_example = "Hello! This is an example of a paragraph that has been split into its basic components. I wonder what will come next! Any guesses?"

# build the vocab
de_stoi = build_vocab(de_example)
en_stoi = build_vocab(en_example)

# build integer-to-string decoders for the vocab
de_itos = {v:k for k,v in de_stoi.items()}
en_itos = {v:k for k,v in en_stoi.items()}

To demonstrate a forward pass, three German-English pairs are created. These pairs have to be tokenized, indexed against the vocabulary, and padded.
de_sequences = ["Ich frage mich, was als nächstes kommt!",
                "Dies ist ein Beispiel für einen Absatz.",
                "Hallo, was ist ein Grundkomponenten?"]

en_sequences = ["I wonder what will come next!",
                "This is a basic example paragraph.",
                "Hello, what is a basic split?"]

# tokenize the sequences and index them against the vocab
# (this step is implied by the text above and uses the appendix helpers)
de_tokenized_sequences = [tokenize(seq) for seq in de_sequences]
en_tokenized_sequences = [tokenize(seq) for seq in en_sequences]

de_indexed_sequences = [[de_stoi[token] for token in seq] for seq in de_tokenized_sequences]
en_indexed_sequences = [[en_stoi[token] for token in seq] for seq in en_tokenized_sequences]

# pad the sequences
max_length = 9
pad_idx = de_stoi['<pad>']

de_padded_seqs = []
en_padded_seqs = []

# pad each sequence
for de_seq, en_seq in zip(de_indexed_sequences, en_indexed_sequences):
    de_padded_seqs.append(pad_seq(torch.Tensor(de_seq), max_length, pad_idx))
    en_padded_seqs.append(pad_seq(torch.Tensor(en_seq), max_length, pad_idx))

# create tensors from the padded sequences
de_tensor_sequences = torch.stack(de_padded_seqs).long()
en_tensor_sequences = torch.stack(en_padded_seqs).long()

Now the target sequences can be prepared for training. Each target sequence has nine tokens: six from the original sentence, a start token, an end token, and a padding token. The tokenized representation of the first sequence looks like this:

['<bos>', 'i', 'wonder', 'what', 'will', 'come', 'next', '<eos>', '<pad>']

As described earlier, the decoder's input is a subset of this sequence. The last token must be removed from the target sequence so that the decoder predicts the next token at each position:

['<bos>', 'i', 'wonder', 'what', 'will', 'come', 'next', '<eos>']

Likewise, the expected output is another subset of the sequence, created by removing the first token:

['i', 'wonder', 'what', 'will', 'come', 'next', '<eos>', '<pad>']

The code for these subsets is below, and the source and target masks can be generated from them as well.

# remove the last token
trg = en_tensor_sequences[:,:-1]

# remove the first token
expected_output = en_tensor_sequences[:,1:]

# generate masks
src_mask = make_src_mask(de_tensor_sequences, pad_idx)
trg_mask = make_trg_mask(trg, pad_idx)

The target mask for the first sequence can be inspected:

display_mask(trg[0].int().tolist(), trg_mask[0])

The result is a lower-triangular mask plotted over the eight tokens of the decoder input, showing that each token can only attend to itself and the tokens before it.

Here, the model can be created. nn.Sequential can be used to chain each embedding layer with the positional encoding, creating a single forward pass for the source and target embeddings; the encoder and decoder are then initialized separately.

# parameters
de_vocab_size = len(de_stoi)
en_vocab_size = len(en_stoi)
d_model = 32
d_ffn = d_model*4 # 128
n_heads = 4
n_layers = 3
dropout = 0.1
max_pe_length = 10

# create the embeddings
de_lut = Embeddings(de_vocab_size, d_model) # look-up table (lut)
en_lut = Embeddings(en_vocab_size, d_model)

# create the positional encodings
pe = PositionalEncoding(d_model=d_model, dropout=0.1, max_length=max_pe_length)

# embed and encode
de_embed = nn.Sequential(de_lut, pe)
en_embed = nn.Sequential(en_lut, pe)

# initialize the encoder
encoder = Encoder(d_model, n_layers, n_heads, d_ffn, dropout)

# initialize the decoder
decoder = Decoder(en_vocab_size, d_model, n_layers, n_heads, d_ffn, dropout)

Once the layers are created, the model can be initialized in an nn.ModuleList, which stores all of the components in a list that can be accessed through Module methods such as parameters().

# initialize the model
model = nn.ModuleList([de_embed, en_embed, encoder, decoder])

# initialize the weights with xavier uniform
for p in model.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

The total number of parameters can be previewed with a simple function.

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model):,} trainable parameters.')

The output is:

The model has 91,675 trainable parameters.
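If a per-component breakdown is useful, the same helper can be applied to each piece of the model (an optional check reusing the objects created above):

# optional: break the parameter count down by component
for name, module in [("de_embed", de_embed), ("en_embed", en_embed),
                     ("encoder", encoder), ("decoder", decoder)]:
    print(f"{name}: {count_parameters(module):,} parameters")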

Now a simple forward pass can be performed, and the predictions can be previewed by taking the argmax of the logits.

# pass through the encoder
encoded_embeddings = encoder(src=de_embed(de_tensor_sequences),
                             src_mask=src_mask)

# logits for each output
logits = decoder(trg=en_embed(trg), src=encoded_embeddings,
                 trg_mask=trg_mask, src_mask=src_mask)

# preview the predicted tokens
predictions = [[en_itos[tok] for tok in seq] for seq in logits.argmax(-1).tolist()]

The output is:

[['a', '<eos>', 'basic', 'a', 'this', 'a', 'a', 'an'],
 ['wonder', 'into', 'any', 'wonder', 'i', 'wonder', 'wonder', 'an'],
 ['that', 'any', 'has', 'basic', 'split', 'wonder', 'example', 'wonder']]
Without training, the output is useless, but it demonstrates a basic forward pass. The model can now be trained to generate the expected output. Hyperparameters, an optimizer, and a loss function have to be chosen: Adam is the optimizer, and cross-entropy loss is used to evaluate the model's loss.
# hyperparameters
LEARNING_RATE = 0.005
EPOCHS = 50

# adam optimizer
optimizer = torch.optim.Adam(model.parameters(), lr = LEARNING_RATE)

# loss function
criterion = nn.CrossEntropyLoss(ignore_index = en_stoi["<pad>"])

Since only three sequences are used, a simple training loop can update the parameters and preview the predictions at each iteration.

# set the model to training mode
model.train()

# loop through each epoch
for i in range(EPOCHS):
    epoch_loss = 0

    # zero the gradients
    optimizer.zero_grad()

    # pass through the encoder
    encoded_embeddings = encoder(src=de_embed(de_tensor_sequences),
                                 src_mask=src_mask)

    # logits for each output
    logits = decoder(trg=en_embed(trg), src=encoded_embeddings,
                     trg_mask=trg_mask, src_mask=src_mask)

    # calculate the loss
    loss = criterion(logits.contiguous().view(-1, logits.shape[-1]),
                     expected_output.contiguous().view(-1))

    # backpropagation
    loss.backward()

    # clip the gradients
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1)

    # update the weights
    optimizer.step()

    # preview the predictions
    predictions = [[en_itos[tok] for tok in seq] for seq in logits.argmax(-1).tolist()]

    if i % 7 == 0:
        print("="*25)
        print(f"epoch: {i}")
        print(f"loss: {loss.item()}")
        print(f"predictions: {predictions}")

The output is:

=========================
epoch: 0
loss: 3.8633525371551514
predictions: [['an', 'an', 'an', 'an', 'an', 'an', 'an', 'an'],
              ['of', 'an', 'an', 'an', 'an', 'an', 'an', 'an'],
              ['an', 'an', 'an', 'an', 'an', 'an', 'an', 'been']]
=========================
epoch: 7
loss: 2.7589643001556396
predictions: [['i', 'i', 'i', 'i', 'i', 'i', 'i', 'i'],
              ['is', 'is', 'is', 'is', 'is', 'paragraph', 'is', 'is'],
              ['is', 'is', 'is', 'a', 'is', 'basic', 'basic', 'basic']]
=========================
epoch: 14
loss: 1.7105616331100464
predictions: [['i', 'i', 'i', 'a', 'will', '<eos>', '<eos>', '<eos>'],
              ['hello', 'is', 'this', 'is', 'paragraph', 'paragraph', '<eos>', 'paragraph'],
              ['hello', 'example', 'is', 'is', 'basic', '<eos>', '<eos>', '<eos>']]
=========================
epoch: 21
loss: 1.2171827554702759
predictions: [['i', 'what', 'what', 'next', 'next', 'next', '<eos>', '<eos>'],
              ['this', 'is', 'a', 'basic', 'paragraph', 'a', 'paragraph', '<eos>'],
              ['this', 'basic', 'is', 'a', 'basic', 'basic', '<eos>', '<eos>']]
=========================
epoch: 28
loss: 0.8726108074188232
predictions: [['i', 'what', 'what', 'will', 'come', '<eos>', '<eos>', '<eos>'],
              ['this', 'is', 'a', 'basic', 'paragraph', 'example', '<eos>', 'paragraph'],
              ['hello', 'what', 'is', 'a', 'basic', 'split', '<eos>', '<eos>']]
=========================
epoch: 35
loss: 0.6604534387588501
predictions: [['i', 'wonder', 'next', 'will', 'come', 'next', '<eos>', 'next'],
              ['this', 'is', 'a', 'basic', 'example', 'paragraph', '<eos>', 'paragraph'],
              ['hello', 'what', 'is', 'a', 'basic', 'basic', '<eos>', '<eos>']]
=========================
epoch: 42
loss: 0.3311622142791748
predictions: [['i', 'wonder', 'what', 'will', 'come', 'next', '<eos>', '<eos>'],
              ['this', 'is', 'a', 'basic', 'paragraph', 'paragraph', '<eos>', '<eos>'],
              ['hello', 'what', 'is', 'a', 'basic', 'split', '<eos>', '<eos>']]
=========================
epoch: 49
loss: 0.19808804988861084
predictions: [['i', 'wonder', 'what', 'will', 'come', 'next', '<eos>', '<eos>'],
              ['this', 'is', 'a', 'basic', 'example', 'paragraph', '<eos>', 'paragraph'],
              ['hello', 'what', 'is', 'a', 'basic', 'split', '<eos>', 'split']]
By the 50th epoch, the model predicts all three sequences correctly; the leftover tokens after <eos> fall in padded positions, which the loss ignores. The decoder's attention between the source and target of the first sequence can now be visualized.
# convert the indices to strings
decoder_input = [en_itos[i] for i in trg[0].tolist()]

display_attention(de_tokenized_sequences[0], decoder_input, decoder.attn_probs[0],
                  n_heads, n_rows=2, n_cols=2)
The result is a grid of four heatmaps, one per head, showing how each English query token attends to the German source tokens.


The masked self-attention can also be viewed.
display_attention(decoder_input, decoder_input, decoder.masked_attn_probs[0],
                  n_heads, n_rows=2, n_cols=2)
The result is a grid of four lower-triangular heatmaps, one per head, since each token can only attend to itself and the tokens before it.



These examples look different from the ones earlier in the article because the model was trained on so few samples. However, this approach can easily be scaled to many thousands of examples, which is the goal of the final article in this series.








Appendix:

The functions below assume the following imports:

import torch
from torch import Tensor
from torch.nn.functional import pad
import matplotlib.pyplot as plt
  • Tokenization

def tokenize(sequence, special_toks=True):
    # remove punctuation
    for punc in ["!", ".", "?", ","]:
        sequence = sequence.replace(punc, "")

    # split the sequence on spaces and lowercase each token
    sequence = [token.lower() for token in sequence.split(" ")]

    # add beginning and end tokens
    if special_toks:
        sequence = ['<bos>'] + sequence + ['<eos>']

    return sequence
  • Build Vocabulary

def build_vocab(data):
    # tokenize the data and remove duplicates
    vocab = list(set(tokenize(data, special_toks=False)))

    # sort the vocabulary
    vocab.sort()

    # add special tokens
    vocab = ['<pad>', '<bos>', '<eos>'] + vocab

    # assign an integer to each word
    stoi = {word: i for i, word in enumerate(vocab)}

    return stoi
  • Padding

def pad_seq(seq: Tensor, max_length: int = 10, pad_idx: int = 0):
    """
    Args:
        seq:         raw sequence (batch_size, seq_length)
        max_length:  maximum length of a sequence
        pad_idx:     index for padding tokens

    Returns:
        padded seq:  padded sequence (batch_size, max_length)
    """
    pad_to_add = max_length - len(seq)  # amount of padding to add

    return pad(seq, (0, pad_to_add), value=pad_idx)
  • Source Mask

def make_src_mask(src: Tensor, pad_idx: int = 0):
    """
    Args:
        src:       raw sequences with padding  (batch_size, seq_length)

    Returns:
        src_mask:  mask for each sequence      (batch_size, 1, 1, seq_length)
    """
    # assign 1 to tokens that need attending to and 0 to padding tokens,
    # then add 2 dimensions
    src_mask = (src != pad_idx).unsqueeze(1).unsqueeze(2)

    return src_mask
  • Target Mask

def make_trg_mask(trg: Tensor, pad_idx: int = 0):
    """
    Args:
        trg:       raw sequences with padding  (batch_size, seq_length)

    Returns:
        trg_mask:  mask for each sequence      (batch_size, 1, seq_length, seq_length)
    """
    seq_length = trg.shape[1]

    # assign True to tokens that need attending to and False to padding tokens,
    # then add 2 dimensions
    trg_mask = (trg != pad_idx).unsqueeze(1).unsqueeze(2)  # (batch_size, 1, 1, seq_length)

    # generate a subsequent mask
    trg_sub_mask = torch.tril(torch.ones((seq_length, seq_length))).bool()  # (seq_length, seq_length)

    # bitwise "and" operator | 0 & 0 = 0, 1 & 1 = 1, 1 & 0 = 0
    trg_mask = trg_mask & trg_sub_mask

    return trg_mask
  • Display Mask

def display_mask(sentence: list, mask: Tensor):
    """
    Display the target mask for a sequence.

    Args:
        sentence:  sequence to be masked
        mask:      target mask for the sequence
    """
    # figure size
    fig = plt.figure(figsize=(8, 8))

    # create a plot
    ax = fig.add_subplot(mask.shape[0], 1, 1)

    # make the mask a numpy array for plotting
    mask = mask.squeeze(0).cpu().detach().numpy()

    # plot the matrix
    cax = ax.matshow(mask, cmap='bone')

    # set the size of the labels
    ax.tick_params(labelsize=12)

    # set the indices for the tick marks
    ax.set_xticks(range(len(sentence)))
    ax.set_yticks(range(len(sentence)))

    # set labels
    ax.xaxis.set_label_position('top')
    ax.set_ylabel("$Q$")
    ax.set_xlabel("$K^T$")

    if isinstance(sentence[0], int):
        # convert indices to English tokens
        sentence = [en_itos[tok] for tok in sentence]

    ax.set_xticklabels(sentence, rotation=75)
    ax.set_yticklabels(sentence)

    plt.show()
  • Display Attention

def display_attention(sentence: list, translation: list, attention: Tensor,
                      n_heads: int = 8, n_rows: int = 4, n_cols: int = 2):
    """
    Display the attention matrix for each head of a sequence.

    Args:
        sentence:     German sentence to be translated to English; list
        translation:  English sentence predicted by the model
        attention:    attention scores for the heads
        n_heads:      number of heads
        n_rows:       number of rows
        n_cols:       number of columns
    """
    # ensure the number of rows and columns are equal to the number of heads
    assert n_rows * n_cols == n_heads

    # figure size
    fig = plt.figure(figsize=(15, 25))

    # visualize each head
    for i in range(n_heads):
        # create a plot
        ax = fig.add_subplot(n_rows, n_cols, i + 1)

        # select the respective head and make it a numpy array for plotting
        _attention = attention.squeeze(0)[i, :, :].cpu().detach().numpy()

        # plot the matrix
        cax = ax.matshow(_attention, cmap='bone')

        # set the size of the labels
        ax.tick_params(labelsize=12)

        # set the indices for the tick marks
        ax.set_xticks(range(len(sentence)))
        ax.set_yticks(range(len(translation)))

        # if the provided sequences are sentences or indices
        if isinstance(sentence[0], str):
            ax.set_xticklabels([t.lower() for t in sentence], rotation=45)
            ax.set_yticklabels(translation)
        elif isinstance(sentence[0], int):
            ax.set_xticklabels(sentence)
            ax.set_yticklabels(translation)

    plt.show()




