使用Stable Diffusion图像修复来生成自己的目标检测数据集

2023-11-18

点击上方“AI公园”,关注公众号,选择加“星标“或“置顶”


作者:Rédigé par Gabriel Guerin

编译:ronghuaiyang

导读

有些情况下,收集各种场景下的数据很困难,本文给出了一种方法。

c88b3e81d116ba7659389133a2356236.jpeg

深度学习模型需要大量的数据才能得到很好的结果,目标检测模型也是一样。

要训练一个YOLOv5的模型来自动检测你最喜欢的玩具,你需要拍几千张你的玩具在不同上下文中的照片,对于每张图,你需要标注玩具在图中的位置。

这样是非常耗时的。

本文提出了使用图像分割和stable diffusion来自动生成目标检测数据集的方法。

e572d5dbe912d3df83062dd48eeb4ac7.jpeg

生成自定义数据集的pipeline

生成目标检测数据集的pipeline包含4个步骤:

  • 找一个和你要识别的物体属于相同实例的数据集(比如狗数据集)。

  • 使用图像分割生成狗的mask。

  • 微调图像修复Stable Diffusion模型。

  • 使用Stable Diffusion图像修复模型和生成的mask来生成数据。

图像分割:生成mask图像

Stable Diffusion图像修复pipeline需要输入一个提示,一张图像和一张mask图像,这个模型会只从mask图像中的白色像素部分上去生成新的图像。

PixelLib这个库帮助我们来做图像分割,只用几行代码就可以,在这个例子里,我们会使用PointRend模型来检测狗,下面是图像分割的代码。

import pixellib
from pixellib.torchbackend.instance import instanceSegmentation

ins = instanceSegmentation()
ins.load_model("pointrend_resnet50.pkl")
target_classes = ins.select_target_classes(dog=True)
results, output = ins.segmentImage(
  "dog.jpg", 
  show_bboxes=True, 
  segment_target_classes=target_classes, 
  output_image_name="mask_image.jpg"
)
使用pixellib来做图像分割

segmentImage 函数返回一个tuple:

  • results : 是一个字典,包含了 'boxes', 'class_ids', 'class_names', 'object_counts', 'scores', 'masks', 'extracted_objects'这些字段。

  • output : 原始的图像和mask图像进行了混合,如果show_bboxes 设置为True,还会有包围框。

生成mask图像

我们生成的mask只包含白色和黑色的像素,我们的mask会比原来图中的狗略大一些,这样可以给Stable Diffusion足够的空间来进行修复。为了做到这种效果,我们将mask向左、右、上、下分别平移了10个像素。

from PIL import Image
import numpy as np

width, height = 512, 512
image=Image.open("dog.jpg")

# Store the mask of dogs found by the pointrend model
mask_image = np.zeros(image.size)
for idx, mask in enumerate(results["masks"].transpose()):
  if results["class_names"][idx] == "dog":
    mask_image += mask


# Create a mask image bigger than the original segmented image
mask_image += np.roll(mask_image, 10, axis=[0, 0]) # Translate the mask 10 pixels to the left
mask_image += np.roll(mask_image, -10, axis=[0, 0]) # Translate the mask 10 pixels to the right
mask_image += np.roll(mask_image, 10, axis=[1, 1]) # Translate the mask 10 pixels to the bottom
mask_image += np.roll(mask_image, -10, axis=[1, 1]) # Translate the mask 10 pixels to the top


# Set non black pixels to white pixels
mask_image = np.clip(mask_image, 0, 1).transpose() * 255
# Save the mask image
mask_image = Image.fromarray(np.uint8(mask_image)).resize((width, height))
mask_image.save("mask_image.jpg")
从pixellib的输出生成图像的mask

现在,我们有了狗图像的原始图和其对应的mask。

79cea50500cbc234b12e73efbbc36830.jpeg

使用pixellib基于狗的图像生成mask

微调Stable Diffusion图像修复pipeline

Dreambooth是微调Stable Diffusion的一种技术,我们可以使用很少的几张照片将新的概念教给模型,我们准备使用这种技术来微调图像修复模型。train_dreambooth_inpaint.py这个脚本中展示了如何在你自己的数据集上微调Stable Diffusion模型。

微调需要的硬件资源

在单个24GB的GPU上可以使用gradient_checkpointingmixed_precision来微调模型,如果要使用更大的batch_size 和更快的训练,需要使用至少30GB的GPU。

安装依赖

在运行脚本之前,确保安装了这些依赖:

pip install git+https://github.com/huggingface/diffusers.git
pip install -U -r requirements.txt

并初始化加速环境:

accelerate config

你需要注册Hugging Face Hub的用户,你还需要token来使用这些代码,运行下面的命令来授权你的token:

