CVPR 3D 点云补全模块PF-Net,即插即用,即用即涨点

文摘   2025-01-12 17:20   上海  

论文介绍

题目:PF-Net: Point Fractal Network for 3D Point Cloud Completion

论文地址:https://arxiv.org/abs/2003.00410

QQ深度学习交流群:994264161

扫描下方二维码,加入深度学习论文指南星球!

加入即可获得,模块缝合、制作、写作技巧,学会“结构”创新、“创新点”创新,从三区到顶会,小论文大论文,毕业一站式服务

创新点

  • 空间结构的保留与局部细节的还原:PF-Net与现有的点云补全网络不同,仅预测丢失的点云部分,而不修改原始部分。这种设计保留了原始点云的几何特性,同时使网络能够专注于识别丢失部分的结构和位置。

  • 多分辨率编码器(MRE):提出了新的特征提取器——组合多层感知器(CMLP),用于从不同分辨率的点云中提取多尺度特征。这种方法增强了网络对几何和语义信息的提取能力。

  • 点金字塔解码器(PPD):设计了一种基于特征点的分层生成网络,通过从粗到细逐步生成点云,既关注整体形状,又保留细节特征。该方法有效减少了现有方法中常见的几何失真问题。

  • 多阶段补全损失和对抗损失:结合了Chamfer Distance作为多阶段损失函数,使网络更关注特征点。此外,通过对抗损失,PF-Net在生成多种可能的点云模式时能更好地优化预测结果,生成更逼真的点云补全。

  • 高效的训练与预测:使用PyTorch框架实现,结合IFPS(迭代最远点采样)方法提高训练效率,同时适用于多种点云缺失比例的情况,展现了强大的鲁棒性。

方法

整体架构

       PF-Net 是一种用于 3D 点云补全的网络架构,由多分辨率编码器(MRE)、点金字塔解码器(PPD)和判别器组成。MRE 提取多分辨率的全局和局部特征,PPD 采用从粗到细的分层预测方法逐步生成丢失的点云部分,而判别器通过对抗损失提升生成点云的真实性。该模型在保留原始点云几何特性的同时,能有效还原丢失部分的细节,实现高精度、低失真的点云补全。

  • 多分辨率编码器(Multi-Resolution Encoder, MRE)

    • 使用一种新的特征提取器——组合多层感知器(CMLP),替代传统的多层感知器(MLP)。CMLP通过对不同层的特征进行拼接,形成一个高维的组合特征向量,保留了丰富的语义和几何信息。

    • 输入点云通过迭代最远点采样(Iterative Farthest Point Sampling, IFPS)方法生成不同分辨率的点云数据。

    • 每一层特征由独立的 CMLP 进行编码,最终所有层的特征向量拼接形成全局的特征表示。

    • 作用:提取点云的多尺度特征,包括局部和全局特征,以及低级和高级特征。

  • 点金字塔解码器(Point Pyramid Decoder, PPD)

    • 解码器以编码器输出的全局特征向量为输入。

    • 首先生成低分辨率的主中心点(Primary Center Points),作为点云的骨架。

    • 随后在更高分辨率层中预测更精细的**次中心点(Secondary Center Points)**和最终的详细点云,逐步完善点云的结构细节。

    • 各层通过特征传播(Feature Propagation)机制,从低分辨率到高分辨率逐步补全几何信息。

    • 这种分层结构受启发于数学中的分形几何(Fractal Geometry)

    • 作用:根据编码器提取的全局特征,逐层生成丢失的点云部分,采用从粗到细的分层预测方法。

  • 判别器(Discriminator)

    • 判别器接收生成的点云与真实点云,学习它们的分布差异。

    • 使用对抗损失(Adversarial Loss)优化生成器(编码器和解码器),以提高生成点云的真实性。

    • 作用:与生成对抗网络(GAN)的判别器类似,用于区分生成的点云和真实点云。

即插即用模块作用

PFNet 作为一个即插即用模块

  • 提升点云数据完整性:PF-Net 可以有效修复不完整的点云数据,通过生成丢失部分,提供更完整的三维环境或物体模型。

  • 提高后续算法的精度:对于后续的点云处理任务,如分类、分割、识别和建模,PF-Net 提供了更精确、完整的输入,避免了因为数据不完整导致的误差。

  • 适应不同场景的鲁棒性:PF-Net 在处理不同缺失比例和复杂度的点云数据时,展现出强大的鲁棒性和适应能力,能在不同的实际应用中保持高效的补全性能。

