有关于pytorch单精度bfloat16位

1. 反直觉的bfloat16

torch支持单精度浮点数bfloat16。这种数据类型在使用的时候需要格外小心,因为它很可能会表现出一系列的“反人类直觉”特性。

什么是bfloat16

BF16是brain float的简称(来源于google brain)。不同于普通的单精度浮点数FP16(i.e., torch.float16),BF16是介于FP16和FP32之间的一种浮点数格式。BF16的指数位比FP16多,跟FP32一样,不过小数位比较少。即,BF16尝试采用牺牲精度的方法,来换取更大的数值空间(Dynamic Range)。

bfloat16带来的问题

虽然有实验和研究都已经表明,BF16的这种“牺牲精度”并不会损害性能 (甚至某些情况下能带来性能提升),并且速度更快,内存消耗更少 (和FP16一样)。但是在实际使用过程中,它往往还是会带来很多的模型训练时的负面影响,如:

  • 混合精度训练时,loss出现NANINF
  • 巨大的数值间隙,令人费解

比方说,下面这张图,计算constractive loss的时候,需要把positive loss (一个较大的数值)和negative loss (一个较小的数值)相加,就会由于BF16的小数精度表达能力过弱,而导致negative loss根本不起效果:
Positive loss根本没有发生变化
再比如下面这张图,也非常容易让人confused:

所以,在某些特定场景和需求下,我们可以选择不用BF16,而使用传统的FP16。比方说,计算constractive loss,对于小数数值精度有一定要求。

但很遗憾,目前很多与训练语言模型,比方说T5,都是使用BF16进行预训练的…

计算机体系结构的知识忘的太多了,现在已经有点记不得浮点数的相关概念了。。。

解决方案

参考自:float16 vs bfloat16 numerical properties comparison

主要有两种策略来应对bfloat16的精度问题:

  1. 添加assert:
    比较保守的一种策略。如果实在不确定自己的代码当前bfloat16计算是不是已经出现了浮点数溢出问题,可以在代码里面添加断言,来检测模型训练过程中 (尤其是计算loss的时候),tensor中是否已经出现了overflow现象。例如:
assert not torch.isinf(loss).any().item() and not torch.isnan(loss).any().item()
  1. 禁用pytorch subnormal numbers
    如果需要使用bfloat生成高精度的tensor,用下面这段代码把torch的denormal操作关闭:
_ = torch.set_flush_denormal(False)

示范:

  1. 强行更改dtype
    最暴力的一种解决精度问题的策略,自然就是强行把bfloat的矩阵类型改为双精度。适用于需要较高精度支持的场景:

2. 参考:

  • Mixed precision for bfloat16-pretrained models
  • 深度学习与bfloat16(BF16)
  • What Every User Should Know About Mixed Precision Training in PyTorch
  • float16 vs bfloat16 numerical properties comparison

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
青葱年少的头像青葱年少普通用户
上一篇 2023年6月25日
下一篇 2023年6月25日

相关推荐