在很多任务中,需要将 YOLO 检测网络与其他分类网络结合使用,这时需要根据 YOLO 的标签,把标签中的目标从原图中裁剪出来,作为分类网络的数据集。
脚本很简单,记录下来是为了节约时间,方便反复使用。
代码简单介绍:
读取图像对应的标签 txt 文件,判断每个目标的类别是否是需要提取的类;若是,则用标签中的坐标框信息从原图中裁剪出目标 ROI,并保存到输出目录下以类别编号命名的子目录内。
代码:
import os
import cv2
def _yolo_to_pixel_box(fields, w, h):
    """Convert a split YOLO label line to a clamped pixel box.

    Args:
        fields: The whitespace-split label line,
            ``[cls, x_center, y_center, width, height, ...]``, with the four
            coordinates normalised to the image size.
        w: Image width in pixels.
        h: Image height in pixels.

    Returns:
        ``(x1, y1, x2, y2)`` with ``0 <= x1 < x2 <= w`` and
        ``0 <= y1 < y2 <= h``, or ``None`` if the line is too short, has a
        non-numeric coordinate, or the box is empty after clamping.
    """
    if len(fields) < 5:
        return None
    try:
        x_center, y_center, box_w, box_h = (float(v) for v in fields[1:5])
    except ValueError:
        return None
    # Clamp to the image: a negative start index would make numpy slice from
    # the end of the array and produce a wrong or empty crop.
    x1 = min(max(int((x_center - box_w / 2) * w), 0), w)
    y1 = min(max(int((y_center - box_h / 2) * h), 0), h)
    x2 = min(max(int((x_center + box_w / 2) * w), 0), w)
    y2 = min(max(int((y_center + box_h / 2) * h), 0), h)
    if x2 <= x1 or y2 <= y1:
        # cv2.imwrite cannot write an empty array.
        return None
    return x1, y1, x2, y2


def _crop_objects(path_img, path_label, img_name, wanted, output_path):
    """Save one ROI per wanted-class line of ``path_label`` cropped from ``path_img``.

    ROIs are written to ``output_path/<cls>/<img_name>_<n>.jpg``, where ``n``
    counts the ROIs saved for this image (starting at 1).
    """
    img = cv2.imread(path_img)
    if img is None:
        # cv2.imread returns None instead of raising on unreadable files.
        print('warning: cannot read image', path_img)
        return
    h, w = img.shape[:2]
    count = 0
    # Read-only open inside `with` so the handle is always closed.
    with open(path_label, 'r', encoding='utf-8') as f:
        for line in f:
            # split() with no argument tolerates repeated/leading spaces and
            # the trailing newline.
            fields = line.split()
            if not fields or fields[0] not in wanted:
                continue
            cls = fields[0]
            box = _yolo_to_pixel_box(fields, w, h)
            if box is None:
                print('warning: skipping unusable box in', path_label, ':', line.strip())
                continue
            x1, y1, x2, y2 = box
            print(x1, ",", y1, ",", x2, ",", y2)
            # The slice is only read by imwrite, so no copy of the image is needed.
            img_roi = img[y1:y2, x1:x2]
            save_path = os.path.join(output_path, cls)
            os.makedirs(save_path, exist_ok=True)
            count += 1
            save_roi = os.path.join(save_path, img_name + '_' + str(count) + '.jpg')
            if not cv2.imwrite(save_roi, img_roi):
                print('warning: failed to write', save_roi)


def main(path_root_labels='/media/clw/work/workspace/source/heima_train/train/labels',
         path_root_imgs='/media/clw/work/workspace/source/heima_train/train/images',
         type_object='.txt',
         type_img='jpg',
         cls_idx=('1', '2'),
         output_path='./output_cls'):
    """Crop YOLO-labelled objects of selected classes into per-class folders.

    Each image under ``path_root_imgs`` whose extension is ``type_img`` is
    paired with a label file of the same stem plus ``type_object``. Every
    label line of a class listed in ``cls_idx`` is cropped from the image.
    The crop is saved under ``output_path/<cls>/``.

    The defaults reproduce the original hard-coded configuration, so calling
    ``main()`` without arguments processes the same directories as before.

    Args:
        path_root_labels: Directory holding the YOLO label files.
        path_root_imgs: Directory holding the images (walked recursively).
        type_object: Label-file suffix, including the dot.
        type_img: Image extension to process, without the dot
            (compared case-insensitively).
        cls_idx: Class ids, as strings, to extract.
        output_path: Root directory for the cropped ROIs; created on demand.
    """
    wanted = set(cls_idx)
    for dirpath, _dirnames, filenames in os.walk(path_root_imgs):
        # Labels are looked up in the same relative sub-directory as the
        # image; for images directly in path_root_imgs this is simply
        # path_root_labels/<stem><type_object>.
        rel_dir = os.path.relpath(dirpath, path_root_imgs)
        for name in filenames:
            # splitext copes with names that have several dots or none,
            # unlike name.split(".")[1] / name[:-4].
            stem, ext = os.path.splitext(name)
            if ext[1:].lower() != type_img.lower():
                continue
            path_img = os.path.join(dirpath, name)
            print(path_img)
            path_label = os.path.normpath(
                os.path.join(path_root_labels, rel_dir, stem + type_object))
            if not os.path.exists(path_label):
                continue
            _crop_objects(path_img, path_label, name, wanted, output_path)
# Run the extraction with its default configuration when executed as a script.
if __name__ == '__main__':
    main()
文章出处登录后可见!
已经登录?立即刷新