PF-Net基于深度学习的点云补全网络

目录


1. 论文和代码

论文:Point Fractal Network for 3D Point Cloud Completionhttps://openaccess.thecvf.com/content_CVPR_2020/papers/Huang_PF-Net_Point_Fractal_Network_for_3D_Point_Cloud_Completion_CVPR_2020_paper.pdfhttps://openaccess.thecvf.com/content_CVPR_2020/papers/Huang_PF-Net_Point_Fractal_Network_for_3D_Point_Cloud_Completion_CVPR_2020_paper.pdf

作者来自上海交通大学和上汤科技的大佬,发表在2020CVPR。 

代码: 

https://github.com/zztianzz/PF-Net-Point-Fractal-Networkhttps://github.com/zztianzz/PF-Net-Point-Fractal-Network

2. 论文阅读笔记

2.1 目的和框架

        该PF-Net要做的是点云补全,即将有残缺的点云数据(比如上图飞机少了机头,或者凳子少了腿),通过一些技术补全为完整的点云数据。

 简单来讲,PF-Net输入残缺后点云(飞机的机身),输出残缺的部分点云(飞机的机尾),端对端训练,作为生成器网络,生成残缺点云,再接一个判别器网络。

        该网络的特点:不改变原始的数据,只生成残缺部分的点云数据。即机身的点云数据不变,直接生成机头部分的点云。

       算法步骤:

(1)原始的黄色点云输入数据,经过了两次IFPS下采样,获得三种尺度的点云输入数据,其中N是原始的点云中点的个数,k是下采样倍数;

(2)再经过CMLP全链接网络,获得Latent vector F;

(3)再将各个latent vector拼接起来获得Final Laten Map M;

(4)接一个MLP和Linear全链接网络,再使用FPN特征金字塔作为解码网络,获取三种尺度下的残缺点云数据;

(5)对原始尺度下的残缺点云预测加一个判别器网络,使其生成的残缺数据更真实。

下面对各个部件,从输入到输出一个一个梳理。

2.2 IFPS 下采样

        Iterative farthest point sampling (IFPS),迭代最远点采样(技术来自Pointnet++),采集点云数据中骨架点点集合,通俗的将不破坏点云整体结构的情况下,就是只保留一些点。用该技术进行才采样比CNNs更快。

上图,原始台灯有 2048个点,即使下采样到128个点(保留了6.25%),依然很好的保留了台灯的基本骨架。

实现参考iterative farthest point sample (IFPS or FPS)_Mr.Q的博客-CSDN博客迭代最远距离采样,在点云论文PointNet++和PF-Net中用于对点云数据下采样。PF-Net基于深度学习的点云补全网络https://blog.csdn.net/jizhidexiaoming/article/details/128198099?spm=1001.2014.3001.5501

3. 源码解读

3.1 载入数据

shapenet_part_loader.py

# from __future__ import print_function
import torch.utils.data as data
import os
import os.path
import torch
import json
import numpy as np
import sys

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
dataset_path = os.path.abspath(
    os.path.join(BASE_DIR, '../dataset/shapenet_part/shapenetcore_partanno_segmentation_benchmark_v0/'))


