NeRF 源码分析解读(二)

NeRF 源码分析解读(二)

光线的生成

由上一章节我们得到了加载到的数据,包括读取图像的数组、图像的高宽焦距、相机的 pose 、以及用于分割测试集、训练集的分割数组。得到这些数据后,我们开始进行生成光线的步骤。
生成光线的步骤是 NeRF 代码中最为关键的一步,实际上我们模拟的光线就是三维空间中在指定方向上的一系列离散的点的坐标。有了这些点坐标,我们将其投入到 NeRF 的 MLP 神经网络中,计算这个点的密度值以及颜色值。
下面我们对具体代码进行分析。

def train():
	...
	# 加载数据,具体加载代码分析详见 上一篇博客
	...
	
    if K is None:
       K = np.array([
           [focal, 0, 0.5*W],
           [0, focal, 0.5*H],
           [0, 0, 1]
       ])

注意这里的 K 。这里的 K 指的是相机的内参,具体的作用会在后面的分析中进行解释说明

1、初始化 NeRF 网络模型

def train():
	...
	# 初始化 NeRF 网络模型
	render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer = create_nerf(args)

	bds_dict = {
        'near' : near,
        'far' : far,
    }
    # 更新字典,加入三维空间的边界框 bounding box
    render_kwargs_train.update(bds_dict)
    render_kwargs_test.update(bds_dict)
	...

我们首先对 create_nerf() 返回的值进行解释说明:

render_kwargs_train:一个字典,包含了用于训练的各个参数值。具体字典内容详见下面的代码分析

render_kwargs_test:

start:

grad_vars: 整个网络的梯度变量

optimizer: 整个网络的优化器

接下来我们对 create_nerf() 代码进行具体的逐行分析

def create_nerf(args):
    embed_fn, input_ch = get_embedder(args.multires, args.i_embed)  
	...

这行语句实际上获得一个编码器 embed_fn 以及一个编码后的维度,给定 embed_fn 一个输入,就可以获得输入的编码后的数据。具体的编码公式如下,详见论文 5.1 节,代码分析详见:位置编码代码注释分析

NeRF 源码分析解读(二)

以下是网络结构的初始化代码,网络的层级模型详见论文 补充材料 pic 7。 关于 NeRF 的 model 部分我们会单独开一个章节进行解读。(待更新),总之我们只需要知道我们在这里创建了 NeRF 的粗网络,给这个粗网络输入一个 5D 的输入,就可以得到一个 (RGB,A)的输出,即:
NeRF 源码分析解读(二)
NeRF 源码分析解读(二) 就是我们创建的网络 model

def create_nerf(args):
	...
	# 初始化MLP模型参数,网络的层级模型详见论文 补充材料 pic 7
    model = NeRF(D=args.netdepth, W=args.netwidth,
                 input_ch=input_ch, output_ch=output_ch, skips=skips,
                 input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs).to(device)
    # 模型中的梯度变量
    grad_vars = list(model.parameters())
    ...
    # 定义一个查询点的颜色和密度的匿名函数,实际上给定点坐标,方向,以及查询网络,我们就可以得到该点在该网络下的输出([rgb,alpha])
    network_query_fn = lambda inputs, viewdirs, network_fn : run_network(inputs, viewdirs, network_fn,
                                                                embed_fn=embed_fn,
                                                                embeddirs_fn=embeddirs_fn,
                                                                netchunk=args.netchunk  # 网络批处理查询点的数量)

可以看到这里的 network_query_fn 是一个匿名函数,真正起作用的函数是 run_network() 。下面我们对 run_network() 进行分析,观察我们是如何生成给定点的颜色和密度的。
run_network() 代码分析如下:

def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024*64):
    """
    对 input 进行处理,应用 神经网络 fn
    """
    inputs_flat = torch.reshape(inputs, [-1, inputs.shape[-1]])
	embedded = embed_fn(inputs_flat)  # 对输入进行位置编码,得到编码后的结果,是一个 array 数组

	if viewdirs is not None:
		# 视图不为 None,即输入了视图方向,那么我们就应该考虑对视图方向作出处理,用以生成颜色
		input_dirs = viewdirs[:, None].expand(inputs.shape)
		input_dirs_flat = torch.reshape(input_dirs, [-1, input_dirs.shape[-1]])
		embedded_dirs = embeddirs_fn(input_dirs_flat)  # 对输入方向进行编码
		embedded = torch.cat([embedded, embedded_dirs], -1)  # 将编码后的位置和方向聚合到一起

	outputs_flat = batchify(fn, netchunk)(embedded)  # 将编码过的点以批处理的形式输入到 网络模型 中得到 输出(RGB,A)
	outputs = torch.reshape(outputs_flat, list(inputs.shape[:-1]) + [outputs_flat.shape[-1]])
	return outputs

这里的 batchify() 函数会把 embedded 数组分批输入到网络 fn 中,前向传播得到对应的 (RGB,A)。
接下来我们继续对 create_nerf() 进行分析

def create_nerf(args):
	...
	# 创建网络的优化器
    optimizer = torch.optim.Adam(params=grad_vars, lr=args.lrate, betas=(0.9, 0.999))
	
	...
	# 关于加载已有模型的部分我们不在逐行分析,对整体算法的分析没有任何影响

	...
	
	# 注意看,现在整体的初始化已经完成,我们需要对返回值进行一些处理
	render_kwargs_train = {
        'network_query_fn' : network_query_fn,
        'perturb' : args.perturb,
        'N_importance' : args.N_importance,
        'network_fine' : model_fine,
        'N_samples' : args.N_samples,
        'network_fn' : model,
        'use_viewdirs' : args.use_viewdirs,
        'white_bkgd' : args.white_bkgd,
        'raw_noise_std' : args.raw_noise_std,
    }

    # NDC 空间,只对前向场景有效,具体解释可以看论文,这里不再展开
    if args.dataset_type != 'llff' or args.no_ndc:
        print('Not ndc!')
        render_kwargs_train['ndc'] = False
        render_kwargs_train['lindisp'] = args.lindisp

    render_kwargs_test = {k : render_kwargs_train[k] for k in render_kwargs_train}
    render_kwargs_test['perturb'] = False
    render_kwargs_test['raw_noise_std'] = 0.

    return render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer

下面我们对字典 render_kwargs_train 中的键值进行分析:

'network_query_fn' : network_query_fn,  # 上文已经解释过了,这是一个匿名函数,
给这个函数输入位置坐标,方向坐标,以及神经网络,就可以利用神经网络返回该点对应的 颜色和密度

'perturb' : args.perturb,  # 扰动,对整体算法理解没有影响

'N_importance' : args.N_importance,  # 每条光线上细采样点的数量

'network_fine' : model_fine,  # 论文中的 精细网络

'N_samples' : args.N_samples,  # 每条光线上粗采样点的数量

'network_fn' : model,  # 论文中的 粗网络

'use_viewdirs' : args.use_viewdirs,  # 是否使用视点方向,影响到神经网络是否输出颜色

'white_bkgd' : args.white_bkgd,  # 如果为 True 将输入的 png 图像的透明部分转换成白色

'raw_noise_std' : args.raw_noise_std,  # 归一化密度

以上我们完成了 NeRF 模型的初始化部分。我们在下一章节继续对 train() 函数进行分析。NeRF源码分析解读(三)

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(1)
xiaoxingxing的头像xiaoxingxing管理团队
上一篇 2023年2月26日 上午9:11
下一篇 2023年2月26日

相关推荐