论文介绍
题目:Multi-Head Attention as Mixture-of-Head Attention
论文地址:https://arxiv.org/pdf/2410.11842
QQ深度学习交流群:994264161
扫描下方二维码,加入深度学习论文指南星球!
加入即可获得,模块缝合、制作、写作技巧,学会“结构”创新、“创新点”创新,从三区到顶会,小论文大论文,毕业一站式服务
创新点
引入Mixture-of-Head Attention (MoH)
该论文提出了一种新的多头注意力机制,称为Mixture-of-Head Attention (MoH),将注意力头视为Mixture-of-Experts (MoE)框架中的专家。通过这种机制,每个token可以动态选择最相关的注意力头,从而提高推理效率,同时保持甚至超越原始多头注意力的性能。动态路由机制
MoH通过动态路由机制实现了每个token选择适当的注意力头,从而减少冗余头的激活。这种设计在不增加参数数量的情况下提升了模型性能和推理效率。加权求和替代标准求和
在传统多头注意力中,输出是所有头的简单求和,而MoH引入了加权求和机制,为注意力机制带来了更大的灵活性和性能潜力。共享头和两阶段路由策略
论文中提出了共享头的概念,用于捕获跨上下文的通用知识。同时设计了两阶段路由策略,在共享头和动态路由头之间平衡权重,从而进一步优化模型性能。支持预训练模型的迁移和改进
MoH能够将现有的预训练多头注意力模型(如LLaMA3-8B)继续微调为MoH模型。例如,MoH-LLaMA3-8B在14个基准测试上的平均准确率提高了2.4%,仅激活了75%的注意力头。广泛的实验验证
论文通过在图像分类(ViT)、类条件图像生成(DiT)和大语言模型(LLMs)等多个模型框架中的实验验证了MoH的有效性,表现出优于传统多头注意力的性能,且只需激活50%-90%的注意力头。
方法
整体架构
这篇论文提出的模型结构是基于Mixture-of-Head Attention (MoH) 的架构,它将传统的多头注意力机制与Mixture-of-Experts (MoE) 框架结合。具体而言,MoH使用一个路由器为每个token动态选择Top-K的注意力头,并通过加权求和代替标准求和来生成输出。此外,模型中引入了共享头(捕获通用知识)和两阶段路由策略(在共享头与动态路由头之间平衡权重),从而在不增加参数数量的情况下实现高效的推理和卓越的性能,适用于图像分类、生成和语言建模等多种任务。
多头注意力作为专家(Heads as Experts):
在MoH中,将传统多头注意力中的注意力头视为“专家”,并通过路由器动态激活每个token的Top-K头,从而选择最相关的头参与计算。动态路由器(Dynamic Router):
MoH通过一个动态路由器,根据输入token的特性为每个token分配路由分数,仅激活相关的注意力头以提高推理效率。共享头(Shared Heads)和路由头(Routed Heads):
模型中的一部分注意力头被设定为共享头,用于捕获通用知识(例如语言中的语法规则),这些共享头始终被激活。其余头作为路由头,根据动态路由器的分数动态激活。两阶段路由策略(Two-Stage Routing Strategy):
设计了一种两阶段路由策略,以动态平衡共享头和路由头的权重,使得模型能够更高效地利用注意力资源。加权求和(Weighted Summation):
替代传统多头注意力的简单求和方式,MoH对激活的头进行加权求和,从而增加了注意力机制的灵活性和性能潜力。整体结构可扩展性:
MoH结构在多个模型框架中得到验证,包括Vision Transformers (ViT) 用于图像分类、Diffusion Transformers (DiT) 用于图像生成,以及大语言模型 (LLMs) 用于语言任务。其核心特点是在不增加模型参数数量的情况下,通过动态头选择和共享知识捕获实现性能提升。
即插即用模块作用
MOH 作为一个即插即用模块:
图像分类任务
在基于Transformer的视觉模型(如Vision Transformers, ViT)中,MOH可以直接替代传统的多头注意力模块,用于图像分类任务。通过减少冗余头的激活和动态选择相关注意力头,MOH能够提升计算效率,同时保持甚至超越原始模型的分类性能。类条件图像生成任务
在扩散模型(如DiT, Diffusion models with Transformers)中,MOH可以用来优化图像生成任务中的注意力机制。通过动态路由机制,MOH能够更高效地捕捉像素级别的细粒度关系,从而提高生成质量并减少计算资源消耗。自然语言处理任务(NLP)
在大语言模型(LLMs,如LLaMA)中,MOH可以优化模型的注意力计算效率和性能,特别适合需要处理长文本、复杂上下文或多样化任务的语言建模场景。例如,MOH在文本分类、问答、逻辑推理等任务中都显示出优异的表现。迁移学习与模型微调
MOH可以无缝集成到已有的预训练模型中,通过继续微调的方式替代传统多头注意力结构。例如,它可以提升预训练模型在少量数据上的迁移学习性能,同时减少激活的注意力头以提高效率。
消融实验结果
单独加入共享头显著提升了模型的准确率(如图像分类任务中的Top-1准确率从75.6%提升至78.3%)。
再加入两阶段路由后,进一步提升模型性能,表现最佳(准确率达78.6%)。
说明共享头有效捕获了通用知识,而两阶段路由策略动态优化了共享头和路由头之间的权重平衡
在共享头比例范围从13.9%到74.0%内,模型的性能基本保持稳定(准确率在78.4%到78.6%之间)。
这表明共享头比例只要不过高或过低,对模型性能的影响有限,同时共享头起到了“Soft MoE”的作用
即插即用模块
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.jit import Final
from timm.layers import use_fused_attn
# 论文: Multi-Head Attention as Mixture-of-Head Attention
# 论文地址:https://arxiv.org/pdf/2410.11842
class MoHAttention(nn.Module):
fused_attn: Final[bool]
LOAD_BALANCING_LOSSES = []
def __init__(
self,
dim,
num_heads=8,
qkv_bias=False,
qk_norm=False,
attn_drop=0.,
proj_drop=0.,
norm_layer=nn.LayerNorm,
shared_head=0,
routed_head=0,
head_dim=None,
):
super().__init__()
# assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.num_heads = num_heads
if head_dim is None:
self.head_dim = dim // num_heads
else:
self.head_dim = head_dim
self.scale = self.head_dim ** -0.5
self.fused_attn = use_fused_attn()
self.qkv = nn.Linear(dim, (self.head_dim * self.num_heads) * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(self.head_dim * self.num_heads, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.shared_head = shared_head
self.routed_head = routed_head
if self.routed_head > 0:
self.wg = torch.nn.Linear(dim, num_heads - shared_head, bias=False)
if self.shared_head > 0:
self.wg_0 = torch.nn.Linear(dim, 2, bias=False)
if self.shared_head > 1:
self.wg_1 = torch.nn.Linear(dim, shared_head, bias=False)
def forward(self, x):
B, N, C = x.shape
_x = x.reshape(B * N, C)
if self.routed_head > 0:
logits = self.wg(_x)
gates = F.softmax(logits, dim=1)
num_tokens, num_experts = gates.shape
_, indices = torch.topk(gates, k=self.routed_head, dim=1)
mask = F.one_hot(indices, num_classes=num_experts).sum(dim=1)
if self.training:
me = gates.mean(dim=0)
ce = mask.float().mean(dim=0)
l_aux = torch.mean(me * ce) * num_experts * num_experts
MoHAttention.LOAD_BALANCING_LOSSES.append(l_aux)
routed_head_gates = gates * mask
denom_s = torch.sum(routed_head_gates, dim=1, keepdim=True)
denom_s = torch.clamp(denom_s, min=torch.finfo(denom_s.dtype).eps)
routed_head_gates /= denom_s
routed_head_gates = routed_head_gates.reshape(B, N, -1) * self.routed_head
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
q, k = self.q_norm(q), self.k_norm(k)
if self.fused_attn:
x = F.scaled_dot_product_attention(
q, k, v,
dropout_p=self.attn_drop.p if self.training else 0.,
)
else:
q = q * self.scale
attn = q @ k.transpose(-2, -1)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = attn @ v
if self.routed_head > 0:
x = x.transpose(1, 2)
if self.shared_head > 0:
shared_head_weight = self.wg_1(_x)
shared_head_gates = F.softmax(shared_head_weight, dim=1).reshape(B, N, -1) * self.shared_head
weight_0 = self.wg_0(_x)
weight_0 = F.softmax(weight_0, dim=1).reshape(B, N, 2) * 2
shared_head_gates = torch.einsum("bn,bne->bne", weight_0[:,:,0], shared_head_gates)
routed_head_gates = torch.einsum("bn,bne->bne", weight_0[:,:,1], routed_head_gates)
masked_gates = torch.cat([shared_head_gates, routed_head_gates], dim=2)
else:
masked_gates = routed_head_gates
x = torch.einsum("bne,bned->bned", masked_gates, x)
x = x.reshape(B, N, self.head_dim * self.num_heads)
else:
shared_head_weight = self.wg_1(_x)
masked_gates = F.softmax(shared_head_weight, dim=1).reshape(B, N, -1) * self.shared_head
x = x.transpose(1, 2)
x = torch.einsum("bne,bned->bned", masked_gates, x)
x = x.reshape(B, N, self.head_dim * self.num_heads)
x = self.proj(x)
x = self.proj_drop(x)
return x
def main():
batch_size = 2
num_tokens = 16
embed_dim = 64
input = torch.rand(batch_size, num_tokens, embed_dim)
num_heads = 4
attn_layer = MoHAttention(
dim=embed_dim,
num_heads=num_heads,
qkv_bias=True,
qk_norm=True,
attn_drop=0.1,
proj_drop=0.1,
shared_head=2,
routed_head=2,
head_dim=16
)
attn_layer.train()
output = attn_layer(input)
print(input.size())
print(output.size())
if __name__ == "__main__": main()
便捷下载方式
浏览打开网址:https://github.com/ai-dawang/PlugNPlay-Modules
更多分析可见原文