class PartDataset(data.Dataset):
    def __init__(self, root=dataset_path, npoints=2500, classification=False, class_choice=None, split='train',
                 normalize=True):
        """

        Parameters
        ----------
        root: str. 数据集完整路径
        npoints: 2048. the point number of a sample. 输入到网络中点云的点个数。
        classification: bool. True. "Airplane" or "Mug" or something else.
        class_choice: list. None. 训练指定的类别。
        split: str. train/test
        normalize: bool. 是否归一化
        """
        self.npoints = npoints
        self.root = root
        self.catfile = os.path.join(self.root, 'synsetoffset2category.txt')  # 映射表格
        self.cat = {}  # 存放映射字典, {airplane: 11231414, ...}
        self.classification = classification
        self.normalize = normalize

        with open(self.catfile, 'r') as f:
            for line in f:
                ls = line.strip().split()
                self.cat[ls[0]] = ls[1]
        # print(self.cat)
        if not class_choice is None:
            self.cat = {k: v for k, v in self.cat.items() if k in class_choice}
            print(self.cat)
        self.meta = {}
        with open(os.path.join(self.root, 'train_test_split', 'shuffled_train_file_list.json'), 'r') as f:
            train_ids = set([str(d.split('/')[2]) for d in json.load(f)])  # 点云文件名称
        with open(os.path.join(self.root, 'train_test_split', 'shuffled_val_file_list.json'), 'r') as f:
            val_ids = set([str(d.split('/')[2]) for d in json.load(f)])
        with open(os.path.join(self.root, 'train_test_split', 'shuffled_test_file_list.json'), 'r') as f:
            test_ids = set([str(d.split('/')[2]) for d in json.load(f)])

        # 获取datapath list [("Airplane", 点云文件路径,分割文件路径,点云文件夹id,点云文件名称), ...]
        for item in self.cat:
            # print('category', item)
            self.meta[item] = []  # {"Airplane": [(点云文件路径,分割文件路径,点云类别id,点云文件名称), ...],
                                  #  "": [], ...}
            dir_point = os.path.join(self.root, self.cat[item], 'points')  # 当前类别的点云文件夹路径
            dir_seg = os.path.join(self.root, self.cat[item], 'points_label')  # 当前类别的分割文件夹路径
            # print(dir_point, dir_seg)
            fns = sorted(os.listdir(dir_point))  # 当前类别的所有点云文件名
            if split == 'trainval':
                fns = [fn for fn in fns if ((fn[0:-4] in train_ids) or (fn[0:-4] in val_ids))]
            elif split == 'train':
                fns = [fn for fn in fns if fn[0:-4] in train_ids]  # 获取所有属于训练集的点云文件名称
            elif split == 'val':
                fns = [fn for fn in fns if fn[0:-4] in val_ids]
            elif split == 'test':
                fns = [fn for fn in fns if fn[0:-4] in test_ids]
            else:
                print('Unknown split: %s. Exiting..' % (split))
                sys.exit(-1)

            for fn in fns:  #
                token = (os.path.splitext(os.path.basename(fn))[0])  # 获取点云文件名称
                self.meta[item].append((os.path.join(dir_point, token + '.pts'), os.path.join(dir_seg, token + '.seg'),
                                        self.cat[item], token))  # {"Airplane": [(点云文件路径,分割文件路径,点云文件夹id,点云文件名称), ...]}
        self.datapath = []  # [("Airplane", 点云文件路径,分割文件路径,点云文件夹id,点云文件名称), ...]
        for item in self.cat:
            for fn in self.meta[item]:
                self.datapath.append((item, fn[0], fn[1], fn[2], fn[3]))
        # ["cls_name": cls_id, ...]
        self.classes = dict(zip(sorted(self.cat), range(len(self.cat))))  # {"Airplane": 0, "", 1, ...} 按首字母排序。
        print(self.classes)
        self.num_seg_classes = 0
        if not self.classification:
            for i in range(len(self.datapath) // 50):
                l = len(np.unique(np.loadtxt(self.datapath[i][2]).astype(np.uint8)))
                if l > self.num_seg_classes:
                    self.num_seg_classes = l
        # print(self.num_seg_classes)
        self.cache = {}  # from index to (point_set, cls, seg) tuple
        self.cache_size = 18000  # 加载一次后,不会重复加载

    def __getitem__(self, index):
        if index in self.cache:  # 加载一次后,不会重复加载,所以如果在缓存中,直接取出来即可。
            #            point_set, seg, cls= self.cache[index]
            point_set, seg, cls, foldername, filename = self.cache[index]
        else:
            fn = self.datapath[index]
            # 1. cls. "Mug"类别id是11
            cls = self.classes[self.datapath[index][0]]
            # 2. point_set
            point_set = np.loadtxt(fn[1]).astype(np.float32)  # (2817, 3). 载入点云,并转成float32类型
            if self.normalize:
                point_set = self.pc_normalize(point_set)
            # 3. seg
            seg = np.loadtxt(fn[2]).astype(np.int64) - 1  # 分割类别id
            # 4. foldername 点云文件夹
            foldername = fn[3]
            # 5. filename 点云文件名称
            filename = fn[4]
            if len(self.cache) < self.cache_size:  # 载入缓存,以便下次迭代时使用
                self.cache[index] = (point_set, seg, cls, foldername, filename)

        # 随机选择npoints个点参与训练
        choice_idx = np.random.choice(len(seg), self.npoints, replace=True)  # 其实可以不用seg文件来随机
        # resample
        point_set = point_set[choice_idx, :]
        seg = seg[choice_idx]

        # To Pytorch
        point_set = torch.from_numpy(point_set)  # (2048,3)
        seg = torch.from_numpy(seg)  # (2048,)
        cls = torch.from_numpy(np.array([cls]).astype(np.int64))  # (1,)
        if self.classification:
            return point_set, cls
        else:
            return point_set, seg, cls

    def __len__(self):
        return len(self.datapath)

    def pc_normalize(self, pc):
        """ pc: NxC, return NxC """
        # l = pc.shape[0]
        centroid = np.mean(pc, axis=0)  # [-0.00400733  0.14655513  0.0053034 ]
        pc = pc - centroid  # 所有的值减去均值
        m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))  # sqrt(x1^2+y1^2+z1^2) + sqrt(x2^2+y2^2+z2^2)+...  0.55
        pc = pc / m
        return pc


