COCO数据集的下载、介绍及如何使用(数据载入及数据增广,含代码)

2023-11-05

如何使用COCO数据集
COCO数据集可以说是语义分割等计算机视觉任务中应用较为广泛的一个数据集,具体可以应用到物体识别、语义分割及目标检测等方面。我是在做语义分割方面任务时用到了COCO数据集,但本文主要讲解的是数据载入方面,因此可以通用。

一、下载COCO数据集

首先,我们要下载COCO数据集,本文主要使用的是COCO2014和COCO2017,因为是国外数据集,因此下载需要翻墙下载。
MSCOCO数据集的官网为:http://mscoco.org/
具体来说,如果想只下载COCO2017/COCO2014的话,可以不需要翻墙下载,复制以下链接打开迅雷等下载软件下载即可,网速还可以。
COCO2017 训练数据:http://images.cocodataset.org/zips/train2017.zip
http://images.cocodataset.org/annotations/annotations_trainval2017.zip
COCO2017验证数据:http://images.cocodataset.org/zips/val2017.zip
http://images.cocodataset.org/annotations/stuff_annotations_trainval2017.zip
COCO2017测试数据集:http://images.cocodataset.org/zips/test2017.zip
http://images.cocodataset.org/annotations/image_info_test2017.zip

COCO2014的相关数据只需要将以上链接中的7改成4即可。

二、COCO数据集介绍

网上关于COCO数据集的介绍多如牛毛,本文就不过多的加以介绍了,简要的介绍以下。
以COCO2014为例:
下载完COCO2014后进行解压后,目录如下:

三、COCO数据集使用(数据载入)

所需环境为:

  1. numpy
  2. torch
  3. tqdm(可视化数据载入)
  4. os
  5. pycocotools(coco数据集的应用API)
  6. torchvision
  7. PIL

如何安装pycocotools

相信能用到COCO数据集做语义分割等任务的大佬们应该都能安装以上绝大多数库,这里主要讲一下如何安装pycocotools库。作者在安装这个库的时候遇到了一些问题,不过及时的解决了。
步骤如下:

  1. 首先下载cocoapi,在终端输入
git clone git@github.com:lucky-ing/cocoapi.git
  1. 此时可以看到一个叫coco的文件夹,进入coco/PythonAPI中,懒人操作如下:
cd coco/PythonAPI
  1. 开始安装,在终端输入以下命令
    如果使用的是python2:
python setup.py build_ext install

如果使用的是python3

python3 setup.py build_ext install
  1. 如果一切顺利,安装完成,即可进入下一章节具体使用,作者在安装时遇到了以下问题。
error: command 'C:\Program Files (x86)\Microsoft Visual Studio\2017\BuildTools\VC\Tools\MSVC\14.16.27023\bin\HostX86\x64\cl.exe' failed with exit status 2

解决方法很简单,在终端安装cython即可,在终端输入:

conda install cython

若是没有使用conda,在终端输入

pip install cython

COCO数据集的载入

  1. dataloader
import numpy as np
import torch
from torch.utils.data import Dataset
from tqdm import trange
import os
from pycocotools.coco import COCO
from pycocotools import mask
from torchvision import transforms
import custom_transforms as tr
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True


