论文介绍
题目:PF-Net: Point Fractal Network for 3D Point Cloud Completion
论文地址:https://arxiv.org/abs/2003.00410
QQ深度学习交流群:994264161
扫描下方二维码,加入深度学习论文指南星球!
加入即可获得,模块缝合、制作、写作技巧,学会“结构”创新、“创新点”创新,从三区到顶会,小论文大论文,毕业一站式服务
创新点
空间结构的保留与局部细节的还原:PF-Net与现有的点云补全网络不同,仅预测丢失的点云部分,而不修改原始部分。这种设计保留了原始点云的几何特性,同时使网络能够专注于识别丢失部分的结构和位置。
多分辨率编码器(MRE):提出了新的特征提取器——组合多层感知器(CMLP),用于从不同分辨率的点云中提取多尺度特征。这种方法增强了网络对几何和语义信息的提取能力。
点金字塔解码器(PPD):设计了一种基于特征点的分层生成网络,通过从粗到细逐步生成点云,既关注整体形状,又保留细节特征。该方法有效减少了现有方法中常见的几何失真问题。
多阶段补全损失和对抗损失:结合了Chamfer Distance作为多阶段损失函数,使网络更关注特征点。此外,通过对抗损失,PF-Net在生成多种可能的点云模式时能更好地优化预测结果,生成更逼真的点云补全。
高效的训练与预测:使用PyTorch框架实现,结合IFPS(迭代最远点采样)方法提高训练效率,同时适用于多种点云缺失比例的情况,展现了强大的鲁棒性。
方法
整体架构
PF-Net 是一种用于 3D 点云补全的网络架构,由多分辨率编码器(MRE)、点金字塔解码器(PPD)和判别器组成。MRE 提取多分辨率的全局和局部特征,PPD 采用从粗到细的分层预测方法逐步生成丢失的点云部分,而判别器通过对抗损失提升生成点云的真实性。该模型在保留原始点云几何特性的同时,能有效还原丢失部分的细节,实现高精度、低失真的点云补全。
多分辨率编码器(Multi-Resolution Encoder, MRE)
使用一种新的特征提取器——组合多层感知器(CMLP),替代传统的多层感知器(MLP)。CMLP通过对不同层的特征进行拼接,形成一个高维的组合特征向量,保留了丰富的语义和几何信息。
输入点云通过迭代最远点采样(Iterative Farthest Point Sampling, IFPS)方法生成不同分辨率的点云数据。
每一层特征由独立的 CMLP 进行编码,最终所有层的特征向量拼接形成全局的特征表示。
作用:提取点云的多尺度特征,包括局部和全局特征,以及低级和高级特征。
点金字塔解码器(Point Pyramid Decoder, PPD)
解码器以编码器输出的全局特征向量为输入。
首先生成低分辨率的主中心点(Primary Center Points),作为点云的骨架。
随后在更高分辨率层中预测更精细的**次中心点(Secondary Center Points)**和最终的详细点云,逐步完善点云的结构细节。
各层通过特征传播(Feature Propagation)机制,从低分辨率到高分辨率逐步补全几何信息。
这种分层结构受启发于数学中的分形几何(Fractal Geometry)。
作用:根据编码器提取的全局特征,逐层生成丢失的点云部分,采用从粗到细的分层预测方法。
判别器(Discriminator)
判别器接收生成的点云与真实点云,学习它们的分布差异。
使用对抗损失(Adversarial Loss)优化生成器(编码器和解码器),以提高生成点云的真实性。
作用:与生成对抗网络(GAN)的判别器类似,用于区分生成的点云和真实点云。
即插即用模块作用
PFNet 作为一个即插即用模块:
提升点云数据完整性:PF-Net 可以有效修复不完整的点云数据,通过生成丢失部分,提供更完整的三维环境或物体模型。
提高后续算法的精度:对于后续的点云处理任务,如分类、分割、识别和建模,PF-Net 提供了更精确、完整的输入,避免了因为数据不完整导致的误差。
适应不同场景的鲁棒性:PF-Net 在处理不同缺失比例和复杂度的点云数据时,展现出强大的鲁棒性和适应能力,能在不同的实际应用中保持高效的补全性能。
消融实验结果
该表展示了 PointNet-MLP、PointNet-CMLP 和 PF-Net 在 ModelNet40 数据集上的分类准确度对比。通过比较,表明 CMLP 和 PF-Net 相较于传统的 PointNet-MLP 在分类任务中有更好的表现,尤其是在几何信息提取上表现更佳。
该表验证了组合多层感知器(CMLP)在特征提取中的优势,尤其是在增强网络对几何和语义信息的理解方面。
该表展示了 PF-Net(vanilla) 与其他几种基线(如单分辨率 MLP、CMLP 和多分辨率 CMLP)在“椅子”和“桌子”类别上的点云补全效果。PF-Net(vanilla) 在 Pred → GT 和 GT → Pred 错误上均表现出显著优势,特别是在细节保留和全局几何结构的重建上。
该表通过与不同基线的比较,验证了多分辨率编码器(MRE)和点金字塔解码器(PPD)的有效性,说明这两个组件对于精确补全丢失的点云部分至关重要。
该表展示了 PF-Net 在不同丢失比例(25%、50%、75%)下的鲁棒性测试结果。表中的结果表明,无论输入点云缺失多少,PF-Net 都能够稳定地生成准确的补全结果,且补全效果在不同缺失程度下变化较小。
通过这一实验,验证了 PF-Net 在面对不同程度丢失的点云时具有很强的鲁棒性和适应能力。
即插即用模块
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
class Convlayer(nn.Module):
def __init__(self, point_scales):
super(Convlayer, self).__init__()
self.point_scales = point_scales
self.conv1 = torch.nn.Conv2d(1, 64, (1, 3))
self.conv2 = torch.nn.Conv2d(64, 64, 1)
self.conv3 = torch.nn.Conv2d(64, 128, 1)
self.conv4 = torch.nn.Conv2d(128, 256, 1)
self.conv5 = torch.nn.Conv2d(256, 512, 1)
self.conv6 = torch.nn.Conv2d(512, 1024, 1)
self.maxpool = torch.nn.MaxPool2d((self.point_scales, 1), 1)
self.bn1 = nn.BatchNorm2d(64)
self.bn2 = nn.BatchNorm2d(64)
self.bn3 = nn.BatchNorm2d(128)
self.bn4 = nn.BatchNorm2d(256)
self.bn5 = nn.BatchNorm2d(512)
self.bn6 = nn.BatchNorm2d(1024)
def forward(self, x):
x = torch.unsqueeze(x, 1)
x = F.relu(self.bn1(self.conv1(x)))
x = F.relu(self.bn2(self.conv2(x)))
x_128 = F.relu(self.bn3(self.conv3(x)))
x_256 = F.relu(self.bn4(self.conv4(x_128)))
x_512 = F.relu(self.bn5(self.conv5(x_256)))
x_1024 = F.relu(self.bn6(self.conv6(x_512)))
x_128 = torch.squeeze(self.maxpool(x_128), 2)
x_256 = torch.squeeze(self.maxpool(x_256), 2)
x_512 = torch.squeeze(self.maxpool(x_512), 2)
x_1024 = torch.squeeze(self.maxpool(x_1024), 2)
L = [x_1024, x_512, x_256, x_128]
x = torch.cat(L, 1)
return x
class Latentfeature(nn.Module):
def __init__(self, num_scales, each_scales_size, point_scales_list):
super(Latentfeature, self).__init__()
self.num_scales = num_scales
self.each_scales_size = each_scales_size
self.point_scales_list = point_scales_list
self.Convlayers1 = nn.ModuleList(
[Convlayer(point_scales=self.point_scales_list[0]) for i in range(self.each_scales_size)])
self.Convlayers2 = nn.ModuleList(
[Convlayer(point_scales=self.point_scales_list[1]) for i in range(self.each_scales_size)])
self.Convlayers3 = nn.ModuleList(
[Convlayer(point_scales=self.point_scales_list[2]) for i in range(self.each_scales_size)])
self.conv1 = torch.nn.Conv1d(3, 1, 1)
self.bn1 = nn.BatchNorm1d(1)
def forward(self, x):
outs = []
for i in range(self.each_scales_size):
outs.append(self.Convlayers1[i](x[0]))
for j in range(self.each_scales_size):
outs.append(self.Convlayers2[j](x[1]))
for k in range(self.each_scales_size):
outs.append(self.Convlayers3[k](x[2]))
latentfeature = torch.cat(outs, 2)
latentfeature = latentfeature.transpose(1, 2)
latentfeature = F.relu(self.bn1(self.conv1(latentfeature)))
latentfeature = torch.squeeze(latentfeature, 1)
return latentfeature
class PointcloudCls(nn.Module):
def __init__(self, num_scales, each_scales_size, point_scales_list, k=40):
super(PointcloudCls, self).__init__()
self.latentfeature = Latentfeature(num_scales, each_scales_size, point_scales_list)
self.fc1 = nn.Linear(1920, 1024)
self.fc2 = nn.Linear(1024, 512)
self.fc3 = nn.Linear(512, 256)
self.fc4 = nn.Linear(256, k)
self.dropout = nn.Dropout(p=0.3)
self.bn1 = nn.BatchNorm1d(1024)
self.bn2 = nn.BatchNorm1d(512)
self.bn3 = nn.BatchNorm1d(256)
self.relu = nn.ReLU()
def forward(self, x):
x = self.latentfeature(x)
x = F.relu(self.bn1(self.fc1(x)))
x = F.relu(self.bn2(self.dropout(self.fc2(x))))
x = F.relu(self.bn3(self.dropout(self.fc3(x))))
x = self.fc4(x)
return F.log_softmax(x, dim=1)
class _netG(nn.Module):
def __init__(self, num_scales, each_scales_size, point_scales_list, point_num):
super(_netG, self).__init__()
self.point_num = point_num # 保存输入的点数
self.latentfeature = Latentfeature(num_scales, each_scales_size, point_scales_list)
self.fc1 = nn.Linear(1920, 1024)
self.fc2 = nn.Linear(1024, 512)
self.fc3 = nn.Linear(512, 256)
self.fc_final = nn.Linear(256, point_num * 3) # 最后一个全连接层输出维度为 point_num * 3
def forward(self, x):
x = self.latentfeature(x)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = F.relu(self.fc3(x))
x = self.fc_final(x) # 输出维度为 [batch_size, point_num * 3]
x = x.reshape(-1, self.point_num, 3) # 重塑为 [batch_size, point_num, 3]
return x
class _netlocalD(nn.Module):
def __init__(self, crop_point_num):
super(_netlocalD, self).__init__()
self.crop_point_num = crop_point_num
self.conv1 = torch.nn.Conv2d(1, 64, (1, 3))
self.conv2 = torch.nn.Conv2d(64, 64, 1)
self.conv3 = torch.nn.Conv2d(64, 128, 1)
self.conv4 = torch.nn.Conv2d(128, 256, 1)
self.maxpool = torch.nn.MaxPool2d((self.crop_point_num, 1), 1)
self.bn1 = nn.BatchNorm2d(64)
self.bn2 = nn.BatchNorm2d(64)
self.bn3 = nn.BatchNorm2d(128)
self.bn4 = nn.BatchNorm2d(256)
self.fc1 = nn.Linear(448, 256)
self.fc2 = nn.Linear(256, 128)
self.fc3 = nn.Linear(128, 16)
self.fc4 = nn.Linear(16, 1)
self.bn_1 = nn.BatchNorm1d(256)
self.bn_2 = nn.BatchNorm1d(128)
self.bn_3 = nn.BatchNorm1d(16)
def forward(self, x):
x = F.relu(self.bn1(self.conv1(x)))
x_64 = F.relu(self.bn2(self.conv2(x)))
x_128 = F.relu(self.bn3(self.conv3(x_64)))
x_256 = F.relu(self.bn4(self.conv4(x_128)))
x_64 = torch.squeeze(self.maxpool(x_64))
x_128 = torch.squeeze(self.maxpool(x_128))
x_256 = torch.squeeze(self.maxpool(x_256))
Layers = [x_256, x_128, x_64]
x = torch.cat(Layers, 1)
x = F.relu(self.bn_1(self.fc1(x)))
x = F.relu(self.bn_2(self.fc2(x)))
x = F.relu(self.bn_3(self.fc3(x)))
x = self.fc4(x)
return x
if __name__ == '__main__':
input1 = torch.randn(64, 2048, 3)
input2 = torch.randn(64, 512, 3)
input3 = torch.randn(64, 256, 3)
input = [input1, input2, input3]
block = _netG(num_scales=3, each_scales_size=1, point_scales_list=[2048, 512, 256], point_num=2048)
output = block(input) print(output.shape)
便捷下载方式
浏览打开网址:https://github.com/ai-dawang/PlugNPlay-Modules
更多分析可见原文