huggingface-cli login

微调样本

在运行这些计算量很大的训练的时候,超参数微调很关键,需要在你跑训练的机器上尝试不同的参数,我推荐的参数如下:

$ accelerate launch train_dreambooth_inpaint.py \
  --pretrained_model_name_or_path="runwayml/stable-diffusion-inpainting"  \
  --instance_data_dir="dog_images" \
  --output_dir="stable-diffusion-inpainting-toy-cat" \
  --instance_prompt="a photo of a toy cat" \
  --resolution=512 \
  --train_batch_size=1 \
  --learning_rate=5e-6 \   
  --lr_scheduler="constant" \   
  --lr_warmup_steps=0 \   
  --max_train_steps=400 \
  --gradient_accumulation_steps=2 \
  --gradient_checkpointing \
  --train_text_encoder

运行Stable Diffusion图像修复pipeline

Stable Diffusion图像修复是一个text2image的扩散模型,使用一张带mask的图像和文本输入来生成真实的图像。使用https://github.com/huggingface/diffusers来实现这个功能。

from PIL import Image
from diffusers import StableDiffusionInpaintPipeline


# Image and Mask
image = Image.open("dog.jpg")
mask_image = Image.open("mask_image.jpg")


# Inpainting model
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "stable-diffusion-inpainting-toy-cat",
    torch_dtype=torch.float16,
)
image = pipe(prompt="a toy cat", image=image, mask_image=mask_image).images[0]

使用微调过的模型运行Stable Diffusion图像修复。

结论Conclusion

总结一下:

  • 使用pixellib进行图像分割,得到图像的mask。

  • 微调runwayml/stable-diffusion-inpainting模型使得该模型能够学习到新的玩具猫类型。

  • 在狗的图像上,使用微调过的模型和生成的mask运行StableDiffusionInpaintPipeline

最终的结果

所有步骤完成之后,我们生成了一个新的图像,玩具猫代替了原来的狗的位置,这样,2张图像可以使用相同的包围框。

![img](Stable Diffusion Inpainting Generate a Custom Dataset for Object Detection.assets/Capturedecran2023-01-22a23_17_25_8025faada328368a6335c61ced262d96_800.jpg)

我们现在可以为数据集中的所有的图像都生成新的图像。

局限性

Stable Diffusion并不能每次都生成好的结果,数据集生成之后,还需要进行清理的工作。

这个pipeline是非常耗费计算量的,Stable Diffusion的微调需要24GB内存的显卡,推理的时候也是需要GPU的。

这种构建数据集的方法当数据集中的图像很难获得的时候是很有用的,比如,你需要检测森林火焰,最好是使用这种方法,而不是去森林里放火。但是,对于普通的场景,数据标注还是最标准的做法。

46c30899308aa04b2cbe696632fc2108.png

—END—

英文原文:https://www.sicara.fr/blog-technique/dataset-generation-fine-tune-stable-diffusion-inpainting

9e7d7361f3f44f36e5fe653334743180.jpeg

请长按或扫描二维码关注本公众号

喜欢的话,请给我个在看吧

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

使用Stable Diffusion图像修复来生成自己的目标检测数据集 的相关文章