class COCOSegmentation(Dataset):
    NUM_CLASSES = 21
    CAT_LIST = [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4,
        1, 64, 20, 63, 7, 72]

    def __init__(self,
                 args,
                 base_dir=./Path/COCO/,
                 split='train',
                 year='2014'):
        super().__init__()
        ann_file = os.path.join(base_dir, 'annotations/instances_{}{}.json'.format(split, year))
        ids_file = os.path.join(base_dir, 'annotations/{}_ids_{}.pth'.format(split, year))
        self.img_dir = os.path.join(base_dir, 'images/{}{}'.format(split, year))
        self.split = split
        self.coco = COCO(ann_file)
        self.coco_mask = mask
        if os.path.exists(ids_file):
            self.ids = torch.load(ids_file)
        else:
            ids = list(self.coco.imgs.keys())
            self.ids = self._preprocess(ids, ids_file)
        self.args = args

    def __getitem__(self, index):
        _img, _target = self._make_img_gt_point_pair(index)
        sample = {'image': _img, 'label': _target}

        if self.split == "train":
            return self.transform_tr(sample)
        elif self.split == 'val':
            return self.transform_val(sample)

    def _make_img_gt_point_pair(self, index):
        coco = self.coco
        img_id = self.ids[index]
        img_metadata = coco.loadImgs(img_id)[0]
        path = img_metadata['file_name']
        _img = Image.open(os.path.join(self.img_dir, path)).convert('RGB')
        cocotarget = coco.loadAnns(coco.getAnnIds(imgIds=img_id))
        _target = Image.fromarray(self._gen_seg_mask(
            cocotarget, img_metadata['height'], img_metadata['width']))

        return _img, _target

    def _preprocess(self, ids, ids_file):
        print("Preprocessing mask, this will take a while. " + \
              "But don't worry, it only run once for each split.")
        tbar = trange(len(ids))
        new_ids = []
        for i in tbar:
            img_id = ids[i]
            cocotarget = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id))
            img_metadata = self.coco.loadImgs(img_id)[0]
            mask = self._gen_seg_mask(cocotarget, img_metadata['height'],
                                      img_metadata['width'])
            # more than 1k pixels
            if (mask > 0).sum() > 1000:
                new_ids.append(img_id)
            tbar.set_description('Doing: {}/{}, got {} qualified images'. \
                                 format(i, len(ids), len(new_ids)))
        print('Found number of qualified images: ', len(new_ids))
        torch.save(new_ids, ids_file)
        return new_ids

    def _gen_seg_mask(self, target, h, w):
        mask = np.zeros((h, w), dtype=np.uint8)
        coco_mask = self.coco_mask
        for instance in target:
            rle = coco_mask.frPyObjects(instance['segmentation'], h, w)
            m = coco_mask.decode(rle)
            cat = instance['category_id']
            if cat in self.CAT_LIST:
                c = self.CAT_LIST.index(cat)
            else:
                continue
            if len(m.shape) < 3:
                mask[:, :] += (mask == 0) * (m * c)
            else:
                mask[:, :] += (mask == 0) * (((np.sum(m, axis=2)) > 0) * c).astype(np.uint8)
        return mask

    def transform_tr(self, sample):
        composed_transforms = transforms.Compose([
            tr.RandomHorizontalFlip(),
            tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size),
            tr.RandomGaussianBlur(),
            tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            tr.ToTensor()])

        return composed_transforms(sample)

    def transform_val(self, sample):

        composed_transforms = transforms.Compose([
            tr.FixScaleCrop(crop_size=self.args.crop_size),
            tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            tr.ToTensor()])

        return composed_transforms(sample)


    def __len__(self):
        return len(self.ids)



if __name__ == "__main__":
    from dataloaders import custom_transforms as tr
    from dataloaders.utils import decode_segmap
    from torch.utils.data import DataLoader
    from torchvision import transforms
    import matplotlib.pyplot as plt
    import argparse

    parser = argparse.ArgumentParser()
    args = parser.parse_args()
    args.base_size = 513
    args.crop_size = 513

    coco_val = COCOSegmentation(args, split='val', year='2017')

    dataloader = DataLoader(coco_val, batch_size=4, shuffle=True, num_workers=0)

    for ii, sample in enumerate(dataloader):
        for jj in range(sample["image"].size()[0]):
            img = sample['image'].numpy()
            gt = sample['label'].numpy()
            tmp = np.array(gt[jj]).astype(np.uint8)
            segmap = decode_segmap(tmp, dataset='coco')
            img_tmp = np.transpose(img[jj], axes=[1, 2, 0])
            img_tmp *= (0.229, 0.224, 0.225)
            img_tmp += (0.485, 0.456, 0.406)
            img_tmp *= 255.0
            img_tmp = img_tmp.astype(np.uint8)
            plt.figure()
            plt.title('display')
            plt.subplot(211)
            plt.imshow(img_tmp)
            plt.subplot(212)
            plt.imshow(segmap)

        if ii == 1:
            break

    plt.show(block=True)

下面的main函数为测试使用。

  1. custom_transforms.py 是数据增广的代码
