yolov5 自动标注代码

发布时间: 2023-03-22 21:16:05  作者: CHHC
# yolov5自动标注,打开labelimg,Open Dir:选择图片文件夹;Change Save Dir:选择标签文件夹,完成操作后,自动关联
import os
import torch
import cv2
import numpy as np

class AutoLabelImg:
    """Pre-label an image folder with a YOLOv5 model for use in LabelImg.

    For every image, each detected object whose name is a key of ``dictype``
    becomes one ``<class id> <x center> <y center> <width> <height>`` line
    (coordinates normalized by the image size) in ``<labeldir>/<image stem>.txt``.
    ``classes.txt`` in the same folder lists the class names ordered by id.
    """

    def __init__(self, dictype, yolov5src, yolov5pt, conf=0.1):
        """Load the detector and prepare the classes.txt content.

        Args:
            dictype: mapping of detector class name -> class id to emit. Objects
                whose name is not a key are ignored. Ids are presumably 0..n-1 so
                they match the line numbers of classes.txt.
            yolov5src: local YOLOv5 repository directory, loaded by
                ``torch.hub.load`` with ``source='local'``.
            yolov5pt: path of the custom weights file.
            conf: detector confidence threshold (default 0.1, the previously
                hard-coded value).
        """
        self.dictype = dictype
        # classes.txt content, sorted by id so that (for ids 0..n-1) line i names class i
        self.strtype = self._class_names_text(dictype)

        self.detector = torch.hub.load(yolov5src, 'custom', yolov5pt, source='local')
        self.detector.conf = conf

    @staticmethod
    def _class_names_text(dictype):
        """Return one class name per line (each newline-terminated), ordered by class id."""
        return "".join(name + "\n" for name in sorted(dictype, key=dictype.get))

    def write_txt(self, filepath, msg):
        """Overwrite ``filepath`` with ``msg`` (after echoing both to stdout)."""
        print('filepath:' + filepath + ",msg:" + msg)
        # `with` guarantees the handle is flushed and closed (the original leaked it)
        with open(filepath, 'w') as file:
            file.write(msg)

    def xyxy2xywh(self, x):
        """Convert nx4 boxes from [x1, y1, x2, y2] to [x_center, y_center, w, h].

        xy1 is the top-left corner and xy2 the bottom-right. Accepts a torch
        tensor or a numpy array and returns a new object of the same kind.
        """
        y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
        y[:, 0] = (x[:, 0] + x[:, 2]) / 2  # x center
        y[:, 1] = (x[:, 1] + x[:, 3]) / 2  # y center
        y[:, 2] = x[:, 2] - x[:, 0]  # width
        y[:, 3] = x[:, 3] - x[:, 1]  # height
        return y

    def start(self, imagedir, labeldir):
        """Regenerate all labels for ``imagedir`` into ``labeldir``.

        ``labeldir`` is created if missing. Every file already inside it is then
        deleted (sub-folders are emptied but kept). Finally, classes.txt and one
        label file per image with at least one recognized object are written.
        """
        os.makedirs(labeldir, exist_ok=True)
        self.delte_labelfile(labeldir)
        self.write_txt(os.path.join(labeldir, "classes.txt"), self.strtype)
        self.add_labelfile(imagedir, labeldir)

    def delte_labelfile(self, path):
        """Recursively delete every file under ``path``; directories themselves are kept."""
        for entry in os.listdir(path):
            c_path = os.path.join(path, entry)
            if os.path.isdir(c_path):
                # Sub-folder: recurse (the original called a nonexistent `dellabelfile`)
                self.delte_labelfile(c_path)
            else:
                os.remove(c_path)

    def _label_lines(self, detections, img_shape):
        """Convert one image's detections into YOLO label lines.

        Args:
            detections: per-image DataFrame from ``results.pandas().xyxy``.
                Columns 0-3 are read as pixel x1, y1, x2, y2, and column 6 as the
                class name matched against ``self.dictype``.
            img_shape: ``img.shape`` of the image (height first, then width).

        Returns:
            A list of ``"<id> <xc> <yc> <w> <h>"`` strings with normalized
            coordinates, in detection order. Unknown class names are skipped.
        """
        gn = torch.tensor(img_shape)[[1, 0, 1, 0]]  # normalization gain [w, h, w, h]
        lines = []
        for obj in detections.to_numpy():
            obj_name = obj[6]
            if obj_name not in self.dictype:
                continue
            box = torch.Tensor([obj[0], obj[1], obj[2], obj[3]])
            labeltag = (self.xyxy2xywh(box.view(1, 4)) / gn).view(-1).tolist()
            lines.append(" ".join([str(self.dictype[obj_name])] + [str(v) for v in labeltag]))
        return lines

    def add_labelfile(self, imagedir, labeldir):
        """Run the detector on every readable image in ``imagedir`` and write its label file.

        All objects of an image go into one file (one line each). Images with no
        recognized object get no label file. Entries that OpenCV cannot read
        (sub-folders, non-image files) are skipped.
        """
        for filename in os.listdir(imagedir):
            image_filepath = os.path.join(imagedir, filename)
            img = cv2.imread(image_filepath)
            if img is None:
                # cv2.imread returns None when the path is not a readable image
                continue

            # OpenCV loads BGR; the YOLOv5 hub model expects RGB arrays
            results = self.detector(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            lines = self._label_lines(results.pandas().xyxy[0], img.shape)
            if not lines:
                continue

            # Write all objects at once; writing per object overwrote earlier ones
            root, _ = os.path.splitext(filename)
            label_filepath = os.path.join(labeldir, root + ".txt")
            self.write_txt(label_filepath, "\n".join(lines))

if __name__ == '__main__':
    # Detector class name -> class id written into the label files.
    class_ids = {'cat': 0, 'dog': 1}

    # Local YOLOv5 checkout and the weights file to load from it.
    repo_dir = 'G:\\yolov\\yolov5-5.0'
    weights_path = 'G:\\yolov\\yolov5-5.0\\weights\\yolov5s.pt'

    # Dataset roots. The split name is appended directly, so keep the trailing separator.
    image_root = 'G:\\yolov\\datasets\\catdog\\images\\'
    label_root = 'G:\\yolov\\datasets\\catdog\\labels\\'

    labeler = AutoLabelImg(class_ids, repo_dir, weights_path)

    # Splits to label. Only 'test' is enabled; add 'train' and/or 'val'
    # to this tuple to label the training and validation images as well.
    for split in ('test',):
        labeler.start(image_root + split, label_root + split)