万物分割SAM使用教程

2023-10-26


原理篇

安装

# 创建虚拟环境
conda create -n sam python=3.8
# 激活环境
conda activate sam
# 下载代码
git clone git@github.com:facebookresearch/segment-anything.git
# 安装
cd segment-anything; pip install -e .
# 常见库安装
pip install torch torchvision opencv-python pycocotools matplotlib onnxruntime onnx

下载模型,放置models文件夹,本示例使用ViT-H
在这里插入图片描述

使用

SAM输入为points, boxes, textmask

全图分割

输入图片‘onepiece.jpg’,
在这里插入图片描述
输出结果如下图,
在这里插入图片描述

代码:

# coding=utf-8
import numpy as np
import matplotlib.pyplot as plt
import cv2
from pathlib import Path
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry, SamPredictor

def show_anns(anns):
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)

    img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4))
    img[:,:,3] = 0
    for ann in sorted_anns:
        m = ann['segmentation']
        color_mask = np.concatenate([np.random.random(3), [0.35]])
        img[m] = color_mask
    ax.imshow(img)
    
def process_img(img_path):
    '''img_path to img(np.array)
    '''
    image = cv2.imread(img_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    return image

def entire_img(img_path):
    '''whole img generate mask
    '''
    image = process_img(img_path)
    sam = sam_model_registry["vit_h"](checkpoint="./models/sam_vit_h_4b8939.pth")
    sam.to(device="cuda")
    mask_generator = SamAutomaticMaskGenerator(sam)
    masks = mask_generator.generate(image)
    plt.figure(figsize=(20,20))
    plt.imshow(image)
    show_anns(masks)
    plt.axis('off')
    plt.savefig(str(Path(img_path).name))
    
    # predictor = SamPredictor(sam)
def main():
    img_path = './notebooks/images/onepiece.jpg'
    entire_img(img_path)


if __name__ == "__main__":
    main()

选取绿色五角星位置[1064, 1205]
在这里插入图片描述

选取框坐标[1305, 244, 2143, 1466]
在这里插入图片描述

完整代码

完整代码如下,欢迎大家体验

# coding=utf-8
import numpy as np
import matplotlib.pyplot as plt
import cv2
from pathlib import Path
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry, SamPredictor

def show_anns(anns):
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)

    img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4))
    img[:,:,3] = 0
    for ann in sorted_anns:
        m = ann['segmentation']
        color_mask = np.concatenate([np.random.random(3), [0.35]])
        img[m] = color_mask
    ax.imshow(img)
    
def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)
    
def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)   
    
def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2)) 