if __name__ == '__main__':
    dset = PartDataset(root='./dataset/shapenetcore_partanno_segmentation_benchmark_v0/', classification=True,
                       class_choice=None, npoints=4096, split='train')
    #    d = PartDataset( root='./dataset/shapenetcore_partanno_segmentation_benchmark_v0/',classification=False, class_choice=None, npoints=4096, split='test')
    print(len(dset))
    ps, cls = dset[10000]
    print(cls)
#    print(ps.size(), ps.type(), cls.size(), cls.type())
#    print(ps)
#    ps = ps.numpy()
#    np.savetxt('ps'+'.txt', ps, fmt = "%f %f %f")

3.1.1 归一化操作

(1)坐标值减去各自坐标值的均值;

(2)sqrt(x1^2+y1^2+z1^2) + sqrt(x2^2+y2^2+z2^2)+…  == 0.55

(3)坐标值 / 0.55

3.2 数据前处理

Trian_PFNet.py

dset = shapenet_part_loader.PartDataset(
    root='/home/zxq/code/python/PF-Net-Point-Fractal-Network/dataset/shapenetcore_partanno_segmentation_benchmark_v0/',
    classification=True, 
    class_choice=None, 
    npoints=opt.pnum, 
    split='train')
assert dset
dataloader = torch.utils.data.DataLoader(dset, batch_size=opt.batchSize, shuffle=True, num_workers=int(opt.workers))

real_label = 1
fake_label = 0

for i, data in enumerate(dataloader, 0):

    real_point, target = data  # 点云坐标(b,2048,3). 点云类别(b,1) (Airplane or Mug).

    batch_size = real_point.size()[0]
    real_center = torch.FloatTensor(batch_size, 1, opt.crop_point_num, 3)  # (b,1,512,3). # 保存裁剪点的坐标
    input_cropped1 = torch.FloatTensor(batch_size, opt.pnum, 3)  # (b,2048,3). 原始点云数据的坐标,后面将裁剪掉crop_point_num个点
    input_cropped1 = input_cropped1.data.copy_(real_point)  # input_cropped1的地址指向没变,只是重新赋值。
    real_point = torch.unsqueeze(real_point, 1)  # (b,2048,3) -> (b,1,2048,3)
    input_cropped1 = torch.unsqueeze(input_cropped1, 1)  # (b,2048,3) -> (b,1,2048,3)
    p_origin = [0, 0, 0]

    # 计算点云和各自视点之间的距离,并从小到大排序;裁剪点云
    # input_cropped1被裁剪后的点云,real_center是被裁剪下来的点云
    # Set viewpoints
    vp_choice_list = [torch.Tensor([1, 0, 0]), torch.Tensor([0, 0, 1]), torch.Tensor([1, 0, 1]),
                      torch.Tensor([-1, 0, 0]), torch.Tensor([-1, 1, 0])]
    for m in range(batch_size):  # 计算batch中所有点云距离vp
        cur_vp_index = random.sample(vp_choice_list, 1)  # Random choose one of the viewpoint
        p_center = cur_vp_index[0]  # eg. [1,0,0]
        distance_list = []  # 点和各自vp之间的距离
        for n in range(opt.pnum):  # 点云中第n个点
            distance_list.append(distance_squre(real_point[m, 0, n], p_center))  # 当前点和vp之间的距离
        distance_order = sorted(enumerate(distance_list), key=lambda x: x[1])  # enumerate使其变成2维,x[1]第二维度
        # 裁剪掉距离视点最近的前crop_point_num个点
        for sp in range(opt.crop_point_num):  # distance_order[sp] == (point_idx, dist_val)
            input_cropped1.data[m, 0, distance_order[sp][0]] = torch.FloatTensor([0, 0, 0])  # 坐标置为0
            real_center.data[m, 0, sp] = real_point[m, 0, distance_order[sp][0]]  # 保存裁剪点的坐标

    label.resize_([batch_size, 1]).fill_(real_label)  # (b,) -> (b,1).  填充1

    # to cuda
    real_point = real_point.to(device)  # (b,1,2048,3) 原始完整点云坐标数据
    real_center = real_center.to(device)  # (b,1,512,3) 被裁剪下来的点云
    input_cropped1 = input_cropped1.to(device)  # (b,1,2048,3) 被裁剪后的点云
    label = label.to(device)  # (2,1) 1是真实,0是生成

    ############################
    # (1) data prepare
    ###########################
    # 被裁剪下来的点云
    # scale 0
    real_center = Variable(real_center, requires_grad=True)
    real_center = torch.squeeze(real_center, 1)  # (b,1,512,3) -> (b,512,3)
    # scale 1
    real_center_key1_idx = utils.farthest_point_sample(real_center, 64, RAN=False)  # 提取64个点作为骨架点
    real_center_key1 = utils.index_points(real_center, real_center_key1_idx)
    real_center_key1 = Variable(real_center_key1, requires_grad=True)
    # scale 2
    real_center_key2_idx = utils.farthest_point_sample(real_center, 128, RAN=True)  # 提取128个点作为骨架点
    real_center_key2 = utils.index_points(real_center, real_center_key2_idx)  # 被裁剪下来的点云
    real_center_key2 = Variable(real_center_key2, requires_grad=True)
    # 被裁剪后的点云
    # scale 0
    input_cropped1 = torch.squeeze(input_cropped1, 1)  # (b,1,2048,3) -> (b,512,3)
    # scale 1
    input_cropped2_idx = utils.farthest_point_sample(input_cropped1, opt.point_scales_list[1], RAN=True)  # 1024
    input_cropped2 = utils.index_points(input_cropped1, input_cropped2_idx)
    # scale 2
    input_cropped3_idx = utils.farthest_point_sample(input_cropped1, opt.point_scales_list[2], RAN=False)  # 512
    input_cropped3 = utils.index_points(input_cropped1, input_cropped3_idx)

    input_cropped1 = Variable(input_cropped1, requires_grad=True)
    input_cropped2 = Variable(input_cropped2, requires_grad=True)
    input_cropped3 = Variable(input_cropped3, requires_grad=True)

    # to cuda
    input_cropped2 = input_cropped2.to(device)
    input_cropped3 = input_cropped3.to(device)
    input_cropped = [input_cropped1, input_cropped2, input_cropped3]  # 被裁剪后的点云 from diff scales

 得到数据:

