如何将 detectorron2 的增强功能与使用 register_coco_instances 加载的数据集结合使用

2023-12-09

我已经在以 coco 格式标记和导出的自定义数据上训练了 detectorron2 模型,但现在我想应用增强并使用增强数据进行训练。如果我不使用自定义 DataLoader,而是使用 register_coco_instances 函数,我该如何做到这一点。

cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5 
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
predictor = DefaultPredictor(cfg)
outputs = predictor(im)

train_annotations_path = "./data/cvat-corn-train-coco-1.0/annotations/instances_default.json"
train_images_path = "./data/cvat-corn-train-coco-1.0/images"
validation_annotations_path = "./data/cvat-corn-validation-coco-1.0/annotations/instances_default.json"
validation_images_path = "./data/cvat-corn-validation-coco-1.0/images"

register_coco_instances(
    "train-corn",
    {},
    train_annotations_path,
    train_images_path
)
register_coco_instances(
    "validation-corn",
    {},
    validation_annotations_path,
    validation_images_path
)
metadata_train = MetadataCatalog.get("train-corn")
dataset_dicts = DatasetCatalog.get("train-corn")

cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
cfg.DATASETS.TRAIN = ("train-corn",)
cfg.DATASETS.TEST = ("validation-corn",)
cfg.DATALOADER.NUM_WORKERS = 2
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")  # Let training initialize from model zoo
cfg.SOLVER.IMS_PER_BATCH = 2
cfg.SOLVER.BASE_LR = 0.00025
cfg.SOLVER.MAX_ITER = 10000
cfg.SOLVER.STEPS = []
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128 
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 4
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
trainer = DefaultTrainer(cfg) 
trainer.resume_or_load(resume=False)
trainer.train()

我在文档中看到您可以加载数据集并应用增强,如下所示:

dataloader = build_detection_train_loader(cfg,
   mapper=DatasetMapper(cfg, is_train=True, augmentations=[
      T.Resize((800, 800))
   ]))

但我没有使用自定义数据加载器,执行此操作的最佳方法是什么?


根据我的经验,如何注册数据集(即告诉 Detectron2 如何获取名为"my_dataset")与训练期间使用什么数据加载器(即如何从注册数据集中加载信息并将其处理为模型所需的格式)无关。

因此,您可以根据需要注册数据集 - 可以使用register_coco_instances函数或使用数据集 API (DatasetCatalog, MetadataCatalog) 直接地;没关系。重要的是您想要在数据加载部分应用一些转换。

基本上,您想要自定义数据加载部分,这只能通过使用自定义数据加载器来实现(除非您执行离线增强,这可能不是您想要的)。

现在,您不需要直接在顶级代码中定义和使用自定义数据加载器。您可以创建自己的训练器,派生自DefaultTrainer,并覆盖它的build_train_loader方法。这很简单,如下所示。

class MyTrainer(DefaultTrainer):

    @classmethod
    def build_train_loader(cls, cfg):
        mapper = DatasetMapper(cfg, is_train=True, augmentations=[T.Resize((800, 800))])
        return build_detection_train_loader(cfg, mapper=mapper)

那么,在您的顶级代码中,唯一需要的更改就是使用MyTrainer代替DefaultTrainer.

trainer = MyTrainer(cfg) 
trainer.resume_or_load(resume=False)
trainer.train()
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如何将 detectorron2 的增强功能与使用 register_coco_instances 加载的数据集结合使用 的相关文章