def process_img(img_path):
    '''img_path to img(np.array)
    '''
    image = cv2.imread(img_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    return image

def entire_img(img_path):
    '''whole img generate mask
    '''
    image = process_img(img_path)
    sam = sam_model_registry["vit_h"](checkpoint="./models/sam_vit_h_4b8939.pth")
    sam.to(device="cuda")
    mask_generator = SamAutomaticMaskGenerator(sam)
    masks = mask_generator.generate(image)
    plt.figure(figsize=(20,20))
    plt.imshow(image)
    show_anns(masks)
    plt.axis('off')
    plt.savefig(str(Path(img_path).name))

def predict(img_path, type='point'):
    image = process_img(img_path)
    sam = sam_model_registry["vit_h"](checkpoint="./models/sam_vit_h_4b8939.pth")
    sam.to(device="cuda")

    predictor = SamPredictor(sam)
    predictor.set_image(image)
    if type == 'point':
        # [X, Y]
        input_point = np.array([[1064, 1205]])
        input_label = np.array([1])
        masks, scores, logits = predictor.predict(
                point_coords=input_point,
                point_labels=input_label,
                multimask_output=True,
        )
    elif type == 'bbox':
        input_box = np.array([1305, 244, 2143, 1466])
        masks, scores, logits = predictor.predict(
            point_coords=None,
            point_labels=None,
            box=input_box[None, :],
            multimask_output=False,
        )


    index = np.argmax(scores)

    plt.figure(figsize=(10,10))
    plt.imshow(image)
    show_mask(masks[index], plt.gca())
    if type == 'point':
        show_points(input_point, input_label, plt.gca())
    elif type == 'bbox':
        show_box(input_box, plt.gca())
    plt.title(f"Score: {scores[index]:.3f}", fontsize=18)
    plt.savefig(str(Path(img_path).stem)+f'{scores[index]:.3f}.png')


    # predictor = SamPredictor(sam)
def main():
    img_path = './notebooks/images/onepiece.jpg'
    # entire_img(img_path)
    predict(img_path, type='bbox')
    # predict(img_path)


if __name__ == "__main__":
    main()
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

万物分割SAM使用教程 的相关文章

  • 将 yerr/xerr 绘制为阴影区域而不是误差线

    在 matplotlib 中 如何将误差绘制为阴影区域而不是误差条 例如 而不是 忽略示例图中各点之间的平滑插值 这需要进行一些手动插值 或者只是获得更高分辨率的数据 您可以使用pyplot fill between https matpl
  • 从字典的元素创建 Pandas 数据框

    我正在尝试从字典创建一个 pandas 数据框 字典设置为 nvalues y1 1 2 3 4 y2 5 6 7 8 y3 a b c d 我希望数据框仅包含 y1 and y2 到目前为止我可以使用 df pd DataFrame fr
  • 定义Python源代码编码的正确方法

    PEP 263 http www python org dev peps pep 0263 定义如何声明Python源代码编码 通常 Python 文件的前两行应以以下内容开头 usr bin python coding
  • 如何在python中附加两个字节?

    说你有b x04 and b x00 你如何将它们组合起来b x0400 使用Python 3 gt gt gt a b x04 gt gt gt b b x00 gt gt gt a b b x04 x00
  • 如何将 sql 数据输出到 QCalendarWidget

    我希望能够在日历小部件上突出显示 SQL 数据库中的一天 就像启动程序时突出显示当前日期一样 在我的示例中 它是红色突出显示 我想要发生的是 当用户按下突出显示的日期时 数据库中日期旁边的文本将显示在日历下方的标签上 这是我使用 QT De
  • Python“非规范化”unicode 组合字符

    我正在寻找标准化 python 中的一些 unicode 文本 我想知道是否有一种简单的方法可以在 python 中获得组合 unicode 字符的 非规范化 形式 例如如果我有序列u o xaf i e latin small lette
  • 在 Mac OS X 上安装 libxml2 时出现问题

    我正在尝试在我的 Mac 操作系统 10 6 4 上安装 libxml2 我实际上正在尝试在 Python 中运行 Scrapy 脚本 这需要我安装 Twisted Zope 现在还需要安装 libxml2 我已经下载了最新版本 2 7 7
  • 如何在 Django Rest 框架中编写“删除”操作的测试

    我正在为 Django Rest Framework API 编写测试 我一直在测试 删除 我对 创建 的测试工作正常 这是我的测试代码 import json from django urls import reverse from re
  • PIL.Image.open和tf.image.decode_jpeg返回值的区别

    我使用 PIL Image open 和 tf image decode jpeg 将图像文件解析为数组 但发现PIL Image open 中的像素值与tf image decode jpeg不一样 为什么会出现这种情况 Thanks 代
  • 时间序列数据预处理 - numpy strides 技巧以节省内存

    我正在预处理一个时间序列数据集 将其形状从二维 数据点 特征 更改为三维 数据点 时间窗口 特征 在这样的视角中 时间窗口 有时也称为回顾 指示作为输入变量来预测下一个时间段的先前时间步长 数据点的数量 换句话说 时间窗口是机器学习算法在对
  • 使用标签或 href 传递 Django 数据

    我有一个包含链接的表 当单击该链接进行更多操作时 我想将一些数据传递给我的函数 my html table tbody for query in queries tr td value a href internal my func que
  • Flask 应用程序路由中的多个参数

    烧瓶怎么写app route如果我在 URL 调用中有多个参数 这是我从 AJax 调用的 URL http 0 0 0 0 8888 createcm summary VVV change Feauure 我试图写我的烧瓶app rout
  • 如何在 Seaborn 中的热图轴上表达类

    我使用 Seaborn 创建了一个非常简单的热图 显示相似性方阵 这是我使用的一行代码 sns heatmap sim mat linewidths 0 square True robust True sns plt show 这是我得到的
  • 基于值而不是类型的单次调度

    我在 Django 上构建 SPA 并且有一个庞大的功能 其中包含许多功能if用于检查我的对象字段的状态名称的语句 像这样 if self state new do some logic if self state archive do s
  • 使用 selenium 和 python 来提取 javascript 生成的 HTML?萤火虫?

    这里是Python新手 我遇到的是数据收集问题 我在这个网站上 当我用 Firebug 检查我想要的元素时 它显示了包含我需要的信息的源 然而常规源代码 没有 Firebug 不会给我这个信息 这意味着我也无法通过正常的 selenium
  • 根据多个阈值将 SciPy 分层树状图切割成簇

    我想将 SciPy 的树状图切割成多个具有多个阈值的簇 我尝试过使用 fcluster 但它只能削减一个阈值 例如 这是我从另一个问题中摘取的一段代码 import pandas data pandas DataFrame total ru
  • 高效创建抗锯齿圆形蒙版

    我正在尝试创建抗锯齿 加权而不是布尔 圆形掩模 以制作用于卷积的圆形内核 radius 3 no of pixels to be 1 on either side of the center pixel shall be decimal a
  • 如何绘制更大的边界框和仅裁剪边界框文本 Python Opencv

    我正在使用 easyocr 来检测图像中的文本 该方法给出输出边界框 输入图像如下所示 Image 1 Image 2 使用下面的代码获得输出图像 But I want to draw a Single Bigger bounding bo
  • 如何使用xlwt设置文本颜色

    我无法找到有关如何设置文本颜色的文档 在 xlwt 中如何完成以下操作 style xlwt XFStyle bold font xlwt Font font bold True style font font background col
  • 描述符“join”需要“unicode”对象,但收到“str”

    代码改编自here http wiki geany org howtos convert camelcase from foo bar to Foo Bar def lower case underscore to camel case s

随机推荐

  • 动手学数据分析 Task5

    动手学数据分析 Task5 一 逻辑回归 二 随机森林 三 模型评估 3 1 k折交叉验证 3 2 混淆矩阵 3 3 ROC曲线 一 逻辑回归 LogisticRegression penalty l2 dual False tol 0 0
  • 如何将Zookeeper和Kafka的log4j升级到2.16

    1 删除lib下的jar文件 对于kafka lib 删除 slf4j api 1 7 25 jar slf4j log4j12 1 7 25 jar log4j 1 2 17 jar 对于zk lib 删除 log4j 1 2 17 ja
  • 毕业设计 - stm32单片机的远程WIFI密码锁 - 物联网 嵌入式

    文章目录 0 前言 1 简介 主要器件 实现效果 4 硬件设计 WIFI模块 OLED显示屏 相关原理图 硬件接线 5 软件说明 开发环境介绍 程序下载配置 设备初始化打印的信息 6 部分核心代码 7 最后 0 前言 这两年开始毕业设计和毕
  • kubernetes集群-Master节点升级-kubeadm,kubectl,kubelet升级

    kubernetes Master单节点升级 kubeadm 升级 kubelet 升级 kubectl 升级 生产环境注意事项 由于 kubeadm upgrade 不会升级 etcd 请确保已对其进行了备份 例如 您可以使用 etcdc
  • java setsession_Java Session.setServerAliveInterval方法代码示例

    import com jcraft jsch Session 导入方法依赖的package包 类 private Session startNewSession boolean acquireChannel throws JSchExcep
  • 华为od机试 Java【跳房子2】

    题目 有若干个连续的方格地板 儿童们喜欢在上面玩游戏 在这个游戏中 玩家需要在三个回合内 按照规定的步数 从第一格跳到最后一格 跳到最后的玩家有机会选择一个他们喜欢的房子 直到所有的房子都被选完 当然 游戏中最多房子的人是胜者 但游戏并不那
  • 快速浏览Swift-笔记

    快速浏览Swift 笔记 快速浏览Swift https docs swift org swift book GuidedTour GuidedTour html 变量也常量 多行字符串 使用 let quotation I said I
  • python文件工程化,隐藏源码

    python文件工程化 隐藏源码 py文件转换为pyc文件 全文来自博客https www cnblogs com HByang p 13223118 html pyc介绍 pyc是一种二进制文件 是由py文件经过编译后 生成的文件 是一种
  • 3 个 C 程序示例,用于创建包含数据的文件

    本教程介绍如何使用 C 程序创建文件 在这些示例中 我们将创建新的 HTML 文件并向其中写入一些内容 文件的内容会有所不同 但这三个 C 示例程序应该向大家说明如何使用 fopen fprintf 等 c 文件函数来创建和操作文件 示例一
  • ibm中间键服务器缺少文件夹,存储中间件-MQ常见问题解决方法FAQ.doc

    存储中间件 MQ常见问题解决方法FAQ IBM Websphere MQ FAQ Last Release 2006 1 2 这里整理了IBM Websphere MQ的一些常见错误和解决方法 当发现MQ错误而一时无法解决时 可以参阅这里的
  • 【LibTorch】C++中部署TorchScript模型

    文章目录 1 LibTorch安装 2 C 调用PyTorch模型 2 1 Python中保存tensor数据 2 2 C 中保存tensor数据 2 3 C 加载tensor并调用模型 3 编译执行C 推理用例 3 1 编写CMakeLi
  • Kali配置SSH服务,并且通过Xshell远程登录

    在很多时候 需要通过远程登录到Kali主机进行操作 什么是SSH SSH 为建立在应用层基础上的安全协议 SSH 是较可靠 专为远程登录会话和其他网络服务提供安全性的协议 利用 SSH 协议可以有效防止远程管理过程中的信息泄露问题 1 配置
  • 把正整数数组里面的数字组合成最小的数字

    题目描述 把数组里所有数字拼接起来排成一个数 打印能拼接出的所有数字中最小的一个 例如输入数组 7 302 12 则打印出这三个数字能排成的最小数字为123027 题目分析 需要打印出三个数字可以排成的最小数字 表明算法涉及全排列 算法设计
  • VSCode之C++ & CUDA极简环境配置

    背景 想要了解CUDA并行计算原理 同时针对深度学习中出现一些 不支持算子 可能需要手写的需要 配置一个简单的CUDA编译环境 探索CUDA编程的范式 注 CUDA环境配置略 结果展示 示例代码 include cuda runtime h
  • Open3D 点云按高程进行渲染赋色

    目录 一 概述 二 代码实现 三 结果展示 一 概述 如题 使用Open3D内置函数来基于点云的高程对点云进行颜色渲染赋值 其结果如下图所示 此外 还可以根据颜色配赋表选择任意形式的渲染效果 附 配赋表 二 代码实现 import nump
  • 文本情感分析当前研究热点

    先介绍文本情感分析主要的数据集 Stanford Sentiment Treebank 11855个句子划分为239231个短语 每个短语有个概率值 越小越负面 越大越正面链接 IMDB 100 000句子 正面负面两类链接 附LSTM和C
  • 使用datetime库,对当前日期输出3种不同日期输出方法。

    import datetime import time print datetime date today print time strftime Y m d time localtime time time 更详细time strftim
  • [Linux] 多网卡主机之间指定双方通信网卡的办法

    一 Linux 下使用router 工具 指定路由解析 先看一下本机的路由信息 root gt route Kernel IP routing table Destination Gateway Genmask Flags Metric R
  • kafka后台启动命令

    命令 sh kafka server start sh config server properties 目的是想让服务后台启动 符号代表后台启动 运行命令后服务确实后台启动了 但日志会打印在控制台 而且关掉命令行窗口 服务就会随之停止 这
  • 万物分割SAM使用教程

    文章目录 安装 使用 全图分割 点 框 完整代码 原理篇 安装 创建虚拟环境 conda create n sam python 3 8 激活环境 conda activate sam 下载代码 git clone git github c