real_center: (b,512,3).  被裁剪下来的点云

input_cropped: list of tensor. (b,2048,3), (b,1024,3), (b,512,3) . 裁剪后的点云

label_center: (b,1). 0/1是否是真是点云

real_center_key1: (b,128,3). 被裁剪下来的点云(下次样)

real_center_key2: (b,64,3). 被裁剪下来的点云(下次样)

3.3 网络输入输出

3.3.1 判别器训练

(1)输入真实的被裁剪下来的点云,判别器进行判断,计算errD_real_loss;

(2)利用被裁剪后的点云,生成假的被裁剪下来的点云,再经过判别器,计算errD_fake_loss;

判别器的目标是:

  • 真的判定为真的,即图中real_center的预测值越接近1,损失越小; 
  • 假的判定为假的,即图中fake的预测值越接近0,损失越小。

 对应的代码

point_netG = point_netG.train()
point_netD = point_netD.train()
############################
# (2) Update D network
###########################
point_netD.zero_grad()
real_center = torch.unsqueeze(real_center, 1)  # (b,512,3) -> (b,1,512,3)
output = point_netD(real_center)  # (b,1,512,3). output: (b,1)
# label: (b,1) fill with 1. 对于判别器来说,output值越大越好,损失值越小
errD_real = criterion(output, label)
errD_real.backward()

# input_cropped: (2,2048,3)/(2,1024,3)/(2,512,3). fake_1: (b,64,3), fake_2: (b,128,3), fake: (b,512,3).
fake_center1, fake_center2, fake = point_netG(input_cropped)
fake = torch.unsqueeze(fake, 1)  # (b,512,3) -> (b,1,512,3)
label.data.fill_(fake_label)  # (b,1). label赋值为0
output = point_netD(fake.detach())  # output: (b,1)
# label: (b,1) fill with 0. 对于判别器来说,output值越小越好,损失值越小
errD_fake = criterion(output, label)  #
errD_fake.backward()

errD = errD_real + errD_fake  # errD 没有参与训练,只是用于打印,没啥其他用处。

optimizerD.step()

3.3.2 生成器训练

对图中生成的4个fake点云进行学习,降低损失函数。

############################
# (3) Update G network: maximize log(D(G(z)))
###########################
point_netG.zero_grad()
label.data.fill_(real_label)  # (b,1). label赋值为1
# fake: (b,1,512,3). output: (b,1)。利用更新后的判别器再次判断fake数据
output = point_netD(fake)
errG_D = criterion(output, label)  # tensor(0.5747)

