即插即用时序建模模块PoseBERT,涨点起飞起飞了

文摘   2025-01-23 17:20   中国香港  

论文介绍

题目:PoseBERT: A Generic Transformer Module for Temporal 3D Human Modeling

论文地址:https://ieeexplore.ieee.org/document/9982410

QQ深度学习交流群:994264161

扫描下方二维码,加入深度学习论文指南星球!

加入即可获得,模块缝合、制作、写作技巧,学会“结构”创新、“创新点”创新,从三区到顶会,小论文大论文,毕业一站式服务

创新点

  • 基于运动捕捉数据的无监督训练PoseBERT不依赖RGB图像的繁琐标注,而是使用运动捕捉(MoCap)数据,通过掩码建模训练模型。这一方法绕过了传统基于伪标注的困难,提高了数据获取的易用性。

  • 通用性和模块化设计PoseBERT是一个可以插入任何基于图像的姿态估计模型的模块,将其转变为基于视频的模型。其设计支持人体和手部建模,适配各种3D参数化模型(如SMPL和MANO)。

  • 多任务应用能力PoseBERT在无需微调的情况下,可以直接应用于一系列任务,如姿态序列去噪、恢复缺失姿态、姿态序列优化、动作补全以及未来动作预测。

  • 基于Transformer的时序建模使用Transformer架构捕获时间信息,使其能够平滑和优化预测姿态,特别是在处理输入序列中缺失或噪声数据时表现出色。

  • 高效的实时性能PoseBERT的计算开销低,可以在实时场景下以30帧/秒运行。例如,该模型已被用于基于摄像头的机械手实时遥操作演示。

  • 显著的性能改进将PoseBERT集成到现有的基于图像的模型中(如SPIN或VIBE),在多种基准数据集上的姿态优化任务中均显著提升了性能。

方法

整体架构

       PoseBERT 的整体结构包括一个输入处理模块,将 3D 姿态序列嵌入到高维特征空间,并加入位置编码;一个基于多层 Transformer 的时序建模模块,用于捕获时间依赖关系和上下文信息;一个逐层迭代回归模块,通过多层感知机逐步优化姿态参数;以及一个训练损失函数,包括重构损失和去噪设计,用于生成平滑且一致的 3D 姿态序列。模型具有通用性和高效性,可插入到任何基于图像的姿态估计方法中,支持实时处理和多任务应用。

1. 输入表示(Input Representations)

  • 输入数据: 输入是一个由TT 帧组成的姿态序列P={p1,p2,...,pT}P = \{p_1, p_2, ..., p_T\}。每帧的姿态表示可以是 3D 骨架关键点、3D 参数化模型(如 SMPL 或 MANO)的旋转参数。

  • 嵌入处理

    • 对每个姿态ptp_t 使用线性投影将其嵌入到高维空间中,形成DD-维的特征向量。

    • 如果某些帧的姿态缺失,则用一个可学习的特殊标记(mask token)代替。

    • 为每个时间步添加位置编码,以引入时序信息。


2. Transformer 模块(Temporal Modeling with Transformer)

  • 多层 Transformer 块

    • 输入经过一系列标准的 Transformer 块,每个块包含多头自注意力机制和前馈网络。

    • 自注意力机制捕捉输入序列中的时间依赖关系。

    • 输出通过残差连接和层归一化处理,形成更新后的特征表示。

  • 上下文建模

    • Transformer 模块通过时间上下文信息对输入序列进行建模,使其能够平滑预测姿态并填补缺失帧。


3. 逐层迭代回归(Iterative Pose Regression)

  • 姿态参数更新

    • Transformer 块的输出与当前的姿态估计θtl1\theta^{l-1}_t 进行拼接,通过多层感知机(MLP)生成新的姿态增量Δθtl\Delta\theta^l_t

    • 更新规则为:θtl=θtl1+Δθtl

    • 初始姿态θ0\theta^0 通常设置为参数化模型的平均姿态。

  • 逐层优化

    • MLP 在每一层 Transformer 模块后进行迭代优化,逐步提高姿态估计精度。


4. 损失函数(Loss Function)

  • 重构损失(Reconstruction Loss)

    • 姿态参数的 L2 损失:Lpose=t=1Tθtθt2

    • 平移参数的 L2 损失:Ltranslation=t=1Tγtγt2

  • 噪声鲁棒性(Denoising)

    • 在输入序列中引入高斯噪声或随机替换部分帧,模拟实际场景中的遮挡或模糊。


5. 输出(Output Representation)

  • 预测结果

    • 输出序列M={m1,m2,...,mT}M = \{m_1, m_2, ..., m_T\} 是时间一致的 3D 姿态网格(如人体或手部的 SMPL/MANO 模型)。

  • 填补缺失帧

    • 模型能够生成平滑且一致的姿态序列,即使在输入中有缺失的帧。


6. 实时应用能力

  • 低计算开销

    • 由于其高效的结构设计,PoseBERT 能够以 30 FPS 的速度实时处理输入序列,适合在线应用。

即插即用模块作用

