pytorch的eval()失效剖析及解决方案

一般情况下,模型在训练的时候,需要保证是train()模式,而在评估时需要保证是eval()模式。因为很多时候,模型中会包含dropout、BN的操作,而eval()模式下这两个功能是不会生效的,这样保证评估时候的稳定性。否则每次随机的dropout,那么每次评估的结果也是不同的。

假如有一个模型,我们标记为model,并不是model.eval()总是会让dropout失效的。比如,对于复杂模型,由多个的block组成stack,由多个stack组成一个模型model。

如上图所示,它是nbeats算法的模型,stacks是模型最外层的结构,总共由2个stack组成stacks结构。而每个stack又有3个block模块。对应的模型结构信息如下:

| N-Beats
| --  Stack Generic (#0) (share_weights_in_stack=False)
     | -- GenericBlock(units=128, thetas_dim=4, backcast_length=12, forecast_length=4, share_thetas=False) at @2631227197480
     | -- GenericBlock(units=128, thetas_dim=4, backcast_length=12, forecast_length=4, share_thetas=False) at @2631296887440
     | -- GenericBlock(units=128, thetas_dim=4, backcast_length=12, forecast_length=4, share_thetas=False) at @2631296888056
| --  Stack Seasonality (#1) (share_weights_in_stack=False)
     | -- SeasonalityBlock(units=128, thetas_dim=8, backcast_length=12, forecast_length=4, share_thetas=True) at @2631296888784
     | -- SeasonalityBlock(units=128, thetas_dim=8, backcast_length=12, forecast_length=4, share_thetas=True) at @2631297372736
     | -- SeasonalityBlock(units=128, thetas_dim=8, backcast_length=12, forecast_length=4, share_thetas=True) at @2631297373296

此时直接对模型model使用model.eval(),我们会发现dropout仍然生效。

我们通过JIT跟踪模型的执行,代码如下:

jit_model = torch.jit.trace(model.eval(), total_x_input)

执行后会在日志中生成如下信息:

具体错误信息如下:

D:\programs\python\Lib\site-packages\torch\jit\__init__.py:914: TracerWarning: Trace had nondeterministic nodes. Did you forget call .eval() on your model? Nodes:
  %input.3 : Float(113, 128) = aten::dropout(%input.2, %97, %98), scope: NBeatsNet/GenericBlock/Dropout[fc1_dropout] # D:\programs\python\Lib\site-packages\torch\nn\functional.py:806:0
  %input.6 : Float(113, 128) = aten::dropout(%input.5, %105, %106), scope: NBeatsNet/GenericBlock/Dropout[fc2_dropout] # D:\programs\python\Lib\site-packages\torch\nn\functional.py:806:0
  %input.9 : Float(113, 128) = aten::dropout(%input.8, %113, %114), scope: NBeatsNet/GenericBlock/Dropout[fc3_dropout] # D:\programs\python\Lib\site-packages\torch\nn\functional.py:806:0
  %input.12 : Float(113, 128) = aten::dropout(%input.11, %121, %122), scope: NBeatsNet/GenericBlock/Dropout[fc4_dropout] # D:\programs\python\Lib\site-packages\torch\nn\functional.py:806:0

 JIT提示我们忘记了eval()操作,可是我们明明输入的是model.eval()的模型信息。

这到底是怎么回事?让我们细细分析!

首先,如第一张图所示,stacks的数据类型是class list,同时每一个stack的数据类型也是class list。也就是说,我们使用model.eval(),只是让model这个实例的training属性变为了False,并没有针对每个stakc下的block改变任何属性。这决定了model.eval()只是操作了个寂寞!

那么怎么改呢?我们的方向是让每个包含执行节点的block变为eval()模式,所以操作代码如下:

model.stacks[0][0].eval()
model.stacks[0][1].eval()
model.stacks[0][2].eval()
model.stacks[1][0].eval()
model.stacks[1][1].eval()
model.stacks[1][2].eval()

也许加一个循环会让代码变得更加优雅。

再执行”jit_model = torch.jit.trace(model.eval(), total_x_input)”操作,日志将不再提醒”Did you forget call .eval() on your model”。

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
青葱年少的头像青葱年少普通用户
上一篇 2022年5月26日
下一篇 2022年5月26日

相关推荐