# fake: (b,1,512,3) -> (b,512,3), real_center: (b,1,512,3) -> (b,512,3)
CD_LOSS = criterion_PointLoss(torch.squeeze(fake, 1), torch.squeeze(real_center, 1))  # 只是打印,没有参与训练

# 生成不同尺度下数据的损失CD
# fake and real_center: (b,1,512,3). 生成的假的被裁剪下来的点云、真的被裁剪下来的点云
# fake_center1 and real_center_key1: (b,64,3)
# fake_center2 and real_center_key2: (b,128,3)
errG_l2 = criterion_PointLoss(torch.squeeze(fake, 1), torch.squeeze(real_center, 1)) \
          + alpha1 * criterion_PointLoss(fake_center1, real_center_key1) \
          + alpha2 * criterion_PointLoss(fake_center2, real_center_key2)

errG = (1 - opt.wtl2) * errG_D + opt.wtl2 * errG_l2  # 0.05*errG_D + 0.95*errG_l2
errG.backward()
optimizerG.step()

3.4 判别器模型

 对应到论文中的框架图:

其中CMLP等于上图的conv2d+maxpool+conc组合操作。

(1) 输入生成的假的被裁剪下来的点云,四次卷积,缩小通道数,获得多尺度特征;

(2)分别对最后三个多尺度卷积结果进行最大池化,4维度变2维度特征;

(3)拼接多个尺度特征,再接4个全链接层。

class _netlocalD(nn.Module):
    def __init__(self, crop_point_num):
        super(_netlocalD, self).__init__()
        self.crop_point_num = crop_point_num
        self.conv1 = torch.nn.Conv2d(in_channels=1, out_channels=64, kernel_size=(1, 3))
        self.conv2 = torch.nn.Conv2d(64, 64, 1)
        self.conv3 = torch.nn.Conv2d(64, 128, 1)
        self.conv4 = torch.nn.Conv2d(128, 256, 1)

        self.maxpool = torch.nn.MaxPool2d(kernel_size=(self.crop_point_num, 1), stride=1)

        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(128)
        self.bn4 = nn.BatchNorm2d(256)

        self.fc1 = nn.Linear(448, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 16)
        self.fc4 = nn.Linear(16, 1)

        self.bn_1 = nn.BatchNorm1d(256)
        self.bn_2 = nn.BatchNorm1d(128)
        self.bn_3 = nn.BatchNorm1d(16)

    def forward(self, x):  # size: (2,1,512,3)
        x = F.relu(self.bn1(self.conv1(x)))  # (b,1,512,3) -> (2,64,512,1). conv2d+bn2d+relu
        x_64 = F.relu(self.bn2(self.conv2(x)))  # (b,64,512,1) -> (b,64,512,1)
        x_128 = F.relu(self.bn3(self.conv3(x_64)))  # (b,64,512,1) -> (b,128,512,1)
        x_256 = F.relu(self.bn4(self.conv4(x_128)))  # (b,128,512,1) -> (b,256,512,1)

        x_64 = torch.squeeze(self.maxpool(x_64))  # (b,64,512,1) -> (b,64,1,1)->(b,64)
        x_128 = torch.squeeze(self.maxpool(x_128))  # (b,128,512,1) -> (b,128,1,1)->(b,128)
        x_256 = torch.squeeze(self.maxpool(x_256))  # (b,256,512,1) -> (b,256,1,1)->(b,256)

        Layers = [x_256, x_128, x_64]  # (b,64), (b,128), (b,256)
        x = torch.cat(Layers, 1)  # (b,448)
        x = F.relu(self.bn_1(self.fc1(x)))  # (b,448) -> (b,256)
        x = F.relu(self.bn_2(self.fc2(x)))  # (b,256) -> (b,128)
        x = F.relu(self.bn_3(self.fc3(x)))  # (b,128) -> (b,16)
        x = self.fc4(x)  # (b,1). real or fake
        return x

3.5 生成器模型

3.5.1 CMLP

 框架图中的CMLP代码如下,输入size: (b,num_points,3),输出size: (b,1024+512+256+128, 1).

