一般情况下,模型在训练的时候,需要保证是train()模式,而在评估时需要保证是eval()模式。因为很多时候,模型中会包含dropout、BN的操作,而eval()模式下这两个功能是不会生效的,这样保证评估时候的稳定性。否则每次随机的dropout,那么每次评估的结果也是不同的。
假如有一个模型,我们标记为model,并不是model.eval()总是会让dropout失效的。比如,对于复杂模型,由多个的block组成stack,由多个stack组成一个模型model。
如上图所示,它是nbeats算法的模型,stacks是模型最外层的结构,总共由2个stack组成stacks结构。而每个stack又有3个block模块。对应的模型结构信息如下:
| N-Beats
| -- Stack Generic (#0) (share_weights_in_stack=False)
| -- GenericBlock(units=128, thetas_dim=4, backcast_length=12, forecast_length=4, share_thetas=False) at @2631227197480
| -- GenericBlock(units=128, thetas_dim=4, backcast_length=12, forecast_length=4, share_thetas=False) at @2631296887440
| -- GenericBlock(units=128, thetas_dim=4, backcast_length=12, forecast_length=4, share_thetas=False) at @2631296888056
| -- Stack Seasonality (#1) (share_weights_in_stack=False)
| -- SeasonalityBlock(units=128, thetas_dim=8, backcast_length=12, forecast_length=4, share_thetas=True) at @2631296888784
| -- SeasonalityBlock(units=128, thetas_dim=8, backcast_length=12, forecast_length=4, share_thetas=True) at @2631297372736
| -- SeasonalityBlock(units=128, thetas_dim=8, backcast_length=12, forecast_length=4, share_thetas=True) at @2631297373296
此时直接对模型model使用model.eval(),我们会发现dropout仍然生效。
我们通过JIT跟踪模型的执行,代码如下:
jit_model = torch.jit.trace(model.eval(), total_x_input)
执行后会在日志中生成如下信息:
具体错误信息如下:
D:\programs\python\Lib\site-packages\torch\jit\__init__.py:914: TracerWarning: Trace had nondeterministic nodes. Did you forget call .eval() on your model? Nodes:
%input.3 : Float(113, 128) = aten::dropout(%input.2, %97, %98), scope: NBeatsNet/GenericBlock/Dropout[fc1_dropout] # D:\programs\python\Lib\site-packages\torch\nn\functional.py:806:0
%input.6 : Float(113, 128) = aten::dropout(%input.5, %105, %106), scope: NBeatsNet/GenericBlock/Dropout[fc2_dropout] # D:\programs\python\Lib\site-packages\torch\nn\functional.py:806:0
%input.9 : Float(113, 128) = aten::dropout(%input.8, %113, %114), scope: NBeatsNet/GenericBlock/Dropout[fc3_dropout] # D:\programs\python\Lib\site-packages\torch\nn\functional.py:806:0
%input.12 : Float(113, 128) = aten::dropout(%input.11, %121, %122), scope: NBeatsNet/GenericBlock/Dropout[fc4_dropout] # D:\programs\python\Lib\site-packages\torch\nn\functional.py:806:0
JIT提示我们忘记了eval()操作,可是我们明明输入的是model.eval()的模型信息。
这到底是怎么回事?让我们细细分析!
首先,如第一张图所示,stacks的数据类型是class list,同时每一个stack的数据类型也是class list。也就是说,我们使用model.eval(),只是让model这个实例的training属性变为了False,并没有针对每个stakc下的block改变任何属性。这决定了model.eval()只是操作了个寂寞!
那么怎么改呢?我们的方向是让每个包含执行节点的block变为eval()模式,所以操作代码如下:
model.stacks[0][0].eval()
model.stacks[0][1].eval()
model.stacks[0][2].eval()
model.stacks[1][0].eval()
model.stacks[1][1].eval()
model.stacks[1][2].eval()
也许加一个循环会让代码变得更加优雅。
再执行”jit_model = torch.jit.trace(model.eval(), total_x_input)”操作,日志将不再提醒”Did you forget call .eval() on your model”。
文章出处登录后可见!