通过yolo格式标注将标签中的目标从原图中截取出来

在很多任务中,需要将yolo网络与其他分类网络相结合,这时候需要通过yolo的标签,将标签中的目标从原图中截取出来,作为分类网络的数据集。

脚本很简单,写下来为了节约时间,反复利用。

代码简单介绍:
读取图像对于的标签txt,判断类别是否是需要提取的类,若是则用标签的坐标框信息从原图中生成目标roi,并保存到输出目录内,以类命名的子目录内。

代码:

import os
import cv2


def main():
	# yolo标签目录
    path_root_labels = '/media/clw/work/workspace/source/heima_train/train/labels'
    # 图像目录
    path_root_imgs ='/media/clw/work/workspace/source/heima_train/train/images'
    #标签文件类型
    type_object = '.txt'
	# 图像文件类型
    type_img = 'jpg'
    # 需要提取的类的编号
    cls_idx=['1','2']
    # 输出目录
    output_path = './output_cls'


    for ii in os.walk(path_root_imgs):
        for j in ii[2]:
            type = j.split(".")[1]
            if type != type_img:
                continue
            path_img = os.path.join(path_root_imgs, j)
            print(path_img)
            label_name = j[:-4]+type_object
            path_label = os.path.join(path_root_labels, label_name)
            # print(path_label)
            if os.path.exists(path_label) == True:
                f = open(path_label, 'r+', encoding='utf-8')
                img = cv2.imread(path_img)
                w = img.shape[1]
                h = img.shape[0]
                new_lines = []
                count = 0
                while True:
                    line = f.readline()
                    if line:
                        img_tmp = img.copy()
                        msg = line.split(" ")
                        cls = msg[0]
                        flag=0
                        for idx in cls_idx:
                            if idx == cls:
                                flag =1
                        if flag==0:
                            continue
                        # print(x_center,",",y_center,",",width,",",height)
                        x1 = int((float(msg[1]) - float(msg[3]) / 2) * w)  # x_center - width/2
                        y1 = int((float(msg[2]) - float(msg[4]) / 2) * h)  # y_center - height/2
                        x2 = int((float(msg[1]) + float(msg[3]) / 2) * w)  # x_center + width/2
                        y2 = int((float(msg[2]) + float(msg[4]) / 2) * h)  # y_center + height/2
                        print(x1,",",y1,",",x2,",",y2)
                        # cv2.rectangle(img_tmp,(x1,y1),(x2,y2),(0,0,255),5)
                        img_roi = img_tmp[y1:y2,x1:x2]
                        # cv2.imshow("show", img_roi)
                        # c = cv2.waitKey(0)
                        if os.path.exists(output_path) == False:
                            os.mkdir(output_path)
                        save_path = os.path.join(output_path, cls)
                        if os.path.exists(save_path) == False:
                            os.mkdir(save_path)
                        count +=1
                        rot_name = j + '_' + str(count) + '.jpg'
                        save_roi = os.path.join(save_path, rot_name)
                        cv2.imwrite(save_roi,img_roi)
                    else :
                        break
            # cv2.imshow("show", img_tmp)
            # c = cv2.waitKey(0)



if __name__ == '__main__':
    main()



文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
上一篇 2022年5月22日 下午12:29
下一篇 2022年5月22日 下午12:33

相关推荐

本站注重文章个人版权,不会主动收集付费或者带有商业版权的文章,如果出现侵权情况只可能是作者后期更改了版权声明,如果出现这种情况请主动联系我们,我们看到会在第一时间删除!本站专注于人工智能高质量优质文章收集,方便各位学者快速找到学习资源,本站收集的文章都会附上文章出处,如果不愿意分享到本平台,我们会第一时间删除!