class Convlayer(nn.Module):
    def __init__(self, point_scales):
        """
        CMLP: conv+max_pool+concat, 其中最大池化的核大小是动态的,使得最后输出的特征向量是固定大小
        Parameters
        ----------
        point_scales: int. 2048/1024/512. 用于最大池化核算子大小,相当与自适应最大池化,把特征图池化到1x1大小
        """
        super(Convlayer, self).__init__()
        self.point_scales = point_scales
        self.conv1 = torch.nn.Conv2d(1, 64, (1, 3))
        self.conv2 = torch.nn.Conv2d(64, 64, 1)
        self.conv3 = torch.nn.Conv2d(64, 128, 1)
        self.conv4 = torch.nn.Conv2d(128, 256, 1)
        self.conv5 = torch.nn.Conv2d(256, 512, 1)
        self.conv6 = torch.nn.Conv2d(512, 1024, 1)
        self.maxpool = torch.nn.MaxPool2d((self.point_scales, 1), 1)
        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(128)
        self.bn4 = nn.BatchNorm2d(256)
        self.bn5 = nn.BatchNorm2d(512)
        self.bn6 = nn.BatchNorm2d(1024)

    def forward(self, x):  # (b,num_point,3)
        x = torch.unsqueeze(x, 1)  # (b,num_point,3) -> (b,1,num_point,3)
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        # 获取4个尺度的4维度特征
        x_128 = F.relu(self.bn3(self.conv3(x)))
        x_256 = F.relu(self.bn4(self.conv4(x_128)))
        x_512 = F.relu(self.bn5(self.conv5(x_256)))
        x_1024 = F.relu(self.bn6(self.conv6(x_512)))
        # 4维度变2维度特征
        x_128 = torch.squeeze(self.maxpool(x_128), 2)  # (b,c,num_point,1) -> (b,c,1)
        x_256 = torch.squeeze(self.maxpool(x_256), 2)
        x_512 = torch.squeeze(self.maxpool(x_512), 2)
        x_1024 = torch.squeeze(self.maxpool(x_1024), 2)
        # 拼接多尺度特征
        L = [x_1024, x_512, x_256, x_128]  # (b,1024,1), (b,512,1),(b,256,1), (b,128,1)
        x = torch.cat(L, 1)  # (b,1024+512+256+128, 1)
        return x

3.5.2 Final Feature Vector V

如下是框架中的特征向量Final feature vector V求取代码.

输入size: list. (b,2048,3)/(b,1024,3)/(b,512,3),输出size: (b,1920).

class Latentfeature(nn.Module):
    def __init__(self, num_scales, each_scales_size, point_scales_list):
        """

        Parameters
        ----------
        num_scales: int. 3. number of scales.
        each_scales_size: int. 1. each scales size. 即每个尺度的shape
        point_scales_list: list. [2048, 1024, 512]. number of points in each scales.
        """
        super(Latentfeature, self).__init__()
        self.num_scales = num_scales
        self.each_scales_size = each_scales_size
        self.point_scales_list = point_scales_list
        self.Convlayers1 = nn.ModuleList(  # CMLP
            [Convlayer(point_scales=self.point_scales_list[0]) for i in range(self.each_scales_size)])
        self.Convlayers2 = nn.ModuleList(
            [Convlayer(point_scales=self.point_scales_list[1]) for i in range(self.each_scales_size)])
        self.Convlayers3 = nn.ModuleList(
            [Convlayer(point_scales=self.point_scales_list[2]) for i in range(self.each_scales_size)])
        self.conv1 = torch.nn.Conv1d(3, 1, 1)
        self.bn1 = nn.BatchNorm1d(1)

    def forward(self, x):
        """

        Parameters
        ----------
        x: list. (b,2048,3)/(b,1024,3)/(b,512,3)

        Returns. (b,1920)
        -------

        """
        outs = []
        # 1, CMLP. input (b,point_num,3), output latent vector.
        for i in range(self.each_scales_size):
            outs.append(self.Convlayers1[i](x[0]))  # CMLP: (2,2048,3) -> (b,1024+512+256+128,1)
        for j in range(self.each_scales_size):
            outs.append(self.Convlayers2[j](x[1]))  # CMLP: (2,1024,3) -> (b,1024+512+256+128,1)
        for k in range(self.each_scales_size):
            outs.append(self.Convlayers3[k](x[2]))  # CMLP: (2,512,3) ->  (b,1024+512+256+128,1)
        # 2, CONCAT
        latentfeature = torch.cat(outs, 2)  # (b,1920,3). final latent map M
        # 3, MLP
        latentfeature = latentfeature.transpose(1, 2)  # (b,1920,3) -> (b,3,1920)
        latentfeature = F.relu(self.bn1(self.conv1(latentfeature)))  # (b,3,1920) -> (b,1,1920)
        latentfeature = torch.squeeze(latentfeature, 1)  # (b,1,1920) -> (b,1920)
        return latentfeature

3.5.3 生成器主代码 

