毕设 ssd tf_gpu2 predict.py 备份代码

2023-11-02

如图 代码所属:https://github.com/bubbliiiing/ssd-tf2

怕到时候改坏了 unbelievable(不是)

import time

import cv2
import numpy as np
import tensorflow as tf
from PIL import Image

from ssd import SSD

gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

if __name__ == "__main__":
    ssd = SSD()
    #----------------------------------------------------------------------------------------------------------#
    #   mode用于指定测试的模式:
    #   'predict'           表示单张图片预测,如果想对预测过程进行修改,如保存图片,截取对象等,可以先看下方详细的注释
    #   'video'             表示视频检测,可调用摄像头或者视频进行检测,详情查看下方注释。
    #   'fps'               表示测试fps,使用的图片是img里面的street.jpg,详情查看下方注释。
    #   'dir_predict'       表示遍历文件夹进行检测并保存。默认遍历img文件夹,保存img_out文件夹,详情查看下方注释。
    #----------------------------------------------------------------------------------------------------------#
    mode = "video"
    #-------------------------------------------------------------------------#
    #   crop                指定了是否在单张图片预测后对目标进行截取
    #   count               指定了是否进行目标的计数
    #   crop、count仅在mode='predict'时有效
    #-------------------------------------------------------------------------#
    crop            = False
    count           = False
    #----------------------------------------------------------------------------------------------------------#
    #   video_path          用于指定视频的路径,当video_path=0时表示检测摄像头
    #                       想要检测视频,则设置如video_path = "xxx.mp4"即可,代表读取出根目录下的xxx.mp4文件。
    #   video_save_path     表示视频保存的路径,当video_save_path=""时表示不保存
    #                       想要保存视频,则设置如video_save_path = "yyy.mp4"即可,代表保存为根目录下的yyy.mp4文件。
    #   video_fps           用于保存的视频的fps
    #
    #   video_path、video_save_path和video_fps仅在mode='video'时有效
    #   保存视频时需要ctrl+c退出或者运行到最后一帧才会完成完整的保存步骤。
    #----------------------------------------------------------------------------------------------------------#
    video_path      = "video/SVID_6.mp4"
    video_save_path = "video_out/SVID_6_ep120_out.mp4"
    video_fps       = 25.0
    #----------------------------------------------------------------------------------------------------------#
    #   test_interval       用于指定测量fps的时候,图片检测的次数。理论上test_interval越大,fps越准确。
    #   fps_image_path      用于指定测试的fps图片
    #   
    #   test_interval和fps_image_path仅在mode='fps'有效
    #----------------------------------------------------------------------------------------------------------#
    test_interval   = 100
    fps_image_path  = "img/street.jpg"
    #-------------------------------------------------------------------------#
    #   dir_origin_path     指定了用于检测的图片的文件夹路径
    #   dir_save_path       指定了检测完图片的保存路径
    #   
    #   dir_origin_path和dir_save_path仅在mode='dir_predict'时有效
    #-------------------------------------------------------------------------#
    dir_origin_path = "img/"
    dir_save_path   = "img_out/"

    if mode == "predict":
        '''
        1、如果想要进行检测完的图片的保存,利用r_image.save("img.jpg")即可保存,直接在predict.py里进行修改即可。 
        2、如果想要获得预测框的坐标,可以进入ssd.detect_image函数,在绘图部分读取top,left,bottom,right这四个值。
        3、如果想要利用预测框截取下目标,可以进入ssd.detect_image函数,在绘图部分利用获取到的top,left,bottom,right这四个值
        在原图上利用矩阵的方式进行截取。
        4、如果想要在预测图上写额外的字,比如检测到的特定目标的数量,可以进入ssd.detect_image函数,在绘图部分对predicted_class进行判断,
        比如判断if predicted_class == 'car': 即可判断当前目标是否为车,然后记录数量即可。利用draw.text即可写字。
        '''
        while True:
            img = input('Input image filename:')
            try:
                image = Image.open(img)
            except:
                print('Open Error! Try again!')
                continue
            else:
                r_image = ssd.detect_image(image, crop = crop, count=count)
                r_image.show()

    elif mode == "video":
        capture = cv2.VideoCapture(video_path)
        if video_save_path!="":
            fourcc  = cv2.VideoWriter_fourcc(*'XVID')
            size    = (int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)), int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT)))
            out     = cv2.VideoWriter(video_save_path, fourcc, video_fps, size)

        ref, frame = capture.read()
        if not ref:
            raise ValueError("未能正确读取摄像头(视频),请注意是否正确安装摄像头(是否正确填写视频路径)。")

        fps = 0.0
        while(True):
            t1 = time.time()
            # 读取某一帧
            ref, frame = capture.read()
            if not ref:
                break
            # 格式转变,BGRtoRGB
            frame = cv2.cvtColor(frame,cv2.COLOR_BGR2RGB)
            # 转变成Image
            frame = Image.fromarray(np.uint8(frame))
            # 进行检测
            frame = np.array(ssd.detect_image(frame))
            # RGBtoBGR满足opencv显示格式
            frame = cv2.cvtColor(frame,cv2.COLOR_RGB2BGR)
            
            fps  = ( fps + (1./(time.time()-t1)) ) / 2
            print("fps= %.2f"%(fps))
            frame = cv2.putText(frame, "fps= %.2f"%(fps), (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
            
            cv2.imshow("video",frame)
            c= cv2.waitKey(1) & 0xff 
            if video_save_path!="":
                out.write(frame)

            if c==27:
                capture.release()
                break

        print("Video Detection Done!")
        capture.release()
        if video_save_path!="":
            print("Save processed video to the path :" + video_save_path)
            out.release()
        cv2.destroyAllWindows()
        
    elif mode == "fps":
        img = Image.open(fps_image_path)
        tact_time = ssd.get_FPS(img, test_interval)
        print(str(tact_time) + ' seconds, ' + str(1/tact_time) + 'FPS, @batch_size 1')

    elif mode == "dir_predict":
        import os

        from tqdm import tqdm

        img_names = os.listdir(dir_origin_path)
        for img_name in tqdm(img_names):
            if img_name.lower().endswith(('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')):
                image_path  = os.path.join(dir_origin_path, img_name)
                image       = Image.open(image_path)
                r_image     = ssd.detect_image(image)
                if not os.path.exists(dir_save_path):
                    os.makedirs(dir_save_path)
                r_image.save(os.path.join(dir_save_path, img_name.replace(".jpg", ".png")), quality=95, subsampling=0)

    else:
        raise AssertionError("Please specify the correct mode: 'predict', 'video', 'fps' or 'dir_predict'.")

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

毕设 ssd tf_gpu2 predict.py 备份代码 的相关文章

  • Tensorflow 中的自定义资源

    由于某些原因 我需要为 Tensorflow 实现自定义资源 我试图从查找表实现中获得灵感 如果我理解得好的话 我需要实现3个TF操作 创建我的资源 资源的初始化 例如 在查找表的情况下填充哈希表 执行查找 查找 查询步骤 为了促进实施 我
  • 异常:加载数据时 URL 获取失败

    我正在尝试设置我的机器来运行 Tensorflow 2 我从未使用过 Tensorflow 只是下载了 Python 3 7 我不确定这是否是我的机器的问题 我按照上面列出的安装说明进行操作TensorFlow 的网站 https www
  • pip:需要将包名称tensorflow-gpu更改为tensorflow

    我正在尝试将具有 GPU 支持的张量流安装到 conda 环境中 我使用命令 pip install ignore installed upgrade https storage googleapis com tensorflow linu
  • 我可以在我的机器上同时安装 python 2.7 和 3.5 的tensorflow吗?

    目前我通过 Anaconda 在我的机器 MAC OX 上安装了 Python 2 7 Python 3 5 Tensorflow for Python 3 5 我也想在我的机器上安装 Tensorflow for Python 2 7 当
  • 张量流中的复杂卷积

    我正在尝试运行一个简单的卷积 但包含复数 r np random random 1 10 10 10 i np random random 1 10 10 10 x tf complex r i conv layer tf layers c
  • 在 Keras 模型中删除然后插入新的中间层

    给定一个预定义的 Keras 模型 我尝试首先加载预先训练的权重 然后删除一到三个模型内部 非最后几层 层 然后用另一层替换它 我似乎找不到任何有关的文档keras io https keras io 即将做这样的事情或从预定义的模型中删除
  • 为什么我的结果仍然无法重现?

    我想要为 CNN 获得可重复的结果 我使用带有 GPU 的 Keras 和 Google Colab 除了建议插入某些代码片段 这应该允许再现性 之外 我还在层中添加了种子 This is the first code snipped to
  • Tensorflow 对 Python3.11 的支持

    我在 Windows10 PC 上安装了 Python3 11 0 尝试使用以下命令安装张量流 pip install tensorflow 给出错误 访问tensorflow网站后 我意识到它仅支持3 7 3 10 我应该降级 pytho
  • AttributeError:模块“tensorflow.python.summary.summary”没有属性“FileWriter”

    我收到此错误 尽管我到处都看过file writer tf summary FileWriter path to logs sess graph 被提到为正确的实施this https github com tensorflow tenso
  • 在 Tensorflow 中每行选择一个元素的优雅方法

    Given 一个矩阵A形状的 m n 张量I形状的 m 我想要一份清单J的元素来自A where J i A i I i 那是 I保存要从每行中选择的元素的索引A 背景 我已经有了argmax A 1 现在我也想要max 我知道我可以使用r
  • 如何在 Tensorflow 对象检测 API 中查找边界框坐标

    我正在使用 Tensorflow 对象检测 API 代码 我训练了我的模型并获得了很高的检测百分比 我一直在尝试获取边界框坐标 但它不断打印出 100 个奇怪数组的列表 经过在线广泛搜索后 我发现数组中的数字意味着什么 边界框坐标相对于底层
  • 移动设备上的 TensorFlow(Android、iOS、Windows Phone)

    我目前正在寻找不同的深度学习框架 特别是用于训练和部署卷积神经网络 要求是 它可以在带有 GPU 的普通 PC 上进行训练 但训练后的模型必须部署在三个主要的移动操作系统上 即 Android iOS 和 Windows Phone Ten
  • 在c++中的嵌入式python中导入tensorflow时出错

    我的问题是关于在 C 程序中嵌入 Python 3 5 解释器以从 C 接收图像 并将其用作我训练的张量流模型的输入 当我在 python 代码中导入tensorflow库时 出现错误 其他库工作正常 简化后的代码如下 include
  • PyInstaller 是否包含 CUDA

    我正在开发一个Python脚本 我使用Python 3 7 3 它使用tensorflow gpu 1 14 0 并使用PyInstaller 3 5将此脚本转换为可执行文件 我使用的是 CUDA 10 0 和 cuDNN 7 6 1 我的
  • 如何在张量流中使用带有估计器的衰减学习率?

    我正在尝试将 LinearClassifier 与具有衰减学习率的 GradientDescentOptimizer 一起使用 My code def main load data features np load data feature
  • 卷积神经网络 (CNN) 输入形状

    我是 CNN 的新手 我有一个关于 CNN 的问题 我对 CNN 特别是 Keras 的输入形状有点困惑 我的数据是不同时隙的二维数据 比方说10X10 因此 我有 3D 数据 我将把这些数据输入到我的模型中来预测即将到来的时间段 所以 我
  • 增加 sigmoid 预测输出值?

    我创建了一个用于文本分类的 Conv1D 模型 当在最后一个密集处使用 softmax sigmoid 时 它产生的结果为 softmax gt 0 98502016 0 0149798 sigmoid gt 0 03902826 0 00
  • 查找张量流运算所依赖的所有变量

    有没有办法找到给定操作 通常是损失 所依赖的所有变量 我想用它来将该集合传递到optimizer minimize or tf gradients 使用各种set intersection 组合 到目前为止我已经找到了op op input
  • Keras 中批量大小可变的batch_dot

    我正在尝试编写一个层来合并 2 个张量formula https i stack imgur com I49aj png x 0 和x 1 的形状都是 1 500 M是500 500的矩阵 我希望输出为 500 500 我认为这在理论上是可
  • Tensorflow 数据集的数据预处理是针对整个数据集还是针对每次调用 iterator.next() 进行一次?

    您好 我现在正在研究tensorflow中的数据集API 我有一个关于执行数据预处理的dataset map 函数的问题 file name image1 jpg image2 jpg im dataset tf data Dataset

随机推荐

  • didUpdateWidget详解

    概述 只要在父widget中调用setState 子widget的didUpdateWidget就一定会被调用 不管父widget传递给子widget构造方法的参数有没有改变 只要didUpdateWidget被调用 接来下build方法就
  • 低功耗蓝牙MESH基础知识

    一 MESH VS 点对点 大多数蓝牙低功耗设备使用一对一简单点对点网络拓扑结构来进行相互间的通信 在蓝牙核心规格中 这称为 微微网 想象一下 智能手机已经建立了与心率监测仪的点对点连接 并可借此传输数据 同样的智能手机也可以建立与其他设备
  • gettimeofday 获取毫秒时间溢出问题

    之前为了测试C中代码执行消耗的时间 所以写了这么一个函数 long long getmstime timeval tv gettimeofday tv NULL return tv tv sec 1000 tv tv usec 1000 之
  • 使用Hutool向第三方的接口发起请求

    使用Hutool工具请求第三方接口遇到的一篮子问题 1 请求第三方接口的几种方式 1 1 使用HttpUtil请求 返回String类型的JSON串 一般用在请求普通的页面情况下 返回的结果是JSON格式 但是如果出现了404 504错误
  • newInstance过时

    在今天使用反射的newInstance 时候发现 jdk9版本将class newInstance 过时 Class stack1 Class forName Stack Stack stack2 Stack stack1 getConst
  • leetcode 300. Longest Increasing Subsequence

    leetcode 300 Longest Increasing Subsequence 题目 Given an unsorted array of integers find the length of longest increasing
  • PHP中小型民宿酒店管理系统源码

    PHP中小型民宿酒店管理系统源码 近年来 民宿酒店行业以其独特的住宿体验和个性化服务受到越来越多旅行者的青睐 为了提高运营效率 改善客户体验 许多中小型民宿酒店开始引入管理系统 本文将介绍一款基于PHP开发的中小型民宿酒店管理系统源码 帮助
  • Axure动态布局,中部加入滚动条

    1 将部件设置为动态面板 然后再部件属性和样式中使用按需显示纵向滚动条 2 可以很好的处理因为内部页面过大挤占低端内容的问题
  • 解决Excel打开CSV文件中文乱码问题

    CSV打开乱码的处理方法 方法一 Excel的数据导入功能 方法二 CSV打开乱码的处理方法 CSV是用UTF 8编码的 而EXCEL是ANSI编码 由于编码方式不一致导致出现乱码 明白了原因之后 我们只需要把CSV文件的编码方式修改成与E
  • 第5章 数组 第3题

    题目 编写一个程序 输入一个字符串 输出其中每个字符在字母表中的序号 对于不是英文字母的字符 输出0 例如 输入为 acbf8g 输出为1 3 2 6 0 7 代码 include
  • 基础练习—矩阵乘法

    题目描述 给定一个N阶矩阵A 输出A的M次幂 M是非负整数 例如 A 1 2 3 4 A的2次幂 7 10 15 22 输入 第一行是一个正整数N M 1 lt N lt 30 0 lt M lt 5 表示矩阵A的阶数和要求的幂数 接下来N
  • excel或txt格式坐标到面图层(python)

    背景 现有如下图所示的多个界址点坐标 excel格式或txt格式 需求 根据大量界址点坐标转换为界址点坐标对应的面shp图层 解决思路 一 为方便处理首先将txt文件或excel文件转换为csv文件 逗号分割符 二 为方便理解和使用 我们将
  • python numpy的学习

    0 引入numpy import numpy as np 1 将list变成np a 1 2 3 4 5 6 b np array a 将list变成array a shape 2 3 2行3列 a shape 0 2 获取行数 a sha
  • 攻防世界之WEB新手练习区(更新至11)

    攻防世界之WEB新手练习区 目录 001 view source 002 get post 003robots 004backup 005cookie 006disable button 007simple js 008xff refere
  • FTDI FT2232H在嵌入式教学中的应用

    FT2232H是FTDI chip在2012年发布的一款高速USB转串行通信的协议转换芯片 作为第五代USB协议转串行总线通信协议的芯片 完全符合USB2 0规范 480Mb s 并且可以依靠编程的方式配置成为串行或者并行的其他总线接口规范
  • Pycharm无法正常安装第三方库的时候,有以下几条应对方法

    1 首先检查自己的环境变量是否配置正确 点击setting 点击 Python Interpreter 点击Add Interpreter 配置完毕之后再试一次从这里下载 如果还不行的话可以换其他方法 2 从cmd或Pycharm Term
  • 什么是EL表达式

    EL表达式 expression language 即表达语言 它是为了便于存取数据而定义的一种语言 JSP2 0之后才成为一种标准 形式 以 开头 以 结尾 通过PAGE指令来说明是否支持EL表达式 具体举例 声明可以使用EL表达式 如果
  • 【腾讯云 TDSQL-C Serverless 产品测评】全面测评TDSQL-C Mysql Serverless

    全面测评TDSQL C Mysql Serverless 文章目录 全面测评TDSQL C Mysql Serverless 前言 什么是TDSQL C Mysql Serverless 初始化 TDSQL C Mysql Serverle
  • 前车之覆,后车之鉴——开源项目经验谈

    前车之覆 后车之鉴 开源项目经验谈 本文发表于 程序员 2005年第2期 随着开源文化的日益普及 参与开源 似乎也变成了一种时尚 一时间 似乎大家都乐于把自己的代码拿出来分享了 就在新年前夕 我的一位老朋友 一位向来对开源嗤之以鼻的J2EE
  • 毕设 ssd tf_gpu2 predict.py 备份代码

    如图 代码所属 https github com bubbliiiing ssd tf2 怕到时候改坏了 unbelievable 不是 import time import cv2 import numpy as np import te