PoseBERT 作为一个即插即用模块

  • 时序建模(Temporal Modeling)

    • PoseBERT 基于 Transformer 架构,能够捕捉姿态序列中的时间依赖关系,为姿态估计引入上下文信息,从而平滑预测结果。

    • 例如,在存在运动模糊或遮挡的情况下,PoseBERT 可利用邻帧信息推断缺失或模糊的姿态。

  • 去噪与修复(Denoising and Refinement)

    • 对于从图像中估计的噪声姿态,PoseBERT 可以通过时序上下文信息优化每一帧的姿态,生成更准确的 3D 网格。

    • 例如,在手部或人体动作捕捉中,PoseBERT 可减少突变、错误预测或数据丢失的影响。

  • 任务无关的通用性(Task-agnostic Generality)

    • PoseBERT 采用基于 MoCap 数据的无监督训练,不依赖特定任务的标注,能够灵活应用于多种任务(如未来预测、动作补全)。

    • 这种通用性使其适合作为一个模块化解决方案。

  • 实时高效性(Real-time Efficiency)

    • PoseBERT 计算开销低,能够以 30 帧/秒的速度运行,适合实时场景。

    • 例如,在机器人操作中,PoseBERT 可实时生成平滑的控制信号。

  • 兼容性与模块化(Compatibility and Modularity)

    • PoseBERT 可以无缝集成到现有的基于图像的姿态估计框架中,作为一个“增强模块”提升性能。

    • 它的即插即用特性适合在不同的数据和模型中快速部署。

消融实验结果

  • 内容:研究了不同预训练策略对模型性能的影响。

  • 实验变量

    • 掩码比例(Masking Percentage):随机屏蔽输入序列的一部分帧,模拟缺失数据。

    • 添加噪声(Noise Injection):在输入序列中加入高斯噪声,模拟实际应用中的遮挡或运动模糊。

  • 结果

    • 屏蔽输入序列的 12.5% 帧可以显著提升预测的平滑度(加速度误差 Accel 降低)。

    • 添加高斯噪声进一步提高了性能和序列平滑度。

  • 结论

    • 使用掩码和噪声注入能够提高 PoseBERT 对时序动态建模的鲁棒性。


  • 内容:研究了不同训练策略对模型性能的影响。

  • 实验变量

    • 添加回归器(Regressor)到 Transformer 中。

    • 不同的帧率(fps):降低帧率以观察较长时间窗口的影响。

    • 随机替换帧(Random Pose/Joints):用随机姿态或关节替换输入的一部分帧。

  • 结果

    • 在 Transformer 中集成回归器能够提高性能,尤其是在更大的时间窗口中建模时序关系。

    • 降低帧率并未显著提高性能,且过低帧率会导致性能下降。

    • 随机替换帧或关节有助于训练的鲁棒性,但加入比例过高会损害性能。

  • 结论

    • PoseBERT 的最佳训练策略是适度的掩码和噪声注入,同时保持合理的时间窗口。


  • 内容:研究模型超参数的影响。

  • 实验变量

    • 是否使用位置编码(Positional Encoding)。

    • 回归器的共享性(Shared Regressor)。

    • Transformer 的深度(Layer Depth,LL)。

    • 嵌入维度(Embedding Dimension,DD)。

    • 输入序列长度(Sequence Length,TT)。

  • 结果

    • 移除位置编码会导致性能下降,说明时间信息对建模至关重要。

    • 共享回归器参数能够减少模型复杂度并略微提高性能。

    • Transformer 的最佳深度为L=4L=4,嵌入维度为D=512D=512,序列长度为T=16T=16

    • 增加深度、嵌入维度或序列长度过多,性能提升有限且带来更高的计算开销。

  • 结论

    • 适当的模型深度和输入长度能够平衡性能与计算效率。

即插即用模块

import torch
from torch import nn
import roma
from einops import rearrange



class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x


class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)


class FeedForwardResidual(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0., out_dim=24 * 6):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim + out_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, out_dim),
        )
        nn.init.xavier_uniform_(self.net[-1].weight, gain=0.01)

    def forward(self, x, init, n_iter=1):
        pred_pose = init
        for _ in range(n_iter):
            xf = torch.cat([x, init], -1)
            pred_pose = pred_pose + self.net(xf)
        return pred_pose


