Paper Introduction
Title: Hourglass Tokenizer for Efficient Transformer-Based 3D Human Pose Estimation
Paper: https://arxiv.org/pdf/2311.12028
Innovations
Introduces the Hourglass Tokenizer (HoT) framework:
Proposes a "prune-and-recover" framework for video-based 3D human pose estimation.
Unlike existing methods, HoT prunes and then recovers pose tokens inside the Transformer blocks, cutting computational cost substantially while keeping the model efficient and accurate.
Proposes the Token Pruning Cluster (TPC) module:
Dynamically selects representative tokens with high semantic diversity, removing redundant information across video frames.
Uses a density-peaks clustering algorithm (DPC-kNN) so that semantically representative cluster centers drive efficient pose-token selection.
Develops the Token Recovering Attention (TRA) module:
Uses a lightweight cross-attention mechanism to recover the detailed spatio-temporal information lost during pruning.
Restores tokens from low temporal resolution back to full temporal resolution, as required for fast inference.
Generality across Transformer architectures:
The method integrates seamlessly into a variety of existing video pose Transformer (VPT) models (e.g., MHFormer, MixSTE, and MotionBERT) and supports both mainstream inference pipelines (seq2seq and seq2frame).
Clear gains in computational efficiency and inference speed:
On the Human3.6M and MPI-INF-3DHP datasets, HoT reduces FLOPs by up to 50% compared with the baseline models while maintaining or slightly improving accuracy.
A general pruning-and-recovering strategy:
Offers flexible pruning and recovering settings that can be adapted to different tasks and hardware constraints.
Method
Overall Architecture
The proposed model is a Transformer-based framework for video 3D human pose estimation named Hourglass Tokenizer (HoT). It takes a 2D pose sequence as input, generates spatio-temporal pose tokens through a pose embedding module, and captures global information in the first few Transformer blocks. A Token Pruning Cluster (TPC) module is then inserted in the middle of the network to dynamically prune redundant tokens, a Token Recovering Attention (TRA) module restores the full-length tokens at the end, and a regression head finally outputs the 3D poses. Through this "prune-and-recover" design, HoT improves computational efficiency and fits both inference pipelines (seq2seq and seq2frame).
1. Overall Architecture
The paper proposes a framework named Hourglass Tokenizer (HoT), designed specifically for Transformer-based video 3D human pose estimation. The architecture consists of the following parts (a minimal code sketch follows this list):
Input stage: accepts a 2D human pose sequence (keypoint coordinates for every frame).
Pose Embedding Module:
Encodes the input 2D poses into pose tokens that carry spatio-temporal information.
Transformer Blocks:
A stack of Transformer blocks that captures global spatio-temporal dependencies.
The first few blocks keep the full-length pose tokens to preserve rich information.
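To make this flow concrete, here is a minimal, self-contained sketch of a prune-and-recover pipeline, not the authors' implementation: the class name HoTSketch, the norm-based token scoring, and the use of the cached full-length tokens as recovery queries are simplifying assumptions (the released HoT code selects tokens with DPC-kNN clustering and recovers them with a dedicated cross-attention module).

import torch
import torch.nn as nn

class HoTSketch(nn.Module):
    # Hypothetical skeleton: embed 2D poses, run Transformer blocks, prune tokens in the
    # middle (stand-in for TPC), recover full length at the end (stand-in for TRA), regress 3D.
    def __init__(self, num_joints=17, dim=256, depth=8, prune_layer=3, keep=16):
        super().__init__()
        self.embed = nn.Linear(num_joints * 2, dim)                 # pose embedding
        self.blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(dim, nhead=8, batch_first=True) for _ in range(depth)])
        self.prune_layer, self.keep = prune_layer, keep
        self.recover = nn.MultiheadAttention(dim, num_heads=8, batch_first=True)
        self.head = nn.Linear(dim, num_joints * 3)                  # regression head

    def forward(self, pose_2d):                                     # pose_2d: (B, F, J, 2)
        b, f, j, _ = pose_2d.shape
        x = self.embed(pose_2d.flatten(2))                          # (B, F, dim)
        full = None
        for i, blk in enumerate(self.blocks):
            if i == self.prune_layer:                               # TPC stand-in: keep top-scoring tokens
                full = x                                            # cache full-length tokens
                idx = x.norm(dim=-1).topk(self.keep, dim=1).indices.sort(dim=1).values
                x = x.gather(1, idx.unsqueeze(-1).expand(-1, -1, x.size(-1)))
            x = blk(x)
        x_full, _ = self.recover(full, x, x)                        # TRA stand-in: restore full length
        return self.head(x_full).reshape(b, f, j, 3)

out = HoTSketch()(torch.randn(2, 81, 17, 2))
print(out.shape)                                                    # torch.Size([2, 81, 17, 3])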
2. Key Modules of the Hourglass Tokenizer
HoT processes tokens efficiently through two key modules (a sketch of both follows this list):
Token Pruning Cluster (TPC):
Dynamic pruning: prunes redundant pose tokens in the middle Transformer blocks, keeping only semantically rich representative tokens.
The pruning relies on a density-peaks clustering algorithm (DPC-kNN) to pick tokens with high semantic diversity as the representatives.
Token Recovering Attention (TRA):
Full-length recovery: after the last Transformer block, a lightweight multi-head cross-attention (MCA) recovers full-temporal-resolution pose tokens from the pruned ones.
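A compact sketch of both operations under simplifying assumptions is given below: dpc_knn_select mirrors the density-peaks scoring used by cluster_dpc_knn_center in the plug-and-play code further down (without the token mask and the forced center frame), and TRARecover is one plausible reading of the recovery step, where learnable full-length queries cross-attend to the pruned tokens; names and hyper-parameters are illustrative.

import torch
import torch.nn as nn

def dpc_knn_select(x, num_keep, k=5):
    # Density-peaks (DPC-kNN) style selection of num_keep representative tokens from x (B, N, C).
    B, N, C = x.shape
    dist = torch.cdist(x, x) / (C ** 0.5)                           # pairwise token distances
    knn_dist = dist.topk(k, dim=-1, largest=False).values           # distances to k nearest neighbours
    density = (-(knn_dist ** 2).mean(dim=-1)).exp()                 # local density
    higher = density[:, None, :] > density[:, :, None]              # tokens with higher density
    dist_max = dist.flatten(1).max(dim=-1).values[:, None, None]
    delta = torch.where(higher, dist, dist_max).min(dim=-1).values  # distance to nearest denser token
    score = density * delta                                         # peaks: dense and well separated
    return score.topk(num_keep, dim=-1).indices.sort(dim=-1).values

class TRARecover(nn.Module):
    # Recovery sketch: learnable full-length queries attend to the pruned tokens (keys/values).
    def __init__(self, dim, full_len, num_heads=8):
        super().__init__()
        self.queries = nn.Parameter(torch.zeros(1, full_len, dim))
        self.mca = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, pruned):                                       # (B, n, C) -> (B, full_len, C)
        q = self.queries.expand(pruned.size(0), -1, -1)
        out, _ = self.mca(q, pruned, pruned)
        return out

x = torch.randn(2, 81, 256)                                          # 81 full-length frame tokens
idx = dpc_knn_select(x, num_keep=27)                                 # indices of 27 representative frames
pruned = x.gather(1, idx.unsqueeze(-1).expand(-1, -1, x.size(-1)))
recovered = TRARecover(dim=256, full_len=81)(pruned)
print(idx.shape, pruned.shape, recovered.shape)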
3. Inference Pipelines
HoT supports the two main inference pipelines (compare the snippet after this list):
seq2seq:
Takes a 2D pose sequence as input and outputs the 3D pose sequence of all frames.
TPC prunes tokens in the middle layers and TRA recovers them at the end.
seq2frame:
Takes a 2D pose sequence as input and outputs the 3D pose of the center frame only.
Only TPC pruning is used; there is no need to recover the full-length tokens.
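The snippet below illustrates how the two pipelines differ at the output stage, with assumed shapes and a shared regression head; the center-token lookup is simplified here, whereas the real model keeps track of where the center frame ends up after pruning.

import torch
import torch.nn as nn

dim, full_len, kept, num_joints = 256, 81, 27, 17
head = nn.Linear(dim, num_joints * 3)                               # shared regression head

# seq2seq: TPC prunes in the middle, TRA restores full length, every frame is regressed.
recovered = torch.randn(2, full_len, dim)                           # tokens after TRA recovery
pose_seq = head(recovered).reshape(2, full_len, num_joints, 3)      # (B, F, J, 3)

# seq2frame: only TPC is used; the head regresses the center frame alone.
pruned = torch.randn(2, kept, dim)                                  # tokens after TPC pruning
center_token = pruned[:, kept // 2]                                 # stand-in for the kept center-frame token
pose_center = head(center_token).reshape(2, num_joints, 3)          # (B, J, 3)
print(pose_seq.shape, pose_center.shape)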
4. Regression Module
Regression Head:
Maps the recovered pose tokens (seq2seq) or the center-frame token (seq2frame) to 3D pose coordinates.
Role of the Plug-and-Play Module
TPC as a plug-and-play module:
Efficient video 3D human pose estimation: Transformer-based 3D pose estimation has to process many video frames, which drives up the computational cost. TPC dynamically prunes redundant frame information, making it suitable for video-processing scenarios that need efficient computation.
Deployment on resource-constrained devices: on edge devices, mobile devices, and other platforms with limited compute, TPC cuts redundant computation and FLOPs, lowering hardware requirements and improving runtime efficiency.
Pruning tasks that must preserve accuracy: TPC keeps semantically rich representative tokens while pruning, so it fits accuracy-sensitive tasks such as human pose estimation and video action recognition.
High spatio-temporal resolution data: for long sequences or high-frame-rate video, TPC dynamically selects key frames and reduces the computational complexity.
Ablation Results
Comparison of the two inference pipelines (seq2seq and seq2frame) in terms of efficiency (FPS) and accuracy (MPJPE):
seq2seq is computationally more efficient but slightly less accurate, while seq2frame is more accurate but slower; integrating HoT markedly lowers the computational cost and raises the inference speed.
Effect of the number of pruned layers (n):
pruning more layers lowers the computational cost (FLOPs) but slightly degrades performance (MPJPE); choosing the pruning position appropriately strikes a good balance between accuracy and efficiency.
Effect of the number of representative tokens (f):
a moderate number of tokens (e.g., f = 81) gives the best trade-off between keeping key information and cutting redundant computation.
Plug-and-Play Module (code)
import math
import torch
import torch.nn as nn
from functools import partial
from timm.models.layers import DropPath
def index_points(points, idx):
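    # Gather the rows of `points` (B, N, C) at the indices in `idx` (B, K) -> (B, K, C).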
device = points.device
B = points.shape[0]
view_shape = list(idx.shape)
view_shape[1:] = [1] * (len(view_shape) - 1)
repeat_shape = list(idx.shape)
repeat_shape[0] = 1
batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
new_points = points[batch_indices, idx, :]
return new_points
def cluster_dpc_knn_center(x, cluster_num, k, center, token_mask=None):
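    # DPC-kNN clustering: pick `cluster_num` representative tokens from x (B, N, C) by a
    # density-peaks score; the `center` token is excluded here and appended back by the caller.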
with torch.no_grad():
B, N, C = x.shape
dist_matrix = torch.cdist(x, x) / (C ** 0.5)
if token_mask is not None:
token_mask = token_mask > 0
dist_matrix = dist_matrix * token_mask[:, None, :] + (dist_matrix.max() + 1) * (~token_mask[:, None, :])
dist_nearest, index_nearest = torch.topk(dist_matrix, k=k, dim=-1, largest=False)
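        # local density: tokens close to their k nearest neighbours get a higher density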
density = (-(dist_nearest ** 2).mean(dim=-1)).exp()
density = density + torch.rand(density.shape, device=density.device, dtype=density.dtype) * 1e-6
if token_mask is not None:
density = density * token_mask
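        # for each token, the distance to the nearest token of higher density (the DPC separation)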
mask = density[:, None, :] > density[:, :, None]
mask = mask.type(x.dtype)
dist_max = dist_matrix.flatten(1).max(dim=-1)[0][:, None, None]
dist, index_parent = (dist_matrix * mask + dist_max * (1 - mask)).min(dim=-1)
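        # cluster centers are tokens that have both high density and large separation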
score = dist * density
## remove center
score[:, center] = -math.inf
_, index_down = torch.topk(score, k=cluster_num, dim=-1)
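        # assign every token to its nearest selected cluster center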
dist_matrix = index_points(dist_matrix, index_down)
idx_cluster = dist_matrix.argmin(dim=1)
idx_batch = torch.arange(B, device=x.device)[:, None].expand(B, cluster_num)
idx_tmp = torch.arange(cluster_num, device=x.device)[None, :].expand(B, cluster_num)
idx_cluster[idx_batch.reshape(-1), index_down.reshape(-1)] = idx_tmp.reshape(-1)
return index_down, idx_cluster
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class Cross_Attention(nn.Module):
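    # Multi-head cross-attention: queries from x_1, keys from x_2, values from x_3.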
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.linear_q = nn.Linear(dim, dim, bias=qkv_bias)
self.linear_k = nn.Linear(dim, dim, bias=qkv_bias)
self.linear_v = nn.Linear(dim, dim, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x_1, x_2, x_3):
B, N, C = x_1.shape
q = self.linear_q(x_1).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
k = self.linear_k(x_2).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
v = self.linear_v(x_3).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class SHR_Block(nn.Module):
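    # SHR block (self-refinement, MHFormer-style): each of the three hypothesis streams runs its
    # own self-attention, then a shared MLP mixes the concatenated streams.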
def __init__(self, dim, num_heads, mlp_hidden_dim, qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.norm1_1 = norm_layer(dim)
self.norm1_2 = norm_layer(dim)
self.norm1_3 = norm_layer(dim)
self.attn_1 = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, \
qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
self.attn_2 = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, \
qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
self.attn_3 = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, \
qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim * 3)
self.mlp = Mlp(in_features=dim * 3, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
def forward(self, x_1, x_2, x_3):
x_1 = x_1 + self.drop_path(self.attn_1(self.norm1_1(x_1)))
x_2 = x_2 + self.drop_path(self.attn_2(self.norm1_2(x_2)))
x_3 = x_3 + self.drop_path(self.attn_3(self.norm1_3(x_3)))
x = torch.cat([x_1, x_2, x_3], dim=2)
x = x + self.drop_path(self.mlp(self.norm2(x)))
x_1 = x[:, :, :x.shape[2] // 3]
x_2 = x[:, :, x.shape[2] // 3: x.shape[2] // 3 * 2]
x_3 = x[:, :, x.shape[2] // 3 * 2: x.shape[2]]
return x_1, x_2, x_3
class CHI_Block(nn.Module):
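    # CHI block (cross-hypothesis interaction, MHFormer-style): each stream is updated by a
    # cross-attention whose queries and keys come from the other two streams, followed by a shared MLP.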
def __init__(self, dim, num_heads, mlp_hidden_dim, qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.norm3_11 = norm_layer(dim)
self.norm3_12 = norm_layer(dim)
self.norm3_13 = norm_layer(dim)
self.norm3_21 = norm_layer(dim)
self.norm3_22 = norm_layer(dim)
self.norm3_23 = norm_layer(dim)
self.norm3_31 = norm_layer(dim)
self.norm3_32 = norm_layer(dim)
self.norm3_33 = norm_layer(dim)
self.attn_1 = Cross_Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, \
qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
self.attn_2 = Cross_Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, \
qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
self.attn_3 = Cross_Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, \
qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim * 3)
self.mlp = Mlp(in_features=dim * 3, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
def forward(self, x_1, x_2, x_3):
x_1 = x_1 + self.drop_path(self.attn_1(self.norm3_11(x_2), self.norm3_12(x_3), self.norm3_13(x_1)))
x_2 = x_2 + self.drop_path(self.attn_2(self.norm3_21(x_1), self.norm3_22(x_3), self.norm3_23(x_2)))
x_3 = x_3 + self.drop_path(self.attn_3(self.norm3_31(x_1), self.norm3_32(x_2), self.norm3_33(x_3)))
x = torch.cat([x_1, x_2, x_3], dim=2)
x = x + self.drop_path(self.mlp(self.norm2(x)))
x_1 = x[:, :, :x.shape[2] // 3]
x_2 = x[:, :, x.shape[2] // 3: x.shape[2] // 3 * 2]
x_3 = x[:, :, x.shape[2] // 3 * 2: x.shape[2]]
return x_1, x_2, x_3
class Transformer(nn.Module):
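    # Multi-stream backbone with HoT-style token pruning: DPC-kNN clustering keeps at most
    # `token_num` representative frames (always including the center frame) at SHR layer `layer_index`.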
def __init__(self, depth=3, embed_dim=512, mlp_hidden_dim=1024, token_num=117, layer_index=1, h=8, drop_rate=0.1,
length=27):
super().__init__()
drop_path_rate = 0.20
attn_drop_rate = 0.
qkv_bias = True
qk_scale = None
self.center = (length - 1) // 2
self.token_num = token_num
self.layer_index = layer_index
print(self.token_num, self.layer_index)
norm_layer = partial(nn.LayerNorm, eps=1e-6)
self.pos_embed_1 = nn.Parameter(torch.zeros(1, length, embed_dim))
self.pos_embed_2 = nn.Parameter(torch.zeros(1, length, embed_dim))
self.pos_embed_3 = nn.Parameter(torch.zeros(1, length, embed_dim))
self.pos_drop_1 = nn.Dropout(p=drop_rate)
self.pos_drop_2 = nn.Dropout(p=drop_rate)
self.pos_drop_3 = nn.Dropout(p=drop_rate)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
self.SHR_blocks = nn.ModuleList([
SHR_Block(
dim=embed_dim, num_heads=h, mlp_hidden_dim=mlp_hidden_dim, qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
for i in range(depth - 1)])
self.CHI_blocks = nn.ModuleList([
CHI_Block(
dim=embed_dim, num_heads=h, mlp_hidden_dim=mlp_hidden_dim, qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[depth - 1], norm_layer=norm_layer)
for i in range(1)])
self.norm = norm_layer(embed_dim * 3)
def forward(self, x_1, x_2, x_3, index=None):
b, f, c = x_1.shape
x_1 += self.pos_embed_1
x_2 += self.pos_embed_2
x_3 += self.pos_embed_3
x_1 = self.pos_drop_1(x_1)
x_2 = self.pos_drop_2(x_2)
x_3 = self.pos_drop_3(x_3)
for i, blk in enumerate(self.SHR_blocks):
            ##-----------------Clustering-----------------##
if i == self.layer_index:
if index is None:
x_knn = torch.cat([x_1, x_2, x_3], dim=2)
                    # make sure cluster_num does not exceed the sequence length of x_knn
adjusted_cluster_num = min(self.token_num - 1, x_knn.shape[1] - 1)
index, idx_cluster = cluster_dpc_knn_center(x_knn, adjusted_cluster_num, 2, self.center)
index_center = self.center * torch.ones(b, 1, device=x_knn.device, dtype=index.dtype)
index = torch.cat([index, index_center], dim=-1)
index, _ = torch.sort(index)
batch_ind = torch.arange(b, device=x_1.device).unsqueeze(-1)
x_1 = x_1[batch_ind, index]
x_2 = x_2[batch_ind, index]
x_3 = x_3[batch_ind, index]
            ##-----------------Clustering-----------------##
x_1, x_2, x_3 = self.SHR_blocks[i](x_1, x_2, x_3)
x_1, x_2, x_3 = self.CHI_blocks[0](x_1, x_2, x_3)
x = torch.cat([x_1, x_2, x_3], dim=2)
x = self.norm(x)
return x, index
if __name__ == '__main__':
args = {
'depth': 3,
'embed_dim': 512,
'mlp_hidden_dim': 1024,
'token_num': 117,
'layer_index': 1,
'h': 8,
'drop_rate': 0.1,
'length': 27
}
model = Transformer(depth=args['depth'], embed_dim=args['embed_dim'], mlp_hidden_dim=args['mlp_hidden_dim'],
token_num=args['token_num'], layer_index=args['layer_index'], h=args['h'],
drop_rate=args['drop_rate'], length=args['length'])
batch_size = 2
sequence_length = args['length']
embed_dim = args['embed_dim']
x_1 = torch.rand(batch_size, sequence_length, embed_dim)
x_2 = torch.rand(batch_size, sequence_length, embed_dim)
x_3 = torch.rand(batch_size, sequence_length, embed_dim)
output, _ = model(x_1, x_2, x_3)
print('Output size:', output.size())
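Note that with the demo settings above (length=27 while token_num=117), adjusted_cluster_num is clamped to the sequence length, so all 27 frames are kept and the script prints Output size: torch.Size([2, 27, 1536]); with an input sequence longer than token_num, the pruned token count would be token_num instead.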
Download
Open this URL in a browser: https://github.com/ai-dawang/PlugNPlay-Modules
See the original paper for more analysis.