消融实验结果

  • 该表展示了 PointNet-MLP、PointNet-CMLP 和 PF-Net 在 ModelNet40 数据集上的分类准确度对比。通过比较,表明 CMLP 和 PF-Net 相较于传统的 PointNet-MLP 在分类任务中有更好的表现,尤其是在几何信息提取上表现更佳。

  • 该表验证了组合多层感知器(CMLP)在特征提取中的优势,尤其是在增强网络对几何和语义信息的理解方面。


  • 该表展示了 PF-Net(vanilla) 与其他几种基线(如单分辨率 MLP、CMLP 和多分辨率 CMLP)在“椅子”和“桌子”类别上的点云补全效果。PF-Net(vanilla) 在 Pred → GT 和 GT → Pred 错误上均表现出显著优势,特别是在细节保留和全局几何结构的重建上。

  • 该表通过与不同基线的比较,验证了多分辨率编码器(MRE)和点金字塔解码器(PPD)的有效性,说明这两个组件对于精确补全丢失的点云部分至关重要。


    • 该表展示了 PF-Net 在不同丢失比例(25%、50%、75%)下的鲁棒性测试结果。表中的结果表明,无论输入点云缺失多少,PF-Net 都能够稳定地生成准确的补全结果,且补全效果在不同缺失程度下变化较小。

    • 通过这一实验,验证了 PF-Net 在面对不同程度丢失的点云时具有很强的鲁棒性和适应能力。

即插即用模块

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
class Convlayer(nn.Module):
    def __init__(self, point_scales):
        super(Convlayer, self).__init__()
        self.point_scales = point_scales
        self.conv1 = torch.nn.Conv2d(1, 64, (1, 3))
        self.conv2 = torch.nn.Conv2d(64, 64, 1)
        self.conv3 = torch.nn.Conv2d(64, 128, 1)
        self.conv4 = torch.nn.Conv2d(128, 256, 1)
        self.conv5 = torch.nn.Conv2d(256, 512, 1)
        self.conv6 = torch.nn.Conv2d(512, 1024, 1)
        self.maxpool = torch.nn.MaxPool2d((self.point_scales, 1), 1)
        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(128)
        self.bn4 = nn.BatchNorm2d(256)
        self.bn5 = nn.BatchNorm2d(512)
        self.bn6 = nn.BatchNorm2d(1024)

    def forward(self, x):
        x = torch.unsqueeze(x, 1)
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x_128 = F.relu(self.bn3(self.conv3(x)))
        x_256 = F.relu(self.bn4(self.conv4(x_128)))
        x_512 = F.relu(self.bn5(self.conv5(x_256)))
        x_1024 = F.relu(self.bn6(self.conv6(x_512)))
        x_128 = torch.squeeze(self.maxpool(x_128), 2)
        x_256 = torch.squeeze(self.maxpool(x_256), 2)
        x_512 = torch.squeeze(self.maxpool(x_512), 2)
        x_1024 = torch.squeeze(self.maxpool(x_1024), 2)
        L = [x_1024, x_512, x_256, x_128]
        x = torch.cat(L, 1)
        return x


class Latentfeature(nn.Module):
    def __init__(self, num_scales, each_scales_size, point_scales_list):
        super(Latentfeature, self).__init__()
        self.num_scales = num_scales
        self.each_scales_size = each_scales_size
        self.point_scales_list = point_scales_list
        self.Convlayers1 = nn.ModuleList(
            [Convlayer(point_scales=self.point_scales_list[0]) for i in range(self.each_scales_size)])
        self.Convlayers2 = nn.ModuleList(
            [Convlayer(point_scales=self.point_scales_list[1]) for i in range(self.each_scales_size)])
        self.Convlayers3 = nn.ModuleList(
            [Convlayer(point_scales=self.point_scales_list[2]) for i in range(self.each_scales_size)])
        self.conv1 = torch.nn.Conv1d(3, 1, 1)
        self.bn1 = nn.BatchNorm1d(1)

    def forward(self, x):
        outs = []
        for i in range(self.each_scales_size):
            outs.append(self.Convlayers1[i](x[0]))
        for j in range(self.each_scales_size):
            outs.append(self.Convlayers2[j](x[1]))
        for k in range(self.each_scales_size):
            outs.append(self.Convlayers3[k](x[2]))
        latentfeature = torch.cat(outs, 2)
        latentfeature = latentfeature.transpose(1, 2)
        latentfeature = F.relu(self.bn1(self.conv1(latentfeature)))
        latentfeature = torch.squeeze(latentfeature, 1)
        return latentfeature


