目标检测中的 RPN（Region Proposal Network，区域建议网络）到底该怎么理解？

https://github.com/WZMIAOMIAO/deep-learning-for-image-processing/blob/797f12c91fbb6caaa748c09f16f0cd0fbb9cbd61/pytorch_object_detection/mask_rcnn/network_files/rpn_function.py

import torch.nn as nn
import torch
import torch.nn.functional as F
from typing import List
import utils.det_utils  as det_utils
import torchvision

@torch.jit.unused
def _onnx_get_num_anchors_and_pre_nms_top_n(ob, orig_pre_nms_top_n):
    # type: (Tensor, int) -> Tuple[int, int]
    from torch.onnx import operators
    num_anchors = operators.shape_as_tensor(ob)[1].unsqueeze(0)
    pre_nms_top_n = torch.min(torch.cat(
        (torch.tensor([orig_pre_nms_top_n], dtype=num_anchors.dtype),
         num_anchors), 0))

    return num_anchors, pre_nms_top_n


class AnchorsGenerator(nn.Module):
    """Generate anchor boxes for every location of every feature map.

    Args:
        sizes: anchor side lengths in input-image pixels. Pass either one
            tuple per feature level, e.g. ``((32,), (64,), (128,))``, or a
            flat tuple such as the default ``(128, 256, 512)``. A flat tuple is
            read as one size per feature level.
        aspect_ratios: anchor height/width ratios. Pass either one tuple per
            feature level, or a flat tuple that every level shares.
    """

    def __init__(self, sizes=(128, 256, 512), aspect_ratios=(0.5, 1.0, 2.0)):
        super(AnchorsGenerator, self).__init__()

        # The other methods treat sizes[i] / aspect_ratios[i] as the sequence
        # for feature level i (they call len() on it and zip over it). Flat
        # inputs, including the defaults, must therefore be wrapped per level.
        # Without this wrapping the default constructor fails with
        # len(128) -> TypeError.
        if len(sizes) > 0 and not isinstance(sizes[0], (list, tuple)):
            sizes = tuple((s,) for s in sizes)
        if len(aspect_ratios) > 0 and not isinstance(aspect_ratios[0], (list, tuple)):
            aspect_ratios = (aspect_ratios,) * len(sizes)
        if len(sizes) != len(aspect_ratios):
            raise ValueError(
                "sizes and aspect_ratios must describe the same number of feature "
                "levels, got {} and {}".format(len(sizes), len(aspect_ratios))
            )

        self.sizes = sizes
        self.aspect_ratios = aspect_ratios
        # Per-level anchor templates centred at (0, 0); built lazily by set_cell_anchors.
        self.cell_anchors = None
        # Maps str(grid_sizes) + str(strides) -> list of per-level anchor tensors.
        self._cache = {}

    def num_anchors_per_location(self):
        """Return, per feature level, how many anchors each spatial location predicts."""
        return [len(s) * len(a) for s, a in zip(self.sizes, self.aspect_ratios)]

    def generate_anchors(self, scales, aspect_ratios, dtype=torch.float32, device=torch.device('cpu')):
        """Build anchor templates centred at (0, 0).

        Returns:
            Tensor of shape ``[len(aspect_ratios) * len(scales), 4]`` holding
            (x1, y1, x2, y2) rounded to the nearest integer. Templates are
            ordered ratio-major: all scales of the first ratio come first.
        """
        scales = torch.as_tensor(scales, dtype=dtype, device=device)
        aspect_ratios = torch.as_tensor(aspect_ratios, dtype=dtype, device=device)
        # height / width == ratio, while height * width keeps the scale's area.
        h_ratios = torch.sqrt(aspect_ratios)
        w_ratios = 1.0 / h_ratios

        # Outer product [ratios]' x [scales], flattened.
        ws = (w_ratios[:, None] * scales[None, :]).view(-1)
        hs = (h_ratios[:, None] * scales[None, :]).view(-1)

        # Left-top and right-bottom corners relative to the anchor centre (0, 0).
        base_anchors = torch.stack([-ws, -hs, ws, hs], dim=1) / 2
        return base_anchors.round()

    def grid_anchors(self, gird_sizes, strides):
        """Place the anchor templates at every feature-map location.

        Args:
            gird_sizes: (height, width) of each feature map. The misspelled
                parameter name is kept for backward compatibility.
            strides: (stride_height, stride_width) per level, in input pixels.

        Returns:
            One tensor per level of shape ``[height * width * A, 4]`` with
            (x1, y1, x2, y2) anchors in input-image coordinates, where A is the
            number of templates of that level.
        """
        anchors = []
        for size, stride, base_anchors in zip(gird_sizes, strides, self.cell_anchors):
            grid_height, grid_width = size
            stride_height, stride_width = stride
            device = base_anchors.device

            # Feature-map cell (i, j) maps to input pixel (j * stride_w, i * stride_h).
            shifts_x = torch.arange(0, grid_width, dtype=torch.float32, device=device) * stride_width
            shifts_y = torch.arange(0, grid_height, dtype=torch.float32, device=device) * stride_height
            shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
            shift_x = shift_x.reshape(-1)
            shift_y = shift_y.reshape(-1)
            # Both corners of a box are shifted by the same offset.
            shifts = torch.stack([shift_x, shift_y, shift_x, shift_y], dim=1)

            # [H*W, 1, 4] + [1, A, 4] -> [H*W, A, 4]: every template at every location.
            shifts_anchor = shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)
            anchors.append(shifts_anchor.reshape(-1, 4))
        return anchors

    def cached_grid_anchors(self, grid_sizes, strides):
        """Return grid_anchors(grid_sizes, strides), memoized in self._cache."""
        key = str(grid_sizes) + str(strides)
        if key in self._cache:
            return self._cache[key]
        anchors = self.grid_anchors(grid_sizes, strides)
        self._cache[key] = anchors
        return anchors

    def set_cell_anchors(self, dtype, device):
        """(Re)build the per-level templates unless they already match dtype and device."""
        # Check the dtype as well as the device. Otherwise a switch to, e.g.,
        # float64 features would silently keep stale float32 templates.
        if self.cell_anchors:
            first = self.cell_anchors[0]
            if first.device == device and first.dtype == dtype:
                return

        self.cell_anchors = [
            self.generate_anchors(sizes, aspect_ratios, dtype, device)
            for sizes, aspect_ratios in zip(self.sizes, self.aspect_ratios)
        ]

    def forward(self, image_list, feature_maps):
        """Return one tensor of anchors per image, with all levels concatenated.

        Args:
            image_list: object exposing ``tensors`` (the batched input tensor,
                whose last two dims give height/width) and ``image_sizes``
                (one entry per image).
            feature_maps: list of per-level feature maps shaped ``[..., H, W]``.
        """
        # (height, width) of every prediction feature map.
        grid_sizes = list([feature_map.shape[-2:] for feature_map in feature_maps])
        # (height, width) of the batched input tensor.
        image_size = image_list.tensors.shape[-2:]
        dtype, device = feature_maps[0].dtype, feature_maps[0].device

        # How many input pixels one step on each feature map corresponds to.
        strides = [[torch.tensor(image_size[0] // g[0], dtype=torch.int64, device=device),
                    torch.tensor(image_size[1] // g[1], dtype=torch.int64, device=device)] for g in grid_sizes]

        # Build (or reuse) the per-level templates centred at (0, 0).
        self.set_cell_anchors(dtype, device)

        # Anchors mapped onto the input image, as one tensor per level.
        anchors_over_all_feature_maps = self.cached_grid_anchors(grid_sizes, strides)

        anchors = torch.jit.annotate(List[List[torch.Tensor]], [])
        # Anchors depend only on the batched tensor size, so every image gets the same list.
        for _ in image_list.image_sizes:
            anchors_in_image = []
            for anchors_per_feature_map in anchors_over_all_feature_maps:
                anchors_in_image.append(anchors_per_feature_map)
            anchors.append(anchors_in_image)

        # Concatenate the levels of each image into a single [num_anchors, 4] tensor.
        anchors = [torch.cat(anchors_per_image) for anchors_per_image in anchors]
        # Clear the cache in case that memory leaks.
        self._cache.clear()
        return anchors

class RPNHead(nn.Module):
    """Shared RPN head: a 3x3 conv + ReLU followed by two sibling 1x1 convs.

    Args:
        in_channels: channel count of every input feature map.
        num_anchors: anchors per spatial location. The head emits one
            objectness logit and four box-regression values per anchor.
    """

    def __init__(self, in_channels, num_anchors):
        super(RPNHead, self).__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
        self.cls_logits = nn.Conv2d(in_channels, num_anchors, kernel_size=1, stride=1)
        self.bbox_pred = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=1, stride=1)

        for layer in self.children():
            if isinstance(layer, nn.Conv2d):
                # Use the in-place initializers (trailing underscore). The
                # non-underscore torch.nn.init.normal is a deprecated alias.
                torch.nn.init.normal_(layer.weight, std=0.01)
                torch.nn.init.constant_(layer.bias, 0)

    def forward(self, x):
        """Apply the head to each feature map in ``x``.

        Returns:
            (logits, bbox_reg): two lists with one tensor per input level,
            shaped ``[N, num_anchors, H, W]`` and ``[N, num_anchors * 4, H, W]``.
        """
        logits = []
        bbox_reg = []
        for feature in x:
            t = F.relu(self.conv(feature))
            logits.append(self.cls_logits(t))
            bbox_reg.append(self.bbox_pred(t))
        return logits, bbox_reg

def permute_and_flatten(layer, N, A, C, H, W):
    """Reshape a head output ``[N, A*C, H, W]`` into ``[N, H*W*A, C]``.

    The channel axis is split into (anchor, C). The spatial axes are moved in
    front of the anchor axis, so rows are ordered by (h, w, anchor). ``A`` is
    not used: the anchor count is inferred from the channel size.
    """
    # [N, A*C, H, W] -> [N, A, C, H, W]
    grouped = layer.view(N, -1, C, H, W)
    # [N, A, C, H, W] -> [N, H, W, A, C] -> [N, H*W*A, C]
    return grouped.permute(0, 3, 4, 1, 2).reshape(N, -1, C)

def concat_box_prediction_layers(box_cls, box_regression):
    """Flatten and concatenate the per-level RPN head outputs.

    Args:
        box_cls: list of per-level objectness tensors ``[N, A, H, W]``.
        box_regression: list of per-level delta tensors ``[N, A*4, H, W]``.

    Returns:
        (box_cls, box_regression) shaped ``[N * total_anchors, 1]`` and
        ``[N * total_anchors, 4]``. Rows are ordered by image, then level,
        then (h, w, anchor).
    """
    cls_parts = []
    reg_parts = []
    for level_cls, level_reg in zip(box_cls, box_regression):
        N, A, H, W = level_cls.shape
        # [N, A, H, W] -> [N, H*W*A, 1]
        cls_parts.append(permute_and_flatten(level_cls, N, A, 1, H, W))
        # [N, A*4, H, W] -> [N, H*W*A, 4]
        reg_parts.append(permute_and_flatten(level_reg, N, A, 4, H, W))

    merged_cls = torch.cat(cls_parts, dim=1).flatten(0, -2)
    merged_reg = torch.cat(reg_parts, dim=1).reshape(-1, 4)
    return merged_cls, merged_reg


class RegionProposalNetwork(torch.nn.Module):
    """Region Proposal Network (RPN).

    The module runs ``head`` on every feature map and decodes the predicted
    box deltas against the anchors from ``anchor_generator``. It then filters
    the decoded proposals. In training mode it also computes the objectness
    and box-regression losses.

    Args:
        anchor_generator: module called as ``anchor_generator(images, features)``
            that returns one anchor tensor per image.
        head: module called on the list of feature maps that returns
            per-level objectness logits and box-regression deltas.
        fg_iou_thresh: foreground IoU threshold passed to ``det_utils.Matcher``.
            Anchors above it (e.g. 0.7) become positives.
        bg_iou_thresh: background IoU threshold passed to ``det_utils.Matcher``.
            Anchors below it (e.g. 0.3) become negatives.
        batch_size_per_image: passed to ``det_utils.BalancedPositiveNegativeSampler``
            (e.g. 256).
        positive_fraction: passed to the same sampler (e.g. 0.5).
        pre_nms_top_n: dict with ``'training'`` / ``'testing'`` keys. Gives
            the number of top-scoring anchors kept per feature level before NMS.
        post_nms_top_n: dict with ``'training'`` / ``'testing'`` keys. Gives
            the number of proposals kept per image after NMS.
        nms_thresh: IoU threshold for the per-level NMS.
        score_thresh: proposals whose sigmoid objectness is below this value
            are dropped before NMS.
    """

    def __init__(self, anchor_generator, head, fg_iou_thresh, bg_iou_thresh,
                 batch_size_per_image, positive_fraction,
                 pre_nms_top_n, post_nms_top_n, nms_thresh, score_thresh=0.0):
        super(RegionProposalNetwork, self).__init__()
        self.anchor_generator = anchor_generator
        self.head = head

        # Encodes/decodes box regression targets relative to anchors.
        self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))

        # Samples the positive/negative anchors that enter the loss.
        self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler(
            batch_size_per_image, positive_fraction  # e.g. 256, 0.5
        )

        # Matches each anchor to a ground-truth box by IoU:
        # above fg_iou_thresh -> positive, below bg_iou_thresh -> negative,
        # in between -> discarded.
        self.proposal_matcher = det_utils.Matcher(
            fg_iou_thresh,
            bg_iou_thresh,
            allow_low_quality_matches=True
        )

        # Proposal-filtering parameters (the top-n dicts hold values for both
        # training and testing).
        self._pre_nms_top_n = pre_nms_top_n
        self._post_nms_top_n = post_nms_top_n
        self.nms_thresh = nms_thresh
        self.score_thresh = score_thresh
        self.min_size = 1.

    def pre_nms_top_n(self):
        """Return the pre-NMS top-n for the current mode (training/testing)."""
        if self.training:
            return self._pre_nms_top_n['training']
        return self._pre_nms_top_n['testing']

    def post_nms_top_n(self):
        """Return the post-NMS top-n for the current mode (training/testing)."""
        if self.training:
            return self._post_nms_top_n['training']
        return self._post_nms_top_n['testing']

    def assign_targets_to_anchors(self, anchors, targets):
        """Match every anchor to its best ground-truth box and label it.

        Args:
            anchors: list with one anchor tensor per image.
            targets: list with one dict per image; ``targets[i]["boxes"]`` holds
                that image's ground-truth boxes.

        Returns:
            labels: one float tensor per image. 1.0 means foreground, 0.0
                means background, and -1.0 means between the thresholds
                (discarded).
            matched_gt_boxes: one tensor per image with the ground-truth box
                assigned to each anchor (all zeros when the image has no boxes).
        """
        labels = []
        matched_gt_boxes = []
        for anchors_per_image, targets_per_image in zip(anchors, targets):
            gt_boxes = targets_per_image["boxes"]
            if gt_boxes.numel() == 0:
                # No ground truth: every anchor is background.
                device = anchors_per_image.device
                matched_gt_boxes_per_image = torch.zeros(anchors_per_image.shape, dtype=torch.float32, device=device)
                labels_per_image = torch.zeros((anchors_per_image.shape[0],), dtype=torch.float32, device=device)
            else:
                match_quality_matrix = det_utils.box_iou(gt_boxes, anchors_per_image)
                # Index of the best gt for each anchor, or the matcher's
                # negative sentinel values (below-low / between thresholds).
                matched_idxs = self.proposal_matcher(match_quality_matrix)

                # Clamp the sentinel values so indexing is valid. Those anchors'
                # labels are overwritten below anyway.
                matched_gt_boxes_per_image = gt_boxes[matched_idxs.clamp(min=0)]

                labels_per_image = matched_idxs >= 0
                labels_per_image = labels_per_image.to(dtype=torch.float32)

                # background (negative examples)
                bg_indices = matched_idxs == self.proposal_matcher.BELOW_LOW_THRESHOLD  # -1
                labels_per_image[bg_indices] = 0.0

                # discard indices that are between thresholds
                inds_to_discard = matched_idxs == self.proposal_matcher.BETWEEN_THRESHOLDS  # -2
                labels_per_image[inds_to_discard] = -1.0

            labels.append(labels_per_image)
            matched_gt_boxes.append(matched_gt_boxes_per_image)
        return labels, matched_gt_boxes

    def _get_top_n_idx(self, objectness, num_anchors_per_level):
        """Return, per image, the indices of the top pre_nms_top_n anchors of each level.

        Args:
            objectness: ``[num_images, total_anchors]`` scores.
            num_anchors_per_level: anchor count of each level, in order.

        Returns:
            ``[num_images, sum_k]`` indices into the concatenated anchor axis.
            They stay on ``objectness``'s device, so they can index tensors
            that live on that device without an extra host/device copy.
        """
        r = []
        offset = 0
        for ob in objectness.split(num_anchors_per_level, 1):
            if torch.jit.is_tracing():
                num_anchors, pre_nms_top_n = _onnx_get_num_anchors_and_pre_nms_top_n(ob, self.pre_nms_top_n())
            else:
                num_anchors = ob.shape[1]  # anchors on this level
                pre_nms_top_n = min(self.pre_nms_top_n(), num_anchors)
            _, top_n_idx = ob.topk(pre_nms_top_n, dim=1)
            # Shift level-local indices onto the concatenated anchor axis.
            r.append(top_n_idx + offset)
            offset += num_anchors
        return torch.cat(r, dim=1)

    def filter_proposals(self, proposals, objectness, image_shapes, num_anchors_per_level):
        """Select the final proposals for each image.

        Steps: keep the per-level pre-NMS top-n, clip boxes to the image,
        remove boxes smaller than ``min_size``, drop low scores, run NMS per
        level, and keep the post-NMS top-n.

        Returns:
            (final_boxes, final_scores): lists with one tensor per image. The
            scores are sigmoid probabilities.
        """
        num_images = proposals.shape[0]
        device = proposals.device

        # do not backprop through objectness
        objectness = objectness.detach()
        objectness = objectness.reshape(num_images, -1)

        # Level id of every anchor, so NMS can run independently per level.
        levels = [torch.full((n, ), idx, dtype=torch.int64, device=device)
                  for idx, n in enumerate(num_anchors_per_level)]
        levels = torch.cat(levels, 0)
        levels = levels.reshape(1, -1).expand_as(objectness)

        top_n_idx = self._get_top_n_idx(objectness, num_anchors_per_level)

        image_range = torch.arange(num_images, device=device)
        batch_idx = image_range[:, None]  # [batch_size, 1]

        # Gather scores, level ids and boxes of the selected anchors.
        objectness = objectness[batch_idx, top_n_idx]
        levels = levels[batch_idx, top_n_idx]
        proposals = proposals[batch_idx, top_n_idx]

        objectness_prob = torch.sigmoid(objectness)

        final_boxes = []
        final_scores = []
        for boxes, scores, lvl, img_shape in zip(proposals, objectness_prob, levels, image_shapes):
            # Move out-of-bounds coordinates onto the image border.
            boxes = det_utils.clip_boxes_to_image(boxes, img_shape)
            # Keep boxes whose width and height exceed min_size.
            keep = det_utils.remove_small_boxes(boxes, self.min_size)
            boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep]

            # Drop low-probability boxes, see https://github.com/pytorch/vision/pull/3205
            keep = torch.where(torch.ge(scores, self.score_thresh))[0]
            boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep]

            # non-maximum suppression, independently done per level
            keep = det_utils.batched_nms(boxes, scores, lvl, self.nms_thresh)

            # keep only topk scoring predictions
            keep = keep[: self.post_nms_top_n()]
            boxes, scores = boxes[keep], scores[keep]
            final_boxes.append(boxes)
            final_scores.append(scores)

        return final_boxes, final_scores

    def compute_loss(self, objectness, pred_bbox_deltas, labels, regression_targets):
        """Compute the RPN objectness (fg/bg) and box-regression losses.

        Args:
            objectness: objectness logits for all anchors of the batch (flattened here).
            pred_bbox_deltas: ``[total_anchors, 4]`` predicted deltas.
            labels: per-image label tensors from assign_targets_to_anchors.
            regression_targets: per-image encoded regression targets.

        Returns:
            (objectness_loss, box_loss). The box loss is computed on the
            sampled positives only. It is divided by the total number of
            sampled anchors (positives + negatives).
        """
        sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
        # Concatenate the per-image masks and turn them into flat indices.
        sampled_pos_inds = torch.where(torch.cat(sampled_pos_inds, dim=0))[0]
        sampled_neg_inds = torch.where(torch.cat(sampled_neg_inds, dim=0))[0]

        sampled_inds = torch.cat([sampled_pos_inds, sampled_neg_inds], dim=0)
        objectness = objectness.flatten()

        labels = torch.cat(labels, dim=0)
        regression_targets = torch.cat(regression_targets, dim=0)

        box_loss = det_utils.smooth_l1_loss(
            pred_bbox_deltas[sampled_pos_inds],
            regression_targets[sampled_pos_inds],
            beta=1 / 9,
            size_average=False,
        ) / (sampled_inds.numel())

        objectness_loss = F.binary_cross_entropy_with_logits(
            objectness[sampled_inds], labels[sampled_inds]
        )

        return objectness_loss, box_loss

    def forward(self, images, features, targets=None):
        """Produce proposals and, in training mode, the RPN losses.

        Args:
            images: batched images. Must expose ``image_sizes`` and be
                accepted by ``anchor_generator``.
            features: dict of feature maps; its values are used in order.
            targets: list of per-image dicts with a ``"boxes"`` entry.
                Required in training mode.

        Returns:
            (boxes, losses): a list of proposal tensors (one per image) and a
            dict of losses, which is empty outside training.

        Raises:
            ValueError: if called in training mode with ``targets=None``.
        """
        # Fail fast with a clear message. Otherwise the error surfaces as an
        # obscure failure inside assign_targets_to_anchors.
        if self.training and targets is None:
            raise ValueError("targets should not be None")

        features = list(features.values())

        # Per-level objectness logits and box deltas.
        objectness, pred_bbox_deltas = self.head(features)

        # One anchor tensor per image.
        anchors = self.anchor_generator(images, features)
        num_images = len(anchors)

        # Anchor count per level: A * H * W (o[0] has shape [A, H, W]).
        num_anchors_per_level_shape_tensors = [o[0].shape for o in objectness]
        num_anchors_per_level = [s[0] * s[1] * s[2] for s in num_anchors_per_level_shape_tensors]

        # Flatten: [N * total_anchors, 1] logits and [N * total_anchors, 4] deltas.
        objectness, pred_bbox_deltas = concat_box_prediction_layers(objectness, pred_bbox_deltas)

        # Apply the predicted deltas to the anchors to get proposal boxes.
        proposals = self.box_coder.decode(pred_bbox_deltas.detach(), anchors)
        proposals = proposals.view(num_images, -1, 4)

        boxes, scores = self.filter_proposals(proposals, objectness, images.image_sizes, num_anchors_per_level)

        losses = {}
        if self.training:
            # Label anchors as foreground / background / discarded.
            labels, matched_gt_boxes = self.assign_targets_to_anchors(anchors, targets)
            # Regression targets from anchors and their matched gt boxes.
            regression_targets = self.box_coder.encode(matched_gt_boxes, anchors)

            loss_objectness, loss_rpn_box_reg = self.compute_loss(
                objectness, pred_bbox_deltas, labels, regression_targets
            )
            losses = {
                "loss_objectness": loss_objectness,
                "loss_rpn_box_reg": loss_rpn_box_reg
            }
        return boxes, losses

文章出处登录后可见!

已经登录?立即刷新

共计 0 人评分，平均 0 分

到目前为止还没有投票！成为第一位为此文章评分的人。

(0)
作者：xiaoxingxing（管理团队）
上一篇 2022年5月26日
下一篇 2022年5月26日

相关推荐