import torch
import random
import numpy as np

from PIL import Image, ImageOps, ImageFilter

class Normalize(object):
   """Normalize a tensor image with mean and standard deviation.
   Args:
       mean (tuple): means for each channel.
       std (tuple): standard deviations for each channel.
   """
   def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)):
       self.mean = mean
       self.std = std

   def __call__(self, sample):
       img = sample['image']
       mask = sample['label']
       img = np.array(img).astype(np.float32)
       mask = np.array(mask).astype(np.float32)
       img /= 255.0
       img -= self.mean
       img /= self.std

       return {'image': img,
               'label': mask}

class Normalize_test(object):
   """Normalize a tensor image with mean and standard deviation.
   Args:
       mean (tuple): means for each channel.
       std (tuple): standard deviations for each channel.
   """
   def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)):
       self.mean = mean
       self.std = std

   def __call__(self, sample):
       img = sample['image']
       #mask = sample['label']
       img = np.array(img).astype(np.float32)
       #mask = np.array(mask).astype(np.float32)
       img /= 255.0
       img -= self.mean
       img /= self.std

       return {'image': img}


class ToTensor(object):
   """Convert ndarrays in sample to Tensors."""

   def __call__(self, sample):
       # swap color axis because
       # numpy image: H x W x C
       # torch image: C X H X W
       img = sample['image']
       mask = sample['label']
       img = np.array(img).astype(np.float32).transpose((2, 0, 1))
       mask = np.array(mask).astype(np.float32)

       img = torch.from_numpy(img).float()
       mask = torch.from_numpy(mask).float()

       return {'image': img,
               'label': mask}

class ToTensor_test(object):
   """Convert ndarrays in sample to Tensors."""

   def __call__(self, sample):
       # swap color axis because
       # numpy image: H x W x C
       # torch image: C X H X W
       img = sample['image']
       #mask = sample['label']
       img = np.array(img).astype(np.float32).transpose((2, 0, 1))
       #mask = np.array(mask).astype(np.float32)

       img = torch.from_numpy(img).float()
       #mask = torch.from_numpy(mask).float()

       return {'image': img}


class RandomHorizontalFlip(object):
   def __call__(self, sample):
       img = sample['image']
       mask = sample['label']
       if random.random() < 0.5:
           img = img.transpose(Image.FLIP_LEFT_RIGHT)
           mask = mask.transpose(Image.FLIP_LEFT_RIGHT)

       return {'image': img,
               'label': mask}


class RandomRotate(object):
   def __init__(self, degree):
       self.degree = degree

   def __call__(self, sample):
       img = sample['image']
       mask = sample['label']
       rotate_degree = random.uniform(-1*self.degree, self.degree)
       img = img.rotate(rotate_degree, Image.BILINEAR)
       mask = mask.rotate(rotate_degree, Image.NEAREST)

       return {'image': img,
               'label': mask}


class RandomGaussianBlur(object):
   def __call__(self, sample):
       img = sample['image']
       mask = sample['label']
       if random.random() < 0.5:
           img = img.filter(ImageFilter.GaussianBlur(
               radius=random.random()))

       return {'image': img,
               'label': mask}


class RandomScaleCrop(object):
   def __init__(self, base_size, crop_size, fill=0):
       self.base_size = base_size
       self.crop_size = crop_size
       self.fill = fill

   def __call__(self, sample):
       img = sample['image']
       mask = sample['label']
       # random scale (short edge)
       short_size = random.randint(int(self.base_size * 0.5), int(self.base_size * 2.0))
       w, h = img.size
       if h > w:
           ow = short_size
           oh = int(1.0 * h * ow / w)
       else:
           oh = short_size
           ow = int(1.0 * w * oh / h)
       img = img.resize((ow, oh), Image.BILINEAR)
       mask = mask.resize((ow, oh), Image.NEAREST)
       # pad crop
       if short_size < self.crop_size:
           padh = self.crop_size - oh if oh < self.crop_size else 0
           padw = self.crop_size - ow if ow < self.crop_size else 0
           img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
           mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=self.fill)
       # random crop crop_size
       w, h = img.size
       x1 = random.randint(0, w - self.crop_size)
       y1 = random.randint(0, h - self.crop_size)
       img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
       mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))

       return {'image': img,
               'label': mask}


