Implementing a Transformer from Scratch


When I decided to dig deeper into the Transformer architecture, I often felt lost, because the resources I found always seemed to be missing something:

  • The official Tensorflow or Pytorch tutorials use their own APIs, so they stay fairly high-level and force you to dive into their codebases to see what happens under the hood. Reading thousands of lines of code is time-consuming and not always easy.

  • Other tutorials with custom code that I found (links at the end of the article) tend to oversimplify the use case and do not tackle concepts such as masking for batches of variable-length sequences.

So I decided to write my own Transformer, to make sure I understood the concepts and could use it with any dataset.

In this article, we will therefore follow a systematic approach and implement a Transformer layer by layer and block by block.

Obviously, there are already plenty of different implementations, as well as ready-to-use high-level APIs, in Pytorch or Tensorflow, and I am sure they perform better than the model we are going to build.

"OK, but then why not just use the TF/Pytorch implementations?"

The purpose of this article is educational, and I make no claim of beating the Pytorch or Tensorflow implementations. I do believe that the theory and the code behind transformers are not straightforward, which is why I hope this step-by-step tutorial will help you get a better grasp of the concepts and feel more comfortable when building your own code later on.

Another reason to build your own transformer from scratch is that it will allow you to fully understand how to use the APIs mentioned above. If we look at the Pytorch implementation of the forward() method of the Transformer class, you will see many obscure keywords such as:

Source: Pytorch documentation

If you are already familiar with these keywords, you can happily skip this article.

Otherwise, this article will walk you through each of these keywords and the concepts behind them.

A Short Introduction to Transformers

If you have heard of ChatGPT or Gemini, then you have already met a transformer. In fact, the "T" in ChatGPT stands for Transformer.

The architecture was first introduced by Google researchers in 2017 in the paper "Attention Is All You Need". It was revolutionary because previous models used for sequence-to-sequence learning (machine translation, speech-to-text, etc.) relied on RNNs, which are computationally expensive since they have to process sequences step by step, whereas a Transformer only needs to look at the whole sequence once, bringing the number of sequential operations down from O(n) to O(1).

Applications of transformers in NLP are very broad, and include language translation, question answering, document summarization, text generation, and more.

The overall architecture of the Transformer is as follows:

Multi-Head Attention

The first block we will implement is actually the most important part of a Transformer, called Multi-Head Attention. Let's see where it sits in the overall architecture.

Attention is not a mechanism specific to Transformers; it was already used in RNN sequence-to-sequence models.

Attention in the Transformer (source: Tensorflow documentation)

The attention mechanism of the Transformer, in code:
import torch
import torch.nn as nn
import math


class MultiHeadAttention(nn.Module):
    def __init__(self, hidden_dim=256, num_heads=4):
        """
        hidden_dim: Dimensionality of the input.
        num_heads: The number of attention heads to split the input into.
        """
        super(MultiHeadAttention, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        assert hidden_dim % num_heads == 0, "Hidden dim must be divisible by num heads"
        self.Wv = nn.Linear(hidden_dim, hidden_dim, bias=False)  # the Value part
        self.Wk = nn.Linear(hidden_dim, hidden_dim, bias=False)  # the Key part
        self.Wq = nn.Linear(hidden_dim, hidden_dim, bias=False)  # the Query part
        self.Wo = nn.Linear(hidden_dim, hidden_dim, bias=False)  # the output layer

    def check_sdpa_inputs(self, x):
        assert x.size(1) == self.num_heads, \
            f"Expected size of x to be ({-1, self.num_heads, -1, self.hidden_dim // self.num_heads}), got {x.size()}"
        assert x.size(3) == self.hidden_dim // self.num_heads

    def scaled_dot_product_attention(
            self,
            query,
            key,
            value,
            attention_mask=None,
            key_padding_mask=None):
        """
        query : tensor of shape (batch_size, num_heads, query_sequence_length, hidden_dim//num_heads)
        key : tensor of shape (batch_size, num_heads, key_sequence_length, hidden_dim//num_heads)
        value : tensor of shape (batch_size, num_heads, key_sequence_length, hidden_dim//num_heads)
        attention_mask : tensor of shape (query_sequence_length, key_sequence_length)
        key_padding_mask : tensor of shape (batch_size, key_sequence_length)
        """
        self.check_sdpa_inputs(query)
        self.check_sdpa_inputs(key)
        self.check_sdpa_inputs(value)

        d_k = query.size(-1)
        tgt_len, src_len = query.size(-2), key.size(-2)

        # logits = (B, H, tgt_len, E) * (B, H, E, src_len) = (B, H, tgt_len, src_len)
        logits = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

        # Attention mask here
        if attention_mask is not None:
            if attention_mask.dim() == 2:
                assert attention_mask.size() == (tgt_len, src_len)
                attention_mask = attention_mask.unsqueeze(0)
                logits = logits + attention_mask
            else:
                raise ValueError(f"Attention mask size {attention_mask.size()}")

        # Key mask here
        if key_padding_mask is not None:
            key_padding_mask = key_padding_mask.unsqueeze(1).unsqueeze(2)  # Broadcast over batch size, num heads
            logits = logits + key_padding_mask

        attention = torch.softmax(logits, dim=-1)
        output = torch.matmul(attention, value)  # (batch_size, num_heads, sequence_length, hidden_dim)

        return output, attention

    def split_into_heads(self, x, num_heads):
        batch_size, seq_length, hidden_dim = x.size()
        x = x.view(batch_size, seq_length, num_heads, hidden_dim // num_heads)
        return x.transpose(1, 2)  # Final dim will be (batch_size, num_heads, seq_length, hidden_dim // num_heads)

    def combine_heads(self, x):
        batch_size, num_heads, seq_length, head_hidden_dim = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, num_heads * head_hidden_dim)

    def forward(
            self,
            q,
            k,
            v,
            attention_mask=None,
            key_padding_mask=None):
        """
        q : tensor of shape (batch_size, query_sequence_length, hidden_dim)
        k : tensor of shape (batch_size, key_sequence_length, hidden_dim)
        v : tensor of shape (batch_size, key_sequence_length, hidden_dim)
        attention_mask : tensor of shape (query_sequence_length, key_sequence_length)
        key_padding_mask : tensor of shape (batch_size, key_sequence_length)
        """
        q = self.Wq(q)
        k = self.Wk(k)
        v = self.Wv(v)

        q = self.split_into_heads(q, self.num_heads)
        k = self.split_into_heads(k, self.num_heads)
        v = self.split_into_heads(v, self.num_heads)

        attn_values, attn_weights = self.scaled_dot_product_attention(
            query=q,
            key=k,
            value=v,
            attention_mask=attention_mask,
            key_padding_mask=key_padding_mask,
        )
        grouped = self.combine_heads(attn_values)
        output = self.Wo(grouped)

        self.attention_weigths = attn_weights

        return output
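Before moving on, here is a quick, illustrative shape check (not from the original article; the dimensions are chosen arbitrarily):

# Illustrative usage sketch (assumed shapes, not part of the original tutorial)
mha = MultiHeadAttention(hidden_dim=256, num_heads=4)
x = torch.randn(2, 10, 256)          # (batch_size, seq_len, hidden_dim)
out = mha(q=x, k=x, v=x)             # self-attention
print(out.shape)                     # torch.Size([2, 10, 256])
print(mha.attention_weigths.shape)   # torch.Size([2, 4, 10, 10])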

There are a few concepts here that we need to explain.

1) Queries, Keys and Values.

The query is the information you are trying to match;
the keys and values are the stored information.

Think of it as using a dictionary: whenever you use a Python dictionary, if your query does not match one of the keys, nothing is returned. But what if we wanted the dictionary to return a blend of the closest pieces of information? For example, if we had:

d = {"panther": 1, "bear": 10, "dog": 3}
d["wolf"] = 0.2*d["panther"] + 0.7*d["dog"] + 0.1*d["bear"]

That is basically what attention is: looking at different parts of your data and blending them together to obtain a synthesized result as the answer to your query.
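To make the analogy concrete, here is a minimal sketch of scaled dot-product attention on toy tensors, with shapes chosen purely for illustration:

# Minimal scaled dot-product attention on toy tensors (illustrative only)
import torch
import math

q = torch.randn(1, 1, 4)   # one query of dimension 4
k = torch.randn(1, 3, 4)   # three keys
v = torch.randn(1, 3, 4)   # three values

weights = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(4), dim=-1)  # (1, 1, 3), sums to 1
output = weights @ v                                                     # (1, 1, 4), a blend of the values
print(weights, output.shape)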

The relevant part of the code is this one, where we compute the attention weights between the query and the keys:

logits = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k) # we compute the weights of attention

And this one, where we apply the normalized weights to the values:

attention = torch.softmax(logits, dim=-1)
output = torch.matmul(attention, value)  # (batch_size, num_heads, sequence_length, hidden_dim)

2) Attention masking and padding

When attending to parts of a sequential input, we do not want to include useless or forbidden information.

Useless information is, for example, padding: the padding symbols used to align all the sequences in a batch to the same length should be ignored by our model. We will come back to this in the last section.

Forbidden information is a bit more subtle. During training, the model learns to encode the input sequence and to align the target with the input. However, since inference involves looking at previously emitted tokens to predict the next one (think of text generation in ChatGPT), we need to apply the same rule during training.

This is why we apply a causal mask to ensure that, at each time step, the target can only see information from the past. Here is the corresponding part where the mask is applied (computing the mask is covered at the end):

if attention_mask is not None:
    if attention_mask.dim() == 2:
        assert attention_mask.size() == (tgt_len, src_len)
        attention_mask = attention_mask.unsqueeze(0)
        logits = logits + attention_mask

Positional Encoding

It corresponds to the following part of the Transformer:

When receiving and processing an input, a Transformer has no sense of order, because it looks at the sequence as a whole, as opposed to what RNNs do. We therefore need to add a hint of temporal order so that the Transformer can learn dependencies.

The specific details of how positional encoding works are beyond the scope of this article, but feel free to read the original paper to understand them.

# Taken from https://pytorch.org/tutorials/beginner/transformer_tutorial.html#define-the-model
class PositionalEncoding(nn.Module):

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Arguments:
            x: Tensor, shape ``[batch_size, seq_len, embedding_dim]``
        """
        x = x + self.pe[:, :x.size(1), :]
        return x
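As a quick sanity check (illustrative only, not part of the original tutorial), the module simply adds a (1, max_len, d_model) buffer of sinusoids to the input without changing its shape:

# Quick shape check (illustrative)
pos_enc = PositionalEncoding(d_model=16)
x = torch.zeros(2, 10, 16)   # (batch_size, seq_len, d_model)
print(pos_enc(x).shape)      # torch.Size([2, 10, 16]) -- sinusoids added at every position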

The Encoder

We are getting close to having a full encoder! The encoder is the left-hand part of the Transformer.

We will add one small piece to our code: the feed-forward part.

class PositionWiseFeedForward(nn.Module):
    def __init__(self, d_model: int, d_ff: int):
        super(PositionWiseFeedForward, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

Putting the pieces together, we get an encoder block!

class EncoderBlock(nn.Module):
    def __init__(self, n_dim: int, dropout: float, n_heads: int):
        super(EncoderBlock, self).__init__()
        self.mha = MultiHeadAttention(hidden_dim=n_dim, num_heads=n_heads)
        self.norm1 = nn.LayerNorm(n_dim)
        self.ff = PositionWiseFeedForward(n_dim, n_dim)
        self.norm2 = nn.LayerNorm(n_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, src_padding_mask=None):
        assert x.ndim == 3, "Expected input to be 3-dim, got {}".format(x.ndim)
        att_output = self.mha(x, x, x, key_padding_mask=src_padding_mask)
        x = x + self.dropout(self.norm1(att_output))

        ff_output = self.ff(x)
        output = x + self.norm2(ff_output)

        return output

As shown in the diagram, the encoder actually consists of N encoder blocks (or layers), plus an embedding layer for the inputs. Let's therefore create an Encoder by adding the embedding, the positional encoding and the encoder blocks:


class Encoder(nn.Module):
    def __init__(
            self,
            vocab_size: int,
            n_dim: int,
            dropout: float,
            n_encoder_blocks: int,
            n_heads: int):

        super(Encoder, self).__init__()
        self.n_dim = n_dim

        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=n_dim
        )
        self.positional_encoding = PositionalEncoding(
            d_model=n_dim,
            dropout=dropout
        )
        self.encoder_blocks = nn.ModuleList([
            EncoderBlock(n_dim, dropout, n_heads) for _ in range(n_encoder_blocks)
        ])

    def forward(self, x, padding_mask=None):
        x = self.embedding(x) * math.sqrt(self.n_dim)
        x = self.positional_encoding(x)
        for block in self.encoder_blocks:
            x = block(x=x, src_padding_mask=padding_mask)
        return x
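A minimal shape check for the encoder (the hyperparameters below are arbitrary and purely illustrative):

# Illustrative shape check (hyperparameters chosen arbitrarily)
encoder = Encoder(vocab_size=100, n_dim=128, dropout=0.1, n_encoder_blocks=2, n_heads=4)
tokens = torch.randint(0, 100, (2, 15))   # (batch_size, seq_len) of token ids
memory = encoder(tokens)
print(memory.shape)                       # torch.Size([2, 15, 128])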

The Decoder

The decoder is the right-hand part of the architecture and requires a bit more crafting.

There is something called Masked Multi-Head Attention. Remember what we said earlier about the causal mask? Well, this is where it happens. We will use the attention_mask parameter of our multi-head attention module for that (more details on how the mask is computed at the end):

# Stuff before

self.self_attention = MultiHeadAttention(hidden_dim=n_dim, num_heads=n_heads)
masked_att_output = self.self_attention(
    q=tgt,
    k=tgt,
    v=tgt,
    attention_mask=tgt_mask,          # <-- HERE IS THE CAUSAL MASK
    key_padding_mask=tgt_padding_mask)

# Stuff after

The second attention mechanism is called cross-attention. It uses the decoder's queries to match against the encoder's keys and values! Note: they can have different lengths during training, so it is usually good practice to explicitly define the expected shapes of the inputs, as follows:

def scaled_dot_product_attention(
        self,
        query,
        key,
        value,
        attention_mask=None,
        key_padding_mask=None):
    """
    query : tensor of shape (batch_size, num_heads, query_sequence_length, hidden_dim//num_heads)
    key : tensor of shape (batch_size, num_heads, key_sequence_length, hidden_dim//num_heads)
    value : tensor of shape (batch_size, num_heads, key_sequence_length, hidden_dim//num_heads)
    attention_mask : tensor of shape (query_sequence_length, key_sequence_length)
    key_padding_mask : tensor of shape (batch_size, key_sequence_length)
    """

And here is the part where we use the encoder's output (called "memory") together with the decoder's input:

# Stuff before
self.cross_attention = MultiHeadAttention(hidden_dim=n_dim, num_heads=n_heads)
cross_att_output = self.cross_attention(
    q=x1,
    k=memory,
    v=memory,
    attention_mask=None,                    # <-- NO CAUSAL MASK HERE
    key_padding_mask=memory_padding_mask)   # <-- WE NEED TO USE THE PADDING OF THE SOURCE
# Stuff after

Putting these pieces together, we finally get the decoder:

class DecoderBlock(nn.Module):
    def __init__(self, n_dim: int, dropout: float, n_heads: int):
        super(DecoderBlock, self).__init__()

        # The first Multi-Head Attention has a mask to avoid looking at the future
        self.self_attention = MultiHeadAttention(hidden_dim=n_dim, num_heads=n_heads)
        self.norm1 = nn.LayerNorm(n_dim)

        # The second Multi-Head Attention will take inputs from the encoder as key/value inputs
        self.cross_attention = MultiHeadAttention(hidden_dim=n_dim, num_heads=n_heads)
        self.norm2 = nn.LayerNorm(n_dim)

        self.ff = PositionWiseFeedForward(n_dim, n_dim)
        self.norm3 = nn.LayerNorm(n_dim)
        # self.dropout = nn.Dropout(dropout)

    def forward(self, tgt, memory, tgt_mask=None, tgt_padding_mask=None, memory_padding_mask=None):

        masked_att_output = self.self_attention(
            q=tgt, k=tgt, v=tgt, attention_mask=tgt_mask, key_padding_mask=tgt_padding_mask)
        x1 = tgt + self.norm1(masked_att_output)

        cross_att_output = self.cross_attention(
            q=x1, k=memory, v=memory, attention_mask=None, key_padding_mask=memory_padding_mask)
        x2 = x1 + self.norm2(cross_att_output)

        ff_output = self.ff(x2)
        output = x2 + self.norm3(ff_output)

        return output


class Decoder(nn.Module):
    def __init__(
            self,
            vocab_size: int,
            n_dim: int,
            dropout: float,
            n_decoder_blocks: int,
            n_heads: int):

        super(Decoder, self).__init__()

        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=n_dim,
            padding_idx=0
        )
        self.positional_encoding = PositionalEncoding(
            d_model=n_dim,
            dropout=dropout
        )
        self.decoder_blocks = nn.ModuleList([
            DecoderBlock(n_dim, dropout, n_heads) for _ in range(n_decoder_blocks)
        ])

    def forward(self, tgt, memory, tgt_mask=None, tgt_padding_mask=None, memory_padding_mask=None):
        x = self.embedding(tgt)
        x = self.positional_encoding(x)

        for block in self.decoder_blocks:
            x = block(
                x,
                memory,
                tgt_mask=tgt_mask,
                tgt_padding_mask=tgt_padding_mask,
                memory_padding_mask=memory_padding_mask)
        return x
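And a matching sketch for the decoder, using a random tensor as a stand-in for the encoder output (shapes chosen arbitrarily, for illustration only):

# Illustrative shape check (sketch only)
decoder = Decoder(vocab_size=100, n_dim=128, dropout=0.1, n_decoder_blocks=2, n_heads=4)
tgt = torch.randint(0, 100, (2, 12))      # (batch_size, target_seq_len)
memory = torch.randn(2, 15, 128)          # stand-in for the encoder output
out = decoder(tgt, memory)                # cross-attends to the encoder output
print(out.shape)                          # torch.Size([2, 12, 128])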

Padding and Masking

Remember the multi-head attention section, where we mentioned excluding certain parts of the input when computing attention.

During training, we work with batches of inputs and targets, where each instance may have a variable length. Consider the following example, where we batch 4 words: banana, watermelon, pear, blueberry. To process them as a single batch, we need to align all the words to the length of the longest one (watermelon). We therefore add an extra token, PAD, to each word so that they all end up with the same length as watermelon.

In the picture below, the upper table represents the raw data and the lower table the encoded version:

In our case, we want to exclude the padding indices from the attention weights being computed. We can therefore compute the masks for the source and target data as follows:

padding_mask = (x == PAD_IDX)
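As a small illustration (the token values are invented), the boolean mask simply flags the PAD positions; the Transformer class below then converts it into an additive mask of -inf values:

# Illustrative example with PAD_IDX = 0 (values invented for this sketch)
PAD_IDX = 0
x = torch.tensor([[5, 7, 9, 0, 0],
                  [4, 2, 6, 8, 3]])
padding_mask = (x == PAD_IDX)
print(padding_mask)
# tensor([[False, False, False,  True,  True],
#         [False, False, False, False, False]])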

What about the causal mask now? If we want the model, at each time step, to attend only to past steps, this means that for each time step T the model can only attend to steps t in 1…T. That would be a double for loop, so instead we can compute it with a matrix:

def generate_square_subsequent_mask(size: int):
    """Generate a triangular (size, size) mask. From PyTorch docs."""
    mask = (1 - torch.triu(torch.ones(size, size), diagonal=1)).bool()
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask
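For example, for size=4 the mask is zero on and below the diagonal and -inf above it, so each position can only attend to itself and to earlier positions:

print(generate_square_subsequent_mask(4))
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])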

Case Study: a Word-Reverse Transformer

Now let's put the pieces together to build our Transformer!

For our use case, we will work with a very simple dataset to show how Transformers actually learn.

"But why use a Transformer to reverse a word? I already know how to do that in Python with word[::-1]!"

The goal here is to see whether the Transformer's attention mechanism works. What we expect is to see the attention weights move from right to left when given an input sequence. If that is the case, it means our Transformer has learned a very simple grammar, namely reading from right to left, and could generalize to more complex grammars when doing real-life language translation.

Let's start with our custom Transformer class:

import torch
import torch.nn as nn
import math

from .encoder import Encoder
from .decoder import Decoder

class Transformer(nn.Module):
    def __init__(self, **kwargs):
        super(Transformer, self).__init__()

        for k, v in kwargs.items():
            print(f" * {k}={v}")

        self.vocab_size = kwargs.get('vocab_size')
        self.model_dim = kwargs.get('model_dim')
        self.dropout = kwargs.get('dropout')
        self.n_encoder_layers = kwargs.get('n_encoder_layers')
        self.n_decoder_layers = kwargs.get('n_decoder_layers')
        self.n_heads = kwargs.get('n_heads')
        self.batch_size = kwargs.get('batch_size')
        self.PAD_IDX = kwargs.get('pad_idx', 0)

        self.encoder = Encoder(
            self.vocab_size, self.model_dim, self.dropout, self.n_encoder_layers, self.n_heads)
        self.decoder = Decoder(
            self.vocab_size, self.model_dim, self.dropout, self.n_decoder_layers, self.n_heads)
        self.fc = nn.Linear(self.model_dim, self.vocab_size)

    @staticmethod
    def generate_square_subsequent_mask(size: int):
        """Generate a triangular (size, size) mask. From PyTorch docs."""
        mask = (1 - torch.triu(torch.ones(size, size), diagonal=1)).bool()
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def encode(
            self,
            x: torch.Tensor,
    ) -> torch.Tensor:
        """
        Input
            x: (B, S) with elements in (0, C) where C is num_classes
        Output
            (B, S, E) embedding
        """
        mask = (x == self.PAD_IDX).float()
        encoder_padding_mask = mask.masked_fill(mask == 1, float('-inf'))

        # (B, S, E)
        encoder_output = self.encoder(
            x,
            padding_mask=encoder_padding_mask
        )

        return encoder_output, encoder_padding_mask

    def decode(
            self,
            tgt: torch.Tensor,
            memory: torch.Tensor,
            memory_padding_mask=None
    ) -> torch.Tensor:
        """
        B = Batch size
        S = Source sequence length
        L = Target sequence length
        E = Model dimension

        Input
            encoded_x: (B, S, E)
            y: (B, L) with elements in (0, C) where C is num_classes
        Output
            (B, L, C) logits
        """
        mask = (tgt == self.PAD_IDX).float()
        tgt_padding_mask = mask.masked_fill(mask == 1, float('-inf'))

        decoder_output = self.decoder(
            tgt=tgt,
            memory=memory,
            tgt_mask=self.generate_square_subsequent_mask(tgt.size(1)),
            tgt_padding_mask=tgt_padding_mask,
            memory_padding_mask=memory_padding_mask,
        )
        output = self.fc(decoder_output)  # shape (B, L, C)
        return output

    def forward(
            self,
            x: torch.Tensor,
            y: torch.Tensor,
    ) -> torch.Tensor:
        """
        Input
            x: (B, Sx) with elements in (0, C) where C is num_classes
            y: (B, Sy) with elements in (0, C) where C is num_classes
        Output
            (B, L, C) logits
        """
        # Encoder output shape (B, S, E)
        encoder_output, encoder_padding_mask = self.encode(x)

        # Decoder output shape (B, L, C)
        decoder_output = self.decode(
            tgt=y,
            memory=encoder_output,
            memory_padding_mask=encoder_padding_mask
        )
        return decoder_output

Inference with Greedy Decoding

We need to add a method that acts like scikit-learn's model.predict. The goal is to ask the model to dynamically output predictions given an input. During inference there is no target: the model first outputs a token by attending to the encoder output, then uses its own predictions to keep emitting tokens. This is why these models are often called auto-regressive models: they use past predictions to predict the next one.

The problem with greedy decoding is that it takes the token with the highest probability at each step. This can lead to very bad predictions if the first token is completely wrong. There are other decoding methods, such as beam search, which keep a shortlist of candidate sequences (think of keeping the top-k tokens at each time step instead of the argmax) and return the sequence with the highest total probability.

For now, let's implement greedy decoding and add it to our Transformer model:

def predict(
        self,
        x: torch.Tensor,
        sos_idx: int = 1,
        eos_idx: int = 2,
        max_length: int = None
) -> torch.Tensor:
    """
    Method to use at inference time. Predict y from x one token at a time. This method is greedy
    decoding. Beam search can be used instead for a potential accuracy boost.

    Input
        x: str
    Output
        (B, L, C) logits
    """

    # Pad the tokens with beginning and end of sentence tokens
    x = torch.cat([
        torch.tensor([sos_idx]),
        x,
        torch.tensor([eos_idx])]
    ).unsqueeze(0)

    encoder_output, mask = self.encode(x)  # (B, S, E)

    if not max_length:
        max_length = x.size(1)

    outputs = torch.ones((x.size()[0], max_length)).type_as(x).long() * sos_idx
    for step in range(1, max_length):
        y = outputs[:, :step]
        probs = self.decode(y, encoder_output)
        output = torch.argmax(probs, dim=-1)

        # Uncomment if you want to see step by step predictions
        # print(f"Knowing {y} we output {output[:, -1]}")

        if output[:, -1].detach().numpy() in (eos_idx, sos_idx):
            break
        outputs[:, step] = output[:, -1]

    return outputs
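Once the model has been instantiated and trained (see the next sections), a call would look roughly like this; the token ids below are made up purely for illustration:

# Hypothetical usage sketch (token ids invented for illustration)
x = torch.tensor([10, 11, 12])   # an already-tokenized input, without SOS/EOS
out = model.predict(x)           # (1, max_length) tensor of predicted token ids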

Creating Toy Data

We define a small dataset of reversed words, meaning that "helloworld" maps to "dlrowolleh":

import numpy as np
import torch
from torch.utils.data import Dataset

np.random.seed(0)
def generate_random_string():
    len = np.random.randint(10, 20)
    return "".join([chr(x) for x in np.random.randint(97, 97+26, len)])


class ReverseDataset(Dataset):
    def __init__(self, n_samples, pad_idx, sos_idx, eos_idx):
        super(ReverseDataset, self).__init__()
        self.pad_idx = pad_idx
        self.sos_idx = sos_idx
        self.eos_idx = eos_idx
        self.values = [generate_random_string() for _ in range(n_samples)]
        self.labels = [x[::-1] for x in self.values]

    def __len__(self):
        return len(self.values)  # number of samples in the dataset

    def __getitem__(self, index):
        return self.text_transform(self.values[index].rstrip("\n")), \
            self.text_transform(self.labels[index].rstrip("\n"))

    def text_transform(self, x):
        return torch.tensor([self.sos_idx] + [ord(z)-97+3 for z in x] + [self.eos_idx])
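A quick peek at one sample (illustrative; the exact tokens depend on the random seed):

# Illustrative peek at one sample
ds = ReverseDataset(n_samples=5, pad_idx=0, sos_idx=1, eos_idx=2)
src, tgt = ds[0]
print(src)   # e.g. tensor([1, ..., 2]) -- SOS + encoded characters + EOS
print(tgt)   # the same characters, reversed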

We now define the training and evaluation steps:

PAD_IDX = 0
SOS_IDX = 1
EOS_IDX = 2


def train(model, optimizer, loader, loss_fn, epoch):
    model.train()
    losses = 0
    acc = 0
    history_loss = []
    history_acc = []

    with tqdm(loader, position=0, leave=True) as tepoch:
        for x, y in tepoch:
            tepoch.set_description(f"Epoch {epoch}")

            optimizer.zero_grad()
            logits = model(x, y[:, :-1])
            loss = loss_fn(logits.contiguous().view(-1, model.vocab_size), y[:, 1:].contiguous().view(-1))
            loss.backward()
            optimizer.step()
            losses += loss.item()

            preds = logits.argmax(dim=-1)
            masked_pred = preds * (y[:, 1:] != PAD_IDX)
            accuracy = (masked_pred == y[:, 1:]).float().mean()
            acc += accuracy.item()

            history_loss.append(loss.item())
            history_acc.append(accuracy.item())
            tepoch.set_postfix(loss=loss.item(), accuracy=100. * accuracy.item())

    return losses / len(list(loader)), acc / len(list(loader)), history_loss, history_acc


def evaluate(model, loader, loss_fn):
    model.eval()
    losses = 0
    acc = 0
    history_loss = []
    history_acc = []

    for x, y in tqdm(loader, position=0, leave=True):

        logits = model(x, y[:, :-1])
        loss = loss_fn(logits.contiguous().view(-1, model.vocab_size), y[:, 1:].contiguous().view(-1))
        losses += loss.item()

        preds = logits.argmax(dim=-1)
        masked_pred = preds * (y[:, 1:] != PAD_IDX)
        accuracy = (masked_pred == y[:, 1:]).float().mean()
        acc += accuracy.item()

        history_loss.append(loss.item())
        history_acc.append(accuracy.item())

    return losses / len(list(loader)), acc / len(list(loader)), history_loss, history_acc

And train the model for a few epochs:

import torch
import time
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from mpl_toolkits.axes_grid1 import ImageGrid


def collate_fn(batch):
    """
    This function pads inputs with PAD_IDX to have batches of equal length
    """
    src_batch, tgt_batch = [], []
    for src_sample, tgt_sample in batch:
        src_batch.append(src_sample)
        tgt_batch.append(tgt_sample)

    src_batch = pad_sequence(src_batch, padding_value=PAD_IDX, batch_first=True)
    tgt_batch = pad_sequence(tgt_batch, padding_value=PAD_IDX, batch_first=True)
    return src_batch, tgt_batch


# Model hyperparameters
args = {
    'vocab_size': 128,
    'model_dim': 128,
    'dropout': 0.1,
    'n_encoder_layers': 1,
    'n_decoder_layers': 1,
    'n_heads': 4
}

# Define model here
model = Transformer(**args)

# Instantiate datasets
train_iter = ReverseDataset(50000, pad_idx=PAD_IDX, sos_idx=SOS_IDX, eos_idx=EOS_IDX)
eval_iter = ReverseDataset(10000, pad_idx=PAD_IDX, sos_idx=SOS_IDX, eos_idx=EOS_IDX)
dataloader_train = DataLoader(train_iter, batch_size=256, collate_fn=collate_fn)
dataloader_val = DataLoader(eval_iter, batch_size=256, collate_fn=collate_fn)

# During debugging, we ensure sources and targets are indeed reversed
# s, t = next(iter(dataloader_train))
# print(s[:4, ...])
# print(t[:4, ...])
# print(s.size())

# Initialize model parameters
for p in model.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

# Define loss function: we ignore logits which are padding tokens
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.98), eps=1e-9)

# Save history to dictionary
history = {
    'train_loss': [],
    'eval_loss': [],
    'train_acc': [],
    'eval_acc': []
}

# Main loop
for epoch in range(1, 4):
    start_time = time.time()
    train_loss, train_acc, hist_loss, hist_acc = train(model, optimizer, dataloader_train, loss_fn, epoch)
    history['train_loss'] += hist_loss
    history['train_acc'] += hist_acc
    end_time = time.time()
    val_loss, val_acc, hist_loss, hist_acc = evaluate(model, dataloader_val, loss_fn)
    history['eval_loss'] += hist_loss
    history['eval_acc'] += hist_acc
    print((f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Train acc: {train_acc:.3f}, "
           f"Val loss: {val_loss:.3f}, Val acc: {val_acc:.3f} "
           f"Epoch time = {(end_time - start_time):.3f}s"))

Visualizing Attention

We now write a small snippet to access and plot the weights of the attention heads:

fig = plt.figure(figsize=(10., 10.))

images = model.decoder.decoder_blocks[0].cross_attention.attention_weigths[0, ...].detach().numpy()

grid = ImageGrid(fig, 111,  # similar to subplot(111)
                 nrows_ncols=(2, 2),  # creates 2x2 grid of axes
                 axes_pad=0.1,  # pad between axes in inch.
                 )

for ax, im in zip(grid, images):
    # Iterating over the grid returns the Axes.
    ax.imshow(im)

Reading the weights from the top, we can see a nice right-to-left pattern. The vertical sections at the bottom of the y-axis most likely represent weights that are masked out because of the padding mask.

Testing Our Model!

To test our model with new data, we will define a small Translator class to help us with the decoding:

class Translator(nn.Module):
    def __init__(self, transformer):
        super(Translator, self).__init__()
        self.transformer = transformer

    @staticmethod
    def str_to_tokens(s):
        return [ord(z)-97+3 for z in s]

    @staticmethod
    def tokens_to_str(tokens):
        return "".join([chr(x+94) for x in tokens])

    def __call__(self, sentence, max_length=None, pad=False):

        x = torch.tensor(self.str_to_tokens(sentence))
        x = torch.cat([torch.tensor([SOS_IDX]), x, torch.tensor([EOS_IDX])]).unsqueeze(0)

        encoder_output, mask = self.transformer.encode(x)  # (B, S, E)

        if not max_length:
            max_length = x.size(1)

        outputs = torch.ones((x.size()[0], max_length)).type_as(x).long() * SOS_IDX

        for step in range(1, max_length):
            y = outputs[:, :step]
            probs = self.transformer.decode(y, encoder_output)
            output = torch.argmax(probs, dim=-1)
            print(f"Knowing {y} we output {output[:, -1]}")

            if output[:, -1].detach().numpy() in (EOS_IDX, SOS_IDX):
                break
            outputs[:, step] = output[:, -1]

        return self.tokens_to_str(outputs[0])


translator = Translator(model)
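We can now call the translator on a sample word; "reversethis" is the word used in the attention plot further below. The step-by-step prints and the quality of the output depend on how training went, so treat this as a sketch:

# Illustrative call (output quality depends on training)
sentence = "reversethis"
out = translator(sentence)
print(out)   # should contain the reversed input once the model has trained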

You should see something like the following:

And if we plot the attention heads, we observe the following:

fig = plt.figure()
images = model.decoder.decoder_blocks[0].cross_attention.attention_weigths[0, ...].detach().numpy().mean(axis=0)

fig, ax = plt.subplots(1, 1, figsize=(10., 10.))
# Iterating over the grid returns the Axes.
ax.set_yticks(range(len(out)))
ax.set_xticks(range(len(sentence)))

ax.xaxis.set_label_position('top')

ax.set_xticklabels(iter(sentence))
ax.set_yticklabels([f"step {i}" for i in range(len(out))])
ax.imshow(images)

We can clearly see that, when reversing the sentence "reversethis", the model attends from right to left! (Step 0 actually receives the start-of-sentence token.)

Conclusion

You are now able to write a Transformer and use it with larger datasets to perform machine translation, or even to create your own BERT, for example!

Stay tuned!
