【神经网络量化】——非线性激活函数sigmoid,tanh的量化推理

sigmoid, tanh, 量化推理

介绍

在嵌入式设备,ARM的M系列,或者存硬件实现网络的推理,这时就需要所有的运算都需要用int型(int8,int15)或者自定义的数据类型。这里包括常见的conv2d,devconv2d…等算子,relu,prelu,sigmoid,tanh等非线性激活函数。

初步知识

  • 神经网络的量化
  • 负补

sigmoid/tanh量化推理

这里我们以ARM的CMSIS_5中的代码进行原理和代码的解说。

1. 查表法

sigmoid,tanh及类似的非线性激活函数都是通过查表的方式来进行推理的。这里以int8为例

1.1表的生成
sigmoid,tanh都会趋于饱和,所以选定浮点的输入范围为[-8, 8), 其他的数进行clip就行。

 * sigmoid(8) = 0.9996646498695336
 * tanh(8) = 0.9999997749296758

但是量化模型的输入应该是int8的数据类型才是,[-128,127],那我们的目的应该是通过输入[-128, 127]为索引创建sigmoid( [-8, 8) )->[]的映射表。

1.2输入[-8,8)的转换
索引转换
这里有一个有趣的转变。详情如下:

128 = (uint8)(-128)
129 = (uint8)(-127)
130 = (uint8)(-126)

可参考:short | long | char自动类型转换的两个例子
这就是索引的变换。

输入tensor [-8,8)建立一个256的表
用numpy中的 linspace 函数
np.linspace(-8, 8, 256, endpoint=False)

然后再把这所有的256个浮点数,进行sigmoid函数,得到浮点的输出。后续再把浮点的out量化成int8.

def fp2q7(self, x):
    x_int = math.floor(x*(2**7)+0.5)
    if x_int >= 128 :
      x_int = 127
    if x_int < -128 :
      x_int = -128
    if x_int >= 0 :
      return x_int
    else :
      return 0x100 + x_int        # 0x100 是什么运算???