class PointcloudCls(nn.Module):
    def __init__(self, num_scales, each_scales_size, point_scales_list, k=40):
        super(PointcloudCls, self).__init__()
        self.latentfeature = Latentfeature(num_scales, each_scales_size, point_scales_list)
        self.fc1 = nn.Linear(1920, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 256)
        self.fc4 = nn.Linear(256, k)
        self.dropout = nn.Dropout(p=0.3)
        self.bn1 = nn.BatchNorm1d(1024)
        self.bn2 = nn.BatchNorm1d(512)
        self.bn3 = nn.BatchNorm1d(256)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.latentfeature(x)
        x = F.relu(self.bn1(self.fc1(x)))
        x = F.relu(self.bn2(self.dropout(self.fc2(x))))
        x = F.relu(self.bn3(self.dropout(self.fc3(x))))
        x = self.fc4(x)
        return F.log_softmax(x, dim=1)


class _netG(nn.Module):
    def __init__(self, num_scales, each_scales_size, point_scales_list, point_num):
        super(_netG, self).__init__()
        self.point_num = point_num # 保存输入的点数
        self.latentfeature = Latentfeature(num_scales, each_scales_size, point_scales_list)
        self.fc1 = nn.Linear(1920, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 256)
        self.fc_final = nn.Linear(256, point_num * 3) # 最后一个全连接层输出维度为 point_num * 3

    def forward(self, x):
        x = self.latentfeature(x)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = self.fc_final(x) # 输出维度为 [batch_size, point_num * 3]
        x = x.reshape(-1, self.point_num, 3) # 重塑为 [batch_size, point_num, 3]
        return x



class _netlocalD(nn.Module):
    def __init__(self, crop_point_num):
        super(_netlocalD, self).__init__()
        self.crop_point_num = crop_point_num
        self.conv1 = torch.nn.Conv2d(1, 64, (1, 3))
        self.conv2 = torch.nn.Conv2d(64, 64, 1)
        self.conv3 = torch.nn.Conv2d(64, 128, 1)
        self.conv4 = torch.nn.Conv2d(128, 256, 1)
        self.maxpool = torch.nn.MaxPool2d((self.crop_point_num, 1), 1)
        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(128)
        self.bn4 = nn.BatchNorm2d(256)
        self.fc1 = nn.Linear(448, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 16)
        self.fc4 = nn.Linear(16, 1)
        self.bn_1 = nn.BatchNorm1d(256)
        self.bn_2 = nn.BatchNorm1d(128)
        self.bn_3 = nn.BatchNorm1d(16)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x_64 = F.relu(self.bn2(self.conv2(x)))
        x_128 = F.relu(self.bn3(self.conv3(x_64)))
        x_256 = F.relu(self.bn4(self.conv4(x_128)))
        x_64 = torch.squeeze(self.maxpool(x_64))
        x_128 = torch.squeeze(self.maxpool(x_128))
        x_256 = torch.squeeze(self.maxpool(x_256))
        Layers = [x_256, x_128, x_64]
        x = torch.cat(Layers, 1)
        x = F.relu(self.bn_1(self.fc1(x)))
        x = F.relu(self.bn_2(self.fc2(x)))
        x = F.relu(self.bn_3(self.fc3(x)))
        x = self.fc4(x)
        return x



if __name__ == '__main__':

    input1 = torch.randn(64, 2048, 3)
    input2 = torch.randn(64, 512, 3)
    input3 = torch.randn(64, 256, 3)
    input = [input1, input2, input3]
    block = _netG(num_scales=3, each_scales_size=1, point_scales_list=[2048, 512, 256], point_num=2048)
    output = block(input)    print(output.shape)

便捷下载方式

浏览打开网址:https://github.com/ai-dawang/PlugNPlay-Modules

更多分析可见原文


ai缝合大王
聚焦AI前沿,分享相关技术、论文,研究生自救指南
 最新文章