01
Introduction
This is the fourth article in the Transformer-from-scratch series. It introduces the position-wise feed-forward network (FFN) from the ground up: a network that applies the same fully connected layers to every position in the sequence.
02
Background
The position-wise feed-forward network (FFN) consists of two fully connected layers, i.e. a multi-layer perceptron (MLP). The dimension of the hidden layer, d_ffn, is generally set to roughly four times d_model. For this reason it is sometimes called an expand-and-contract network.
The first FFN layer has weights of shape (d_model, d_ffn), which means the weights are broadcast across every sequence during the tensor multiplication. In other words, each sequence is multiplied by the same weights, so identical input sequences produce identical outputs. The same logic applies to the second fully connected layer of size (d_ffn, d_model), which returns the tensor to its original size.
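A minimal sketch of this broadcasting behavior, assuming d_model = 8 and d_ffn = 32 (w_1_demo here is only an illustrative stand-in for the first FFN layer):
import torch
import torch.nn as nn

torch.manual_seed(0)
w_1_demo = nn.Linear(8, 32)                  # weights of shape (d_model, d_ffn)
x = torch.randn(1, 1, 8)                     # a single position vector
x_repeated = x.repeat(1, 6, 1)               # the same vector at six positions
out = w_1_demo(x_repeated)                   # (1, 6, 32)
print(torch.allclose(out[0, 0], out[0, 5]))  # True: same input, same weights, same output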
A ReLU activation, max(0, X), is applied between the two layers. Any value greater than 0 remains unchanged, and any value less than or equal to 0 becomes 0. It introduces non-linearity and helps prevent vanishing gradients.
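A quick illustration of the activation on a toy tensor (not part of the original example):
import torch
x = torch.tensor([[-1.5, 0.0, 2.3]])
print(x.relu())  # negatives and zeros are clamped to 0: [[0.0, 0.0, 2.3]]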
03
Basic Implementation
import torch
import torch.nn as nn

# tokenize, stoi, Embeddings, PositionalEncoding, and MultiHeadAttention are
# the helpers defined in the earlier articles of this series
torch.set_printoptions(precision=2, sci_mode=False)
# convert the sequences to integers
sequences = ["I wonder what will come next!",
"This is a basic example paragraph.",
"Hello, what is a basic split?"]
# tokenize the sequences
tokenized_sequences = [tokenize(seq) for seq in sequences]
# index the sequences
indexed_sequences = [[stoi[word] for word in seq] for seq in tokenized_sequences]
# convert the sequences to a tensor
tensor_sequences = torch.tensor(indexed_sequences).long()
# vocab size
vocab_size = len(stoi)
# embedding dimensions
d_model = 8
# create the embeddings
lut = Embeddings(vocab_size, d_model) # look-up table (lut)
# create the positional encodings
pe = PositionalEncoding(d_model=d_model, dropout=0.1, max_length=10)
# embed the sequence
embeddings = lut(tensor_sequences)
# positionally encode the sequences
X = pe(embeddings)
# set the n_heads
n_heads = 4
# create the attention layer
attention = MultiHeadAttention(d_model, n_heads, dropout=0.1)
# use X as the query, key, and value; the attention layer projects it to Q, K, and V
output, attn_probs = attention(X, X, X, mask=None)
print(output)
The output above can now be fed into the FFN. This expands the 8-dimensional embeddings into a 32-dimensional representation, which also passes through the ReLU activation. The new tensor has shape (3, 6, 8) x (8, 32) → (3, 6, 32).
d_ffn = d_model * 4 # 32
w_1 = nn.Linear(d_model, d_ffn) # (8, 32)
w_2 = nn.Linear(d_ffn, d_model) # (32, 8)
ffn_1 = w_1(output).relu()
print(ffn_1)
tensor([
# sequence 0
[[ 0.00, 0.00, 0.58, 0.00, 0.86, 0.00, 0.00, 0.44, 0.00, 0.00, 0.00, 0.23, 0.00, 0.40, 0.00, 0.30, 0.10, 0.00, 0.48, 0.00, 0.00, 0.00, 0.30, 0.71, 0.17, 0.00, 0.47, 0.00, 0.00, 0.00, 0.00, 0.40],
[ 0.00, 0.00, 0.62, 0.00, 0.90, 0.00, 0.00, 0.51, 0.00, 0.00, 0.05, 0.29, 0.00, 0.37, 0.00, 0.33, 0.02, 0.00, 0.44, 0.00, 0.00, 0.00, 0.20, 0.83, 0.19, 0.00, 0.47, 0.00, 0.00, 0.00, 0.00, 0.32],
[ 0.00, 0.00, 0.28, 0.00, 0.81, 0.00, 0.00, 0.53, 0.00, 0.00, 0.00, 0.04, 0.23, 0.30, 0.00, 0.61, 0.00, 0.00, 0.52, 0.00, 0.00, 0.00, 0.17, 0.80, 0.08, 0.00, 0.46, 0.00, 0.00, 0.00, 0.00, 0.50],
[ 0.06, 0.00, 0.11, 0.00, 0.60, 0.00, 0.00, 0.47, 0.00, 0.00, 0.00, 0.00, 0.41, 0.10, 0.00, 0.76, 0.00, 0.14, 0.35, 0.00, 0.00, 0.00, 0.13, 0.49, 0.00, 0.00, 0.28, 0.00, 0.00, 0.00, 0.00, 0.57],
[ 0.00, 0.12, 0.40, 0.00, 0.63, 0.00, 0.00, 0.34, 0.00, 0.25, 0.26, 0.40, 0.00, 0.31, 0.00, 0.21, 0.03, 0.00, 0.62, 0.00, 0.00, 0.00, 0.00, 1.83, 0.45, 0.00, 0.65, 0.00, 0.00, 0.09, 0.00, 0.00],
[ 0.00, 0.13, 0.29, 0.00, 0.67, 0.00, 0.00, 0.41, 0.00, 0.15, 0.27, 0.30, 0.00, 0.27, 0.00, 0.40, 0.00, 0.00, 0.58, 0.00, 0.00, 0.00, 0.00, 1.78, 0.34, 0.00, 0.62, 0.00, 0.00, 0.06, 0.00, 0.00]],
# sequence 1
[[ 0.00, 0.00, 0.89, 0.00, 0.51, 0.00, 0.00, 0.28, 0.00, 0.00, 0.00, 0.38, 0.00, 0.17, 0.00, 0.00, 0.32, 0.00, 0.35, 0.06, 0.00, 0.00, 0.11, 0.54, 0.47, 0.00, 0.32, 0.20, 0.20, 0.04, 0.00, 0.17],
[ 0.00, 0.00, 0.96, 0.00, 0.39, 0.00, 0.00, 0.15, 0.00, 0.00, 0.00, 0.31, 0.00, 0.03, 0.05, 0.00, 0.50, 0.00, 0.12, 0.00, 0.00, 0.00, 0.20, 0.36, 0.41, 0.00, 0.32, 0.29, 0.56, 0.08, 0.00, 0.22],
[ 0.07, 0.00, 0.56, 0.00, 0.22, 0.00, 0.00, 0.38, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.13, 0.00, 0.26, 0.00, 0.00, 0.04, 0.07, 0.31, 0.00, 0.11, 0.25, 0.00, 0.41, 0.15, 0.00, 0.34],
[ 0.68, 0.00, 0.01, 0.31, 0.00, 0.18, 0.00, 0.00, 0.77, 0.23, 0.00, 0.00, 0.00, 0.00, 0.34, 0.00, 0.45, 0.00, 0.37, 0.00, 0.10, 0.00, 0.00, 0.50, 0.00, 0.00, 0.05, 0.00, 0.34, 0.00, 0.00, 0.00],
[ 0.00, 0.00, 0.31, 0.32, 0.00, 0.00, 0.00, 0.11, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00, 0.23, 0.05, 0.58, 0.00, 0.23, 0.00, 0.00, 0.83, 0.62, 0.19, 0.34, 0.19, 0.19, 0.00, 0.00, 0.25],
[ 0.24, 0.00, 0.12, 0.00, 0.00, 0.00, 0.00, 0.22, 0.00, 0.00, 0.00, 0.00, 0.13, 0.00, 0.00, 0.32, 0.08, 0.00, 0.49, 0.00, 0.00, 0.00, 0.00, 0.59, 0.00, 0.00, 0.28, 0.00, 0.00, 0.00, 0.00, 0.24]],
# sequence 2
[[ 0.00, 1.00, 0.67, 0.07, 1.18, 0.00, 0.00, 0.85, 0.00, 0.00, 0.00, 0.98, 0.00, 0.44, 0.00, 0.17, 0.00, 0.09, 1.07, 0.38, 0.10, 0.12, 0.00, 1.89, 2.11, 1.44, 0.69, 0.91, 0.00, 0.06, 0.00, 0.22],
[ 0.00, 0.10, 0.00, 0.68, 0.00, 0.00, 0.42, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.19, 0.00, 0.92, 0.00, 0.43, 0.05, 0.00, 1.76, 0.92, 0.00, 0.57, 0.07, 0.00, 0.00, 0.00, 0.12],
[ 0.00, 0.00, 0.00, 0.14, 0.00, 0.00, 0.30, 0.00, 0.00, 0.10, 0.00, 0.00, 0.26, 0.00, 0.00, 0.50, 0.05, 0.00, 0.77, 0.00, 0.08, 0.00, 0.00, 1.43, 0.00, 0.00, 0.53, 0.00, 0.00, 0.00, 0.00, 0.23],
[ 0.00, 0.08, 0.00, 0.22, 0.00, 0.00, 0.45, 0.32, 0.00, 0.00, 0.00, 0.00, 0.27, 0.00, 0.00, 0.46, 0.00, 0.11, 1.03, 0.00, 0.22, 0.39, 0.00, 1.66, 0.49, 0.49, 0.54, 0.00, 0.00, 0.00, 0.00, 0.22],
[ 0.35, 0.00, 0.00, 0.55, 0.00, 0.13, 0.03, 0.00, 0.51, 0.42, 0.00, 0.00, 0.00, 0.00, 0.03, 0.00, 0.63, 0.00, 0.66, 0.00, 0.32, 0.00, 0.00, 0.61, 0.00, 0.00, 0.27, 0.00, 0.00, 0.00, 0.28, 0.09],
[ 0.01, 0.00, 0.00, 0.79, 0.00, 0.01, 0.42, 0.00, 0.28, 0.52, 0.00, 0.00, 0.00, ...]]])
(remaining output truncated)
ffn_2 = w_2(ffn_1)
print(ffn_2)
The second linear layer then projects the tensor back to its original shape: (3, 6, 32) x (32, 8) → (3, 6, 8).
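As a quick sanity check (the exact values depend on the random initialization, but the shapes do not):
print(ffn_1.shape)  # torch.Size([3, 6, 32])
print(ffn_2.shape)  # torch.Size([3, 6, 8])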
04
FFN in the Transformer
The implementation of the FFN in the Transformer is straightforward. It consists mainly of two linear layers: the first of size (d_model, d_ffn) and the second of size (d_ffn, d_model).
The input X to the module has shape (batch_size, seq_length, d_model), so it undergoes the following transformations:
(batch_size, seq_length, d_model) x (d_model, d_ffn) = (batch_size, seq_length, d_ffn)
max(0, (batch_size, seq_length, d_ffn)) = (batch_size, seq_length, d_ffn)
(batch_size, seq_length, d_ffn) x (d_ffn, d_model) = (batch_size, seq_length, d_model)
The corresponding code implementation is shown below:
class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model: int, d_ffn: int, dropout: float = 0.1):
        """
        Args:
            d_model: dimension of embeddings
            d_ffn: dimension of feed-forward network
            dropout: probability of dropout occurring
        """
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ffn)
        self.w_2 = nn.Linear(d_ffn, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        Args:
            x: output from attention (batch_size, seq_length, d_model)

        Returns:
            expanded-and-contracted representation (batch_size, seq_length, d_model)
        """
        # w_1(x).relu(): (batch_size, seq_length, d_model) x (d_model, d_ffn) -> (batch_size, seq_length, d_ffn)
        # w_2(w_1(x).relu()): (batch_size, seq_length, d_ffn) x (d_ffn, d_model) -> (batch_size, seq_length, d_model)
        return self.w_2(self.dropout(self.w_1(x).relu()))
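A minimal usage sketch (ffn_demo and x_demo are illustrative names and the input is random; the forward pass in the next section feeds in the attention output instead):
ffn_demo = PositionwiseFeedForward(d_model=8, d_ffn=32, dropout=0.1)
x_demo = torch.randn(3, 6, 8)   # (batch_size, seq_length, d_model)
print(ffn_demo(x_demo).shape)   # torch.Size([3, 6, 8]) -- the shape is preserved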
05
Forward Pass
The forward pass assumes the data has already gone through the embedding layer, the positional encoding layer, and the multi-head attention layer. For now it does not use layer normalization or residual addition; those will be implemented immediately before and after this network in the encoder later on.
torch.set_printoptions(precision=2, sci_mode=False)
# convert the sequences to integers
sequences = ["I wonder what will come next!",
"This is a basic example paragraph.",
"Hello, what is a basic split?"]
# tokenize the sequences
tokenized_sequences = [tokenize(seq) for seq in sequences]
# index the sequences
indexed_sequences = [[stoi[word] for word in seq] for seq in tokenized_sequences]
# convert the sequences to a tensor
tensor_sequences = torch.tensor(indexed_sequences).long()
# vocab size
vocab_size = len(stoi)
# embedding dimensions
d_model = 8
# create the embeddings
lut = Embeddings(vocab_size, d_model) # look-up table (lut)
# create the positional encodings
pe = PositionalEncoding(d_model=d_model, dropout=0.1, max_length=10)
# embed the sequence
embeddings = lut(tensor_sequences)
# positionally encode the sequences
X = pe(embeddings)
# set the n_heads
n_heads = 4
# create the attention layer
attention = MultiHeadAttention(d_model, n_heads, dropout=0.1)
# use X as the query, key, and value; the attention layer projects it to Q, K, and V
output, attn_probs = attention(X, X, X, mask=None)
# calculate the d_ffn
d_ffn = d_model*4 # 32
# pass the tensor through the position-wise feed-forward network
ffn = PositionwiseFeedForward(d_model, d_ffn, dropout=0.1)
print(ffn(output))
The printed result again has shape (batch_size, seq_length, d_model), i.e. (3, 6, 8); the exact values depend on the random initialization and dropout.
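As a preview (a minimal sketch only; the exact arrangement is covered in the upcoming articles on layer normalization and the encoder), the residual addition and layer normalization mentioned earlier would wrap the FFN roughly like this:
norm = nn.LayerNorm(d_model)
# add the FFN output back to its input, then normalize
out = norm(output + ffn(output))
print(out.shape)  # torch.Size([3, 6, 8])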
With the above, the code implementation of the position-wise feed-forward network should be clear. The next article in this series covers layer normalization.