array([-8.    , -7.9375, -7.875 , -7.8125, -7.75  , -7.6875, -7.625 ,
       -7.5625, -7.5   , -7.4375, -7.375 , -7.3125, -7.25  , -7.1875,
       -7.125 , -7.0625, -7.    , -6.9375, -6.875 , -6.8125, -6.75  ,
       -6.6875, -6.625 , -6.5625, -6.5   , -6.4375, -6.375 , -6.3125,
       -6.25  , -6.1875, -6.125 , -6.0625, -6.    , -5.9375, -5.875 ,
       -5.8125, -5.75  , -5.6875, -5.625 , -5.5625, -5.5   , -5.4375,
       -5.375 , -5.3125, -5.25  , -5.1875, -5.125 , -5.0625, -5.    ,
       -4.9375, -4.875 , -4.8125, -4.75  , -4.6875, -4.625 , -4.5625,
       -4.5   , -4.4375, -4.375 , -4.3125, -4.25  , -4.1875, -4.125 ,
       -4.0625, -4.    , -3.9375, -3.875 , -3.8125, -3.75  , -3.6875,
       -3.625 , -3.5625, -3.5   , -3.4375, -3.375 , -3.3125, -3.25  ,
       -3.1875, -3.125 , -3.0625, -3.    , -2.9375, -2.875 , -2.8125,
       -2.75  , -2.6875, -2.625 , -2.5625, -2.5   , -2.4375, -2.375 ,
       -2.3125, -2.25  , -2.1875, -2.125 , -2.0625, -2.    , -1.9375,
       -1.875 , -1.8125, -1.75  , -1.6875, -1.625 , -1.5625, -1.5   ,
       -1.4375, -1.375 , -1.3125, -1.25  , -1.1875, -1.125 , -1.0625,
       -1.    , -0.9375, -0.875 , -0.8125, -0.75  , -0.6875, -0.625 ,
       -0.5625, -0.5   , -0.4375, -0.375 , -0.3125, -0.25  , -0.1875,
       -0.125 , -0.0625,  0.    ,  0.0625,  0.125 ,  0.1875,  0.25  ,
        0.3125,  0.375 ,  0.4375,  0.5   ,  0.5625,  0.625 ,  0.6875,
        0.75  ,  0.8125,  0.875 ,  0.9375,  1.    ,  1.0625,  1.125 ,
        1.1875,  1.25  ,  1.3125,  1.375 ,  1.4375,  1.5   ,  1.5625,
        1.625 ,  1.6875,  1.75  ,  1.8125,  1.875 ,  1.9375,  2.    ,
        2.0625,  2.125 ,  2.1875,  2.25  ,  2.3125,  2.375 ,  2.4375,
        2.5   ,  2.5625,  2.625 ,  2.6875,  2.75  ,  2.8125,  2.875 ,
        2.9375,  3.    ,  3.0625,  3.125 ,  3.1875,  3.25  ,  3.3125,
        3.375 ,  3.4375,  3.5   ,  3.5625,  3.625 ,  3.6875,  3.75  ,
        3.8125,  3.875 ,  3.9375,  4.    ,  4.0625,  4.125 ,  4.1875,
        4.25  ,  4.3125,  4.375 ,  4.4375,  4.5   ,  4.5625,  4.625 ,
        4.6875,  4.75  ,  4.8125,  4.875 ,  4.9375,  5.    ,  5.0625,
        5.125 ,  5.1875,  5.25  ,  5.3125,  5.375 ,  5.4375,  5.5   ,
        5.5625,  5.625 ,  5.6875,  5.75  ,  5.8125,  5.875 ,  5.9375,
        6.    ,  6.0625,  6.125 ,  6.1875,  6.25  ,  6.3125,  6.375 ,
        6.4375,  6.5   ,  6.5625,  6.625 ,  6.6875,  6.75  ,  6.8125,
        6.875 ,  6.9375,  7.    ,  7.0625,  7.125 ,  7.1875,  7.25  ,
        7.3125,  7.375 ,  7.4375,  7.5   ,  7.5625,  7.625 ,  7.6875,
        7.75  ,  7.8125,  7.875 ,  7.9375])

但是,这个表没法索引。将上面的向量量化到int8为:

array([-128., -127., -126., -125., -124., -123., -122., -121., -120.,
       -119., -118., -117., -116., -115., -114., -113., -112., -111.,
       -110., -109., -108., -107., -106., -105., -104., -103., -102.,
       -101., -100.,  -99.,  -98.,  -97.,  -96.,  -95.,  -94.,  -93.,
        -92.,  -91.,  -90.,  -89.,  -88.,  -87.,  -86.,  -85.,  -84.,
        -83.,  -82.,  -81.,  -80.,  -79.,  -78.,  -77.,  -76.,  -75.,
        -74.,  -73.,  -72.,  -71.,  -70.,  -69.,  -68.,  -67.,  -66.,
        -65.,  -64.,  -63.,  -62.,  -61.,  -60.,  -59.,  -58.,  -57.,
        -56.,  -55.,  -54.,  -53.,  -52.,  -51.,  -50.,  -49.,  -48.,
        -47.,  -46.,  -45.,  -44.,  -43.,  -42.,  -41.,  -40.,  -39.,
        -38.,  -37.,  -36.,  -35.,  -34.,  -33.,  -32.,  -31.,  -30.,
        -29.,  -28.,  -27.,  -26.,  -25.,  -24.,  -23.,  -22.,  -21.,
        -20.,  -19.,  -18.,  -17.,  -16.,  -15.,  -14.,  -13.,  -12.,
        -11.,  -10.,   -9.,   -8.,   -7.,   -6.,   -5.,   -4.,   -3.,
         -2.,   -1.,    0.,    1.,    2.,    3.,    4.,    5.,    6.,
          7.,    8.,    9.,   10.,   11.,   12.,   13.,   14.,   15.,
         16.,   17.,   18.,   19.,   20.,   21.,   22.,   23.,   24.,
         25.,   26.,   27.,   28.,   29.,   30.,   31.,   32.,   33.,
         34.,   35.,   36.,   37.,   38.,   39.,   40.,   41.,   42.,
         43.,   44.,   45.,   46.,   47.,   48.,   49.,   50.,   51.,
         52.,   53.,   54.,   55.,   56.,   57.,   58.,   59.,   60.,
         61.,   62.,   63.,   64.,   65.,   66.,   67.,   68.,   69.,
         70.,   71.,   72.,   73.,   74.,   75.,   76.,   77.,   78.,
         79.,   80.,   81.,   82.,   83.,   84.,   85.,   86.,   87.,
         88.,   89.,   90.,   91.,   92.,   93.,   94.,   95.,   96.,
         97.,   98.,   99.,  100.,  101.,  102.,  103.,  104.,  105.,
        106.,  107.,  108.,  109.,  110.,  111.,  112.,  113.,  114.,
        115.,  116.,  117.,  118.,  119.,  120.,  121.,  122.,  123.,
        124.,  125.,  126.,  127.])

