









# coding:utf-8

# pip install lxml

import glob
import json
import shutil
import numpy as np
import pickle
import os
import argparse
import xml.etree.ElementTree as ET
from pathlib import Path
def get(root, name):
    """Return the list produced by ``root.findall(name)``.

    :param root: Element (or any object exposing ``findall``) to search in
    :param name: Tag name / path handed to ``findall``
    :return: List of matching elements, empty when nothing matches
    """
    matches = root.findall(name)
    return matches

def get_and_check(root, name, length):
    """Look up ``name`` under ``root`` and validate how many elements matched.

    :param root: Element to search in (its ``tag`` is used in error messages)
    :param name: Tag name / path handed to ``findall``
    :param length: Expected number of matches; values <= 0 disable the count check
    :return: The single element when ``length == 1``, otherwise the list of matches
    :raises NotImplementedError: When nothing matches, or when ``length > 0``
        and the number of matches differs from ``length``
    """
    found = root.findall(name)
    count = len(found)
    if count == 0:
        raise NotImplementedError('Can not find %s in %s.' % (name, root.tag))
    if length > 0 and count != length:
        raise NotImplementedError(
            'The size of %s is supposed to be %d, but is %d.' % (name, length, count))
    return found[0] if length == 1 else found

def _collect_class_names(ann_dir):
    """Return the distinct ``<object><name>`` values, in first-seen order.

    Only XML files located directly inside ``ann_dir`` are parsed; files that
    ``os.walk`` reports from sub-directories are skipped.
    """
    classes = []
    for _dirpath, _dirnames, filenames in os.walk(ann_dir):
        for filename in filenames:
            if not filename.endswith('.xml'):
                continue
            ann_path = os.path.join(ann_dir, filename)
            if not os.path.isfile(ann_path):
                continue
            root = ET.parse(ann_path).getroot()
            for obj in root.findall('object'):
                name = str(obj.find('name').text)
                if name not in classes:
                    classes.append(name)
    return classes


def _stem(path):
    """Return the file name of ``path`` without directory and extension."""
    return os.path.splitext(os.path.basename(path))[0]


def _write_lines(path, lines):
    """Write ``lines`` to ``path``, one per line, closing the file afterwards."""
    with open(path, 'w', encoding='utf-8') as f:
        for line in lines:
            f.write(line + '\n')


def _copy_split_files(src_dir, split_dir, stems, image_ext):
    """Copy the image and the XML of every stem into ``split_dir``.

    :return: Paths of the copied XML files, in the order of ``stems``
    """
    # NOTE(review): the statement that followed `if os.path.exists(...)` is
    # missing from the original file. Nothing is deleted here; the directory
    # is simply created when absent — confirm whether stale files should be
    # removed first.
    os.makedirs(split_dir, exist_ok=True)
    copied_xml = []
    for stem in stems:
        shutil.copy(os.path.join(src_dir, stem + image_ext), split_dir)
        shutil.copy(os.path.join(src_dir, stem + '.xml'), split_dir)
        copied_xml.append(os.path.join(split_dir, stem + '.xml'))
    return copied_xml


def _load_annotations(xml_paths, image_ext):
    """Parse VOC XML files into a list of per-image annotation dicts.

    Each dict has the keys ``filename``, ``width``, ``height`` and ``ann``
    (``bboxes`` as ``[xmin, ymin, xmax, ymax]`` rows, ``labels`` as integers).

    :raises ValueError: When a box has ``xmax <= xmin`` or ``ymax <= ymin``,
        or when a ``<name>`` value is not an integer
    """
    records = []
    for xml_path in xml_paths:
        root = ET.parse(xml_path).getroot()
        size = get_and_check(root, 'size', 1)
        width = int(get_and_check(size, 'width', 1).text)
        height = int(get_and_check(size, 'height', 1).text)

        bboxes = []
        labels = []
        for obj in get(root, 'object'):
            bndbox = get_and_check(obj, 'bndbox', 1)
            xmin, ymin, xmax, ymax = (
                int(float(get_and_check(bndbox, tag, 1).text))
                for tag in ('xmin', 'ymin', 'xmax', 'ymax'))
            if xmax <= xmin:
                raise ValueError("xmax <= xmin, {}".format(xml_path))
            if ymax <= ymin:
                raise ValueError("ymax <= ymin, {}".format(xml_path))
            bboxes.append([xmin, ymin, xmax, ymax])
            # NOTE(review): the label is the integer value of the <name> text,
            # exactly as in the original code; non-numeric class names raise
            # ValueError — confirm whether an index into class.txt was intended.
            labels.append(int(get_and_check(obj, 'name', 1).text))

        records.append({
            'filename': _stem(xml_path) + image_ext,
            'width': width,
            'height': height,
            'ann': {'bboxes': np.array(bboxes), 'labels': np.array(labels).T},
        })
    return records


if __name__ == '__main__':
    # Split a directory of images + VOC XML files into train/val/test sets and
    # export class.txt, <split>.txt, <split>/ and <split>.pkl into output_dir.
    parser = argparse.ArgumentParser()
    parser.add_argument('--image_xml_dir', type=str, required=True,
                        help='Directory of images and xml.')
    parser.add_argument('--image_type', type=str, default='.png',
                        choices=['.jpg', '.png'], help='Type of image file.')
    parser.add_argument('--output_dir', type=str, required=True,
                        help='Directory of output.')
    a = parser.parse_args()

    image_ext = a.image_type        # image file suffix, e.g. ".png"
    source_dir = a.image_xml_dir    # directory holding all XML files and images
    output_dir = a.output_dir
    os.makedirs(output_dir, exist_ok=True)

    # class.txt: one class name per line, in first-seen order.
    _write_lines(os.path.join(output_dir, 'class.txt'), _collect_class_names(source_dir))

    # Deterministic 70 % / 20 % / remainder split over the sorted XML list.
    train_ratio = 0.7
    val_ratio = 0.2
    xml_list = sorted(glob.glob(os.path.join(source_dir, '*.xml')))
    train_num = int(len(xml_list) * train_ratio)
    val_num = int(len(xml_list) * val_ratio)
    splits = [
        ('train', xml_list[:train_num]),
        ('val', xml_list[train_num:train_num + val_num]),
        ('test', xml_list[train_num + val_num:]),
    ]

    for split_name, split_xml in splits:
        _write_lines(os.path.join(output_dir, split_name + '.txt'),
                     [_stem(xml_path) for xml_path in split_xml])
    print("train number:", len(splits[0][1]))
    print("val number:", len(splits[1][1]))
    print("test number:", len(splits[2][1]))

    for split_name, split_xml in splits:
        stems = [_stem(xml_path) for xml_path in split_xml]
        copied_xml = _copy_split_files(
            source_dir, os.path.join(output_dir, split_name), stems, image_ext)
        with open(os.path.join(output_dir, split_name + '.pkl'), 'wb') as f:
            pickle.dump(_load_annotations(copied_xml, image_ext), f)


# pylint: disable=no-member

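# --res_dir: directory where the images with drawn bounding boxes are stored;
# --concat_dir: directory where the concatenated images are stored.
# --num: grid size of a concatenated image, e.g. num=3 produces a 3*3 mosaic.
# --type: label type; the options are xml_cp, xml_lt, txt_cp and txt_lt.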

import os
import argparse
import logging
from typing import List, Tuple, Union

from xml.dom import minidom
import imagesize
import cv2
import numpy as np
from tqdm import tqdm

# Module-level list of class names, indexed by class id. It is declared
# ``global`` in read_class_names() and looked up by get_class_name()
# (id -> name) and get_class_id() / save_txt() (name -> id via ``.index``).
class_names = []

def read_xml_bbox(xml_path_list: List[str]) -> List[List[Union[List, Tuple]]]:
    """Get bounding boxes from XML files.

    :param xml_path_list: Paths of the XML annotation files
    :return: One list per XML file; every entry is
        ``[class_name, (xmin, ymin), (xmax, ymin), (xmax, ymax), (xmin, ymax)]``
        with coordinates shifted by -1 so that x is in [0, width) and
        y is in [0, height)
    """
    def node_text(parent, tag):
        # Text content of the first <tag> element below ``parent``.
        return parent.getElementsByTagName(tag)[0].firstChild.data.strip()

    ans = []
    for curr_xml_path in tqdm(xml_path_list, desc='Loading XML files'):
        curr_image = []
        dom = minidom.parse(curr_xml_path)
        for obj_node in dom.getElementsByTagName('object'):
            class_name = node_text(obj_node, 'name')

            bbox_node = obj_node.getElementsByTagName('bndbox')[0]
            xmin, ymin, xmax, ymax = (
                int(float(node_text(bbox_node, name))) - 1
                for name in ['xmin', 'ymin', 'xmax', 'ymax'])  # x in [0, width), y in [0, height)

            curr_image.append([class_name, (xmin, ymin), (xmax, ymin), (xmax, ymax), (xmin, ymax)])
        ans.append(curr_image)
    return ans

def save_xml(bbox_list, save_dir, image_path_list):
    """Save bounding boxes for all images as XML annotation files.

    One ``<image stem>.xml`` file is written into ``save_dir`` per image. Box
    coordinates are written with +1 added to the stored values.

    :param bbox_list: Bounding box list, one list of boxes per image; each box
        is ``[class, point, point, ...]`` with ``(x, y)`` points
    :param save_dir: Save directory
    :param image_path_list: Image path list
    :return: None
    """
    for bounding_boxes, image_path in tqdm(zip(bbox_list, image_path_list),
                                           total=min(len(bbox_list), len(image_path_list)),
                                           desc='Saving XML files'):
        doc = minidom.Document()

        root_node = doc.createElement('annotation')
        doc.appendChild(root_node)

        create_element_with_text(doc, root_node, 'folder', os.path.split(image_path)[-2])
        create_element_with_text(doc, root_node, 'filename', os.path.split(image_path)[-1])
        create_element_with_text(doc, root_node, 'path', image_path)

        source_node = doc.createElement('source')
        create_element_with_text(doc, source_node, 'database', 'Unknown')
        root_node.appendChild(source_node)

        # Width and height come from the image file; depth is fixed to 3.
        size_node = doc.createElement('size')
        for element_name, value in zip(['width', 'height', 'depth'],
                                       [*imagesize.get(image_path), 3]):
            elem = doc.createElement(element_name)
            elem.appendChild(doc.createTextNode(str(value)))
            size_node.appendChild(elem)
        root_node.appendChild(size_node)

        create_element_with_text(doc, root_node, 'segmented', '0')

        for bbox in bounding_boxes:
            obj_node = doc.createElement('object')

            create_element_with_text(doc, obj_node, 'name', get_class_name(bbox[0]))
            create_element_with_text(doc, obj_node, 'pose', 'Unspecified')
            create_element_with_text(doc, obj_node, 'truncated', '0')
            create_element_with_text(doc, obj_node, 'difficult', '0')

            # Axis-aligned extent of all points of the box.
            points = bbox[1:]
            xs = [point[0] for point in points]
            ys = [point[1] for point in points]
            bndbox_node = doc.createElement('bndbox')
            for element_name, value in zip(['xmin', 'ymin', 'xmax', 'ymax'],
                                           [min(xs), min(ys), max(xs), max(ys)]):
                create_element_with_text(doc, bndbox_node, element_name, str(value + 1))
            obj_node.appendChild(bndbox_node)

            root_node.appendChild(obj_node)

        stem = os.path.splitext(os.path.basename(image_path))[0]
        with open(os.path.join(save_dir, stem + '.xml'), 'w', encoding='utf-8') as file:
            doc.writexml(file, indent='', addindent='\t', newl='\n', encoding='utf-8')

def create_element_with_text(doc, node, element_name, text):
    """Append ``<element_name>text</element_name>`` as the last child of ``node``.

    :param doc: DOM document used to create the new nodes
    :param node: Parent node that receives the new element
    :param element_name: Tag name of the new element
    :param text: Text content; converted with ``str`` before being stored
    :return: The newly created element
    """
    element_node = doc.createElement(element_name)
    element_node.appendChild(doc.createTextNode(str(text)))
    node.appendChild(element_node)
    return element_node

def read_txt_bbox(txt_path_list: List[str],
                  image_path_list: List[str]) -> List[List[Union[List, Tuple]]]:
    """Get bounding boxes from TXT files.

    Every non-empty line is read as ``class_id x y w h``; the four values are
    multiplied by the image width / height and treated as box centre and size.

    :param txt_path_list: Paths of the TXT label files
    :param image_path_list: Paths of the matching images, in the same order
    :return: One list per TXT file; every entry is
        ``[class_id, (xmin, ymin), (xmax, ymin), (xmax, ymax), (xmin, ymax)]``
    """
    ans = []
    for txt_path, image_path in tqdm(zip(txt_path_list, image_path_list),
                                     total=len(txt_path_list), desc='Loading TXT files'):
        curr_image = []
        with open(txt_path, 'r') as file:
            lines = file.read().strip().split('\n')
        image_width, image_height = imagesize.get(image_path)
        for line in lines:
            if not line.strip():
                continue
            class_id, x, y, w, h = line.split()
            class_id = int(class_id)
            x, y = float(x) * image_width, float(y) * image_height
            w, h = float(w) * image_width, float(h) * image_height
            xmin, ymin, xmax, ymax = \
                int(x - w / 2), int(y - h / 2), int(x + w / 2) - 1, int(y + h / 2) - 1
            curr_image.append([class_id, (xmin, ymin), (xmax, ymin), (xmax, ymax), (xmin, ymax)])
        ans.append(curr_image)
    return ans

def save_txt(bbox_list, output_label_dir, image_path_list):
    """Save bounding boxes for all images as TXT label files.

    One ``<image stem>.txt`` file is written per image, with one line
    ``class_id cx cy w h`` per box; the four values are divided by the image
    width / height.

    :param bbox_list: Bounding box list, one list of boxes per image; each box
        is ``[class, point, point, ...]`` with ``(x, y)`` points
    :param output_label_dir: Directory the TXT files are written to
    :param image_path_list: Image path list
    :return: None
    """
    for bounding_boxes, image_path in tqdm(zip(bbox_list, image_path_list),
                                           total=min(len(bbox_list), len(image_path_list)),
                                           desc='Saving TXT files'):
        image_width, image_height = imagesize.get(image_path)

        stem = os.path.splitext(os.path.basename(image_path))[0]
        txt_path = os.path.join(output_label_dir, stem + '.txt')
        with open(txt_path, 'w') as file:
            for bbox in bounding_boxes:
                points = bbox[1:]
                xs = [point[0] for point in points]
                ys = [point[1] for point in points]
                xmin, xmax = min(xs), max(xs)
                ymin, ymax = min(ys), max(ys)
                cx, cy = (xmin + xmax) / 2, (ymin + ymax) / 2
                w, h = xmax - xmin, ymax - ymin
                # get_class_id accepts both class names (XML input) and integer
                # ids (TXT input); class_names.index() alone fails for the latter.
                fields = [get_class_id(bbox[0]),
                          cx / image_width, cy / image_height,
                          w / image_width, h / image_height]
                file.write(' '.join(str(value) for value in fields) + '\n')

def read_txt_rotate_bbox(txt_path_list, image_path_list):
    """Get bounding boxes from TXT files with rotation.

    Every non-empty line is read as ``class_id cx cy w h theta``; the first
    four numbers are multiplied by the image width / height, and each corner
    of the resulting box is passed through ``rotated_point`` around the centre
    with angle ``theta``.

    :param txt_path_list: Paths of the TXT label files
    :param image_path_list: Paths of the matching images, in the same order
    :return: One list per TXT file; every entry is ``[class_id, p1, p2, p3, p4]``
        with integer ``(x, y)`` corner points
    """
    ans = []
    for txt_path, image_path in tqdm(zip(txt_path_list, image_path_list),
                                     total=len(txt_path_list), desc='Loading TXT files'):
        curr_image = []
        with open(txt_path, 'r') as file:
            lines = file.read().strip().split('\n')
        image_width, image_height = imagesize.get(image_path)
        for line in lines:
            if not line.strip():
                continue
            class_id, cx, cy, w, h, theta = line.split()
            class_id = int(class_id)
            cx, cy = float(cx) * image_width, float(cy) * image_height
            w, h = float(w) * image_width, float(h) * image_height
            theta = float(theta)
            xmin, ymin, xmax, ymax = \
                int(cx - w / 2), int(cy - h / 2), int(cx + w / 2) - 1, int(cy + h / 2) - 1
            points = []
            for corner in [(xmin, ymin), (xmax, ymin), (xmax, ymax), (xmin, ymax)]:
                rx, ry = rotated_point(corner, (cx, cy), theta)
                points.append((int(rx), int(ry)))
            curr_image.append([class_id, *points])
        ans.append(curr_image)
    return ans

def save_txt_rotate(bbox_list, output_label_dir, image_path_list):
    """Write one TXT label file per image using rotated rectangles.

    Each box is fitted with ``cv2.minAreaRect`` and written as
    ``class_id cx cy w h angle`` where the first four numbers are divided by
    the image width / height and ``angle`` is ``180 - theta``.

    :param bbox_list: Bounding box list, one list of boxes per image
    :param output_label_dir: Directory the TXT files are written to
    :param image_path_list: Image path list
    :return: None
    """
    pairs = zip(bbox_list, image_path_list)
    for boxes, img_path in tqdm(pairs, desc='Saving TXT files'):
        img_w, img_h = imagesize.get(img_path)

        stem = os.path.split(img_path)[-1][:-4]
        out_path = os.path.join(output_label_dir, stem + '.txt')
        with open(out_path, 'w') as out:
            for box in boxes:
                (cx, cy), (w, h), theta = cv2.minAreaRect(np.array(box[1:]))
                fields = [get_class_id(box[0]),
                          cx / img_w, cy / img_h,
                          w / img_w, h / img_h,
                          180 - theta]
                out.write(' '.join(str(value) for value in fields) + '\n')

def rotated_point(p, q, theta):
    """Return the coordinate of point p after
    rotating counterclockwise by angle theta (in degrees) around point q."""
    # Image coordinates have the y axis flipped compared with the Cartesian
    # system, hence the rotation is applied with the negated angle.
    angle = np.deg2rad(-theta)
    cos_a = np.cos(angle)
    sin_a = np.sin(angle)
    dx = p[0] - q[0]
    dy = p[1] - q[1]
    x = dx * cos_a - dy * sin_a + q[0]
    y = dx * sin_a + dy * cos_a + q[1]
    return x, y

def get_path_list(root_dir):
    """Return ``root_dir`` joined with every entry name that ``os.listdir`` reports for it."""
    paths = []
    for entry in os.listdir(root_dir):
        paths.append(os.path.join(root_dir, entry))
    return paths

def read_class_names(class_path):
    """Load the class names from ``class_path`` into the global ``class_names``.

    The file is expected to hold one class name per line; surrounding
    whitespace is stripped and blank lines are ignored.

    :param class_path: Path to the text file that stores the class names
    :return: None
    """
    global class_names
    with open(class_path, 'r') as file:
        class_names = [line.strip() for line in file if line.strip()]

def get_class_name(class_id: Union[str, int]) -> str:
    """Return the class name for ``class_id``.

    A string argument is returned unchanged; any other value is used as an
    index into the module-level ``class_names`` list.
    """
    return class_id if isinstance(class_id, str) else class_names[class_id]

def get_class_id(class_name: Union[str, int]) -> int:
    """Return the class id for ``class_name``.

    An integer argument is returned unchanged; any other value is looked up
    with ``class_names.index`` in the module-level ``class_names`` list.
    """
    return class_name if isinstance(class_name, int) else class_names.index(class_name)

def draw_one_image(image_path, bbox_list, output_path):
    """Draw all bounding boxes in one image.

    :param image_path: Image path
    :param bbox_list: Bounding box list; each box is ``[class, point, point, ...]``
        with integer ``(x, y)`` points
    :param output_path: Output path to the final image
    :return: None
    :raises ValueError: When the image cannot be read
    """
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise ValueError('Cannot read image: {}'.format(image_path))

    for bbox in bbox_list:
        # draw current bounding box as a closed polygon
        points = bbox[1:]
        for start, end in zip(points, points[1:] + points[:1]):
            cv2.line(image, start, end, (0, 255, 0), 2)

        # draw class name near the smallest (x, y) point, kept inside the image
        top_left = min(points)
        class_name = get_class_name(bbox[0])
        x = min(top_left[0], image.shape[1] - 16 * len(class_name))
        y = max(top_left[1], 20)
        cv2.putText(image, class_name, (x, y), cv2.FONT_HERSHEY_SIMPLEX, 0.75,
                    (0, 255, 0), 1, cv2.LINE_AA, False)

    cv2.imwrite(output_path, image)

def concat_images_into_one(image_path_list, output_path, row, col, width, height):
    """Concatenate images into one image.

    Images fill the grid row by row; cells without an image stay black.

    :param image_path_list: paths of input images
    :param output_path: output image path
    :param row: Number of images in one column of the output image
    :param col: Number of images in one row of the output image
    :param width: width of one column
    :param height: height of one row
    :return: None
    :raises ValueError: When there are more images than grid cells, or when an
        image cannot be decoded
    """
    if row * col < len(image_path_list):
        raise ValueError('Cannot concat too many images')

    # 8-bit BGR canvas; every tile is decoded as 3-channel colour so that it
    # always fits the canvas slice.
    output_image = np.zeros((height * row, width * col, 3), dtype=np.uint8)
    for idx, image_path in enumerate(image_path_list):
        i, j = divmod(idx, col)
        # np.fromfile + imdecode instead of imread, as in the original code.
        image = cv2.imdecode(np.fromfile(image_path, dtype=np.uint8), cv2.IMREAD_COLOR)
        if image is None:
            raise ValueError('Cannot decode image: {}'.format(image_path))
        output_image[height * i:height * (i + 1), width * j:width * (j + 1)] = \
            cv2.resize(image, (width, height))
    cv2.imwrite(output_path, output_image)

def concat_all_images(image_path_list, output_dir, row, col, width, height):
    """Concatenate images in groups.

    Full ``row`` x ``col`` grids are written first, then one grid made of the
    remaining complete rows, then a single row with whatever is left. Outputs
    are named ``0.png``, ``1.png``, ... inside ``output_dir``.

    :param image_path_list: paths of input images
    :param output_dir: output image path
    :param row: Number of images in one column of the output image
    :param col: Number of images in one row of the output image
    :param width: width of one column
    :param height: height of one row
    :return: None
    :raises ValueError: When ``row`` or ``col`` is not positive
    """
    if row <= 0 or col <= 0:
        # A non-positive grid size would make the grouping loop never advance.
        raise ValueError('row and col must be positive')

    total = len(image_path_list)
    group = row * col
    i = 0
    count = 0

    with tqdm(total=total, desc='Concatenating') as pbar:
        # Full grids.
        while i + group <= total:
            concat_images_into_one(image_path_list[i:i + group],
                                   os.path.join(output_dir, str(count) + '.png'),
                                   row, col, width, height)
            i += group
            count += 1
            pbar.update(group)

        # Remaining complete rows.
        if i + col <= total:
            curr_row = (total - i) // col
            concat_images_into_one(image_path_list[i:i + curr_row * col],
                                   os.path.join(output_dir, str(count) + '.png'),
                                   curr_row, col, width, height)
            i += curr_row * col
            count += 1
            pbar.update(curr_row * col)

        # Leftover images that do not fill a row.
        if i < total:
            concat_images_into_one(image_path_list[i:],
                                   os.path.join(output_dir, str(count) + '.png'),
                                   1, total - i, width, height)
            pbar.update(total - i)

if __name__ == '__main__':
    # Convert label files between formats and optionally draw the boxes and
    # concatenate the drawn images into mosaics.
    parser = argparse.ArgumentParser()
    parser.add_argument('--image_dir', type=str, default='resize_img_30_new/resize_img_30_new',
                        help='Directory of images.')
    parser.add_argument('--image_type', type=str, default='png',
                        choices=['jpg', 'png'], help='Type of image file.')
    parser.add_argument('--label_dir', type=str, default='xml_resize_30_2',
                        help='Directory of labels.')
    parser.add_argument('--label_type', type=str, default='xml',
                        choices=['xml', 'txt', 'txt_rotate'], help='Type of input label file.')

    parser.add_argument('--class_path', type=str, default='class.txt',
                        help='Path to the file which stores class names.')

    parser.add_argument('--draw_dir', type=str, default='',
                        help='Output directory of drawn images.')

    parser.add_argument('--concat_dir', type=str, default='',
                        help='Output directory of concatenated images.')
    parser.add_argument('--concat_row', type=int, default=4,
                        help='The number of rows of the images in each concatenated image.')
    parser.add_argument('--concat_col', type=int, default=4,
                        help='The number of columns of the images in each concatenated image.')
    parser.add_argument('--concat_width', type=int, default=800,
                        help='The width of each image in each concatenated image.')
    parser.add_argument('--concat_height', type=int, default=600,
                        help='The height of each image in each concatenated image.')

    parser.add_argument('--output_label_type', type=str, default='txt',
                        choices=['xml', 'txt', 'txt_rotate'], help='Type of output label file.')
    parser.add_argument('--output_label_dir', type=str, default='dx_crop_sf30_txt',
                        help='Directory of output label files.')
    args = parser.parse_args()

    # Get file path list
    image_path_list = get_path_list(args.image_dir)
    label_path_list = get_path_list(args.label_dir)

    # Read class names
    read_class_names(args.class_path)

    # Filter path ('txt_rotate' labels are stored in plain .txt files)
    image_ext = '.' + args.image_type
    label_ext = '.txt' if args.label_type == 'txt_rotate' else '.' + args.label_type
    image_path_list = [p for p in image_path_list if p[-4:] == image_ext]
    label_path_list = [p for p in label_path_list if p[-4:] == label_ext]

    # Filter uncommon files
    common_names = set(os.path.split(path)[-1][:-4] for path in label_path_list)
    common_names.intersection_update(os.path.split(path)[-1][:-4] for path in image_path_list)
    if len(common_names) < len(image_path_list) or len(common_names) < len(label_path_list):
        logging.warning('Files in the label folder and the image folder are inconsistent.')
    # Sort both lists by file stem so that the image at index i always belongs
    # to the label at index i; os.listdir gives no ordering guarantee.
    image_path_list = sorted(
        (p for p in image_path_list if os.path.split(p)[-1][:-4] in common_names),
        key=lambda p: os.path.split(p)[-1][:-4])
    label_path_list = sorted(
        (p for p in label_path_list if os.path.split(p)[-1][:-4] in common_names),
        key=lambda p: os.path.split(p)[-1][:-4])

    # Read bounding boxes
    if args.label_type == 'xml':
        bbox_list = read_xml_bbox(label_path_list)
    elif args.label_type == 'txt':
        bbox_list = read_txt_bbox(label_path_list, image_path_list)
    else:
        bbox_list = read_txt_rotate_bbox(label_path_list, image_path_list)

    # Output labels
    if args.output_label_type and args.output_label_dir:
        if not os.path.exists(args.output_label_dir):
            os.makedirs(args.output_label_dir)

        if args.output_label_type == 'xml':
            save_xml(bbox_list, args.output_label_dir, image_path_list)
        elif args.output_label_type == 'txt':
            save_txt(bbox_list, args.output_label_dir, image_path_list)
        else:
            save_txt_rotate(bbox_list, args.output_label_dir, image_path_list)
    else:
        logging.warning('Since the directory or type of output labels is not specified, '
                        'skip output labels.')

    # Draw bounding boxes
    if args.draw_dir:
        if not os.path.exists(args.draw_dir):
            os.makedirs(args.draw_dir)
        for image_path, bbox in tqdm(zip(image_path_list, bbox_list),
                                     total=len(image_path_list), desc='Drawing bounding boxes'):
            draw_one_image(image_path, bbox,
                           os.path.join(args.draw_dir, os.path.split(image_path)[-1]))
    else:
        logging.warning('Since the directory of drawn images is not specified, '
                        'skip drawing bounding boxes.')

    # Concatenate images
    if args.draw_dir and args.concat_dir:
        if not os.path.exists(args.concat_dir):
            os.makedirs(args.concat_dir)
        # Sorted so that the mosaics are reproducible between runs.
        labeled_image_path_list = [os.path.join(args.draw_dir, name + '.' + args.image_type)
                                   for name in sorted(common_names)]
        concat_all_images(labeled_image_path_list, args.concat_dir,
                          args.concat_row, args.concat_col, args.concat_width, args.concat_height)
    else:
        logging.warning('Since draw directory or concat directory is not specified, '
                        'skip concatenating labels.')
# pip install lxml
import os
import glob
import json
import shutil
import numpy as np
import xml.etree.ElementTree as ET
# Base directory setting of this script.
# NOTE(review): not referenced in the visible part of the file — confirm usage.
path2 = "."
# Image file extension; convert() appends it to the stem of each XML file
# name to build the image file name.
image_geshi = ".png"
def get(root, name):
    """Return ``root.findall(name)``: the list of elements under ``root`` matching ``name``."""
    elements = root.findall(name)
    return elements
def get_and_check(root, name, length):
    """Find ``name`` under ``root`` and check the number of matches.

    :param root: Element to search in (its ``tag`` appears in error messages)
    :param name: Tag name / path handed to ``findall``
    :param length: Expected match count; values <= 0 skip the count check
    :return: The only match when ``length == 1``, otherwise the list of matches
    :raises NotImplementedError: When there is no match, or when ``length > 0``
        and the match count is not ``length``
    """
    matches = root.findall(name)
    n_matches = len(matches)
    if n_matches == 0:
        raise NotImplementedError('Can not find %s in %s.' % (name, root.tag))
    if length > 0 and n_matches != length:
        raise NotImplementedError(
            'The size of %s is supposed to be %d, but is %d.' % (name, length, n_matches))
    if length == 1:
        return matches[0]
    return matches
def convert(xml_list, json_file):
    """Convert a list of XML annotation files into one COCO-style JSON file.

    Each XML file must contain a single ``size`` element (with ``width`` and
    ``height``) and any number of ``object`` elements, each with one ``name``
    and one ``bndbox`` (``xmin``/``ymin``/``xmax``/``ymax``).

    The function reads three module-level settings:
    ``pre_define_categories`` (category name -> id),
    ``only_care_pre_define_categories`` (when true, objects whose category is
    not predefined are skipped; otherwise a new id is allocated for them) and
    ``image_geshi`` (suffix appended to the XML base name to form the image
    file name).

    Args:
        xml_list: sequence of XML file paths; the position in the sequence
            determines the image id.
        json_file: path of the JSON file to write.

    Raises:
        NotImplementedError: propagated from ``get_and_check`` when a
            required element is missing or duplicated.
        AssertionError: when a bounding box has ``xmax <= xmin`` or
            ``ymax <= ymin``.
    """
    json_dict = {"images": [], "type": "instances", "annotations": [], "categories": []}
    categories = pre_define_categories.copy()
    # Annotation ids must be unique across the whole output file; start at 1.
    bnd_id = 1
    # Occurrence count of every category name seen, predefined or not.
    all_categories = {}
    for index, line in enumerate(xml_list):
        xml_f = line
        tree = ET.parse(xml_f)
        root = tree.getroot()
        filename = os.path.basename(xml_f)[:-4] + image_geshi
        image_id = 20190000001 + index
        size = get_and_check(root, 'size', 1)
        width = int(get_and_check(size, 'width', 1).text)
        height = int(get_and_check(size, 'height', 1).text)
        image = {'file_name': filename, 'height': height, 'width': width, 'id':image_id}
        json_dict['images'].append(image)
        # Segmentation is currently not supported; only boxes are exported.
        for obj in get(root, 'object'):
            category = get_and_check(obj, 'name', 1).text
            all_categories[category] = all_categories.get(category, 0) + 1
            if category not in categories:
                if only_care_pre_define_categories:
                    # Unknown category and the caller only wants predefined ones.
                    continue
                new_id = len(categories) + 1
                print("[warning] category '{}' not in 'pre_define_categories'({}), create new id: {} automatically".format(category, pre_define_categories, new_id))
                categories[category] = new_id
            category_id = categories[category]
            bndbox = get_and_check(obj, 'bndbox', 1)
            # Coordinates may be written as floats; truncate to integers.
            xmin = int(float(get_and_check(bndbox, 'xmin', 1).text))
            ymin = int(float(get_and_check(bndbox, 'ymin', 1).text))
            xmax = int(float(get_and_check(bndbox, 'xmax', 1).text))
            ymax = int(float(get_and_check(bndbox, 'ymax', 1).text))
            assert(xmax > xmin), "xmax <= xmin, {}".format(line)
            assert(ymax > ymin), "ymax <= ymin, {}".format(line)
            o_width = abs(xmax - xmin)
            o_height = abs(ymax - ymin)
            # 'bbox' is stored as [x, y, width, height].
            ann = {'area': o_width*o_height, 'iscrowd': 0, 'image_id':
                   image_id, 'bbox':[xmin, ymin, o_width, o_height],
                   'category_id': category_id, 'id': bnd_id, 'ignore': 0,
                   'segmentation': []}
            json_dict['annotations'].append(ann)
            bnd_id = bnd_id + 1
    for cate, cid in categories.items():
        cat = {'supercategory': 'none', 'id': cid, 'name': cate}
        json_dict['categories'].append(cat)
    json_str = json.dumps(json_dict)
    # The context manager guarantees the file is flushed and closed.
    with open(json_file, 'w') as json_fp:
        json_fp.write(json_str)
    print("------------create {} done--------------".format(json_file))
    print("find {} categories: {} -->>> your pre_define_categories {}: {}".format(len(all_categories), all_categories.keys(), len(pre_define_categories), pre_define_categories.keys()))
    print("category: id --> {}".format(categories))
if __name__ == '__main__':
    # class.txt holds one class name per line; the line order fixes the
    # 1-based category ids used in the generated JSON files.
    with open(r"class.txt", "r", encoding="utf-8") as f:
        a = [i.strip() for i in f.readlines()]
    classes = a
    pre_define_categories = {cls: i + 1 for i, cls in enumerate(classes)}
    # pre_define_categories = {'a1': 1, 'a3': 2, 'a6': 3, 'a9': 4, "a10": 5}
    # When True, convert() skips objects whose category is not in class.txt.
    only_care_pre_define_categories = True
    # only_care_pre_define_categories = False

    # Split ratios. The test split is whatever remains after train and val,
    # so test_ratio is informational only.
    train_ratio = 0.7
    val_ratio = 0.2
    test_ratio = 0.1
    save_json_train = 'instances_train2017.json'
    save_json_val = 'instances_val2017.json'
    save_json_test = 'instances_test2017.json'
    xml_dir = "./dx_coco_sf60_crop_np/tmp/"

    # Sort so the split is deterministic between runs.
    xml_list = sorted(glob.glob(os.path.join(xml_dir, "*.xml")))
    train_num = int(len(xml_list) * train_ratio)
    val_num = int(len(xml_list) * val_ratio)
    xml_list_train = xml_list[:train_num]
    xml_list_val = xml_list[train_num:train_num + val_num]
    xml_list_test = xml_list[train_num + val_num:]

    # (split name, XML paths of the split, JSON file written for the split)
    splits = (
        ("train", xml_list_train, save_json_train),
        ("val", xml_list_val, save_json_val),
        ("test", xml_list_test, save_json_test),
    )

    # NOTE(review): the JSON files are written to the current directory, not
    # into path2/annotations, which is created empty below. Confirm whether
    # they are meant to be moved there.
    for _, split_xmls, split_json in splits:
        convert(split_xmls, split_json)

    # Start from empty output folders on every run.
    for sub_dir in ("annotations", "train2017", "val2017", "test2017"):
        target_dir = path2 + "/" + sub_dir
        if os.path.exists(target_dir):
            shutil.rmtree(target_dir)
        os.makedirs(target_dir)

    # For each split, list the sample base names in <split>.txt and copy the
    # image that sits next to each XML file into <split>2017/.
    for split_name, split_xmls, _ in splits:
        with open(split_name + ".txt", "w") as list_file:
            for xml in split_xmls:
                img = xml[:-4] + image_geshi
                list_file.write(os.path.basename(xml)[:-4] + "\n")
                shutil.copyfile(img, path2 + "/" + split_name + "2017/" + os.path.basename(img))

    print("train number:", len(xml_list_train))
    print("val number:", len(xml_list_val))
    print("test number:", len(xml_list_test))


MMDetection 数据集格式转换——LabelImg/xml/yolo 格式转 Custom 自定义格式数据集/COCO 格式数据集 的相关文章

  • 使用 Mac M1 在 Docker 容器内的 pip 安装中找不到 Tensorflow

    我正在尝试使用新的 Mac M1 运行一些项目 这些项目已经在英特尔处理器上运行 并被使用英特尔的其他开发人员使用 我无法构建这个简单的 Dockerfile FROM python 3 9 RUN python m pip install
  • Heroku 上的 Django 应用程序在一段时间后删除对象

    我编写了一个简单的 Django 问答论坛应用程序并将其部署在 Heroku 上 该网站的本地版本运行良好 但是 生产版本不会将问题 答案等存储超过几个小时 我决定坚持使用 Django 附带的 sqlite3 我预计该网站不会有太多流量
  • SparkSession 初始化需要很长时间

    SparkSession 初始化需要很长时间才能成功 这是我的代码 import findspark findspark init import pyspark from pyspark sql import SparkSession sp
  • 来自 Pandas DataFrame 的用户定义的 Json 格式

    我有一个 pandas dataFrame 打印 pandas DataFrame 后 结果如下所示 country branch no of employee total salary count DOB count email x a
  • ImportError:无法导入名称 GstRtspServer,未找到内省类型库

    我目前正在尝试让一个简单的 GstRtspServer 程序在外部亚马逊 Linux EC2 服务器上运行 但在让它实际运行时遇到了严重的问题 无论我做什么 当我尝试运行它时 即使程序仅减少到 import gi gi require ve
  • 忽略覆盖率报告中的空文件

    覆盖率 py https github com nedbat coveragepy会包括 init py在其报告中并将其显示为 0 行 但覆盖率为 100 我想从覆盖率报告中排除所有空白文件 我不能只添加 init py to omit作为
  • 群组名称不能以数字开头?

    看来我不能使用像这样的正则表达式：(?P<74xxx>[0-9]+)。re 包会引发错误 sre_constants.error: bad character in group name u'74xxx'。我似乎无法使用以数字开头的组名称，为什…
  • sqlalchemy,使用反向包含(不在)子列值列表中进行选择

    我在flask sqlalchemy 中有一个典型的帖子 标签 与一篇帖子相关的许多标签 关系 并且我想选择我提供的列表中未标记任何标签的帖子 首先 我建立的模型 class Post db Model id db Column db In
  • PyGTK TreeView 中的自动换行

    如何在 PyGTK TreeView 中自动换行文本 gtk TreeView 中的文本是使用 gtk CellRendererText 渲染的 文本换行归结为在单元格渲染器上设置正确的属性 为了让文本换行 您需要设置wrap width单
  • 从 Java 调用 Python 代码时出现问题(不使用 jython)

    我发现这是从 java 运行 使用 exec 方法 python 脚本的方法之一 我在 python 文件中有一个简单的打印语句 但是 我的程序在运行时什么也没做 它既不打印Python文件中编写的语句 也不抛出异常 程序什么都不做就终止了
  • Pandas 对 HDFStore 中的大数据进行“分组”查询?

    我有大约 700 万行HDFStore有60多个柱子 数据超出了我的记忆能力 我希望根据 A 列的值将数据聚合到组中 pandas 的文档分割 聚合 组合 http pandas pydata org pandas docs stable
  • 如何更改Python中的全局变量[重复]

    这个问题在这里已经有答案了 我正在尝试更改程序中的变量 我在程序开始时声明了一个全局变量 我想在程序中的不同函数中更改该变量 我可以通过再次声明函数内的变量来做到这一点 但我想知道是否有更好的方法来做到这一点 下面是一些测试代码来解释我的意
  • 函数调用中的星号[重复]

    这个问题在这里已经有答案了 我正在使用 itertools chain 以这种方式 展平 列表列表 uniqueCrossTabs list itertools chain uniqueCrossTabs 这与说有什么不同 uniqueCr
  • 在IPython笔记本中自动播放声音

    我经常在 IPython 笔记本中运行长时间运行的单元 我希望笔记本在单元完成执行时自动发出蜂鸣声或播放声音 有没有办法在 iPython 笔记本中执行此操作 或者我可以在单元格末尾放置一些命令来自动播放声音 我正在使用 Chrome 如果
  • 为什么我只能在异步函数中使用await关键字?

    假设我有这样的代码 async def fetch text gt str return text async def show something something await fetch text print something 这很
  • 计算列表中的子列表

    L 2 4 5 6 2 1 6 6 3 2 4 5 3 4 5 我想知道任意子序列出现了多少次 s 2 4 5 例如会返回2次 I tried L count s 但它不起作用 因为我认为它期望寻找类似的东西 random numbers
  • 类型错误:对于仅使用浮点数的函数,返回数组必须是 ArrayType

    这个实在是难倒我了 我有一个计算单词权重的函数 我已经确认 a 和 b 局部变量都是 float 类型 def word weight term a term freq term print a type a b idf term prin
  • python生成器太慢,无法使用它。我为什么要使用它?什么时候?

    最近我收到一个问题 哪一个是最快的 iterator list comprehension iter list comprehension and generator 然后编写简单的代码如下 n 1000000 iter a iter ra
  • Pepper Robot:如何将 Python 地标检测移植到 Choregraphe?

    我正在尝试编写一个小程序 让 Pepper 通过 Choregraphe 检查房间内的地标 用于地标检测的常规 Python 代码工作得很好 但我无法将其移植到 Choregraphe http doc aldebaran com 2 5
  • 关闭 IPython Notebook 中的自动保存

    我正在寻找一种方法来关闭 iPython 笔记本中的自动保存 我已经通过 Google Stack Overflow 搜索看到了有关如何打开自动保存的参考资料 但我想要相反的内容 关闭自动保存 如果这是可以永久设置的东西而不是在每个笔记本的