class _netG(nn.Module):
    def __init__(self, num_scales, each_scales_size, point_scales_list, crop_point_num):
        """

        Parameters
        ----------
        num_scales: int. 3. number of scales.
        each_scales_size: int. 1. each scales size. 即每个尺度的shape
        point_scales_list: list. [2048, 1024, 512]. number of points in each scale.
        crop_point_num: int. 512. 裁剪多少个点下来
        """
        super(_netG, self).__init__()
        self.crop_point_num = crop_point_num
        self.latentfeature = Latentfeature(num_scales, each_scales_size, point_scales_list)
        self.fc1 = nn.Linear(1920, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 256)

        self.fc1_1 = nn.Linear(1024, 128 * 512)
        self.fc2_1 = nn.Linear(512, 64 * 128)  # nn.Linear(512,64*256) !
        self.fc3_1 = nn.Linear(256, 64 * 3)

        self.conv1_1 = torch.nn.Conv1d(512, 512, 1)  # torch.nn.Conv1d(256,256,1) !
        self.conv1_2 = torch.nn.Conv1d(512, 256, 1)
        self.conv1_3 = torch.nn.Conv1d(256, int((self.crop_point_num * 3) / 128), 1)
        self.conv2_1 = torch.nn.Conv1d(128, 6, 1)  # torch.nn.Conv1d(256,12,1) !

    def forward(self, x):
        """

        Parameters
        ----------
        x: list. (b,2048,3)/(b,1024,3)/(b,512,3)

        Returns (b,64,3), (b,128,3), (b,512,3).
        -------

        """
        # final feature vector V
        x = self.latentfeature(x)  # list -> (b,1920)
        # FPN
        # fc1, fc2, fc3
        x_1 = F.relu(self.fc1(x))  # (b,1920) -> (b,1024)
        x_2 = F.relu(self.fc2(x_1))  # (b,1024) -> (b,512)
        x_3 = F.relu(self.fc3(x_2))  # (b,512) -> (b,256)
        # x_3: fc+reshape. 少了论文中的一个conv
        pc1_feat = self.fc3_1(x_3)  # (b,256) -> (b,192)
        pc1_xyz = pc1_feat.reshape(-1, 64, 3)  # (b,192) -> (b,64,3). 64x3 center1. 64个点
        # x_2: fc+reshape+conv1d
        pc2_feat = F.relu(self.fc2_1(x_2))  # (b,192) -> (b,8192)
        pc2_feat = pc2_feat.reshape(-1, 128, 64)  # (b,8192) -> (b,128,64)
        pc2_xyz = self.conv2_1(pc2_feat)  # (b,128,64) -> (b,6,64). 6x64 center2
        # x_1: fc_reshape+conv1d+conv1d+conv1d
        pc3_feat = F.relu(self.fc1_1(x_1))  # (b,1024) -> (b,65536)
        pc3_feat = pc3_feat.reshape(-1, 512, 128)  # (b,65536) -> (b,512,128)
        pc3_feat = F.relu(self.conv1_1(pc3_feat))  # (b,512,128) -> (b,512,128)
        pc3_feat = F.relu(self.conv1_2(pc3_feat))  # (b,512,128) -> (b,256,128)
        pc3_xyz = self.conv1_3(pc3_feat)  # (b,256,128) -> (b,12,128). 12x128 fine

        # plus: scale 1 + scale 2
        pc1_xyz_expand = torch.unsqueeze(pc1_xyz, 2)  # (b,64,3) -> (b,64,1,3)
        pc2_xyz = pc2_xyz.transpose(1, 2)  # (b,6,64) -> (b,64,6)
        pc2_xyz = pc2_xyz.reshape(-1, 64, 2, 3)  # (b,64,6) -> (b,64,2,3)
        pc2_xyz = pc1_xyz_expand + pc2_xyz  # (b,64,1,3) + (b,64,2,3) = (b,64,2,3)
        pc2_xyz = pc2_xyz.reshape(-1, 128, 3)  # (b,64,2,3) -> (b,128,3)
        # plus: scale 2 + scale 3
        pc2_xyz_expand = torch.unsqueeze(pc2_xyz, 2)  # (b,128,3) -> (b,128,1,3)
        pc3_xyz = pc3_xyz.transpose(1, 2)  # (b,12,128) -> (b,12,128)
        pc3_xyz = pc3_xyz.reshape(-1, 128, int(self.crop_point_num / 128), 3)  # (b,12,128) -> (b,128,4,3)
        pc3_xyz = pc2_xyz_expand + pc3_xyz  # (b,128,1,3) + (b,128,4,3) = (b,128,4,3)
        pc3_xyz = pc3_xyz.reshape(-1, self.crop_point_num, 3)  # (b,128,4,3) -> (b,512,3)

        return pc1_xyz, pc2_xyz, pc3_xyz  # (b,64,3), (b,128,3), (b,512,3). center1, center2, fine

