记录 PyTorch Lightning 的一个坑

项目场景

PyTorch Lightning 对 PyTorch 做了进一步的封装,并继承了日志记录,分布式训练等工具,让我们能够把研究核心放在模型改进上而不是工程代码的编写。近期使用发现一个小问题,在此记录一下。

问题描述

模型训练的时候很正常,但验证的时候报错:

TypeError: validation_step() takes 3 positional arguments but 4 were given

并且,测试的时候也会遇到类似的问题。

原因分析

原来是我重写 LightningModulevalidation_steptest_step 方法时没有指定 batch_idx 参数,虽然这个参数在方法中没有被使用,但是却会被隐式地调用。batch_idx 就是批数据的索引,例如打印训练进度条的时候肯定会被调用的。但如果不显式地指定,就是导致位置参数和关键字参数识别冲突,从而引发异常。

解决方案

这是我原来的代码:

def validation_step(self, batch):
	pass

def test_step(self, batch):
    pass

加上 batch_idx 参数就行了:

def validation_step(self, batch, batch_idx):
	pass

def test_step(self, batch, batch_idx):
    pass

引用参考

https://github.com/PyTorchLightning/pytorch-lightning/issues/1034

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
社会演员多的头像社会演员多普通用户
上一篇 2022年5月23日
下一篇 2022年5月23日

相关推荐