Voc2Json--挑选voc中的类别生成json文件

发布时间 2023-12-26 17:29:40作者: 水木清扬

 

 

import argparse
import json ,shutil
import os,sys
import xml.etree.ElementTree as ET


# Directory that contains this script (symlinks resolved by realpath).
parent = os.path.dirname(os.path.realpath(__file__))
# Directory one level above the script's directory.
gadent = os.path.dirname(parent)
# Register that upper directory on the module search path, both at the front
# and at the end. Presumably this is what lets the `utils.tool` import below
# resolve when the script is run directly -- confirm against the repo layout.
sys.path.insert(0,gadent)
sys.path.append(gadent)

from utils.tool import listDir,draw_res



# Skeleton of the per-image annotation JSON document. voc2Json() fills in
# 'shapes', 'imagePath', 'imageData', 'imageHeight' and 'imageWidth' for every
# image it exports, so the values below only act as placeholders.
json_base_info = {
    "version": "4.5.6",
    "flags": {},
    "shapes": [],
    "imagePath": "0001250.jpg",
    "imageData": "null",
    "imageHeight": 720,
    "imageWidth": 1280,
}


def _collect_shapes(root, classes, width, height):
    """Build the rectangle shape dicts for the wanted objects of one VOC file.

    Args:
        root: Root element of a parsed VOC annotation XML.
        classes: Container of class names to keep.
        width: Image width, used as the upper clamp for the right edge.
        height: Image height, used as the upper clamp for the bottom edge.

    Returns:
        A list with one shape dict per kept ``<object>``; objects whose class
        is not in *classes* or that are flagged ``difficult == 1`` are dropped.
    """
    shapes = []
    for obj in root.iter('object'):
        cls = obj.findtext('name')
        # A missing or empty <difficult> tag is treated as "not difficult"
        # instead of aborting the whole conversion.
        difficult = int(obj.findtext('difficult') or 0)
        if cls is None or cls not in classes or difficult == 1:
            continue

        xmlbox = obj.find('bndbox')
        # Coordinates may be stored as floats; truncate them to integers.
        xmin, ymin, xmax, ymax = (
            int(float(xmlbox.find(tag).text))
            for tag in ('xmin', 'ymin', 'xmax', 'ymax')
        )

        shapes.append({
            # Use the object's real class name rather than a fixed "person".
            "label": cls,
            # Top-left is clamped at 0, bottom-right at the image size.
            "points": [[max(xmin, 0), max(ymin, 0)],
                       [min(xmax, width), min(ymax, height)]],
            "group_id": None,
            "shape_type": "rectangle",
            "flags": {},
        })
    return shapes


def voc2Json(voc_base_dir,save_path_root,classes,prefix):
    """Export selected VOC classes as per-image JSON annotations plus images.

    For every annotation XML found under ``<voc_base_dir>/Annotations`` whose
    matching ``.jpg`` exists under ``<voc_base_dir>/JPEGImages``, the objects
    whose class is listed in *classes* (and that are not marked difficult) are
    written as rectangle shapes to ``<save_path_root>/<prefix><stem>.json``,
    and the image is copied next to it as ``<prefix><stem>.jpg``. Images
    without any kept object are skipped entirely.

    Args:
        voc_base_dir: Root of the VOC dataset (contains ``Annotations`` and
            ``JPEGImages``).
        save_path_root: Destination directory; created on demand.
        classes: Class names to extract.
        prefix: String prepended to every exported file name.
    """
    ann_dir = os.path.abspath(os.path.join(voc_base_dir, "Annotations"))
    img_dir = os.path.abspath(os.path.join(voc_base_dir, "JPEGImages"))

    # listDir comes from utils.tool (not visible here); it is expected to
    # append the paths of the "xml" files under ann_dir to xml_list.
    xml_list = []
    listDir(ann_dir, xml_list, "xml")

    for xml_path in xml_list:
        print("xml path : {}".format(xml_path))

        # Swap only the extension; str.replace would also rewrite any "xml"
        # that happens to be part of the file stem.
        stem = os.path.splitext(os.path.basename(xml_path))[0]
        basename_jpg = stem + ".jpg"
        img_path = os.path.abspath(os.path.join(img_dir, basename_jpg))
        if not os.path.exists(img_path):
            continue

        root = ET.parse(xml_path).getroot()
        size = root.find('size')
        width = int(size.find('width').text)
        height = int(size.find('height').text)

        json_rects = _collect_shapes(root, classes, width, height)
        if not json_rects:
            continue

        new_basename_jpg = prefix + basename_jpg
        save_path = os.path.abspath(os.path.join(save_path_root, new_basename_jpg))
        # Derive the JSON path from the extension only, so a "jpg" appearing
        # in a directory name or in the prefix is left untouched.
        json_data_path = os.path.splitext(save_path)[0] + ".json"
        os.makedirs(os.path.dirname(json_data_path), exist_ok=True)

        # Shallow copy keeps the template's key order and leaves it unmodified.
        json_data = dict(json_base_info)
        json_data['imagePath'] = new_basename_jpg
        json_data['imageData'] = None
        json_data['imageHeight'] = height
        json_data['imageWidth'] = width
        json_data['shapes'] = json_rects

        # ensure_ascii=False emits raw non-ASCII characters, so the encoding
        # must be explicit rather than the platform default.
        with open(json_data_path, "w", encoding='utf-8') as json_file:
            json.dump(json_data, json_file, ensure_ascii=False, indent=4, separators=(',', ': '))
        shutil.copy(img_path, save_path)




class Voc2Json:
    """Object wrapper that runs voc2Json() with settings taken from *args*."""

    def __init__(self, args):
        """Copy the conversion settings off a parsed-arguments object.

        Args:
            args: Object exposing ``VOC_base_dir``, ``save_path``, ``classes``
                and ``prefix`` attributes (e.g. an ``argparse.Namespace``).
        """
        self.voc_base_dir, self.save_path = args.VOC_base_dir, args.save_path
        self.classes, self.prefix = args.classes, args.prefix

    def make(self):
        """Run the conversion with the stored settings."""
        voc2Json(
            voc_base_dir=self.voc_base_dir,
            save_path_root=self.save_path,
            classes=self.classes,
            prefix=self.prefix,
        )



# Command-line interface of the converter.
parser = argparse.ArgumentParser(description='VOC Datasets Convert json dataset')
parser.add_argument('--VOC_base_dir', default=None, type=str, help='VOC数据集的基础路径')
parser.add_argument('--save_path', default=None, type=str, help="转换后的数据集保存路径")
# nargs='+' gathers one or more class names into a list. The former type=list
# split a single value into its characters ("person" -> ['p', 'e', ...]).
parser.add_argument('--classes', default=['person'], type=str, nargs='+', help="需要从VOC提取的数据集的类别标签")
parser.add_argument('--prefix', default="voc2012_", type=str, help="新的数据集前缀")



if __name__ =='__main__':
    args =parser.parse_args()
    # Use the hard-coded local dataset locations only as fallbacks, so that
    # paths given on the command line are no longer silently overwritten.
    if args.VOC_base_dir is None:
        args.VOC_base_dir = "E:/datasets/public_datasets/VOC/VOCtrainval_11-May-2012/VOCdevkit/VOC2012"
    if args.save_path is None:
        args.save_path = "E:/datasets/public_datasets/person_voc12_src_size_box"

    print("start !!!")
    vt = Voc2Json(args)
    vt.make()
    print("Done !!!")