class Attention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, mask=None):
        """
        Args:
            - x: [batch_size,seq_len,dim]
            - mask: [batch_size,seq_len] - dytpe= torch.bool - default True everywhere, if False it means that we don't pay attention to this timestep
        "
""
        b, n, _, h = *x.shape, self.heads
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv)

        dots = torch.einsum('b h i d, b h j d -> b h i j', q, k) * self.scale # [B,H,T,T]
        mask_value = -torch.finfo(dots.dtype).max

        if mask is not None:  # always true
            assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
            mask = mask.unsqueeze(1).unsqueeze(1).repeat(1, 1, n, 1) # updating masked timesteps with context
            dots.masked_fill_(~mask, mask_value) # ~ do the opposite i.e. move True to False here
            del mask
        attn = dots.softmax(dim=-1)

        out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        return out


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=1024):
        super(PositionalEncoding, self).__init__()
        self.pe = nn.Parameter(torch.randn(1, max_len, d_model))

    def forward(self, x, start=0):
        x = x + self.pe[:, start:(start + x.size(1))]
        return x


class TransformerRegressor(nn.Module):

    def __init__(self, dim, depth=2, heads=8, dim_head=32, mlp_dim=32, dropout=0.1, out=[22 * 6, 3],
                 share_regressor=False)
:
        super().__init__()

        self.layers = nn.ModuleList([])
        for i in range(depth):
            list_modules = [
                PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
            ]

            # Regressor
            if i == 0 or not share_regressor:
                # N regressor per layer
                for out_i in out:
                    list_modules.append(PreNorm(dim, FeedForwardResidual(dim, mlp_dim, dropout=dropout, out_dim=out_i)))
            else:
                # Share regressor across layers
                for j in range(2, len(self.layers[0])):
                    list_modules.append(self.layers[0][j])
            self.layers.append(nn.ModuleList(list_modules))

    def forward(self, x, init, mask=None):
        batch_size, seq_len, *_ = x.size()
        y = init
        for layers_i in self.layers:
            # attention and feeforward module
            attn, ff = layers_i[0], layers_i[1]
            x = attn(x, mask=mask) + x
            x = ff(x) + x

            # N regressors
            for j, reg in enumerate(layers_i[2:]):
                y[j] = reg(x, init=y[j], n_iter=1)

        return y


class PoseBERT(nn.Module):
    def __init__(self,
                 in_dim=24 * 6, n_jts_out=24, init_pose=None,
                 dim=512, depth=4, heads=8, dim_head=64, mlp_dim=512, dropout=0.1,
                 share_regressor=1,
                 *args, **kwargs)
:
        super(PoseBERT, self).__init__()

        self.pos = PositionalEncoding(dim, 1024)
        self.emb = nn.Linear(in_dim, dim)
        self.mask_token = nn.Parameter(torch.randn(1, 1, dim))

        self.decoder = TransformerRegressor(dim, depth, heads, dim_head, mlp_dim, dropout,
                                            [n_jts_out * 6],
                                            share_regressor == 1)

        if init_pose is None:
            init_pose = torch.zeros(n_jts_out * 6).float()
        self.register_buffer('init_pose', init_pose.reshape(1, 1, -1))

        # Type of input
        if in_dim == 24 * 6:
            self.input = 'rotmat'
        elif in_dim == 16 * 3 + 6:
            self.input = 'h36m'
        else:
            raise NameError

    def forward(self, rotmat, root=None, rel=None, mask=None):
        """
        Args:
            - rotmat: torch.Tensor - torch.float32 - [batch_size, seq_len, 24, 3, 3]
            - root: torch.Tensor - torch.float32 - [batch_size, seq_len, 3, 3]
            - rel: torch.Tensor - torch.float32 - [batch_size, seq_len, 17, 3]
            - mask: torch.Tensor - torch.bool - [batch_size, seq_len]
        Return:
            - y: torch.Tensor - [batch_size, seq_len, 24, 3, 3] - torch.float32
        "
""

        # Handling input
        if self.input == 'rotmat':
            assert rotmat is not None
            # Keep 6D representation only and concat
            x = rotmat[..., :2].flatten(2) # [batch_size, seq-len, in_dim]
        elif self.input == 'h36m':
            assert root is not None and rel is not None
            # 6D repr of the root rotation and keep the relative pose only (discard the hip because it is centered)
            x = torch.cat([root[..., :2].flatten(2), rel[:, :, 1:].flatten(2)], -1)
        else:
            raise NameError

        batch_size, seq_len, *_ = x.size()

        # Default masks
        if mask is None:
            mask = torch.ones(batch_size, seq_len).type_as(x).bool()

        # Input embedding
        x = self.emb(x)
        x = x * mask.float().unsqueeze(-1) + self.mask_token * (1. - mask.float().unsqueeze(-1)) # masked token
        x = self.pos(x) # inject position info

        # Transformer
        init = [self.init_pose.repeat(batch_size, seq_len, 1)] # init mean pose
        y = self.decoder(x, init, mask)[0]

        # Move from rotation representation from 6D to 9D
        y = roma.special_gramschmidt(y.reshape(batch_size, seq_len, -1, 3, 2))

        return y

if __name__ == '__main__':
    # 初始化模型
    model = PoseBERT()

    # 生成随机旋转矩阵输入,形状为 [batch_size, seq_len, n_jts, rotmat_size, rotmat_size]
    input_rotmat = torch.randn(2, 10, 24, 3, 3)

    # 调用模型
    output = model(rotmat=input_rotmat)

    # 打印输入和输出尺寸
    print("Input size:", input_rotmat.size())
    print("Output size:", output.size())

便捷下载方式

浏览打开网址:https://github.com/ai-dawang/PlugNPlay-Modules

更多分析可见原文


ai缝合大王
聚焦AI前沿,分享相关技术、论文,研究生自救指南
 最新文章