随机推荐

  • 如何摆脱 pygame 表面?

    在下面的代码中 不仅有one在任何给定时间点屏幕上出现圆圈 我想修复这个问题 使其看起来只有一个圆圈 而不是在鼠标光标所在的地方留下污迹 import pygame sys from pygame locals import pygame
  • HTML 数字输入最小值和最大值无法正常工作

    I have type number输入字段和我已经设置min and max它的值
  • JSON 格式在附加文件时添加 \ 字符,但不添加到输出中的字符串

    我正在使用以下函数来获取json来自 flickrAPI 它返回的字符串是格式正确的 JSON 块 def get photo data photo id para para photo id photo id para method fl
  • Tensorflow - 推理时间评估

    我正在使用 Tensorflow 评估不同的图像分类模型 特别是使用不同设备的推理时间 我想知道是否必须使用预训练模型 我使用一个脚本生成 1000 个随机输入图像 将它们一一输入到网络 并计算平均推理时间 谢谢 让我首先发出警告 大多数人
  • 在c#中一段时间​​内非阻塞等待/延迟的最佳实现是什么

    目前我需要在Windows Store应用程序项目中实现一个简单的非阻塞延迟功能 该函数不应该执行任何操作 只是在特定时间段内空闲而不阻塞 UI 我的问题是 如何正确实现这样的功能 我知道这是一个老问题 但在网上搜索后我真的没有任何线索 最
  • Predict.svm 中的错误:测试数据与模型不匹配

    我有一个大约 500 行和 170 列的数据框 我正在尝试使用 e1071 包中的 svm 运行分类模型 分类变量称为 SEGMENT 是一个有 6 个级别的因子变量 数据框中还有其他三个因子变量 其余都是数字 data lt my dat
  • 从 Google 通讯录中删除重复或重复的联系人

    我的目标是制作一个包含客户联系信息 地址和注释的电子表格 创建新的 Google 通讯录条目 并将联系人 ID 和 已添加 标记到表格中 该表格将已输入到 Google 通讯录的联系人标记为 已添加 这张表 我成功地做到了 一切正常 问题只
  • 即时搜索 PB 级数据

    我需要在 CSV 格式文件中搜索 PB 级的数据 使用LUCENE建立索引后 索引文件的大 小是原始文件的两倍 是否可以减小索引文件的大 小 如何在HADOOP中分发LUCENE索引文件以及如何在搜索环境中使用 或者是否有必要 我应该使用s
  • 如何在 pytorch 中更改输入图片的尺寸?

    我制作了一个卷积神经网络 我希望它获取输入图片和输出图片 但是当我将图片转换为张量时 它们的尺寸错误 RuntimeError Expected 4 dimensional input for 4 dimensional weight 20
  • Jenkins 无法识别生成的 allure 报告 xml 文件的正确目录

    我已成功将 Allure 报告集成到我的基于 Maven 的 testNG 项目中 并且能够使用 jetty 服务器查看该报告 但现在我正在尝试按照此处建议的说明将魅力报告与詹金斯集成 http wiki qatools ru displa
  • 如何处理不在 UINavigationController 堆栈顶部的 UIViewController 的旋转?

    我在 UINavigationController 中有一个根 UIViewController VC1 它通过在 willRotateToInterfaceOrientation 方法中手动调整其视图 子视图框架来处理旋转 如果根 UIV
  • 加密/解密字节数组 Crypto++

    我正在尝试使用 AES 加密字节数组 我已经能够毫无问题地加密字符串和文件 但是字节数组似乎不适合我 我传入一个要加密的字节数组 为了便于测试 我只传入由 crypto bArrayToEncrypt 生成的 AES 密钥 加密似乎有效 但
  • systemd 服务未使用 dbus 接口启动

    我正在尝试启动 systemd 服务 usnig dbus 服务 我正在关注下面提到的链接的示例 5 http www freedesktop org software systemd man systemd service html 我的
  • 在 Lua 中按值对表进行关联排序

    我有一个 key gt value 表 我想在 Lua 中排序 键都是整数 但不连续 并且有意义 Lua唯一的排序函数似乎是table sort 它将表视为简单数组 丢弃原始键及其与特定项目的关联 相反 我本质上希望能够使用PHP s as
  • GCP Firestore Python 凭证

    我在将数据从 Linux 虚拟机发送到 GCP 的 Firestore 时遇到问题 我只是想更新数据库内的项目 我遇到有关凭据的问题 根据我使用的方法 我会得到不同的错误 但我相信它们都源于同一问题 请注意 我有一个带有 json 凭据的服
  • 为什么 T* 可以在寄存器中传递,但 unique_ptr 却不能?

    我正在观看 Chandler Carruth 在 CppCon 2019 上的演讲 不存在零成本抽象 在其中 他举了一个例子 说明他对使用std unique ptr
  • 为此使用什么正则表达式

    我正在编写一个正则表达式 它将找到 1个或多个点 后面跟一个空格或者后面根本不跟任何东西 1 个或多个问号 再次后面跟一个空格或者后面根本不跟任何东西 我该如何编写这个正则表达式 以便让它执行此或操作 你只需要逃避 or with a 从字
  • Rails 3 关联错误

    我有一个表格页面和一个表格作者 每一页都属于一位作者 还为表和模型创建了迁移 但在表单中使用它时出现此错误 NoMethodError in Pages new Showing C rorapp app views pages form h
  • 不活动和活动、应用程序空闲、用户不活动自动注销

    经过大量谷歌搜索并花费了 4 个小时后 我想这是查找用户不活动和锁定屏幕的最佳方法 public MainWindow InitializeComponent var timer new DispatcherTimer Interval T
  • 如何将 detectorron2 的增强功能与使用 register_coco_instances 加载的数据集结合使用

    我已经在以 coco 格式标记和导出的自定义数据上训练了 detectorron2 模型 但现在我想应用增强并使用增强数据进行训练 如果我不使用自定义 DataLoader 而是使用 register coco instances 函数 我