3.6 测试效果

 

测试代码

# 1. init model
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
point_netG = _netG(opt.num_scales, opt.each_scales_size, opt.point_scales_list, opt.crop_point_num)
point_netG = torch.nn.DataParallel(point_netG)
point_netG.to(device)
point_netG.load_state_dict(torch.load(opt.netG, map_location=lambda storage, location: storage)['state_dict'])
point_netG.eval()

# 2. load incomplete point cloud
input_cropped1 = np.loadtxt(opt.infile, delimiter=',')  # (1536,3). csv文件
input_cropped1 = torch.FloatTensor(input_cropped1)  # (1536,3)
input_cropped1 = torch.unsqueeze(input_cropped1, 0)  # (1,1536,3)

Zeros = torch.zeros(1, 512, 3)  # (1,512,3)
input_cropped1 = torch.cat((input_cropped1, Zeros), 1)  # (1,1536+512,3) = (1,2048,3)

# 2. preprocess
# 获得多尺度输入: [input_cropped1, input_cropped2, input_cropped3]. (1,2048,3)/(1,1024,3)/(1,512,3)
input_cropped2_idx = utils.farthest_point_sample(input_cropped1, opt.point_scales_list[1], RAN=True)
input_cropped2 = utils.index_points(input_cropped1, input_cropped2_idx)  # (1,1024,3)
input_cropped3_idx = utils.farthest_point_sample(input_cropped1, opt.point_scales_list[2], RAN=False)
input_cropped3 = utils.index_points(input_cropped1, input_cropped3_idx)  # (1,512,3)
# input_cropped4_idx = utils.farthest_point_sample(input_cropped1, 256, RAN=True)
# input_cropped4 = utils.index_points(input_cropped1, input_cropped4_idx)  # (1,256,3). 没啥用

# to cuda
input_cropped2 = input_cropped2.to(device)  # (1,1024,3)
input_cropped3 = input_cropped3.to(device)  # (1,512,3)
input_cropped = [input_cropped1, input_cropped2, input_cropped3]
# 3. infer. fake.size: (1,512,3)
fake_center1, fake_center2, fake = point_netG(input_cropped)
# fake = fake.cuda()  # 返回的本来就在cuda设备上
# fake_center1 = fake_center1.cuda()
# fake_center2 = fake_center2.cuda()

# 4. post-process
# input_cropped2 = input_cropped2.cpu()
# input_cropped3 = input_cropped3.cpu()
# input_cropped4 = input_cropped4.cpu()

# np_crop2 = input_cropped2[0].detach().numpy()
# np_crop3 = input_cropped3[0].detach().numpy()
# np_crop4 = input_cropped4[0].detach().numpy()

# # 真实被裁剪下来的点云,并生成多尺度真实点云
# real = np.loadtxt(opt.infile_real, delimiter=',')
# real = torch.FloatTensor(real)
# real = torch.unsqueeze(real, 0)
# real2_idx = utils.farthest_point_sample(real, 64, RAN=False)
# real2 = utils.index_points(real, real2_idx)
# real3_idx = utils.farthest_point_sample(real, 128, RAN=True)
# real3 = utils.index_points(real, real3_idx)
#
# real2 = real2.cpu()
# real3 = real3.cpu()
#
# np_real2 = real2[0].detach().numpy()
# np_real3 = real3[0].detach().numpy()

fake = fake.cpu()
# fake_center1 = fake_center1.cpu()
# fake_center2 = fake_center2.cpu()
np_fake = fake[0].detach().numpy()  # (1,512,3) -> (512,3)
# np_fake1 = fake_center1[0].detach().numpy()
# np_fake2 = fake_center2[0].detach().numpy()
input_cropped1 = input_cropped1.cpu()
np_crop = input_cropped1[0].numpy()  # (1,2048,3) -> (2048,3)

np.savetxt('test_one/crop_ours' + '.csv', np_crop, fmt="%f,%f,%f")
np.savetxt('test_one/fake_ours' + '.csv', np_fake, fmt="%f,%f,%f")
np.savetxt('test_one/crop_ours_txt' + '.txt', np_crop, fmt="%f,%f,%f")
np.savetxt('test_one/fake_ours_txt' + '.txt', np_fake, fmt="%f,%f,%f")

        

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
青葱年少的头像青葱年少普通用户
上一篇 2023年4月5日
下一篇 2023年4月5日

相关推荐