Plug-and-Play Dynamic Pruning Module TPC: Accuracy Gains That Take Off


Paper Overview

Title: Hourglass Tokenizer for Efficient Transformer-Based 3D Human Pose Estimation

Paper: https://arxiv.org/pdf/2311.12028


Key Innovations

  • Introduces the Hourglass Tokenizer (HoT) framework

    • Proposes a "prune-and-recover" framework for video-based 3D human pose estimation.

    • Unlike existing methods, HoT prunes and then recovers pose tokens inside the Transformer blocks, substantially cutting computational cost while keeping the model efficient and accurate.

  • Proposes the Token Pruning Cluster (TPC) module

    • Dynamically selects representative tokens with high semantic diversity and discards redundant information across video frames.

    • Uses a density-peaks clustering algorithm with k-nearest neighbors (DPC-kNN) to pick semantically representative cluster centers as the retained pose tokens (a tiny standalone sketch of this scoring follows this list).

  • Develops the Token Recovering Attention (TRA) module

    • Uses a lightweight cross-attention mechanism to recover the detailed spatio-temporal information lost during pruning.

    • Restores the sequence from low temporal resolution back to full temporal resolution to support fast inference.

  • Generality across Transformer architectures

    • The method integrates seamlessly into a variety of existing video pose Transformer (VPT) models (such as MHFormer, MixSTE, and MotionBERT) and supports both mainstream inference pipelines (seq2seq and seq2frame).

  • Significant gains in computational efficiency and inference speed

    • On the Human3.6M and MPI-INF-3DHP datasets, HoT cuts FLOPs by up to 50% compared with the baseline models while maintaining or slightly improving accuracy.

  • A general pruning-and-recovering strategy

    • Offers flexible pruning and recovering settings that adapt to different tasks and hardware constraints.
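The standalone sketch below isolates the DPC-kNN scoring idea behind TPC on a toy set of tokens; the token count, feature dimension, and k are illustrative assumptions, and the full batched implementation appears in the code section further down.

import torch

# toy example: 8 tokens of dimension 4; pick the 3 highest-scoring DPC centers
x = torch.rand(8, 4)
dist = torch.cdist(x, x) / (x.shape[-1] ** 0.5)

# density: tokens whose k nearest neighbours are close get high density
k = 2
knn_dist, _ = dist.topk(k, dim=-1, largest=False)
density = (-(knn_dist ** 2).mean(-1)).exp()

# separation: distance to the nearest token of higher density
higher = density[None, :] > density[:, None]
sep = torch.where(higher, dist, dist.max()).min(-1).values

# representative tokens have high density AND are far from any denser token
score = sep * density
print(score.topk(3).indices)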

Method

Overall Architecture

The proposed model is a Transformer-based framework for video 3D human pose estimation called Hourglass Tokenizer (HoT). It takes a 2D pose sequence as input; a pose embedding module turns it into spatio-temporal pose tokens; the first few Transformer blocks capture global information over the full-length sequence; a **Token Pruning Cluster (TPC)** module in the middle dynamically prunes redundant tokens; a **Token Recovering Attention (TRA)** module at the end recovers the full-length tokens; and a regression head finally outputs the 3D poses. With this "prune-and-recover" design, HoT improves computational efficiency and works with both inference pipelines (seq2seq and seq2frame).
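A minimal, self-contained sketch of this prune-and-recover flow is shown below; the module layout, the top-k scoring used here in place of DPC-kNN, and all hyperparameters are illustrative assumptions rather than the authors' implementation (the actual TPC code is listed in the module section below).

import torch
import torch.nn as nn

class HoTSketch(nn.Module):
    # full-length blocks -> token pruning -> pruned-length blocks
    # -> cross-attention recovery -> per-frame 3D regression
    def __init__(self, num_joints=17, frames=81, kept=27, dim=256, depth=6, prune_at=2, heads=8):
        super().__init__()
        self.embed = nn.Linear(num_joints * 2, dim)                # pose embedding: 2D pose -> token
        self.blocks = nn.ModuleList(
            [nn.TransformerEncoderLayer(dim, heads, dim * 2, batch_first=True) for _ in range(depth)])
        self.prune_at, self.kept = prune_at, kept
        self.score = nn.Linear(dim, 1)                             # toy stand-in for DPC-kNN selection
        self.queries = nn.Parameter(torch.zeros(1, frames, dim))   # learnable full-length queries (TRA)
        self.recover = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.head = nn.Linear(dim, num_joints * 3)                 # regression head

    def forward(self, pose2d):                                     # pose2d: (B, frames, num_joints * 2)
        x = self.embed(pose2d)
        for i, blk in enumerate(self.blocks):
            if i == self.prune_at:                                 # TPC stand-in: keep the top-k scored tokens
                idx = self.score(x).squeeze(-1).topk(self.kept, dim=1).indices.sort(dim=1).values
                x = x.gather(1, idx.unsqueeze(-1).expand(-1, -1, x.size(-1)))
            x = blk(x)
        q = self.queries.expand(x.size(0), -1, -1)                 # TRA: recover full temporal resolution
        x, _ = self.recover(q, x, x)
        return self.head(x)                                        # (B, frames, num_joints * 3)

out = HoTSketch()(torch.rand(2, 81, 17 * 2))
print(out.shape)                                                   # torch.Size([2, 81, 51])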

1. Overall Architecture

The paper proposes a framework named Hourglass Tokenizer (HoT), designed specifically for Transformer-based video 3D human pose estimation. The architecture consists of the following parts:

  • Input stage: accepts a 2D human pose sequence (keypoint coordinates for each frame).

  • Pose Embedding Module

    • Encodes the input 2D poses into pose tokens that carry spatio-temporal information.

  • Transformer Blocks

    • A stack of Transformer blocks captures global spatio-temporal dependencies.

    • The first few blocks keep the full-length pose tokens to preserve rich information.

2. Key Modules of the Hourglass Tokenizer

HoT achieves efficient token processing through two key modules:

  • Token Pruning Cluster (TPC)

    • Dynamic pruning: prunes redundant pose tokens in the intermediate Transformer blocks, keeping only semantically rich representative tokens.

    • The selection is driven by the density-peaks clustering algorithm DPC-kNN, which picks tokens with high semantic diversity as the representatives.

  • Token Recovering Attention (TRA)

    • Full-length recovery: after the last Transformer block, a lightweight multi-head cross-attention (MCA) recovers full-temporal-resolution pose tokens from the pruned ones (a sketch of such a recovery layer follows this list).
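The plug-and-play code further below covers only the TPC part, so here is a hedged sketch of what a TRA-style recovery layer could look like; the learnable full-length queries, the single cross-attention layer, and the dimensions are assumptions rather than the paper's exact implementation.

import torch
import torch.nn as nn

class TRASketch(nn.Module):
    # recover full-length tokens from a pruned sequence with one lightweight cross-attention
    def __init__(self, dim=512, full_len=27, num_heads=8):
        super().__init__()
        self.full_tokens = nn.Parameter(torch.zeros(1, full_len, dim))   # learnable queries, one per frame
        self.norm_q = nn.LayerNorm(dim)
        self.norm_kv = nn.LayerNorm(dim)
        self.mca = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, pruned_tokens):                                    # (B, kept, dim)
        q = self.norm_q(self.full_tokens).expand(pruned_tokens.size(0), -1, -1)
        kv = self.norm_kv(pruned_tokens)
        recovered, _ = self.mca(q, kv, kv)                               # (B, full_len, dim)
        return recovered

# example: recover 27 frames from 12 kept tokens
print(TRASketch(dim=512, full_len=27)(torch.rand(2, 12, 512)).shape)     # torch.Size([2, 27, 512])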

3. Inference Pipelines

HoT supports two main inference pipelines:

  • seq2seq

    • Takes a 2D pose sequence as input and outputs the 3D poses of all frames.

    • TPC prunes tokens in the intermediate layers and TRA recovers them at the end.

  • seq2frame

    • Takes a 2D pose sequence as input and outputs the 3D pose of the center frame.

    • Only TPC pruning is used; the full-length tokens do not need to be recovered.

4. Regression Module

  • Regression Head

    • Maps the recovered pose tokens (or the center-frame representation) to 3D pose coordinates; a small sketch for both pipelines follows this item.
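A small sketch of how the regression head could be applied in the two pipelines; the linear head, the shapes, and the mean-pooling used for the center frame are assumptions for illustration, and the paper's actual aggregation may differ.

import torch
import torch.nn as nn

num_joints, dim = 17, 512
head = nn.Linear(dim, num_joints * 3)                        # shared regression head

# seq2seq: TRA has recovered full-length tokens -> regress every frame
recovered = torch.rand(2, 27, dim)
pose_all = head(recovered).view(2, 27, num_joints, 3)

# seq2frame: only the pruned tokens exist -> regress the center frame from them
pruned = torch.rand(2, 12, dim)
pose_center = head(pruned.mean(dim=1)).view(2, num_joints, 3)

print(pose_all.shape, pose_center.shape)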

Role of the Plug-and-Play Module

TPC as a plug-and-play module

  • Efficient video 3D human pose estimation: Transformer-based 3D pose estimation has to process many video frames, which drives up the computational cost. TPC dynamically prunes redundant frame information, making it a good fit for video-processing scenarios that demand efficient computation.

  • Deployment on resource-constrained devices: on edge devices, mobile devices, and other platforms with limited compute, TPC lowers FLOPs by removing redundant computation, reducing hardware requirements and improving runtime efficiency.

  • Pruning tasks that must preserve accuracy: TPC retains semantically rich representative tokens while pruning, so it suits accuracy-sensitive tasks such as human pose estimation and video action recognition.

  • High spatio-temporal resolution data: for long sequences or high-frame-rate video, TPC dynamically selects key frames and reduces the computational complexity.

Ablation Results

  • Compared the two inference pipelines (seq2seq and seq2frame) in terms of efficiency (FPS) and accuracy (MPJPE).

  • Result: the seq2seq pipeline is more computationally efficient but slightly less accurate, while seq2frame is more accurate but slower; integrating HoT into either one markedly reduces computational cost and speeds up inference.


  • Compared the effect of the pruning position, i.e. how many Transformer layers (n) operate on the pruned tokens.

  • Result: the more layers that run on the pruned tokens, the lower the FLOPs, at the price of a slight drop in performance (MPJPE); choosing the pruning position appropriately gives a good balance between accuracy and efficiency.

  • Tested the effect of the number of representative tokens (f).

  • Result: a moderate number of tokens (e.g. f=81) gives the best trade-off between preserving key information and reducing redundant computation.

Plug-and-Play Module Code

import math
import torch
import torch.nn as nn
from functools import partial
from timm.models.layers import DropPath
def index_points(points, idx):
    # gather points[b, idx[b]] for every batch element b
    device = points.device
    B = points.shape[0]
    view_shape = list(idx.shape)
    view_shape[1:] = [1] * (len(view_shape) - 1)
    repeat_shape = list(idx.shape)
    repeat_shape[0] = 1
    batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
    new_points = points[batch_indices, idx, :]

    return new_points


def cluster_dpc_knn_center(x, cluster_num, k, center, token_mask=None):
    # Select `cluster_num` representative tokens with DPC-kNN; the `center` token
    # is excluded from the candidates here and appended back by the caller.
    with torch.no_grad():
        B, N, C = x.shape

        # pairwise token distances, scaled by the square root of the channel dim
        dist_matrix = torch.cdist(x, x) / (C ** 0.5)

        if token_mask is not None:
            token_mask = token_mask > 0
            dist_matrix = dist_matrix * token_mask[:, None, :] + (dist_matrix.max() + 1) * (~token_mask[:, None, :])

        # local density: tokens whose k nearest neighbours are close get high density
        dist_nearest, index_nearest = torch.topk(dist_matrix, k=k, dim=-1, largest=False)

        density = (-(dist_nearest ** 2).mean(dim=-1)).exp()
        # tiny random noise breaks ties between tokens with identical density
        density = density + torch.rand(density.shape, device=density.device, dtype=density.dtype) * 1e-6

        if token_mask is not None:
            density = density * token_mask

        # separation: distance to the nearest token of higher density
        mask = density[:, None, :] > density[:, :, None]
        mask = mask.type(x.dtype)
        dist_max = dist_matrix.flatten(1).max(dim=-1)[0][:, None, None]
        dist, index_parent = (dist_matrix * mask + dist_max * (1 - mask)).min(dim=-1)

        # DPC score: high density and well separated from denser tokens -> representative token
        score = dist * density

        ## exclude the center frame token; the caller appends it back explicitly
        score[:, center] = -math.inf

        _, index_down = torch.topk(score, k=cluster_num, dim=-1)

        # assign every token to its nearest selected center
        dist_matrix = index_points(dist_matrix, index_down)

        idx_cluster = dist_matrix.argmin(dim=1)

        # make sure each selected center is assigned to its own cluster
        idx_batch = torch.arange(B, device=x.device)[:, None].expand(B, cluster_num)
        idx_tmp = torch.arange(cluster_num, device=x.device)[None, :].expand(B, cluster_num)
        idx_cluster[idx_batch.reshape(-1), index_down.reshape(-1)] = idx_tmp.reshape(-1)

    return index_down, idx_cluster


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads

        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Cross_Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.linear_q = nn.Linear(dim, dim, bias=qkv_bias)
        self.linear_k = nn.Linear(dim, dim, bias=qkv_bias)
        self.linear_v = nn.Linear(dim, dim, bias=qkv_bias)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x_1, x_2, x_3):
        B, N, C = x_1.shape
        q = self.linear_q(x_1).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        k = self.linear_k(x_2).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        v = self.linear_v(x_3).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class SHR_Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_hidden_dim, qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1_1 = norm_layer(dim)
        self.norm1_2 = norm_layer(dim)
        self.norm1_3 = norm_layer(dim)

        self.attn_1 = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, \
                                qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        self.attn_2 = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, \
                                qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        self.attn_3 = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, \
                                qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = norm_layer(dim * 3)
        self.mlp = Mlp(in_features=dim * 3, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x_1, x_2, x_3):
        x_1 = x_1 + self.drop_path(self.attn_1(self.norm1_1(x_1)))
        x_2 = x_2 + self.drop_path(self.attn_2(self.norm1_2(x_2)))
        x_3 = x_3 + self.drop_path(self.attn_3(self.norm1_3(x_3)))

        x = torch.cat([x_1, x_2, x_3], dim=2)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        x_1 = x[:, :, :x.shape[2] // 3]
        x_2 = x[:, :, x.shape[2] // 3: x.shape[2] // 3 * 2]
        x_3 = x[:, :, x.shape[2] // 3 * 2: x.shape[2]]

        return x_1, x_2, x_3


class CHI_Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_hidden_dim, qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm3_11 = norm_layer(dim)
        self.norm3_12 = norm_layer(dim)
        self.norm3_13 = norm_layer(dim)

        self.norm3_21 = norm_layer(dim)
        self.norm3_22 = norm_layer(dim)
        self.norm3_23 = norm_layer(dim)

        self.norm3_31 = norm_layer(dim)
        self.norm3_32 = norm_layer(dim)
        self.norm3_33 = norm_layer(dim)

        self.attn_1 = Cross_Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, \
                                      qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        self.attn_2 = Cross_Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, \
                                      qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        self.attn_3 = Cross_Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, \
                                      qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = norm_layer(dim * 3)
        self.mlp = Mlp(in_features=dim * 3, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x_1, x_2, x_3):
        x_1 = x_1 + self.drop_path(self.attn_1(self.norm3_11(x_2), self.norm3_12(x_3), self.norm3_13(x_1)))
        x_2 = x_2 + self.drop_path(self.attn_2(self.norm3_21(x_1), self.norm3_22(x_3), self.norm3_23(x_2)))
        x_3 = x_3 + self.drop_path(self.attn_3(self.norm3_31(x_1), self.norm3_32(x_2), self.norm3_33(x_3)))

        x = torch.cat([x_1, x_2, x_3], dim=2)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        x_1 = x[:, :, :x.shape[2] // 3]
        x_2 = x[:, :, x.shape[2] // 3: x.shape[2] // 3 * 2]
        x_3 = x[:, :, x.shape[2] // 3 * 2: x.shape[2]]

        return x_1, x_2, x_3


class Transformer(nn.Module):
    def __init__(self, depth=3, embed_dim=512, mlp_hidden_dim=1024, token_num=117, layer_index=1, h=8, drop_rate=0.1,
                 length=27):
        super().__init__()
        drop_path_rate = 0.20
        attn_drop_rate = 0.
        qkv_bias = True
        qk_scale = None

        self.center = (length - 1) // 2
        self.token_num = token_num
        self.layer_index = layer_index

        print(self.token_num, self.layer_index)

        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        self.pos_embed_1 = nn.Parameter(torch.zeros(1, length, embed_dim))
        self.pos_embed_2 = nn.Parameter(torch.zeros(1, length, embed_dim))
        self.pos_embed_3 = nn.Parameter(torch.zeros(1, length, embed_dim))

        self.pos_drop_1 = nn.Dropout(p=drop_rate)
        self.pos_drop_2 = nn.Dropout(p=drop_rate)
        self.pos_drop_3 = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]

        self.SHR_blocks = nn.ModuleList([
            SHR_Block(
                dim=embed_dim, num_heads=h, mlp_hidden_dim=mlp_hidden_dim, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth - 1)])

        self.CHI_blocks = nn.ModuleList([
            CHI_Block(
                dim=embed_dim, num_heads=h, mlp_hidden_dim=mlp_hidden_dim, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[depth - 1], norm_layer=norm_layer)
            for i in range(1)])

        self.norm = norm_layer(embed_dim * 3)

    def forward(self, x_1, x_2, x_3, index=None):
        b, f, c = x_1.shape

        x_1 += self.pos_embed_1
        x_2 += self.pos_embed_2
        x_3 += self.pos_embed_3

        x_1 = self.pos_drop_1(x_1)
        x_2 = self.pos_drop_2(x_2)
        x_3 = self.pos_drop_3(x_3)

        for i, blk in enumerate(self.SHR_blocks):
            ##-----------------Clustering-----------------##
            if i == self.layer_index:
                if index is None:
                    x_knn = torch.cat([x_1, x_2, x_3], dim=2)

                    # ensure cluster_num does not exceed the sequence length of x_knn
                    adjusted_cluster_num = min(self.token_num - 1, x_knn.shape[1] - 1)

                    index, idx_cluster = cluster_dpc_knn_center(x_knn, adjusted_cluster_num, 2, self.center)

                    index_center = self.center * torch.ones(b, 1, device=x_knn.device, dtype=index.dtype)
                    index = torch.cat([index, index_center], dim=-1)
                    index, _ = torch.sort(index)

                batch_ind = torch.arange(b, device=x_1.device).unsqueeze(-1)
                x_1 = x_1[batch_ind, index]
                x_2 = x_2[batch_ind, index]
                x_3 = x_3[batch_ind, index]
            ##-----------------Clustering-----------------##

            x_1, x_2, x_3 = self.SHR_blocks[i](x_1, x_2, x_3)

        x_1, x_2, x_3 = self.CHI_blocks[0](x_1, x_2, x_3)

        x = torch.cat([x_1, x_2, x_3], dim=2)
        x = self.norm(x)

        return x, index



if __name__ == '__main__':

    args = {
        'depth': 3,
        'embed_dim': 512,
        'mlp_hidden_dim': 1024,
        'token_num': 117,  # >= length, so all 27 tokens are kept in this demo; use a value < length (e.g. 12) to actually prune
        'layer_index': 1,
        'h': 8,
        'drop_rate': 0.1,
        'length': 27
    }


    model = Transformer(depth=args['depth'], embed_dim=args['embed_dim'], mlp_hidden_dim=args['mlp_hidden_dim'],
                        token_num=args['token_num'], layer_index=args['layer_index'], h=args['h'],
                        drop_rate=args['drop_rate'], length=args['length'])

    batch_size = 2
    sequence_length = args['length']
    embed_dim = args['embed_dim']

    x_1 = torch.rand(batch_size, sequence_length, embed_dim)
    x_2 = torch.rand(batch_size, sequence_length, embed_dim)
    x_3 = torch.rand(batch_size, sequence_length, embed_dim)

    output, _ = model(x_1, x_2, x_3)

    print('Output size:', output.size())
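To check the claimed FLOPs savings on the module above, one can compare a configuration in which nothing is effectively pruned with one where only a few representative tokens survive. The snippet below is a sketch that assumes the optional thop package is installed and that the Transformer class from the listing above is in scope; thop only counts registered modules such as nn.Linear, so the numbers are indicative rather than exact.

import torch
from thop import profile   # assumed extra dependency: pip install thop

fresh_inputs = lambda: tuple(torch.rand(2, 27, 512) for _ in range(3))

# token_num >= length: all 27 tokens are kept, i.e. no effective pruning
baseline = Transformer(depth=3, embed_dim=512, mlp_hidden_dim=1024,
                       token_num=117, layer_index=1, h=8, drop_rate=0.1, length=27)
# token_num < length: only 12 representative tokens survive after layer_index
pruned = Transformer(depth=3, embed_dim=512, mlp_hidden_dim=1024,
                     token_num=12, layer_index=1, h=8, drop_rate=0.1, length=27)

macs_base, _ = profile(baseline, inputs=fresh_inputs(), verbose=False)
macs_pruned, _ = profile(pruned, inputs=fresh_inputs(), verbose=False)
print(f'MACs without pruning: {macs_base / 1e9:.2f} G, with pruning: {macs_pruned / 1e9:.2f} G')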

Download

Open in a browser: https://github.com/ai-dawang/PlugNPlay-Modules

See the original paper for more analysis.