随机推荐

  • 实际工作中的高级技术(训练加速、推理加速、深度学习自适应、对抗神经网络)

    目录 一 训练加速 1 基于数据的并行 Model Average 模型平均 SSGD 同步随机梯度下降
  • 大学生选课抢课如何提高选中概率

    作者位于哈尔滨某高校 选课总是激动人心的一件大事 但是明明与同学一起进的系统 他就能顺利选课 而我却被强退出来 无数辛酸让我知道了一些道理 写下这篇文章给学弟学妹们作为参考 原理 问 为什么大多数学校教务系统选课时都会卡 答 学校教务系统平
  • 热敏电阻测温

    热敏电阻器主要分为 PTC 和 NTC 正温度系数热敏电阻器 PTC 在温度越高时电阻值越大 负温度系数热敏电阻器 NTC 在温度越高时电阻值越低 它们同属于半导体器件 测温的热敏电阻一般为NTC 其主要参数有以下几个 标称阻值 标称阻值是
  • 期货有哪些(正规期货公司排名)

    期货有哪些 期货暂时重要分为两大版块 辨别是商品期货和金融期货 与此同声这两大版块又不妨辨别细化出各别的品种 商品期货又可细分为非金属商品 动力商品 农产物等 金融期货重要指保守的金融商品或东西 如一手一足 内债 税率 汇率等 商品期货农产
  • 58同城面经

    文章目录 58一面 58二面 58同城通过了技术面试 但迟迟没有hr面 可能表现的不是很好 58一面 自我介绍 数据结构大概有哪些分类 关于项目 为什么会考虑做商城项目 商城首页的优化 操作系统为什么会有线程这个操作吗 Java创建线程的方
  • Golang基础 流程控制 循环控制

    循环控制 01 基础循环 for 02 键值循环 for range 参考资料 循环控制通常用于程序中需要重复执行的逻辑模块 循环结构通常由循环变量 循环终止条件和循环体三个部分构成 01 基础循环 for Golang 中所有的循环控制都
  • PCL 最小点数约束的改进半径滤波(C++详细过程版)

    目录 一 概述 1 不足 2 改进 二 代码实现 三 结果展示 一 概述 1 不足 传统半径滤波算法在点云数据量巨大的情况下 算法效率会大幅度降低 而对于稠密点云数据 一个影响效率的重要因素就是搜索半径的大小 当搜索半径较大时 需要计算邻域
  • @vue/cli 创建项目报Cannot find module ‘inquirer‘错

    解决 这可能是因为cli版本问题 1 第一步 2 第二步 npm uninstall g vue cli 3 第三步 npm install g vue cli
  • 由PyRetri浅谈基于深度学习的图像检索

    前言 最近发现face 开源了一个图像检索和行人重识别的基于深度学习的软件包 最近一段时间也一直在接触图像检索相关的东西 故借此机会 对里面涉及的一些常用的方法模块进行一个简单的介绍总结 便于日后回顾 PyRetri是什么 PyRetri是
  • 如何查看linux服务器字符集,Linux字符集查看与设置

    查看字符集 Linux 中字符集在系统中的体现是一个环境变量 以 CentOS 6 5 为例 查看当前终端使用的字符集的方式有 1 root jerry echo LANG zh CN GB18030 2 root jerry env gr
  • 对 React Hook的闭包陷阱的理解,有哪些解决方案?

    hooks中 奇怪 其实符合逻辑 的 闭包陷阱 的场景 同时 在许多 react hooks 的文章里 也能看到 useRef 的身影 那么为什么使用 useRef 又能摆脱 这个 闭包陷阱 搞清楚这些问题 将能较大的提升对 react h
  • vue 全局组件注册_如何注册vue3全局组件

    vue 全局组件注册 With the new versions of Vue3 out now it s useful to start learning how the new updates will change the way w
  • unity playerprefs android,Unity持久化存储之PlayerPrefs的使用

    一 PlayerPrefs类支持3中数据类型的保存和读取 浮点型 整形 和字符串型 分别对应的函数为 php SetInt 保存整型数据 GetInt 读取整形数据 SetFloat 保存浮点型数据 GetFlost 读取浮点型数据 Set
  • pygame之五子棋的实现

    先上代码 调用pygame库 import pygame import sys 调用常用关键字常量 from pygame locals import QUIT KEYDOWN import numpy as np 初始化pygame py
  • laravel-vue后端返回数据的字符串中(<br/> \n)换行无效

    laravel 做后端 vue做前端 后端返回数据的字符串中含有 br 或 n r n 等换行符 在前端页面无法正常渲染出换行效果 尝试用str replace方法无效 最终找到解决办法 解决办法 给包含换行符的字符串元素增加css whi
  • 【STM32学习】——串口通信协议&STM32-USART外设&数据帧/输入数据策略/波特率发生器&串口发送/接受实操

    文章目录 前言 一 串口通信 1 通信接口 2 串口通信 1 串口简介 2 串口硬件电路 3 串口软件部分 二 STM32的USART外设 1 USART简介 2 图示详解 三 细节问题 1 数据帧 2 输入数据策略 1 起始位侦测 2 数
  • iOS开发,tableView中cell的重用详解

    注意 原创版权 转载必须标明出处作者 翻版必究 iOS中tableView是一个大的模块组件 它的重要性每个iOSCoder都是了解的 但是tableView中却有个重大的坑 就是cell的重用 每个刚接触iOS开发的人都深受其海 那么经过
  • AD18出现Unknown Pin报错解决

    问题描述 检查错误 检查原理图对应元件的封装是否存在 检查原理图与封装PCB引脚数量是否对应 检查原理图与封装的管脚是否统一 找到原因 原理图的管脚命名与PCB封装管脚命名不一致 问题解决 修改原理图管脚名称 修改PCB Library的管
  • luajit struct

    This page is intended to give you an overview of the features of the FFI library by presenting a few use cases and guide
  • 使用Stable Diffusion图像修复来生成自己的目标检测数据集

    点击上方 AI公园 关注公众号 选择加 星标 或 置顶 作者 R dig par Gabriel Guerin 编译 ronghuaiyang 导读 有些情况下 收集各种场景下的数据很困难 本文给出了一种方法 深度学习模型需要大量的数据才能