来源 | 知乎问答
地址 | https://www.zhihu.com/question/298810062
本文仅作学术分享,若侵权请联系后台删文处理
回答一:作者-不是大叔
class BertSelfAttention(nn.Module):
self.query = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768
self.key = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768
self.value = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768
2. 假设三种操作的输入都是同一个矩阵(暂且先别管为什么输入是同一个矩阵),这里暂且定为长度为L的句子,每个token的特征维度是768,那么输入就是(L, 768),每一行就是一个字,像这样:
乘以上面三种操作就得到了Q/K/V,(L, 768)*(768,768) = (L,768),维度其实没变,即此刻的Q/K/V分别为:
代码为:
class BertSelfAttention(nn.Module):
def __init__(self, config):
self.query = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768
self.key = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768
self.value = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768
def forward(self,hidden_states): # hidden_states 维度是(L, 768)
Q = self.query(hidden_states)
K = self.key(hidden_states)
V = self.value(hidden_states)
3. 然后来实现这个操作:
① 首先是Q和K矩阵乘,(L, 768)*(L, 768)的转置=(L,L),看图:
③ 然后就是刚才的注意力权重和V矩阵乘了,如图:
整个过程在草稿纸上画一画简单的矩阵乘就出来了,一目了然~最后上代码:
class BertSelfAttention(nn.Module):
def __init__(self, config):
self.query = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768
self.key = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768
self.value = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768
def forward(self,hidden_states): # hidden_states 维度是(L, 768)
Q = self.query(hidden_states)
K = self.key(hidden_states)
V = self.value(hidden_states)
attention_scores = torch.matmul(Q, K.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
attention_probs = nn.Softmax(dim=-1)(attention_scores)
out = torch.matmul(attention_probs, V)
return out
回答二:作者-到处挖坑蒋玉成
回答三:作者-渠梁
首先,Attention机制是由Encoder-Decoder架构而来,且最初是用于完成NLP领域中的翻译(Translation)任务。那么输入输出就是非常明显的 Source-Target的对应关系,经典的Seq2Seq结构是从Encoder生成出一个语义向量(Context vector)而不再变化,然后将这个语义向量送入Decoder配合解码输出。这种方法的最大问题就是这个语义向量,我们是希望它一成不变好呢?还是它最好能配合Decoder动态调整自己,来使Target中的某些token与Source中的真正“有决定意义”的token关联起来好呢?