【目标检测】(13) 先验框解码,调整预测框,附TensorFlow完整代码

各位同学好,今天和大家分享一下如何使用 TensorFlow 对YOLOV3 和 YOLOV4网络的输出特征进行解码,微调每个先验框的坐标和宽高,使其逼近真实标签框。

YOLOV3 和 YOLOV4 调整先验框的方法类似,代码通用。阅读本篇文章之前,建议先看以下文章

YOLOV2中的先验框: https://blog.csdn.net/dgvv4/article/details/123772756

YOLOV3特征提取网络: https://blog.csdn.net/dgvv4/article/details/121997986

YOLOV4特征提取网络: https://blog.csdn.net/dgvv4/article/details/123818580

1. 生成先验框

网络的三个有效输出特征层分别输出 52*52,26*26,13*13 的特征图。13*13的输出特征负责预测大尺度物体,26*26负责预测中等大小的物体,52*52负责预测小尺度物体。YOLOV3 有9个先验框,最大的三个先验框分配给了13*13的特征图,中等大小的三个先验框分配给了26*26的特征图,最小的三个先验框分配给了52*52的特征图。

如下图,以13*13的特征图为例。物体的中心点落在了红色网格中,那么就需要由这个红色网格生成的三个先验框中的一个去预测这个物体,由三个先验框中与物体真实框的 iou 最大的先验框去预测这个物体。

然而网络输出三种尺度 13*13、26*26、52*52 的特征图,物体的中心点肯定会落在三种尺度的网格中,那么该物体是由哪一种尺度的网格生成的先验框去预测的呢。是由9个先验框中,与真实标签框 iou 最大的那个先验框所在尺度的网格来预测。

【目标检测】(13) 先验框解码,调整预测框,附TensorFlow完整代码

2. 模型输出结果

网络的输入特征图的shape是 [416, 416, 3],经过一系列卷积层,输出三个尺度的特征结果,它们的shape分别是 [52, 52, (3*(5+num_classes))],[26, 26, (3*(5+num_classes))],[13, 13, (3*(5+num_classes))],。

其中通道数 3*(5+num_classes) 可以理解为:每个网格生成 3 个预测框,每个预测框包含了预测框相较于先验框的偏移量,坐标偏移量(tx, ty),宽高偏移量(tw, th),预测框中是否包含目标物体 c,预测框内的物体属于每个类别的条件概率num_classes,在VOC数据集中num_classes=20。其中(tx, ty) 是负无穷到正无穷的任何数,(tw, th) 是归一化后的宽高。

【目标检测】(13) 先验框解码,调整预测框,附TensorFlow完整代码

3. 微调先验框

以某个网格的先验框的调整为例,如下图所示,虚线框代表:和物体的真实标签框 iou 值最大的那个先验框,该先验框的宽高为 (pw, ph) ;蓝色框代表微调先验框后生成的预测框。

(cx,cy)是先验框中心点所在的网格的左上坐标(归一化后的坐标),由于坐标偏移量 (tx,ty) 可以是从负无穷到正无穷的任何数,为了防止坐标调整偏移过大,给偏移量添加sigmoid函数。将坐标偏移量限制在0-1之间,将预测框的中心点限制在它所在的网格内。高宽偏移量 (tw, th)是归一化后的宽高调整值。最终预测框的宽高 (bw, bh)

【目标检测】(13) 先验框解码,调整预测框,附TensorFlow完整代码

4. 代码展示

从网络的输出特征中提取预测帧的中心点坐标、预测帧的宽度和高度、预测帧的置信度以及帧中对象属于每个类别的条件概率。

import tensorflow as tf

