MMDetection dataset format conversion: LabelImg/xml/yolo annotations to a Custom-format dataset or a COCO-format dataset

2023-11-10

Training with a CustomDataset (custom-format dataset)

I. Modifications:

1. In mmdetection/configs/yolox/yolox_s_8x8_300e_coco.py, set the evaluation metric to mAP (see the sketch after this list).

2. In mmdetection/mmdet/datasets/custom.py, comment out line 333 and add a new line 334; the exact content was shown as a figure in the original post.

3. In mmdetection/configs/yolox/yolox_s_8x8_300e_coco.py, update the dataset and .pkl file paths, setting data_root to the root directory of your own dataset (see the sketch after this list).
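A minimal sketch of the config edits in steps 1 and 3, assuming the usual MMDetection 2.x config keys; the paths and the evaluation interval are placeholders, and the actual YOLOX config may wrap the training set in a MultiImageMixDataset, so adapt the layout to your version. The custom.py edit in step 2 depends on your MMDetection version and is not sketched here.

# Sketch only -- placeholder paths, MMDetection 2.x-style keys.
dataset_type = 'CustomDataset'
data_root = '/path/to/your/dataset/'            # step 3: your dataset root

data = dict(
    train=dict(type=dataset_type,
               ann_file=data_root + 'train.pkl',
               img_prefix=data_root + 'train/'),
    val=dict(type=dataset_type,
             ann_file=data_root + 'val.pkl',
             img_prefix=data_root + 'val/'),
    test=dict(type=dataset_type,
              ann_file=data_root + 'test.pkl',
              img_prefix=data_root + 'test/'))

evaluation = dict(interval=10, metric='mAP')    # step 1: mAP instead of the COCO 'bbox' metric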

II. Dataset format conversion:

Put all of the images to be used for training, together with their xml annotation files, into a tmp folder. Taking a power-line dataset as an example, all of the images and xml files are stored in tmp; then, in the directory at the same level as tmp, run the format-conversion script xml_custom.py (the command python xml_custom.py plus its arguments, as shown below). This automatically generates the train, val and test folders along with train.pkl, val.pkl and test.pkl, after which you can simply run the model-training command.
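For reference, a typical invocation, assuming the script below is saved as xml_custom.py (the argument names come from its argparse definition; the directory values here are only examples):

python xml_custom.py --image_xml_dir ./tmp --image_type .png --output_dir .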

xml-to-custom conversion script code

# coding:utf-8

# pip install lxml

import glob
import shutil
import numpy as np
import pickle
import os
import argparse
import xml.etree.ElementTree as ET

def get(root, name):
    return root.findall(name)


def get_and_check(root, name, length):
    vars = root.findall(name)
    if len(vars) == 0:
        raise NotImplementedError('Can not find %s in %s.' % (name, root.tag))
    if length > 0 and len(vars) != length:
        raise NotImplementedError('The size of %s is supposed to be %d, but is %d.' % (name, length, len(vars)))
    if length == 1:
        vars = vars[0]
    return vars

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--image_xml_dir', type=str,
                        help='Directory of images and xml.')
    parser.add_argument('--image_type', type=str, default='.png',
                        choices=['.jpg', '.png'], help='Type of image file.')
    parser.add_argument('--output_dir', type=str,
                        help='Directory of output.')
    a = parser.parse_args()
    image_geshi = a.image_type  # image file extension, e.g. '.png'
    origin_ann_dir = a.image_xml_dir  # directory holding all of the images and xml files (e.g. tmp)
    path2 = a.output_dir  # output directory
    classes = []
    for dirpaths, dirnames, filenames in os.walk(origin_ann_dir):  # walk the annotation directory
        for filename in filenames:
            if filename.endswith('.xml'):
                if os.path.isfile(os.path.join(origin_ann_dir, filename)):  # make sure it is a regular file
                    origin_ann_path = os.path.join(origin_ann_dir, filename)  # absolute path of the xml file
                    # new_ann_path = os.path.join(new_ann_dir, filename)
                    tree = ET.parse(origin_ann_path)  # parse the xml annotation file
                    root = tree.getroot()  # get the root node
                    for obj in root.findall('object'):  # every annotated object in this image
                        name = str(obj.find('name').text)  # class name of the object
                        if name not in classes:
                            classes.append(name)
    with open(path2 + r"/class.txt", "w") as f:
        f.write('\n'.join(classes))

    START_BOUNDING_BOX_ID = 1

    train_ratio = 0.7
    val_ratio = 0.2
    test_ratio = 0.1
    xml_dir = origin_ann_dir

    xml_list = glob.glob(xml_dir + "/*.xml")
    xml_list = np.sort(xml_list)
    np.random.seed(100)
    np.random.shuffle(xml_list)

    train_num = int(len(xml_list) * train_ratio)
    val_num = int(len(xml_list) * val_ratio)
    xml_list_train = xml_list[:train_num]
    xml_list_val = xml_list[train_num:train_num + val_num]
    xml_list_test = xml_list[train_num + val_num:]
    f1 = open(path2 + r"/train.txt", "w")
    for xml in xml_list_train:
        img = xml[:-4] + image_geshi
        f1.write(os.path.basename(xml)[:-4] + "\n")
    f2 = open(path2+ r"/val.txt", "w")
    for xml in xml_list_val:
        img = xml[:-4] + image_geshi
        f2.write(os.path.basename(xml)[:-4] + "\n")
    f3 = open(path2 + r"/test.txt", "w")
    for xml in xml_list_test:
        img = xml[:-4] + image_geshi
        f3.write(os.path.basename(xml)[:-4] + "\n")
    f1.close()
    f2.close()
    f3.close()
    print("-------------------------------")
    print("train number:", len(xml_list_train))
    print("val number:", len(xml_list_val))
    print("test number:", len(xml_list_test))

    with open(path2 + r"/train.txt", "r", encoding="utf-8") as f:
        paths = [i.strip() for i in f.readlines()]

    path3 = path2 +'/train'
    if os.path.exists(path3):
        shutil.rmtree(path3)
        os.mkdir(path3)
    else:
        os.mkdir(path3)
    dst_dir=path2 + "/train"
    for i in paths:
        img_path = os.path.join(xml_dir, i + image_geshi)
        xml_path = os.path.join(xml_dir, i + ".xml")
        shutil.copy(img_path, os.path.join(dst_dir, i + image_geshi))
        shutil.copy(xml_path, os.path.join(dst_dir, i + ".xml"))

    with open(path2 + r"/class.txt", "r", encoding="utf-8") as f:
        a = [i.strip() for i in f.readlines()]
    classes = a
    xml_dir1 = path2 +"/train/"
    xml_list = glob.glob(xml_dir1 + "/*.xml")
    xml_list = np.sort(xml_list)
    pre_define_categories = {}
    pkl_dict = []
    categories = pre_define_categories.copy()
    bnd_id = START_BOUNDING_BOX_ID
    all_categories = {}
    number = 0
    for index, line in enumerate(xml_list):
        box_data = []
        box_data1 = []
        labels_data = []
        xml_f = line
        tree = ET.parse(xml_f)
        root = tree.getroot()
        number = number + 1
        filename = os.path.basename(xml_f)[:-4] + image_geshi
        image_id = 20190000001 + index
        size = get_and_check(root, 'size', 1)
        width = int(get_and_check(size, 'width', 1).text)
        height = int(get_and_check(size, 'height', 1).text)
        for obj in get(root, 'object'):
            category = get_and_check(obj, 'name', 1).text
            bndbox = get_and_check(obj, 'bndbox', 1)
            xmin = int(float(get_and_check(bndbox, 'xmin', 1).text))
            ymin = int(float(get_and_check(bndbox, 'ymin', 1).text))
            xmax = int(float(get_and_check(bndbox, 'xmax', 1).text))
            ymax = int(float(get_and_check(bndbox, 'ymax', 1).text))
            assert (xmax > xmin), "xmax <= xmin, {}".format(line)
            assert (ymax > ymin), "ymax <= ymin, {}".format(line)
            o_width = abs(xmax - xmin)
            o_height = abs(ymax - ymin)
            box_data1 = [xmin, ymin, xmax, ymax]
            box_data.append(box_data1)
            # The script assumes <name> already holds a numeric class id (0, 1, 2, ...);
            # if your xml stores string names, map them first, e.g. classes.index(name).
            name = int(obj.find('name').text)
            labels_data.append(name)
        image = {'filename': filename, 'width': width, 'height': height, 'ann': {'bboxes': np.array(box_data),
                                                                                      'labels': np.array(labels_data).T}}
        pkl_dict.append(image)
        print(number)
    with open(path2 + r"/train.pkl","wb") as f:
        pickle.dump(pkl_dict,f)
    print("success-train")

    # -------------------------------------
    with open(path2 + r"/val.txt", "r", encoding="utf-8") as f:
        paths = [i.strip() for i in f.readlines()]

    path3 =path2 + '/val'
    if os.path.exists(path3):
        shutil.rmtree(path3)
        os.mkdir(path3)
    else:
        os.mkdir(path3)
    dst_dir=path2 + "/val"
    for i in paths:
        img_path = os.path.join(xml_dir, i + image_geshi)
        xml_path = os.path.join(xml_dir, i + ".xml")
        shutil.copy(img_path, os.path.join(dst_dir, i + image_geshi))
        shutil.copy(xml_path, os.path.join(dst_dir, i + ".xml"))

    with open(path2 + r"/class.txt", "r", encoding="utf-8") as f:
        a = [i.strip() for i in f.readlines()]
    classes = a
    xml_dir2 = path2 + "/val/"
    xml_list = glob.glob(xml_dir2 + "/*.xml")
    xml_list = np.sort(xml_list)
    pre_define_categories = {}
    pkl_dict = []
    categories = pre_define_categories.copy()
    bnd_id = START_BOUNDING_BOX_ID
    all_categories = {}
    number = 0
    for index, line in enumerate(xml_list):
        box_data = []
        box_data1 = []
        labels_data = []
        xml_f = line
        tree = ET.parse(xml_f)
        root = tree.getroot()
        number = number + 1
        filename = os.path.basename(xml_f)[:-4] + image_geshi
        image_id = 20190000001 + index
        size = get_and_check(root, 'size', 1)
        width = int(get_and_check(size, 'width', 1).text)
        height = int(get_and_check(size, 'height', 1).text)
        for obj in get(root, 'object'):
            category = get_and_check(obj, 'name', 1).text
            bndbox = get_and_check(obj, 'bndbox', 1)
            xmin = int(float(get_and_check(bndbox, 'xmin', 1).text))
            ymin = int(float(get_and_check(bndbox, 'ymin', 1).text))
            xmax = int(float(get_and_check(bndbox, 'xmax', 1).text))
            ymax = int(float(get_and_check(bndbox, 'ymax', 1).text))
            assert (xmax > xmin), "xmax <= xmin, {}".format(line)
            assert (ymax > ymin), "ymax <= ymin, {}".format(line)
            o_width = abs(xmax - xmin)
            o_height = abs(ymax - ymin)
            box_data1 = [xmin, ymin, xmax, ymax]
            box_data.append(box_data1)
            name = int(obj.find('name').text)
            labels_data.append(name)
        image = {'filename': filename, 'width': width, 'height': height, 'ann': {'bboxes': np.array(box_data),
                                                                                      'labels': np.array(labels_data).T}}
        pkl_dict.append(image)
        print(number)
    with open(path2 + r"/val.pkl", "wb") as f:
        pickle.dump(pkl_dict,f)
    print("success-val")

    # -------------------------------------------
    with open(path2 + r"/test.txt", "r", encoding="utf-8") as f:
        paths = [i.strip() for i in f.readlines()]

    path3 =path2 +  r'/test'
    if os.path.exists(path3):
        shutil.rmtree(path3)
        os.mkdir(path3)
    else:
        os.mkdir(path3)
    dst_dir=path2 + "/test"
    for i in paths:
        img_path = os.path.join(xml_dir, i + image_geshi)
        xml_path = os.path.join(xml_dir, i + ".xml")
        shutil.copy(img_path, os.path.join(dst_dir, i + image_geshi))
        shutil.copy(xml_path, os.path.join(dst_dir, i + ".xml"))

    with open(path2 + r"/class.txt", "r", encoding="utf-8") as f:
        a = [i.strip() for i in f.readlines()]
    classes = a
    xml_dir3 =path2 +  "/test/"
    xml_list = glob.glob(xml_dir3 + "/*.xml")
    xml_list = np.sort(xml_list)
    pre_define_categories = {}
    pkl_dict = []
    categories = pre_define_categories.copy()
    bnd_id = START_BOUNDING_BOX_ID
    all_categories = {}
    number = 0
    for index, line in enumerate(xml_list):
        box_data = []
        box_data1 = []
        labels_data = []
        xml_f = line
        tree = ET.parse(xml_f)
        root = tree.getroot()
        number = number + 1
        filename = os.path.basename(xml_f)[:-4] + image_geshi
        image_id = 20190000001 + index
        size = get_and_check(root, 'size', 1)
        width = int(get_and_check(size, 'width', 1).text)
        height = int(get_and_check(size, 'height', 1).text)
        for obj in get(root, 'object'):
            category = get_and_check(obj, 'name', 1).text
            bndbox = get_and_check(obj, 'bndbox', 1)
            xmin = int(float(get_and_check(bndbox, 'xmin', 1).text))
            ymin = int(float(get_and_check(bndbox, 'ymin', 1).text))
            xmax = int(float(get_and_check(bndbox, 'xmax', 1).text))
            ymax = int(float(get_and_check(bndbox, 'ymax', 1).text))
            assert (xmax > xmin), "xmax <= xmin, {}".format(line)
            assert (ymax > ymin), "ymax <= ymin, {}".format(line)
            o_width = abs(xmax - xmin)
            o_height = abs(ymax - ymin)
            box_data1 = [xmin, ymin, xmax, ymax]
            box_data.append(box_data1)
            name = int(obj.find('name').text)
            labels_data.append(name)
        image = {'filename': filename, 'width': width, 'height': height, 'ann': {'bboxes': np.array(box_data),
                                                                                      'labels': np.array(labels_data).T}}
        pkl_dict.append(image)
        print(number)
    with open(path2 + r"/test.pkl","wb") as f:
        pickle.dump(pkl_dict,f)
    print("success-test")
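The three .pkl files hold the "middle format" that MMDetection's CustomDataset expects: a list with one dict per image. A quick sanity check of what the script above writes (a sketch, run from the output directory):

import pickle

with open('train.pkl', 'rb') as f:
    samples = pickle.load(f)          # list of per-image dicts

first = samples[0]
print(first['filename'], first['width'], first['height'])
print(first['ann']['bboxes'].shape)   # (num_objects, 4) boxes as [xmin, ymin, xmax, ymax]
print(first['ann']['labels'])         # integer class ids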

yolo/xml-to-COCO conversion script code (two scripts follow: a yolo-txt/xml label-format and visualization tool, then an xml-to-COCO converter)

# pylint: disable=no-member
"""
Adjust the arguments in the parameter list: --image_dir is the directory holding the images and
--label_dir the directory holding the labels; --res_dir is where the images with the boxes drawn on
them are saved, and --concat_dir is where the concatenated (tiled) images are saved.
--num is the tiling size, e.g. num=3 produces a 3*3 mosaic.
--type is the label type; the options are xml_cp, xml_lt, txt_cp and txt_lt:
  xml_cp: labels are xml files that should also be converted to txt (class, centre point (x, y),
          width w and height h); the txt files are written into the original label folder.
  xml_lt: labels are xml files and no txt conversion is needed.
  txt_cp: labels are txt files containing class, centre point (x, y), width w and height h.
  txt_lt: labels are txt files containing class, top-left and bottom-right coordinates.

You also need to edit dic, the dictionary mapping the numeric label used during annotation to the
real class name, e.g. dic = {'0': 'red'} means label 0 corresponds to 'red'. If you did not simplify
the labels to numbers 0, 1, 2, ... you can leave it unchanged and the images will simply show the
name used at annotation time (e.g. 'red').

Note: the argument names above describe an earlier revision; the argparse definitions below use
--draw_dir, --concat_row/--concat_col, --label_type and --output_label_type instead, and class names
are read from the file given by --class_path rather than from a dic dictionary.
"""
import os
import argparse
import logging
from typing import List, Tuple, Union

from xml.dom import minidom
import imagesize
import cv2
import numpy as np
from tqdm import tqdm

class_names = []


def read_xml_bbox(xml_path_list: List[str]) -> List[List[Union[List, Tuple]]]:
    """Get bounding boxes from XML files"""
    ans = []
    for curr_xml_path in tqdm(xml_path_list, desc='Loading XML files'):
        curr_image = []
        ans.append(curr_image)
        dom = minidom.parse(curr_xml_path)
        for obj_node in dom.getElementsByTagName('object'):
            class_name = obj_node.getElementsByTagName('name')[0].firstChild.data

            bbox_node = obj_node.getElementsByTagName('bndbox')[0]
            xmin, ymin, xmax, ymax = (
                int(float(bbox_node.getElementsByTagName(name)[0].firstChild.data)) - 1
                for name in ['xmin', 'ymin', 'xmax', 'ymax'])  # x in [0, width), y in [0, height)

            curr_image.append([class_name, (xmin, ymin), (xmax, ymin), (xmax, ymax), (xmin, ymax)])
    return ans


def save_xml(bbox_list, save_dir, image_path_list):
    """
    Save bounding boxes for all images
    :param bbox_list: Bounding box list
    :param save_dir: Save directory
    :param image_path_list: Image path list
    :return: None
    """
    for bounding_boxes, image_path in tqdm(zip(bbox_list, image_path_list),
                                           total=len(image_path_list),
                                           desc='Saving XML files'):
        doc = minidom.Document()

        root_node = doc.createElement('annotation')
        doc.appendChild(root_node)

        create_element_with_text(doc, root_node, 'folder', os.path.split(image_path)[-2])
        create_element_with_text(doc, root_node, 'filename', os.path.split(image_path)[-1])
        create_element_with_text(doc, root_node, 'path', image_path)

        source_node = doc.createElement('source')

        create_element_with_text(doc, source_node, 'database', 'Unknown')

        root_node.appendChild(source_node)

        size_node = doc.createElement('size')
        for element_name, value in zip(['width', 'height', 'depth'],
                                       [*imagesize.get(image_path), 3]):
            elem = doc.createElement(element_name)
            elem.appendChild(doc.createTextNode(str(value)))
            size_node.appendChild(elem)
        root_node.appendChild(size_node)

        create_element_with_text(doc, root_node, 'segmented', '0')

        for bbox in bounding_boxes:
            obj_node = doc.createElement('object')

            create_element_with_text(doc, obj_node, 'name', get_class_name(bbox[0]))
            create_element_with_text(doc, obj_node, 'pose', 'Unspecified')
            create_element_with_text(doc, obj_node, 'truncated', '0')
            create_element_with_text(doc, obj_node, 'difficult', '0')

            bndbox_node = doc.createElement('bndbox')
            for element_name, value in zip(['xmin', 'ymin', 'xmax', 'ymax'],
                                           [min(bbox[1:], key=lambda x: x[0])[0],
                                            min(bbox[1:], key=lambda x: x[1])[1],
                                            max(bbox[1:], key=lambda x: x[0])[0],
                                            max(bbox[1:], key=lambda x: x[1])[1]]):
                create_element_with_text(doc, bndbox_node, element_name, str(value + 1))

            obj_node.appendChild(bndbox_node)
            root_node.appendChild(obj_node)

        with open(os.path.join(save_dir, os.path.split(image_path)[-1][:-4] + '.xml'),
                  'w', encoding='utf-8') as file:
            doc.writexml(file, indent='', addindent='\t', newl='\n', encoding='utf-8')


def create_element_with_text(doc, node, element_name, text):
    element_node = doc.createElement(element_name)
    element_node.appendChild(doc.createTextNode(text))
    node.appendChild(element_node)


def read_txt_bbox(txt_path_list: List[str],
                  image_path_list: List[str]) -> List[List[Union[List, Tuple]]]:
    """Get bounding boxes from TXT files"""
    ans = []
    for i in tqdm(range(len(txt_path_list)), desc='Loading TXT files'):
        curr_image = []
        ans.append(curr_image)
        with open(txt_path_list[i], 'r') as file:
            s = file.read().strip()
        # image_height, image_width = \
        #     cv2.imdecode(np.fromfile(image_path_list[i], dtype=np.uint8), -1).shape[:2]
        image_width, image_height = imagesize.get(image_path_list[i])
        for line in s.strip().split('\n'):
            if not line:
                continue
            class_id, x, y, w, h = line.split()
            class_id = int(class_id)
            x, y = float(x) * image_width, float(y) * image_height
            w, h = float(w) * image_width, float(h) * image_height
            xmin, ymin, xmax, ymax = \
                int(x - w / 2), int(y - h / 2), int(x + w / 2) - 1, int(y + h / 2) - 1
            curr_image.append([class_id, (xmin, ymin), (xmax, ymin), (xmax, ymax), (xmin, ymax)])
    return ans


def save_txt(bbox_list, output_label_dir, image_path_list):
    for bounding_boxes, image_path in tqdm(zip(bbox_list, image_path_list),
                                           total=len(image_path_list),
                                           desc='Saving TXT files'):
        image_width, image_height = imagesize.get(image_path)

        txt_path = os.path.join(output_label_dir, os.path.split(image_path)[-1][:-4] + '.txt')
        with open(txt_path, 'w') as file:
            for bbox in bounding_boxes:
                xmin = min(bbox[1:], key=lambda x: x[0])[0]
                ymin = min(bbox[1:], key=lambda x: x[1])[1]
                xmax = max(bbox[1:], key=lambda x: x[0])[0]
                ymax = max(bbox[1:], key=lambda x: x[1])[1]
                cx, cy = (xmin + xmax) / 2, (ymin + ymax) / 2
                w, h = xmax - xmin, ymax - ymin
                file.write(str(get_class_id(bbox[0])) + ' ' +
                           str(cx / image_width) + ' ' +
                           str(cy / image_height) + ' ' +
                           str(w / image_width) + ' ' +
                           str(h / image_height) + '\n')


def read_txt_rotate_bbox(txt_path_list, image_path_list):
    """Get bounding boxes from TXT files with rotation"""
    ans = []
    for i in tqdm(range(len(txt_path_list)), desc='Loading TXT files'):
        curr_image = []
        ans.append(curr_image)
        with open(txt_path_list[i], 'r') as file:
            s = file.read().strip()
        # image_height, image_width = \
        #     cv2.imdecode(np.fromfile(image_path_list[i], dtype=np.uint8), -1).shape[:2]
        image_width, image_height = imagesize.get(image_path_list[i])
        for line in s.strip().split('\n'):
            if not line:
                continue
            class_id, cx, cy, w, h, theta = line.split()
            class_id = int(class_id)
            cx, cy = float(cx) * image_width, float(cy) * image_height
            w, h = float(w) * image_width, float(h) * image_height
            theta = float(theta)
            xmin, ymin, xmax, ymax = \
                int(cx - w / 2), int(cy - h / 2), int(cx + w / 2) - 1, int(cy + h / 2) - 1
            points = []
            for p in [(xmin, ymin), (xmax, ymin), (xmax, ymax), (xmin, ymax)]:
                p = rotated_point(p, (cx, cy), theta)
                points.append((int(p[0]), int(p[1])))
            curr_image.append([class_id, *points])
    return ans


def save_txt_rotate(bbox_list, output_label_dir, image_path_list):
    for bounding_boxes, image_path in tqdm(zip(bbox_list, image_path_list),
                                           total=len(image_path_list),
                                           desc='Saving TXT files'):
        image_width, image_height = imagesize.get(image_path)

        txt_path = os.path.join(output_label_dir, os.path.split(image_path)[-1][:-4] + '.txt')
        with open(txt_path, 'w') as file:
            for bbox in bounding_boxes:
                rect = cv2.minAreaRect(np.array(bbox[1:]))
                (cx, cy), (w, h), theta = rect
                file.write(str(get_class_id(bbox[0])) + ' ' +
                           str(cx / image_width) + ' ' +
                           str(cy / image_height) + ' ' +
                           str(w / image_width) + ' ' +
                           str(h / image_height) + ' ' +
                           str(180 - theta) + '\n')


def rotated_point(p, q, theta):
    """Return The coordinate of point p after
    rotating counterclockwise by angle theta around point q."""
    # The order of the coordinate axes in the coordinate system of the image
    # is opposite to that in the Cartesian coordinate system, so we need to reverse the angle.
    theta = -theta
    x = (p[0] - q[0]) * np.cos(np.deg2rad(theta)) - (p[1] - q[1]) * np.sin(np.deg2rad(theta)) + q[0]
    y = (p[0] - q[0]) * np.sin(np.deg2rad(theta)) + (p[1] - q[1]) * np.cos(np.deg2rad(theta)) + q[1]
    return x, y


def get_path_list(root_dir):
    """Return all file names in the specified directory"""
    return [os.path.join(root_dir, file_name) for file_name in os.listdir(root_dir)]


def read_class_names(class_path):
    global class_names
    with open(class_path, 'r') as file:
        class_names = file.read().strip().split('\n')


def get_class_name(class_id: Union[str, int]) -> str:
    global class_names
    if isinstance(class_id, str):
        return class_id
    return class_names[class_id]


def get_class_id(class_name: Union[str, int]) -> int:
    global class_names
    if isinstance(class_name, int):
        return class_name
    return class_names.index(class_name)


def draw_one_image(image_path, bbox_list, output_path):
    """
    Draw all bounding boxes in one image.

    :param image_path: Image path
    :param bbox_list: Bounding box list
    :param output_path: Output path to the final image
    :return: None
    """
    image = cv2.imread(image_path)
    for bbox in bbox_list:
        # draw current bounding box
        for i in range(2, len(bbox)):
            cv2.line(image, bbox[i - 1], bbox[i], (0, 255, 0), 2)
        cv2.line(image, bbox[-1], bbox[1], (0, 255, 0), 2)

        # draw class name
        top_left = min(bbox[1:])
        class_name = get_class_name(bbox[0])
        x, y = min((top_left[0], image.shape[1] - 16 * len(class_name))), max(top_left[1], 20)
        cv2.putText(image, class_name, (x, y), cv2.FONT_HERSHEY_SIMPLEX, 0.75,
                    (0, 255, 0), 1, cv2.LINE_AA, False)

    cv2.imwrite(output_path, image)


def concat_images_into_one(image_path_list, output_path, row, col, width, height):
    """
    Concatenate images into one image.

    :param image_path_list: paths of input images
    :param output_path: output image path
    :param row: Number of images in one column of the output image
    :param col: Number of images in one row of the output image
    :param width: width of one column
    :param height: height of one row
    :return: None
    """
    if row * col < len(image_path_list):
        raise ValueError('Cannot concat too many images')
    output_image = np.zeros((height * row, width * col, 3), dtype=np.uint8)
    for i in range(row):
        for j in range(col):
            idx = i * col + j
            if idx >= len(image_path_list):
                cv2.imwrite(output_path, output_image)
                return
            image = cv2.imdecode(np.fromfile(image_path_list[idx], dtype=np.uint8), -1)
            curr_image = cv2.resize(image, (width, height))
            output_image[height * i:height * (i + 1), width * j:width * (j + 1)] = curr_image
    cv2.imwrite(output_path, output_image)


def concat_all_images(image_path_list, output_dir, row, col, width, height):
    """
    Concatenate images in groups.

    :param image_path_list: paths of input images
    :param output_dir: output image path
    :param row: Number of images in one column of the output image
    :param col: Number of images in one row of the output image
    :param width: width of one column
    :param height: height of one row
    :return: None
    """
    i = 0
    count = 0
    pbar = tqdm(total=len(image_path_list), desc='Concatenating')

    while i + row * col - 1 < len(image_path_list):
        concat_images_into_one(image_path_list[i:i + row * col],
                               os.path.join(output_dir, str(count) + '.png'),
                               row, col, width, height)
        i += row * col
        count += 1
        pbar.update(row * col)

    if i + col - 1 < len(image_path_list):
        curr_row = (len(image_path_list) - i) // col
        concat_images_into_one(image_path_list[i:i + curr_row * col],
                               os.path.join(output_dir, str(count) + '.png'),
                               curr_row, col, width, height)
        i += curr_row * col
        count += 1
        pbar.update(curr_row * col)
    if i < len(image_path_list):
        concat_images_into_one(image_path_list[i:],
                               os.path.join(output_dir, str(count) + '.png'),
                               1, len(image_path_list) - i, width, height)
        pbar.update(len(image_path_list) - i)
    pbar.close()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--image_dir', type=str, default='resize_img_30_new/resize_img_30_new',
                        help='Directory of images.')
    parser.add_argument('--image_type', type=str, default='png',
                        choices=['jpg', 'png'], help='Type of image file.')
    parser.add_argument('--label_dir', type=str, default='xml_resize_30_2',
                        help='Directory of labels.')
    parser.add_argument('--label_type', type=str, default='xml',
                        choices=['xml', 'txt', 'txt_rotate'], help='Type of input label file.')

    parser.add_argument('--class_path', type=str, default='class.txt',
                        help='Path to the file which stores class names.')

    parser.add_argument('--draw_dir', type=str, default='',
                        help='Output directory of drawn images.')

    parser.add_argument('--concat_dir', type=str, default='',
                        help='Output directory of concatenated images.')
    parser.add_argument('--concat_row', type=int, default=4,
                        help='The number of rows of the images in each concatenated image.')
    parser.add_argument('--concat_col', type=int, default=4,
                        help='The number of columns of the images in each concatenated image.')
    parser.add_argument('--concat_width', type=int, default=800,
                        help='The width of each image in each concatenated image.')
    parser.add_argument('--concat_height', type=int, default=600,
                        help='The height of each image in each concatenated image.')

    parser.add_argument('--output_label_type', type=str, default='txt',
                        choices=['xml', 'txt', 'txt_rotate'], help='Type of output label file.')
    parser.add_argument('--output_label_dir', type=str, default='dx_crop_sf30_txt',
                        help='Directory of output label files.')
    args = parser.parse_args()

    # Get file path list
    image_path_list = get_path_list(args.image_dir)
    label_path_list = get_path_list(args.label_dir)

    # Read class names
    read_class_names(args.class_path)

    # Filter path
    image_path_list = list(filter(lambda p: p[-4:] == '.' + args.image_type, image_path_list))
    if args.label_type == 'txt_rotate':
        label_path_list = list(filter(lambda p: p[-4:] == '.txt', label_path_list))
    else:
        label_path_list = list(filter(lambda p: p[-4:] == '.' + args.label_type, label_path_list))

    # Filter uncommon files
    common_names = set(os.path.split(path)[-1][:-4] for path in label_path_list)
    common_names.intersection_update(os.path.split(path)[-1][:-4] for path in image_path_list)
    if len(common_names) < len(image_path_list) or len(common_names) < len(label_path_list):
        logging.warning('Files in the label folder and the image folder are inconsistent.')
    image_path_list = list(filter(lambda p: os.path.split(p)[-1][:-4] in common_names,
                                  image_path_list))
    label_path_list = list(filter(lambda p: os.path.split(p)[-1][:-4] in common_names,
                                  label_path_list))
    image_path_list.sort()
    label_path_list.sort()

    # Read bounding boxes
    if args.label_type == 'xml':
        bbox_list = read_xml_bbox(label_path_list)
    elif args.label_type == 'txt':
        bbox_list = read_txt_bbox(label_path_list, image_path_list)
    else:
        bbox_list = read_txt_rotate_bbox(label_path_list, image_path_list)

    # Output labels
    if args.output_label_type and args.output_label_dir:
        if not os.path.exists(args.output_label_dir):
            os.makedirs(args.output_label_dir)

        if args.output_label_type == 'xml':
            save_xml(bbox_list, args.output_label_dir, image_path_list)
        elif args.output_label_type == 'txt':
            save_txt(bbox_list, args.output_label_dir, image_path_list)
        else:
            save_txt_rotate(bbox_list, args.output_label_dir, image_path_list)
    else:
        logging.warning('Since the directory or type of output labels is not specified, '
                        'skip output labels.')

    # Draw bounding boxes
    if args.draw_dir:
        if not os.path.exists(args.draw_dir):
            os.makedirs(args.draw_dir)
        for image_path, bbox in tqdm(zip(image_path_list, bbox_list),
                                     total=len(image_path_list), desc='Drawing bounding boxes'):
            draw_one_image(image_path, bbox,
                           os.path.join(args.draw_dir, os.path.split(image_path)[-1]))
    else:
        logging.warning('Since the directory of drawn images is not specified, '
                        'skip drawing bounding boxes.')

    # Concatenate images
    if args.draw_dir and args.concat_dir:
        if not os.path.exists(args.concat_dir):
            os.makedirs(args.concat_dir)
        labeled_image_path_list = [os.path.join(args.draw_dir, name + '.' + args.image_type)
                                   for name in common_names]
        concat_all_images(labeled_image_path_list, args.concat_dir,
                          args.concat_row, args.concat_col, args.concat_width, args.concat_height)
    else:
        logging.warning('Since draw directory or concat directory is not specified, '
                        'skip concatenating labels.')
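The second script below converts Pascal VOC xml annotations directly into COCO-format instances_train2017.json, instances_val2017.json and instances_test2017.json files, and copies the corresponding images into train2017/, val2017/ and test2017/: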
#coding:utf-8
 
# pip install lxml
 
import os
import glob
import json
import shutil
import numpy as np
import xml.etree.ElementTree as ET
 
 
 
path2 = "."
image_geshi = ".png"
 
START_BOUNDING_BOX_ID = 1
 
 
def get(root, name):
    return root.findall(name)
 
 
def get_and_check(root, name, length):
    vars = root.findall(name)
    if len(vars) == 0:
        raise NotImplementedError('Can not find %s in %s.'%(name, root.tag))
    if length > 0 and len(vars) != length:
        raise NotImplementedError('The size of %s is supposed to be %d, but is %d.'%(name, length, len(vars)))
    if length == 1:
        vars = vars[0]
    return vars
 
 
def convert(xml_list, json_file):
    json_dict = {"images": [], "type": "instances", "annotations": [], "categories": []}
    categories = pre_define_categories.copy()
    bnd_id = START_BOUNDING_BOX_ID
    all_categories = {}
    for index, line in enumerate(xml_list):
        # print("Processing %s"%(line))
        xml_f = line
        tree = ET.parse(xml_f)
        root = tree.getroot()
        
        filename = os.path.basename(xml_f)[:-4] + image_geshi
        image_id = 20190000001 + index
        size = get_and_check(root, 'size', 1)
        width = int(get_and_check(size, 'width', 1).text)
        height = int(get_and_check(size, 'height', 1).text)
        image = {'file_name': filename, 'height': height, 'width': width, 'id':image_id}
        json_dict['images'].append(image)
        ## Currently we do not support segmentation
        #  segmented = get_and_check(root, 'segmented', 1).text
        #  assert segmented == '0'
        for obj in get(root, 'object'):
            category = get_and_check(obj, 'name', 1).text
            if category in all_categories:
                all_categories[category] += 1
            else:
                all_categories[category] = 1
            if category not in categories:
                if only_care_pre_define_categories:
                    continue
                new_id = len(categories) + 1
                print("[warning] category '{}' not in 'pre_define_categories'({}), create new id: {} automatically".format(category, pre_define_categories, new_id))
                categories[category] = new_id
            category_id = categories[category]
            bndbox = get_and_check(obj, 'bndbox', 1)
            xmin = int(float(get_and_check(bndbox, 'xmin', 1).text))
            ymin = int(float(get_and_check(bndbox, 'ymin', 1).text))
            xmax = int(float(get_and_check(bndbox, 'xmax', 1).text))
            ymax = int(float(get_and_check(bndbox, 'ymax', 1).text))
            assert(xmax > xmin), "xmax <= xmin, {}".format(line)
            assert(ymax > ymin), "ymax <= ymin, {}".format(line)
            o_width = abs(xmax - xmin)
            o_height = abs(ymax - ymin)
            ann = {'area': o_width*o_height, 'iscrowd': 0, 'image_id':
                   image_id, 'bbox':[xmin, ymin, o_width, o_height],
                   'category_id': category_id, 'id': bnd_id, 'ignore': 0,
                   'segmentation': []}
            json_dict['annotations'].append(ann)
            bnd_id = bnd_id + 1
 
    for cate, cid in categories.items():
        cat = {'supercategory': 'none', 'id': cid, 'name': cate}
        json_dict['categories'].append(cat)
    with open(json_file, 'w') as json_fp:
        json_fp.write(json.dumps(json_dict))
    print("------------create {} done--------------".format(json_file))
    print("find {} categories: {} -->>> your pre_define_categories {}: {}".format(len(all_categories), all_categories.keys(), len(pre_define_categories), pre_define_categories.keys()))
    print("category: id --> {}".format(categories))
    print(categories.keys())
    print(categories.values())
 
 
if __name__ == '__main__':
    with open(r"class.txt", "r", encoding="utf-8") as f:
        a = [i.strip() for i in f.readlines()]
    classes = a
    pre_define_categories = {}
    for i, cls in enumerate(classes):
        pre_define_categories[cls] = i + 1
    # pre_define_categories = {'a1': 1, 'a3': 2, 'a6': 3, 'a9': 4, "a10": 5}
    only_care_pre_define_categories = True
    # only_care_pre_define_categories = False
 
    train_ratio = 0.7
    val_ratio = 0.2
    test_ratio = 0.1
    save_json_train = 'instances_train2017.json'
    save_json_val = 'instances_val2017.json'
    save_json_test = 'instances_test2017.json'
    xml_dir = "./dx_coco_sf60_crop_np/tmp/"
 
    xml_list = glob.glob(xml_dir + "/*.xml")
    xml_list = np.sort(xml_list)
    np.random.seed(100)
    np.random.shuffle(xml_list)
 
    train_num = int(len(xml_list)*train_ratio)
    val_num = int(len(xml_list) * val_ratio)
    xml_list_train = xml_list[:train_num]
    xml_list_val = xml_list[train_num:train_num+val_num]
    xml_list_test = xml_list[train_num+val_num:]

 
    convert(xml_list_train, save_json_train)
    convert(xml_list_val, save_json_val)
    convert(xml_list_test, save_json_test)
 
    # Note: the instances_*.json files above were written to the current directory; move them
    # into annotations/ afterwards if your config expects the standard COCO layout.
    if os.path.exists(path2 + "/annotations"):
        shutil.rmtree(path2 + "/annotations")
    os.makedirs(path2 + "/annotations")
    if os.path.exists(path2 + "/train2017"):
        shutil.rmtree(path2 + "/train2017")
    os.makedirs(path2 + "/train2017")
    if os.path.exists(path2 + "/val2017"):
        shutil.rmtree(path2 +"/val2017")
    os.makedirs(path2 + "/val2017")
    if os.path.exists(path2 + "/test2017"):
        shutil.rmtree(path2 + "/test2017")
    os.makedirs(path2 + "/test2017")
 
    f1 = open("train.txt", "w")
    for xml in xml_list_train:
        img = xml[:-4] + image_geshi
        f1.write(os.path.basename(xml)[:-4] + "\n")
        shutil.copyfile(img, path2 + "/train2017/" + os.path.basename(img))

    f2 = open("val.txt", "w")
    for xml in xml_list_val:
        img = xml[:-4] + image_geshi
        f2.write(os.path.basename(xml)[:-4] + "\n") 
        shutil.copyfile(img, path2 + "/val2017/" + os.path.basename(img))

    f3 = open("test.txt", "w")
    for xml in xml_list_test:
        img = xml[:-4] + image_geshi
        f3.write(os.path.basename(xml)[:-4] + "\n")
        shutil.copyfile(img, path2 + "/test2017/" + os.path.basename(img))
    f1.close()
    f2.close()
    f3.close()
    print("-------------------------------")
    print("train number:", len(xml_list_train))
    print("val number:", len(xml_list_val))
    print("test number:", len(xml_list_test))
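Once the json files and image folders are in place, a COCO-style MMDetection config can point at them roughly like this (a sketch only; the keys follow the usual MMDetection 2.x CocoDataset layout, and it assumes the instances_*.json files have been moved into the annotations/ folder as noted in the script):

# Sketch only -- adjust data_root and the pipeline settings to your setup.
data_root = './'
data = dict(
    train=dict(type='CocoDataset',
               ann_file=data_root + 'annotations/instances_train2017.json',
               img_prefix=data_root + 'train2017/'),
    val=dict(type='CocoDataset',
             ann_file=data_root + 'annotations/instances_val2017.json',
             img_prefix=data_root + 'val2017/'),
    test=dict(type='CocoDataset',
              ann_file=data_root + 'annotations/instances_test2017.json',
              img_prefix=data_root + 'test2017/'))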
