2024即插即用多头注意力模块MOH,涨点起飞起飞了!

文摘   2024-12-31 17:20   中国香港  

论文介绍

题目:Multi-Head Attention as Mixture-of-Head Attention

论文地址:https://arxiv.org/pdf/2410.11842

QQ深度学习交流群:994264161

扫描下方二维码,加入深度学习论文指南星球!

加入即可获得,模块缝合、制作、写作技巧,学会“结构”创新、“创新点”创新,从三区到顶会,小论文大论文,毕业一站式服务

创新点

  • 引入Mixture-of-Head Attention (MoH)
    该论文提出了一种新的多头注意力机制,称为Mixture-of-Head Attention (MoH),将注意力头视为Mixture-of-Experts (MoE)框架中的专家。通过这种机制,每个token可以动态选择最相关的注意力头,从而提高推理效率,同时保持甚至超越原始多头注意力的性能。

  • 动态路由机制
    MoH通过动态路由机制实现了每个token选择适当的注意力头,从而减少冗余头的激活。这种设计在不增加参数数量的情况下提升了模型性能和推理效率。

  • 加权求和替代标准求和
    在传统多头注意力中,输出是所有头的简单求和,而MoH引入了加权求和机制,为注意力机制带来了更大的灵活性和性能潜力。

  • 共享头和两阶段路由策略
    论文中提出了共享头的概念,用于捕获跨上下文的通用知识。同时设计了两阶段路由策略,在共享头和动态路由头之间平衡权重,从而进一步优化模型性能。

  • 支持预训练模型的迁移和改进
    MoH能够将现有的预训练多头注意力模型(如LLaMA3-8B)继续微调为MoH模型。例如,MoH-LLaMA3-8B在14个基准测试上的平均准确率提高了2.4%,仅激活了75%的注意力头。

  • 广泛的实验验证
    论文通过在图像分类(ViT)、类条件图像生成(DiT)和大语言模型(LLMs)等多个模型框架中的实验验证了MoH的有效性,表现出优于传统多头注意力的性能,且只需激活50%-90%的注意力头。

方法

整体架构

     这篇论文提出的模型结构是基于Mixture-of-Head Attention (MoH) 的架构,它将传统的多头注意力机制与Mixture-of-Experts (MoE) 框架结合。具体而言,MoH使用一个路由器为每个token动态选择Top-K的注意力头,并通过加权求和代替标准求和来生成输出。此外,模型中引入了共享头(捕获通用知识)和两阶段路由策略(在共享头与动态路由头之间平衡权重),从而在不增加参数数量的情况下实现高效的推理和卓越的性能,适用于图像分类、生成和语言建模等多种任务。

  • 多头注意力作为专家(Heads as Experts)
    在MoH中,将传统多头注意力中的注意力头视为“专家”,并通过路由器动态激活每个token的Top-K头,从而选择最相关的头参与计算。

  • 动态路由器(Dynamic Router)
    MoH通过一个动态路由器,根据输入token的特性为每个token分配路由分数,仅激活相关的注意力头以提高推理效率。

  • 共享头(Shared Heads)和路由头(Routed Heads)
    模型中的一部分注意力头被设定为共享头,用于捕获通用知识(例如语言中的语法规则),这些共享头始终被激活。其余头作为路由头,根据动态路由器的分数动态激活。

  • 两阶段路由策略(Two-Stage Routing Strategy)
    设计了一种两阶段路由策略,以动态平衡共享头和路由头的权重,使得模型能够更高效地利用注意力资源。

  • 加权求和(Weighted Summation)
    替代传统多头注意力的简单求和方式,MoH对激活的头进行加权求和,从而增加了注意力机制的灵活性和性能潜力。

  • 整体结构可扩展性
    MoH结构在多个模型框架中得到验证,包括Vision Transformers (ViT) 用于图像分类、Diffusion Transformers (DiT) 用于图像生成,以及大语言模型 (LLMs) 用于语言任务。其核心特点是在不增加模型参数数量的情况下,通过动态头选择和共享知识捕获实现性能提升。

即插即用模块作用

MOH 作为一个即插即用模块

  • 图像分类任务
    在基于Transformer的视觉模型(如Vision Transformers, ViT)中,MOH可以直接替代传统的多头注意力模块,用于图像分类任务。通过减少冗余头的激活和动态选择相关注意力头,MOH能够提升计算效率,同时保持甚至超越原始模型的分类性能。

  • 类条件图像生成任务
    在扩散模型(如DiT, Diffusion models with Transformers)中,MOH可以用来优化图像生成任务中的注意力机制。通过动态路由机制,MOH能够更高效地捕捉像素级别的细粒度关系,从而提高生成质量并减少计算资源消耗。

  • 自然语言处理任务(NLP)
    在大语言模型(LLMs,如LLaMA)中,MOH可以优化模型的注意力计算效率和性能,特别适合需要处理长文本、复杂上下文或多样化任务的语言建模场景。例如,MOH在文本分类、问答、逻辑推理等任务中都显示出优异的表现。

  • 迁移学习与模型微调
    MOH可以无缝集成到已有的预训练模型中,通过继续微调的方式替代传统多头注意力结构。例如,它可以提升预训练模型在少量数据上的迁移学习性能,同时减少激活的注意力头以提高效率。