然后是前面提到的一个有趣的转换。

128 = (uint8)(-128)
129 = (uint8)(-127)
130 = (uint8)(-126)

256 = (uint8)(-1)

所以我们将上表的次序调整一下,将负数按从小到大接到正数后面。为[0~127, -128 ~ -1],这样整个表就建立好了。

0x100 + x_int 负数为什么要 + 0x100

table_gen.py中sigmoid输出的浮点值需要量化成int8,然后写入到C文件中用16进制,具体的处理代码:

  def fp2q7(self, x):
    x_int = math.floor(x*(2**7)+0.5)
    if x_int >= 128 :
      x_int = 127
    if x_int < -128 :
      x_int = -128
    if x_int >= 0 :
      return x_int
    else :
      return 0x100 + x_int        

这是因为16进制,和二进制是等价的,在计算机中正数以原码,负数以补码的形式存储,int8类型负数的补码就是 256+x(x是一个负数)。下面举几个例子:

如果sigmoid的输出量化后为-128
以0x%02x的16进制格式写入,**因为0x只能写入正数**,所以我们求出二进制的补码,将该补码当原码写入。

所以这里就是求负数的补码表示的正数。
在c中直接:
uint(-128) = 128
在python中:
0x100+(-128)= 128

128因为是正数,二进制原码为:10000000
-128的补码:1000 0000
可以看到二者是一样的

计算

CMSIS中的计算code

void arm_nn_activations_direct_q7(q7_t *data, uint16_t size, uint16_t int_width, arm_nn_activation_type type)
{
    uint16_t i = size;
    q7_t *pIn = data;
    q7_t *pOut = data;
    q7_t in;
    q7_t out;
    uint16_t shift_size = 3 - int_width;
    const q7_t *lookup_table;
    switch (type)
    {
    case ARM_SIGMOID:
        lookup_table = sigmoidTable_q7;
        break;
    case ARM_TANH:
    default:
        lookup_table = tanhTable_q7;
        break;
    }
    while (i)
    {
        in = *pIn++;
        out = lookup_table[(uint8_t)(in >> shift_size)];
        *pOut++ = out;
        i--;
    }
}

解释

  • 这里的int_width默认为3,这是和前面输入的范围[-8, 8)对应的。而且建立的表的输入范围为[-8,8),这个的int_width是用来控制输入的范围的。int_width=2, 表示输入的范围为[-4, 4), int_width=1, 输入的范围为[-1, 1),同时可以用原来的表。
  • (uint8_t)(in >> shift_size),shift_size默认为0,于是直接就是(uint8_t)in,这和前面将的次序调整有关。

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(1)
心中带点小风骚的头像心中带点小风骚普通用户
上一篇 2022年3月18日 下午4:21
下一篇 2022年3月18日 下午4:39

相关推荐