Transformer from Scratch: The Feed-Forward Network

Digest   Technology   2024-10-03 12:46   Henan



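This is the fourth article in the Transformer-from-scratch series. It introduces the Position-wise Feed-Forward Network from the ground up. This network uses fully connected layers to transform every position of each sequence.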




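The position-wise feed-forward network (FFN) consists of two fully connected layers, which together form a small multilayer perceptron (MLP). The dimension of the hidden layer (called d_ffn) is usually set to about four times d_model. For this reason it is sometimes called an expand-and-contract network.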
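To get a feel for the cost of the fourfold expansion, here is a minimal sketch that counts the weights and biases of the two layers. The helper name `ffn_param_count` and the larger size (d_model = 512) are assumptions made for illustration; only the 8 and 32 sizes come from this article.

```python
def ffn_param_count(d_model: int, d_ffn: int) -> int:
    """Count the weights and biases of a two-layer FFN (hypothetical helper)."""
    first_layer = d_model * d_ffn + d_ffn      # expand: weights + biases
    second_layer = d_ffn * d_model + d_model   # contract: weights + biases
    return first_layer + second_layer


# toy size used in this article: d_model = 8, d_ffn = 4 * 8 = 32
toy_params = ffn_param_count(8, 4 * 8)

# an illustrative larger size (assumption): d_model = 512, d_ffn = 2048
large_params = ffn_param_count(512, 4 * 512)

print(toy_params)    # 552
print(large_params)  # 2099712
```

The count grows with d_model * d_ffn, so with d_ffn = 4 * d_model it grows roughly with the square of d_model.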

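The weight of the first FFN layer has shape (d_model, d_ffn). During the tensor multiplication this weight is broadcast across the batch and sequence dimensions, so every position in every sequence is multiplied by the same weights. If two positions hold the same input vector, their outputs are also the same. The same logic applies to the second fully connected layer of shape (d_ffn, d_model), which returns the tensor to its original size.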
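The weight sharing described above can be checked directly. The following is a minimal sketch, assuming a random tensor in place of the real attention output: it applies one `nn.Linear` layer to the whole batch at once and to each position separately, and compares the results. Note that `nn.Linear` stores its weight as (out_features, in_features), i.e. (d_ffn, d_model), and transposes it internally.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, d_ffn = 8, 32
w_1 = nn.Linear(d_model, d_ffn)

# random stand-in for the attention output (assumption)
x = torch.randn(3, 6, d_model)   # (batch_size, seq_length, d_model)
x[1, 4] = x[0, 2]                # plant the same token vector at two positions

# apply the layer to all positions at once
batched = w_1(x)                 # (3, 6, 32)

# apply the layer to each position separately and stack the results
per_position = torch.stack(
    [torch.stack([w_1(x[b, t]) for t in range(x.size(1))])
     for b in range(x.size(0))]
)

same_result = torch.allclose(batched, per_position, atol=1e-5)
same_token_same_output = torch.allclose(batched[1, 4], batched[0, 2], atol=1e-5)

print(batched.shape, same_result, same_token_same_output)
```

Both checks should hold: the layer has no notion of position, so it treats the tensor as 18 independent 8-dimensional vectors.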

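The ReLU activation function, max(0, X), is applied between the two layers. Any value greater than 0 is left unchanged, and any value less than or equal to 0 becomes 0. It introduces non-linearity, and because its gradient is 1 for positive inputs it helps mitigate vanishing gradients.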
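As a small illustration of this behaviour, the sketch below applies `torch.relu` to a handful of hand-picked values (an assumption, not data from this article) and uses autograd to read off the gradient at each input.

```python
import torch

x = torch.tensor([-2.0, -0.5, 0.0, 0.5, 2.0], requires_grad=True)

y = torch.relu(x)        # elementwise max(0, x)
y.sum().backward()       # gradient of sum(relu(x)) with respect to x

relu_values = y.detach().tolist()
relu_grads = x.grad.tolist()

print(relu_values)  # [0.0, 0.0, 0.0, 0.5, 2.0]
print(relu_grads)   # [0.0, 0.0, 0.0, 1.0, 1.0]
```

Positive inputs pass through with a gradient of exactly 1, so the activation itself does not shrink the signal flowing backward through those units. Inputs at or below 0 contribute no gradient at all.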



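The code below depends on the implementations of the earlier modules of the Transformer model. Up to this point, the output of each layer has shape (3, 6, 8): there are 3 sequences of 6 tokens, and each token is represented by an 8-dimensional embedding.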
import torch
import torch.nn as nn

# tokenize, stoi, Embeddings, PositionalEncoding, and MultiHeadAttention
# come from the earlier articles in this series

torch.set_printoptions(precision=2, sci_mode=False)

# example sequences
sequences = ["I wonder what will come next!",
             "This is a basic example paragraph.",
             "Hello, what is a basic split?"]

# tokenize the sequences
tokenized_sequences = [tokenize(seq) for seq in sequences]
# index the sequences (convert each token to an integer)
indexed_sequences = [[stoi[word] for word in seq] for seq in tokenized_sequences]
# convert the sequences to a tensor
tensor_sequences = torch.tensor(indexed_sequences).long()

# vocab size
vocab_size = len(stoi)
# embedding dimensions
d_model = 8

# create the embeddings
lut = Embeddings(vocab_size, d_model)  # look-up table (lut)
# create the positional encodings
pe = PositionalEncoding(d_model=d_model, dropout=0.1, max_length=10)

# embed the sequences
embeddings = lut(tensor_sequences)
# positionally encode the sequences
X = pe(embeddings)

# set the n_heads
n_heads = 4
# create the attention layer
attention = MultiHeadAttention(d_model, n_heads, dropout=0.1)

# pass X as the query, key, and value (self-attention)
output, attn_probs = attention(X, X, X, mask=None)
print(output)

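Now the output above can be fed into the FFN. The first layer converts each 8-dimensional embedding into a 32-dimensional representation, which then passes through the ReLU activation function. The shape of the new tensor is (3, 6, 8) x (8, 32) → (3, 6, 32).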

d_ffn = d_model * 4  # 32
w_1 = nn.Linear(d_model, d_ffn)  # (8, 32)
w_2 = nn.Linear(d_ffn, d_model)  # (32, 8)

ffn_1 = w_1(output).relu()
print(ffn_1)
tensor([
        # sequence 0
        [[0.00, 0.00, 0.58, 0.00, 0.86, 0.00, 0.00, 0.44, 0.00, 0.00, 0.00, 0.23, 0.00, 0.40, 0.00, 0.30, 0.10, 0.00, 0.48, 0.00, 0.00, 0.00, 0.30, 0.71, 0.17, 0.00, 0.47, 0.00, 0.00, 0.00, 0.00, 0.40],
         [0.00, 0.00, 0.62, 0.00, 0.90, 0.00, 0.00, 0.51, 0.00, 0.00, 0.05, 0.29, 0.00, 0.37, 0.00, 0.33, 0.02, 0.00, 0.44, 0.00, 0.00, 0.00, 0.20, 0.83, 0.19, 0.00, 0.47, 0.00, 0.00, 0.00, 0.00, 0.32],
         [0.00, 0.00, 0.28, 0.00, 0.81, 0.00, 0.00, 0.53, 0.00, 0.00, 0.00, 0.04, 0.23, 0.30, 0.00, 0.61, 0.00, 0.00, 0.52, 0.00, 0.00, 0.00, 0.17, 0.80, 0.08, 0.00, 0.46, 0.00, 0.00, 0.00, 0.00, 0.50],
         [0.06, 0.00, 0.11, 0.00, 0.60, 0.00, 0.00, 0.47, 0.00, 0.00, 0.00, 0.00, 0.41, 0.10, 0.00, 0.76, 0.00, 0.14, 0.35, 0.00, 0.00, 0.00, 0.13, 0.49, 0.00, 0.00, 0.28, 0.00, 0.00, 0.00, 0.00, 0.57],
         [0.00, 0.12, 0.40, 0.00, 0.63, 0.00, 0.00, 0.34, 0.00, 0.25, 0.26, 0.40, 0.00, 0.31, 0.00, 0.21, 0.03, 0.00, 0.62, 0.00, 0.00, 0.00, 0.00, 1.83, 0.45, 0.00, 0.65, 0.00, 0.00, 0.09, 0.00, 0.00],
         [0.00, 0.13, 0.29, 0.00, 0.67, 0.00, 0.00, 0.41, 0.00, 0.15, 0.27, 0.30, 0.00, 0.27, 0.00, 0.40, 0.00, 0.00, 0.58, 0.00, 0.00, 0.00, 0.00, 1.78, 0.34, 0.00, 0.62, 0.00, 0.00, 0.06, 0.00, 0.00]],

        # sequence 1
        [[0.00, 0.00, 0.89, 0.00, 0.51, 0.00, 0.00, 0.28, 0.00, 0.00, 0.00, 0.38, 0.00, 0.17, 0.00, 0.00, 0.32, 0.00, 0.35, 0.06, 0.00, 0.00, 0.11, 0.54, 0.47, 0.00, 0.32, 0.20, 0.20, 0.04, 0.00, 0.17],
         [0.00, 0.00, 0.96, 0.00, 0.39, 0.00, 0.00, 0.15, 0.00, 0.00, 0.00, 0.31, 0.00, 0.03, 0.05, 0.00, 0.50, 0.00, 0.12, 0.00, 0.00, 0.00, 0.20, 0.36, 0.41, 0.00, 0.32, 0.29, 0.56, 0.08, 0.00, 0.22],
         [0.07, 0.00, 0.56, 0.00, 0.22, 0.00, 0.00, 0.38, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.13, 0.00, 0.26, 0.00, 0.00, 0.04, 0.07, 0.31, 0.00, 0.11, 0.25, 0.00, 0.41, 0.15, 0.00, 0.34],
         [0.68, 0.00, 0.01, 0.31, 0.00, 0.18, 0.00, 0.00, 0.77, 0.23, 0.00, 0.00, 0.00, 0.00, 0.34, 0.00, 0.45, 0.00, 0.37, 0.00, 0.10, 0.00, 0.00, 0.50, 0.00, 0.00, 0.05, 0.00, 0.34, 0.00, 0.00, 0.00],
         [0.00, 0.00, 0.31, 0.32, 0.00, 0.00, 0.00, 0.11, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00, 0.23, 0.05, 0.58, 0.00, 0.23, 0.00, 0.00, 0.83, 0.62, 0.19, 0.34, 0.19, 0.19, 0.00, 0.00, 0.25],
         [0.24, 0.00, 0.12, 0.00, 0.00, 0.00, 0.00, 0.22, 0.00, 0.00, 0.00, 0.00, 0.13, 0.00, 0.00, 0.32, 0.08, 0.00, 0.49, 0.00, 0.00, 0.00, 0.00, 0.59, 0.00, 0.00, 0.28, 0.00, 0.00, 0.00, 0.00, 0.24]],

        # sequence 2
        [[0.00, 1.00, 0.67, 0.07, 1.18, 0.00, 0.00, 0.85, 0.00, 0.00, 0.00, 0.98, 0.00, 0.44, 0.00, 0.17, 0.00, 0.09, 1.07, 0.38, 0.10, 0.12, 0.00, 1.89, 2.11, 1.44, 0.69, 0.91, 0.00, 0.06, 0.00, 0.22],
         [0.00, 0.10, 0.00, 0.68, 0.00, 0.00, 0.42, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.19, 0.00, 0.92, 0.00, 0.43, 0.05, 0.00, 1.76, 0.92, 0.00, 0.57, 0.07, 0.00, 0.00, 0.00, 0.12],
         [0.00, 0.00, 0.00, 0.14, 0.00, 0.00, 0.30, 0.00, 0.00, 0.10, 0.00, 0.00, 0.26, 0.00, 0.00, 0.50, 0.05, 0.00, 0.77, 0.00, 0.08, 0.00, 0.00, 1.43, 0.00, 0.00, 0.53, 0.00, 0.00, 0.00, 0.00, 0.23],
         [0.00, 0.08, 0.00, 0.22, 0.00, 0.00, 0.45, 0.32, 0.00, 0.00, 0.00, 0.00, 0.27, 0.00, 0.00, 0.46, 0.00, 0.11, 1.03, 0.00, 0.22, 0.39, 0.00, 1.66, 0.49, 0.49, 0.54, 0.00, 0.00, 0.00, 0.00, 0.22],
         [0.35, 0.00, 0.00, 0.55, 0.00, 0.13, 0.03, 0.00, 0.51, 0.42, 0.00, 0.00, 0.00, 0.00, 0.03, 0.00, 0.63, 0.00, 0.66, 0.00, 0.32, 0.00, 0.00, 0.61, 0.00, 0.00, 0.27, 0.00, 0.00, 0.00, 0.28, 0.09],
         [0.01, 0.00, 0.00, 0.79, 0.00, 0.01, 0.42, 0.00, 0.28, 0.52, 0.00, 0.00, 0.00 ...

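The tensor above can then be passed through the second fully connected layer to return it to its original size, i.e. (3, 6, 32) x (32, 8) → (3, 6, 8). The values differ from the attention output because of the weights and the activation function.

ffn_2 = w_2(ffn_1)
print(ffn_2)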




The implementation of the FFN in the Transformer is very simple. It mainly consists of two linear layers: the first has size (d_model, d_ffn), and the second has size (d_ffn, d_model).

The input X has shape (batch_size, seq_length, d_model). It therefore goes through the following transformations:

  • (batch_size, seq_length, d_model) x (d_model, d_ffn) = (batch_size, seq_length, d_ffn)
  • max(0, (batch_size, seq_length, d_ffn)) = (batch_size, seq_length, d_ffn)
  • (batch_size, seq_length, d_ffn) x (d_ffn, d_model) = (batch_size, seq_length, d_model)
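The three transformations listed above can also be written out as explicit matrix multiplications. This is a minimal sketch, assuming a random input tensor in place of the real attention output. It takes the weights from two `nn.Linear` layers, which store them as (out_features, in_features), so each weight is transposed before the multiplication.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, seq_length, d_model, d_ffn = 3, 6, 8, 32

w_1 = nn.Linear(d_model, d_ffn)
w_2 = nn.Linear(d_ffn, d_model)

# random stand-in for the attention output (assumption)
x = torch.randn(batch_size, seq_length, d_model)

# step 1: (3, 6, 8) x (8, 32) -> (3, 6, 32)
hidden = x @ w_1.weight.T + w_1.bias
# step 2: max(0, hidden), shape unchanged
activated = hidden.clamp(min=0)
# step 3: (3, 6, 32) x (32, 8) -> (3, 6, 8)
out = activated @ w_2.weight.T + w_2.bias

# the manual computation should agree with calling the layers directly
matches_layers = torch.allclose(out, w_2(w_1(x).relu()), atol=1e-5)

print(hidden.shape, activated.shape, out.shape, matches_layers)
```

Dropout is left out here on purpose so that the manual result and the layer calls can be compared directly.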


class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model: int, d_ffn: int, dropout: float = 0.1):
        """
        Args:
            d_model:      dimension of embeddings
            d_ffn:        dimension of feed-forward network
            dropout:      probability of dropout occurring
        """
        super().__init__()

        self.w_1 = nn.Linear(d_model, d_ffn)
        self.w_2 = nn.Linear(d_ffn, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        Args:
            x:            output from attention (batch_size, seq_length, d_model)

        Returns:
            expanded-and-contracted representation (batch_size, seq_length, d_model)
        """
        # w_1(x).relu():      (batch_size, seq_length, d_model) x (d_model, d_ffn) -> (batch_size, seq_length, d_ffn)
        # w_2(w_1(x).relu()): (batch_size, seq_length, d_ffn) x (d_ffn, d_model) -> (batch_size, seq_length, d_model)
        return self.w_2(self.dropout(self.w_1(x).relu()))




torch.set_printoptions(precision=2, sci_mode=False)

# example sequences
sequences = ["I wonder what will come next!",
             "This is a basic example paragraph.",
             "Hello, what is a basic split?"]

# tokenize the sequences
tokenized_sequences = [tokenize(seq) for seq in sequences]
# index the sequences (convert each token to an integer)
indexed_sequences = [[stoi[word] for word in seq] for seq in tokenized_sequences]
# convert the sequences to a tensor
tensor_sequences = torch.tensor(indexed_sequences).long()

# vocab size
vocab_size = len(stoi)
# embedding dimensions
d_model = 8

# create the embeddings
lut = Embeddings(vocab_size, d_model)  # look-up table (lut)
# create the positional encodings
pe = PositionalEncoding(d_model=d_model, dropout=0.1, max_length=10)

# embed the sequences
embeddings = lut(tensor_sequences)
# positionally encode the sequences
X = pe(embeddings)

# set the n_heads
n_heads = 4
# create the attention layer
attention = MultiHeadAttention(d_model, n_heads, dropout=0.1)

# pass X as the query, key, and value (self-attention)
output, attn_probs = attention(X, X, X, mask=None)

# calculate the d_ffn
d_ffn = d_model * 4  # 32

# create the position-wise feed-forward network
ffn = PositionwiseFeedForward(d_model, d_ffn, dropout=0.1)
# pass the attention output through it: (3, 6, 8) -> (3, 6, 8)
ffn_output = ffn(output)
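The example above needs the tokenizer, embedding, positional encoding, and attention modules from the earlier articles. For a quick standalone check, the sketch below repeats the class in compact form and feeds it a random tensor of shape (3, 6, 8) as a stand-in for the attention output (an assumption, not the real data). It also switches the module to evaluation mode, which turns dropout off so that repeated calls give the same result.

```python
import torch
import torch.nn as nn


class PositionwiseFeedForward(nn.Module):
    """Compact copy of the class from this article, repeated to be self-contained."""

    def __init__(self, d_model: int, d_ffn: int, dropout: float = 0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ffn)
        self.w_2 = nn.Linear(d_ffn, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # expand to d_ffn, apply ReLU and dropout, contract back to d_model
        return self.w_2(self.dropout(self.w_1(x).relu()))


torch.manual_seed(0)
d_model = 8
ffn = PositionwiseFeedForward(d_model, d_model * 4, dropout=0.1)

# random stand-in for the attention output (assumption)
fake_attention_output = torch.randn(3, 6, d_model)

ffn.eval()  # evaluation mode: dropout is disabled
with torch.no_grad():
    first = ffn(fake_attention_output)
    second = ffn(fake_attention_output)

deterministic_in_eval = torch.equal(first, second)

print(first.shape, deterministic_in_eval)
```

In training mode (`ffn.train()`), dropout randomly zeroes hidden activations, so two calls on the same input will generally differ.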




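With the introduction above, hopefully everyone can now follow the code implementation of the feed-forward network. The next article in this series covers "Layer Normalization".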