消融实验结果

  • 单独加入共享头显著提升了模型的准确率(如图像分类任务中的Top-1准确率从75.6%提升至78.3%)。

  • 再加入两阶段路由后,进一步提升模型性能,表现最佳(准确率达78.6%)。

  • 说明共享头有效捕获了通用知识,而两阶段路由策略动态优化了共享头和路由头之间的权重平衡

    • 在共享头比例范围从13.9%到74.0%内,模型的性能基本保持稳定(准确率在78.4%到78.6%之间)。

    • 这表明共享头比例只要不过高或过低,对模型性能的影响有限,同时共享头起到了“Soft MoE”的作用

即插即用模块

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.jit import Final
from timm.layers import use_fused_attn

# 论文: Multi-Head Attention as Mixture-of-Head Attention

# 论文地址:https://arxiv.org/pdf/2410.11842


class MoHAttention(nn.Module):
    fused_attn: Final[bool]
    LOAD_BALANCING_LOSSES = []

    def __init__(
            self,
            dim,
            num_heads=8,
            qkv_bias=False,
            qk_norm=False,
            attn_drop=0.,
            proj_drop=0.,
            norm_layer=nn.LayerNorm,
            shared_head=0,
            routed_head=0,
            head_dim=None,
    ):
        super().__init__()
        # assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        
        if head_dim is None:
            self.head_dim = dim // num_heads
        else:
            self.head_dim = head_dim
        
        self.scale = self.head_dim ** -0.5
        self.fused_attn = use_fused_attn()

        self.qkv = nn.Linear(dim, (self.head_dim * self.num_heads) * 3, bias=qkv_bias)
        
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(self.head_dim * self.num_heads, dim)
        
        self.proj_drop = nn.Dropout(proj_drop)

        self.shared_head = shared_head
        self.routed_head = routed_head
        
        if self.routed_head > 0:
            self.wg = torch.nn.Linear(dim, num_heads - shared_head, bias=False)
            if self.shared_head > 0:
                self.wg_0 = torch.nn.Linear(dim, 2, bias=False)

        if self.shared_head > 1:
            self.wg_1 = torch.nn.Linear(dim, shared_head, bias=False)

    def forward(self, x):
        B, N, C = x.shape

        _x = x.reshape(B * N, C)
        
        if self.routed_head > 0:
            logits = self.wg(_x)
            gates = F.softmax(logits, dim=1)

            num_tokens, num_experts = gates.shape
            _, indices = torch.topk(gates, k=self.routed_head, dim=1)
            mask = F.one_hot(indices, num_classes=num_experts).sum(dim=1)

            if self.training:
                me = gates.mean(dim=0)
                ce = mask.float().mean(dim=0)
                l_aux = torch.mean(me * ce) * num_experts * num_experts

                MoHAttention.LOAD_BALANCING_LOSSES.append(l_aux)

            routed_head_gates = gates * mask
            denom_s = torch.sum(routed_head_gates, dim=1, keepdim=True)
            denom_s = torch.clamp(denom_s, min=torch.finfo(denom_s.dtype).eps)
            routed_head_gates /= denom_s
            routed_head_gates = routed_head_gates.reshape(B, N, -1) * self.routed_head

        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        if self.fused_attn:
            x = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.attn_drop.p if self.training else 0.,
            )
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v
        
        if self.routed_head > 0:
            x = x.transpose(1, 2)

            if self.shared_head > 0:
                shared_head_weight = self.wg_1(_x)
                shared_head_gates = F.softmax(shared_head_weight, dim=1).reshape(B, N, -1) * self.shared_head

                weight_0 = self.wg_0(_x)
                weight_0 = F.softmax(weight_0, dim=1).reshape(B, N, 2) * 2
        
                shared_head_gates = torch.einsum("bn,bne->bne", weight_0[:,:,0], shared_head_gates)
                routed_head_gates = torch.einsum("bn,bne->bne", weight_0[:,:,1], routed_head_gates)
                
                masked_gates = torch.cat([shared_head_gates, routed_head_gates], dim=2)
            else:
                masked_gates = routed_head_gates

            x = torch.einsum("bne,bned->bned", masked_gates, x)
            x = x.reshape(B, N, self.head_dim * self.num_heads)
        else:
            shared_head_weight = self.wg_1(_x)
            masked_gates = F.softmax(shared_head_weight, dim=1).reshape(B, N, -1) * self.shared_head
            x = x.transpose(1, 2)

            x = torch.einsum("bne,bned->bned", masked_gates, x)
            x = x.reshape(B, N, self.head_dim * self.num_heads)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
    
def main():

    batch_size = 2
    num_tokens = 16
    embed_dim = 64

    input = torch.rand(batch_size, num_tokens, embed_dim)

    num_heads = 4
    attn_layer = MoHAttention(
        dim=embed_dim,
        num_heads=num_heads,
        qkv_bias=True,
        qk_norm=True,
        attn_drop=0.1,
        proj_drop=0.1,
        shared_head=2,
        routed_head=2,
        head_dim=16
    )


    attn_layer.train()

    output = attn_layer(input)

    print(input.size())
    print(output.size())

if __name__ == "__main__":    main()

便捷下载方式

浏览打开网址:https://github.com/ai-dawang/PlugNPlay-Modules

更多分析可见原文


ai缝合大王
聚焦AI前沿,分享相关技术、论文,研究生自救指南
 最新文章