#(一)解码网络的输出层的信息
def anchors_decode(feats, anchors, num_classes):
    '''
    feats是某一个特征层的输出结果, 如shape=[b, 13, 13, 3*(5+num_classes)]
    anchors代表每个特征层, 每个网格的三个先验框[3,2]
    num_classes代表分类类别的数量
    '''
    # 计算每个网格几个先验框=3
    num_anchors = len(anchors)

    # 获得图像网格的宽和高的shape=[h,w]=[13,13]
    grid_shape = feats.shape[1:3]
    
    #(1)获得网格中每个网格点的坐标信息
    # 获得网格点的x坐标信息[1]==>[1,13,1,1]
    grid_x = tf.reshape(range(0, grid_shape[1]), shape=[1,-1,1,1])
    # 在y维度上扩张,将前面的数据进行复制然后直接接在原数据后面
    # [1,13,1,1]==>[13,13,3,1]
    grid_x = tf.tile(grid_x, [grid_shape[0], 1, num_anchors, 1])

    # 获得网格点的y坐标信息,方法同上[13]==>[13,1,1,1]
    grid_y = tf.reshape(range(0, grid_shape[0]), shape=[-1,1,1,1])
    # 维度扩张[13,1,1,1]==>[13,13,3,1]
    grid_y = tf.tile(grid_y, [1, grid_shape[1], num_anchors, 1])

    # 在通道维度上合并[13,13,3,2],每个网格的坐标信息, 横纵坐标都是0-12,
    grid = tf.concat([grid_x, grid_y], axis=-1)
    # 转换成tf.float32类型
    grid = tf.cast(grid, tf.float32)

    #(2)调整先验框的信息,13*13个网格,每个网格有3个先验框,每个先验框有(x,y)坐标
    # [3,2]==>[1,1,3,2]
    anchors_tensor = tf.reshape(anchors, shape=[1,1,num_anchors,2])
    # [1,1,3,2]==>[13,13,3,2]
    anchors_tensor = tf.tile(anchors_tensor, [grid_shape[0], grid_shape[1], 1, 1])
    # 转换成float32类型
    anchors_tensor = tf.cast(anchors_tensor, tf.float32)
    
    #(3)调整网络输出特征图的结果
    # [b, 13, 13, 3*(5+num_classes)] = [b, 13, 13, 3, (5+num_classes)]
    '''
    代表13*13个网格, 每个网格有3个先验框, 每个先验框有(5+num_classes)项信息
    其中, 5代表: 中心点坐标(x,y), 宽高(w,h), 置信度c
    num_classes: 检测框属于某个类别的条件概率, VOC数据集中等于20
    '''
    feats = tf.reshape(feats, shape=[-1, grid_shape[0], grid_shape[1], num_anchors, 5+num_classes])

    #(4)调整先验框中心坐标及宽高
    # 对预测框中心点坐标归一化处理,只能在所处的网格中调整
    anchor_xy = tf.nn.sigmoid(feats[..., :2])
    box_xy = anchor_xy + grid  # 每个网格的预测框坐标

    # 网格的预测框宽高默认是归一化之后的,对宽高取指数
    anchors_wh = tf.exp(feats[..., 2:4])
    box_wh = anchors_wh * anchors_tensor  # 预测框的宽高

    # 获得预测框的置信度和每个类别的条件概率
    box_conf = tf.nn.sigmoid(feats[..., 4:5])
    box_prob = tf.nn.sigmoid(feats[..., 5:])

    # 返回预测框信息
    return box_xy, box_wh, box_conf, box_prob

随机生成一个13*13特征层的输出特征图,并给每个网格设置三种长宽比的先验框,来验证三个先验框微调的效果。

#(二)验证
if __name__ == '__main__':

    feat = tf.random.normal([4,13,13,75], mean=0, stddev=0.5)  # 构建输出特征图
    anchors = tf.constant([[142, 110],[192, 243],[459, 401]])  # 每个网格的先验框坐标

    # 返回调整后的预测框信息
    box_xy, box_wh, box_conf, box_prob = anchors_decode(feat, anchors, 20)
    
    '''
    某个网格调整后的预测框宽高
    [[  80.93765 ,  106.666855],
     [ 119.06944 ,  248.29587 ],
     [ 413.01917 ,  339.91293 ]]
    '''

以网格为例,将结果可视化。左图是原始的prior box,三个prior box的中心点在一起,右图是调整后的预测框,红点是每个调整后的预测框的中心点。

【目标检测】(13) 先验框解码,调整预测框,附TensorFlow完整代码

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
扎眼的阳光的头像扎眼的阳光普通用户
上一篇 2022年4月11日 下午5:56
下一篇 2022年4月11日

相关推荐