class FixScaleCrop(object):
   def __init__(self, crop_size):
       self.crop_size = crop_size

   def __call__(self, sample):
       img = sample['image']
       mask = sample['label']
       w, h = img.size
       if w > h:
           oh = self.crop_size
           ow = int(1.0 * w * oh / h)
       else:
           ow = self.crop_size
           oh = int(1.0 * h * ow / w)
       img = img.resize((ow, oh), Image.BILINEAR)
       mask = mask.resize((ow, oh), Image.NEAREST)
       # center crop
       w, h = img.size
       x1 = int(round((w - self.crop_size) / 2.))
       y1 = int(round((h - self.crop_size) / 2.))
       img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
       mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))

       return {'image': img,
               'label': mask}

class FixedResize(object):
   def __init__(self):
       self.size = (size, size)  # size: (h, w)

   def __call__(self, sample):
       img = sample['image']
       mask = sample['label']

       assert img.size == mask.size

       img = img.resize(self.size, Image.BILINEAR)
       mask = mask.resize(self.size, Image.NEAREST)

       return {'image': img,
               'label': mask}

class FixedResize_test(object):
   def __init__(self):
       super().__init__()
       #self.size = (size, size)  # size: (h, w)

   def __call__(self, sample):
       img = sample['image']
       w, h = img.size
       #mask = sample['label']

       #assert img.size == mask.size

       img = img.resize(img.size, Image.BILINEAR)
       #mask = mask.resize(self.size, Image.NEAREST)

       return {'image': img}

将以上两个文件加入到你的代码中,就完成了COCO数据集的载入啦~

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

COCO数据集的下载、介绍及如何使用(数据载入及数据增广,含代码) 的相关文章

  • Linux系统之neofetch工具的基本使用

    Linux系统之neofetch工具的基本使用 一 neofetch工具介绍 1 1 neofetch简介 1 2 neofetch特点 二 检查本地环境 2 1 检查操作系统版本 2 2 检查内核版本 三 安装neofetch工具 3 1
  • VMware Workstation Pro 安装教程

    文章目录 笔者的运行环境 VMware Workstation 16 Pro Red Hat Enterprise Linux 8 3 0 需要提前一个操作系统的镜像文件 ISO 这个文件与 VMware 无关 实际上 在安装完 VMwar
  • 生成字典的三种方式

    字典是记录一些特殊或有目的性的密码集合 通常以txt格式进行记录保存 在渗透许多服务器 smb ftp ssh 远程桌面rdp 网页后台等一些用户登录时 没有正确密码 使用密码字典爆破就是最直接的黑客攻击方法 一 使用cupp工具生成 1

