论文介绍
题目:PoseBERT: A Generic Transformer Module for Temporal 3D Human Modeling
论文地址:https://ieeexplore.ieee.org/document/9982410
QQ深度学习交流群:994264161
扫描下方二维码,加入深度学习论文指南星球!
加入即可获得,模块缝合、制作、写作技巧,学会“结构”创新、“创新点”创新,从三区到顶会,小论文大论文,毕业一站式服务
创新点
基于运动捕捉数据的无监督训练:PoseBERT不依赖RGB图像的繁琐标注,而是使用运动捕捉(MoCap)数据,通过掩码建模训练模型。这一方法绕过了传统基于伪标注的困难,提高了数据获取的易用性。
通用性和模块化设计:PoseBERT是一个可以插入任何基于图像的姿态估计模型的模块,将其转变为基于视频的模型。其设计支持人体和手部建模,适配各种3D参数化模型(如SMPL和MANO)。
多任务应用能力:PoseBERT在无需微调的情况下,可以直接应用于一系列任务,如姿态序列去噪、恢复缺失姿态、姿态序列优化、动作补全以及未来动作预测。
基于Transformer的时序建模:使用Transformer架构捕获时间信息,使其能够平滑和优化预测姿态,特别是在处理输入序列中缺失或噪声数据时表现出色。
高效的实时性能:PoseBERT的计算开销低,可以在实时场景下以30帧/秒运行。例如,该模型已被用于基于摄像头的机械手实时遥操作演示。
显著的性能改进:将PoseBERT集成到现有的基于图像的模型中(如SPIN或VIBE),在多种基准数据集上的姿态优化任务中均显著提升了性能。
方法
整体架构
PoseBERT 的整体结构包括一个输入处理模块,将 3D 姿态序列嵌入到高维特征空间,并加入位置编码;一个基于多层 Transformer 的时序建模模块,用于捕获时间依赖关系和上下文信息;一个逐层迭代回归模块,通过多层感知机逐步优化姿态参数;以及一个训练损失函数,包括重构损失和去噪设计,用于生成平滑且一致的 3D 姿态序列。模型具有通用性和高效性,可插入到任何基于图像的姿态估计方法中,支持实时处理和多任务应用。
1. 输入表示(Input Representations)
输入数据: 输入是一个由
帧组成的姿态序列T T 。每帧的姿态表示可以是 3D 骨架关键点、3D 参数化模型(如 SMPL 或 MANO)的旋转参数。P = { p 1 , p 2 , . . . , p T } P = \{p_1, p_2, ..., p_T\} 嵌入处理:
对每个姿态
使用线性投影将其嵌入到高维空间中,形成p t p_t -维的特征向量。D D 如果某些帧的姿态缺失,则用一个可学习的特殊标记(mask token)代替。
为每个时间步添加位置编码,以引入时序信息。
2. Transformer 模块(Temporal Modeling with Transformer)
多层 Transformer 块:
输入经过一系列标准的 Transformer 块,每个块包含多头自注意力机制和前馈网络。
自注意力机制捕捉输入序列中的时间依赖关系。
输出通过残差连接和层归一化处理,形成更新后的特征表示。
上下文建模:
Transformer 模块通过时间上下文信息对输入序列进行建模,使其能够平滑预测姿态并填补缺失帧。
3. 逐层迭代回归(Iterative Pose Regression)
姿态参数更新:
Transformer 块的输出与当前的姿态估计
进行拼接,通过多层感知机(MLP)生成新的姿态增量θ t l − 1 \theta^{l-1}_t 。Δ θ t l \Delta\theta^l_t 更新规则为:
θ t l = θ t l − 1 + Δ θ t l 初始姿态
通常设置为参数化模型的平均姿态。θ 0 \theta^0 逐层优化:
MLP 在每一层 Transformer 模块后进行迭代优化,逐步提高姿态估计精度。
4. 损失函数(Loss Function)
重构损失(Reconstruction Loss):
姿态参数的 L2 损失:
L pose = ∑ t = 1 T ∣ ∣ θ t − θ t ′ ∣ ∣ 2 平移参数的 L2 损失:
L translation = ∑ t = 1 T ∣ ∣ γ t − γ t ′ ∣ ∣ 2 噪声鲁棒性(Denoising):
在输入序列中引入高斯噪声或随机替换部分帧,模拟实际场景中的遮挡或模糊。
5. 输出(Output Representation)
预测结果:
输出序列
是时间一致的 3D 姿态网格(如人体或手部的 SMPL/MANO 模型)。M = { m 1 , m 2 , . . . , m T } M = \{m_1, m_2, ..., m_T\} 填补缺失帧:
模型能够生成平滑且一致的姿态序列,即使在输入中有缺失的帧。
6. 实时应用能力
低计算开销:
由于其高效的结构设计,PoseBERT 能够以 30 FPS 的速度实时处理输入序列,适合在线应用。
即插即用模块作用
PoseBERT 作为一个即插即用模块:
时序建模(Temporal Modeling):
PoseBERT 基于 Transformer 架构,能够捕捉姿态序列中的时间依赖关系,为姿态估计引入上下文信息,从而平滑预测结果。
例如,在存在运动模糊或遮挡的情况下,PoseBERT 可利用邻帧信息推断缺失或模糊的姿态。
去噪与修复(Denoising and Refinement):
对于从图像中估计的噪声姿态,PoseBERT 可以通过时序上下文信息优化每一帧的姿态,生成更准确的 3D 网格。
例如,在手部或人体动作捕捉中,PoseBERT 可减少突变、错误预测或数据丢失的影响。
任务无关的通用性(Task-agnostic Generality):
PoseBERT 采用基于 MoCap 数据的无监督训练,不依赖特定任务的标注,能够灵活应用于多种任务(如未来预测、动作补全)。
这种通用性使其适合作为一个模块化解决方案。
实时高效性(Real-time Efficiency):
PoseBERT 计算开销低,能够以 30 帧/秒的速度运行,适合实时场景。
例如,在机器人操作中,PoseBERT 可实时生成平滑的控制信号。
兼容性与模块化(Compatibility and Modularity):
PoseBERT 可以无缝集成到现有的基于图像的姿态估计框架中,作为一个“增强模块”提升性能。
它的即插即用特性适合在不同的数据和模型中快速部署。
消融实验结果
内容:研究了不同预训练策略对模型性能的影响。
实验变量:
掩码比例(Masking Percentage):随机屏蔽输入序列的一部分帧,模拟缺失数据。
添加噪声(Noise Injection):在输入序列中加入高斯噪声,模拟实际应用中的遮挡或运动模糊。
结果:
屏蔽输入序列的 12.5% 帧可以显著提升预测的平滑度(加速度误差 Accel 降低)。
添加高斯噪声进一步提高了性能和序列平滑度。
结论:
使用掩码和噪声注入能够提高 PoseBERT 对时序动态建模的鲁棒性。
内容:研究了不同训练策略对模型性能的影响。
实验变量:
添加回归器(Regressor)到 Transformer 中。
不同的帧率(fps):降低帧率以观察较长时间窗口的影响。
随机替换帧(Random Pose/Joints):用随机姿态或关节替换输入的一部分帧。
结果:
在 Transformer 中集成回归器能够提高性能,尤其是在更大的时间窗口中建模时序关系。
降低帧率并未显著提高性能,且过低帧率会导致性能下降。
随机替换帧或关节有助于训练的鲁棒性,但加入比例过高会损害性能。
结论:
PoseBERT 的最佳训练策略是适度的掩码和噪声注入,同时保持合理的时间窗口。
内容:研究模型超参数的影响。
实验变量:
是否使用位置编码(Positional Encoding)。
回归器的共享性(Shared Regressor)。
Transformer 的深度(Layer Depth,
)。L L 嵌入维度(Embedding Dimension,
)。D D 输入序列长度(Sequence Length,
)。T T 结果:
移除位置编码会导致性能下降,说明时间信息对建模至关重要。
共享回归器参数能够减少模型复杂度并略微提高性能。
Transformer 的最佳深度为
,嵌入维度为L = 4 L=4 ,序列长度为D = 512 D=512 。T = 16 T=16 增加深度、嵌入维度或序列长度过多,性能提升有限且带来更高的计算开销。
结论:
适当的模型深度和输入长度能够平衡性能与计算效率。
即插即用模块
import torch
from torch import nn
import roma
from einops import rearrange
class Residual(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(x, **kwargs) + x
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout=0.):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class FeedForwardResidual(nn.Module):
def __init__(self, dim, hidden_dim, dropout=0., out_dim=24 * 6):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim + out_dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, out_dim),
)
nn.init.xavier_uniform_(self.net[-1].weight, gain=0.01)
def forward(self, x, init, n_iter=1):
pred_pose = init
for _ in range(n_iter):
xf = torch.cat([x, init], -1)
pred_pose = pred_pose + self.net(xf)
return pred_pose
class Attention(nn.Module):
def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x, mask=None):
"""
Args:
- x: [batch_size,seq_len,dim]
- mask: [batch_size,seq_len] - dytpe= torch.bool - default True everywhere, if False it means that we don't pay attention to this timestep
"""
b, n, _, h = *x.shape, self.heads
qkv = self.to_qkv(x).chunk(3, dim=-1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv)
dots = torch.einsum('b h i d, b h j d -> b h i j', q, k) * self.scale # [B,H,T,T]
mask_value = -torch.finfo(dots.dtype).max
if mask is not None: # always true
assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
mask = mask.unsqueeze(1).unsqueeze(1).repeat(1, 1, n, 1) # updating masked timesteps with context
dots.masked_fill_(~mask, mask_value) # ~ do the opposite i.e. move True to False here
del mask
attn = dots.softmax(dim=-1)
out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
out = self.to_out(out)
return out
class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=1024):
super(PositionalEncoding, self).__init__()
self.pe = nn.Parameter(torch.randn(1, max_len, d_model))
def forward(self, x, start=0):
x = x + self.pe[:, start:(start + x.size(1))]
return x
class TransformerRegressor(nn.Module):
def __init__(self, dim, depth=2, heads=8, dim_head=32, mlp_dim=32, dropout=0.1, out=[22 * 6, 3],
share_regressor=False):
super().__init__()
self.layers = nn.ModuleList([])
for i in range(depth):
list_modules = [
PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
]
# Regressor
if i == 0 or not share_regressor:
# N regressor per layer
for out_i in out:
list_modules.append(PreNorm(dim, FeedForwardResidual(dim, mlp_dim, dropout=dropout, out_dim=out_i)))
else:
# Share regressor across layers
for j in range(2, len(self.layers[0])):
list_modules.append(self.layers[0][j])
self.layers.append(nn.ModuleList(list_modules))
def forward(self, x, init, mask=None):
batch_size, seq_len, *_ = x.size()
y = init
for layers_i in self.layers:
# attention and feeforward module
attn, ff = layers_i[0], layers_i[1]
x = attn(x, mask=mask) + x
x = ff(x) + x
# N regressors
for j, reg in enumerate(layers_i[2:]):
y[j] = reg(x, init=y[j], n_iter=1)
return y
class PoseBERT(nn.Module):
def __init__(self,
in_dim=24 * 6, n_jts_out=24, init_pose=None,
dim=512, depth=4, heads=8, dim_head=64, mlp_dim=512, dropout=0.1,
share_regressor=1,
*args, **kwargs):
super(PoseBERT, self).__init__()
self.pos = PositionalEncoding(dim, 1024)
self.emb = nn.Linear(in_dim, dim)
self.mask_token = nn.Parameter(torch.randn(1, 1, dim))
self.decoder = TransformerRegressor(dim, depth, heads, dim_head, mlp_dim, dropout,
[n_jts_out * 6],
share_regressor == 1)
if init_pose is None:
init_pose = torch.zeros(n_jts_out * 6).float()
self.register_buffer('init_pose', init_pose.reshape(1, 1, -1))
# Type of input
if in_dim == 24 * 6:
self.input = 'rotmat'
elif in_dim == 16 * 3 + 6:
self.input = 'h36m'
else:
raise NameError
def forward(self, rotmat, root=None, rel=None, mask=None):
"""
Args:
- rotmat: torch.Tensor - torch.float32 - [batch_size, seq_len, 24, 3, 3]
- root: torch.Tensor - torch.float32 - [batch_size, seq_len, 3, 3]
- rel: torch.Tensor - torch.float32 - [batch_size, seq_len, 17, 3]
- mask: torch.Tensor - torch.bool - [batch_size, seq_len]
Return:
- y: torch.Tensor - [batch_size, seq_len, 24, 3, 3] - torch.float32
"""
# Handling input
if self.input == 'rotmat':
assert rotmat is not None
# Keep 6D representation only and concat
x = rotmat[..., :2].flatten(2) # [batch_size, seq-len, in_dim]
elif self.input == 'h36m':
assert root is not None and rel is not None
# 6D repr of the root rotation and keep the relative pose only (discard the hip because it is centered)
x = torch.cat([root[..., :2].flatten(2), rel[:, :, 1:].flatten(2)], -1)
else:
raise NameError
batch_size, seq_len, *_ = x.size()
# Default masks
if mask is None:
mask = torch.ones(batch_size, seq_len).type_as(x).bool()
# Input embedding
x = self.emb(x)
x = x * mask.float().unsqueeze(-1) + self.mask_token * (1. - mask.float().unsqueeze(-1)) # masked token
x = self.pos(x) # inject position info
# Transformer
init = [self.init_pose.repeat(batch_size, seq_len, 1)] # init mean pose
y = self.decoder(x, init, mask)[0]
# Move from rotation representation from 6D to 9D
y = roma.special_gramschmidt(y.reshape(batch_size, seq_len, -1, 3, 2))
return y
if __name__ == '__main__':
# 初始化模型
model = PoseBERT()
# 生成随机旋转矩阵输入,形状为 [batch_size, seq_len, n_jts, rotmat_size, rotmat_size]
input_rotmat = torch.randn(2, 10, 24, 3, 3)
# 调用模型
output = model(rotmat=input_rotmat)
# 打印输入和输出尺寸
print("Input size:", input_rotmat.size())
print("Output size:", output.size())
便捷下载方式
浏览打开网址:https://github.com/ai-dawang/PlugNPlay-Modules
更多分析可见原文