pytorch | loss不收敛或者训练中梯度grad为None的问题

0. 个人觉得有用的参考

  1. 原因定位:
    • https://blog.csdn.net/weixin_44231148/article/details/107240840
    • Pytorch中自定义网络参数,存在梯度但不进行更新 – 漱石的文章 – 知乎
    • https://zhuanlan.zhihu.com/p/92729376
      https://zhuanlan.zhihu.com/p/508458545
  2. 介绍hooks:
    • https://zhuanlan.zhihu.com/p/553627695
    • https://medium.com/analytics-vidhya/pytorch-hooks-5909c7636fb
  3. [强推] 介绍autograd机制,有很多demo,对理解hooks以及autograd非常有帮助:
    • PyTorch 源码解读之 torch.autograd:梯度计算详解 – OpenMMLab的文章 – 知乎
      https://zhuanlan.zhihu.com/p/321449610

其他的想到再补

1. 笔者的一些经验

  1. 笔者血泪史中最重要的一点:网络输出到求loss之间的操作,尽可能简洁
    • 如果NN的输出直接和label可以进行对比,那是最好的情况,比如输出是猫还是狗这种tag
    • 如果不能直接进行对比,则应该尽可能简洁,同时注意以下问题:

1. 1 尽可能注意

  • 避免原地操作:原地操作无法溯源,backward的时候找不到之前的值了
    • 能用torch.squeeze(x),不用x.squeeze_()
    • 能用loss = loss + loss_2,不用loss += loss_2,另外,后者的写法pytorch也会报错
  • 使用pytorch内置的函数或者内置的操作
    • 如果是四则运算可以用+ - * /等无所谓,但是如果是其他操作,pytorch的函数一定比自己写的要好(能用torch的轮子不要自己造),比如logsumexp这种函数,之前还都要自己写,但是现在pytorch也已经帮你写好勒~
  • 尽可能的向量化执行操作。

2. 常用工具

  • 列举一些用过的debug工具,主要用于定位weight不更新,loss不收敛,以及grad为None
  • 工具有好有坏。

2.1 利用hooks输出grad

  1. 众所周知pytorch不保存中间梯度值,所以如果想知道grad到底传没传过去,建议使用pytorch提供的hooks机制

    • hooks机制可能不是很好理解,可以参考上面我给出的一些对autograd和hooks进行介绍的链接。
    • 总之,我的理解是,
      • hooks可以加在pytorch计算图中任何一个地方,开销不大,方便debug,可以获取计算图中任何一个地方的grad或者weight值,尤其是计算图中间节点的grad
      • hook 函数是一个自己定义的函数,只是对函数的input有一些要求,具体要求取决于用的是什么hook
  2. 笔者常用的一套是基于register_hook

grads = {}

def save_grad(name):
	# 返回hook函数 
    def hook(grad):
    # 简单粗暴的直接输出,存到grads中也可以
        # grads[name] = grad
        print(f"name={name}, grad={grad}")
    return hook

def loss_fn(output, target):	

	#### 先对output作一些操作,然后再求mse
	output1 = f_1(output)
	#### 有时候需要检查这个操作会不会影响梯度
	output1.register_hook(save_grad('output1'))
	
	output2 = f_2(output1)
	loss = mse(output2 , target)
	return loss 
for batch_id, data in enumerate(train_loader): 
#.......省略细节,只展示主干

	output = mlp(input_data)
	
	loss = loss_fn(output, target)
	
    optimizer.zero_grad()
    # Only a backward hook is possible for Tensors.
    # 只有执行完backward()之后才会注册hooks
    loss.backward() 
    optimizer.step()

  1. 另一套常用的方法是下面的,基于register_forward_hookregister_full_backward_hook
# hook functions have to take these 3 input
def hook_forward_fn(module, input, output):
    print("It's forward: ")
    print(f"module: {module}")
    print(f"input: {input}")
    print(f"output: {output}")
    print("="*20)

def hook_backward_fn(module, grad_input, grad_output):
    print("It's backward: ")
    print(f"module: {module}")
    print(f"grad_input: {grad_input}")
    print(f"grad_output: {grad_output}")
    print("="*20)

# Set the hooks
mlp.conv1.register_forward_hook(hook_forward_fn)
mlp.conv1.register_full_backward_hook(hook_backward_fn)

  1. 上面两种hooks的区别是:
    • register_hook用于对某个tensor进行加钩子,比如我写的loss函数(没继承module
    • hook_forward_fnhook_backward_fn用于对继承了module的class/对象进行加钩子,比如上面的mlp这个网络

2.2 其他方法输出grad [不建议]

  1. 利用比如retain_graph或者什么方式,也可以保留中间节点的grad值,但是不建议使用,因为内存开销大,偶尔debug下还可以

2.3 tensorboard输出weight值

  1. 利用tensorboard输出weight值可以清楚的看到有没有更新,或者权重更新的方向是不是自己满意的方向
  2. 下图就是一个随着epoch的增加不更新的weight
    image.png
  3. 假设你已经掌握了一点点tensorboard的使用方法,下面是我常用的一个demo:
def plot_alpha(writer,mlp,epoch,tag_str):
    """
    plot weight of 'alpha' in tensorboard
    Args:
        writer:  tensorboard writer
        mlp:     the NN model
        epoch:   epoch
        tag_str: the name of the plot window
    """
    for name, param in mlp.named_parameters():
        # 简单写法:想画什么就大概记录下这一层的名字‘block_alpha’,想画权重就是‘weight’,偏移就是'bias'
        if 'block_alpha' in name and 'weight' in name:
            writer.add_histogram(tag=name + tag_str, values=param.data.clone().cpu().numpy(), global_step=epoch)


for batch_id, data in enumerate(train_loader): 
#.......省略细节,只展示主干

	output = mlp(input_data)
	
	loss = loss_fn(output, target)
	
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
	# plot
    if epoch%3==0:	# 每3个epoch, 画一次weight
        # Record the weight
    	plot_alpha(writer,mlp,epoch,opt.tag_str)
    	
  1. 其他方法:使用注册hooks的方式进行,hook_forward_fnhook_backward_fn上面展示过了

2.4 [不好用] torch.autograd.gradcheck

  1. 代码大概:
from torch.autograd.gradcheck import gradcheck
# grad check
input_test = torch.randn((2,3,300), requires_grad=True,device=device)
test_ans = gradcheck(mlp.to(device), input_test, eps=1e-3)  #, eps=1e-6, atol=1e-4
print("Are the gradients correct: ", test_ans)

  1. 不好用的点在于精度问题,复杂网络或者复杂的loss函数,经常由于精度问题导致无法通过gradcheck
    • 有时候把精度eps设的大一点还可以通过这个梯度检查,但是还有什么意义呢。。

2.5 [不是很好用但也还行]网上的梯度流可视化demo

  1. 来自:https://github.com/t-vi/pytorch-tvmisc/blob/master/visualize/bad_grad_viz.ipynb
  2. 最后可以可视化输出,有问题的节点会被标红

2.6 torch.autograd.detect_anomaly()

  1. 之前用过,但是加上这个后速度非常慢,所以后来也没用过了,想起来再补

欢迎交流:)

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
青葱年少的头像青葱年少普通用户
上一篇 2023年8月16日
下一篇 2023年8月16日

相关推荐