随机推荐

  • 对输入数据排序后进行二分查找(C语言)

    输入数据后的排序方法有很多种 这里我用的是暴力排序 各位友友们可以尝试更改排序方法 include
  • Python selenium 滚动页面以及滚动至元素可见之详细讲解

    我们滚动浏览器页面向上 下 左右可以用一下代码 向上和向左需要加 向下滚动xx个像素 driver execute script window scrollBy 0 xx 向上滚动x个像素 driver execute script win
  • JM解码(一):参考帧列表和DPB处理

    以P帧为例 void alloc ref pic list reordering buffer Slice currSlice int size currSlice gt num ref idx active LIST 0 1 if cur
  • 谷粒商城-分布式高级篇[商城业务-订单服务]

    谷粒商城 分布式基础篇 环境准备 谷粒商城 分布式基础 业务编写 谷粒商城 分布式高级篇 业务编写 持续更新 谷粒商城 分布式高级篇 ElasticSearch 谷粒商城 分布式高级篇 分布式锁与缓存 项目托管于gitee 一 页面环境搭建
  • ubuntu freeradius 3.0 + mariadb

    安装数据库及Radius sudo apt update sudo apt install y freeradius freeradius mysql freeradius utils mariadb server mariadb clie
  • Unity Rotate鼠标控制人物旋转

    添加碰撞盒 一定要添加碰撞盒才能响应鼠标事件 将碰撞盒复制给骨架 如果鼠标划动的向量 X轴大于Y轴 则是左右划动 让它旋转 SpinWithMouse using System Collections using System Collec
  • 窗体,组件,事件

    窗体对象JFrame package frame import javax swing public class JFrameTest public static void main String args 创建窗体对象 JFrame jF
  • 使用javacv中的ffmpeg实现录屏,结果连运行都失败了,现在终于解决了

    前言 今天突发奇想 想自己写一个录屏的软件 上次写了一个专门录音的Demo 但是要把声音和视频放到一起合成一个mp4文件 着实有一点艰难 所以就打算使用ffmpeg来写一个 而这篇博客中会顺便谈一谈我碰到的各种坑 ffmpeg是一个c 程序
  • 中兴EPON OLT-C300开局配置

    一 基础配置 1 自定义时间 clock set hh mm ss Apr 8 2018 con t username zte password zte privilege 15 用户名密码 2 自定义名称 hostname CeShi O
  • js如何进行数组去重?

    1 数组反转 使用 reverse 实现数组反转 const arr 1 2 3 console log arr 1 2 3 arr reverse console log arr 3 2 1 2 数组去重 1 new Set array
  • python学习语法中与c语言不同之处(1)

    一 发现使用打印使用的是print 而在C语言中我们更多的使用的是printf 比如想要打印出来hello world 直接如下 C语言 printf a d a python语言 print hello world 然后就是直接回车键就可
  • angular表单验证

    表单验证 通常 我们都需要对用户的表单输入做验证 以保证数据的整体质量 Angular也有两种验证表单的形式 使用属性验证 用于模板驱动表单 使用验证器函数进行验证 用于响应式表单 验证器 Validator 函数 验证器函数可以是同步函数
  • 13功能之C++类默认生成的六个成员函数的自定义

    13功能之C 类默认生成的六个成员函数的自定义 1 代码理解即可 pragma warning disable 4996 include
  • UGUI之rectTransform属性

    RectTransform 本文转载自uGUI知识点剖析之RectTransform 一 基本要点 RectTransform继承于Transform 在 Transform 基础上 RectTransform 增加了 轴心 pivot 锚
  • 【文献翻译】构建网络安全知识库的框架-A Framework to Construct Knowledge Base for Cyber Security

    摘要 现在有一些针对不同方面的独立网络安全知识库 在互联网上 也有很多网络安全相关的内容以文字的形式存在 融合这些网络安全相关信息可以是一项有意义的工作 在本文中 我们提出了一个框架来整合现有的网络安全知识库并从文本中提取网络安全相关信息
  • java8日期时间相关

    java8时间相关api 一 java8时间相关api出现的原因 二 LocalDate LocalTime LocalDateTime的使用 1 解释 2 学习点 3 代码示例 三 Instant 1 解释 2 学习点 3 代码示例 四
  • Ubuntu下的CUDA编程(二)

    Ubuntu下cuda编程的基本过程 一 运行程序 按照上一篇文章所述 安装好cuda软件以后 就可以使用 nvcc V 命令查看所用到的编译器版本 本人用版本信息来自 Cuda compilation tools release 3 2
  • python学习——如何求质数/素数

    质数判断 方法一 一个大于1的自然数 除了1和它本身外 不能被其他自然数 质数 整除 2 3 5 7等 换句话说就是该数除了1和它本身以外不再有其他的因数 也就是说 从2到n 1遍历 如果存在一个数是这个整数n的因数 那么它就不是质数 但是
  • docker保存镜像到本地,并加载本地镜像文件

    docker保存镜像到本地 并加载本地镜像文件 1 查看已有的镜像文件 docker images 显示效果如下所示 2 将镜像打包成本地文件 指令 docker save 镜像id gt 文件名 tar docker save 17282
  • COCO数据集的下载、介绍及如何使用(数据载入及数据增广,含代码)

    如何使用COCO数据集 COCO数据集可以说是语义分割等计算机视觉任务中应用较为广泛的一个数据集 具体可以应用到物体识别 语义分割及目标检测等方面 我是在做语义分割方面任务时用到了COCO数据集 但本文主要讲解的是数据载入方面 因此可以通用