Tensorflow numpy 图像重塑 [灰度图像]

2023-12-25

我正在尝试使用我训练过的神经网络数据在 jupyter 笔记本中执行 Tensorflow“object_detection_tutorial.py”,但它会抛出 ValueError。上面提到的文件是 YouTube 上用于对象检测的 Sentdexs 张量流教程的一部分。

你可以在这里找到它: ()

我的图像尺寸:490x704。这样就会得到一个 344960 数组。

但它说:ValueError: cannot reshape array of size 344960 into shape (490,704,3)

我究竟做错了什么?

Code:

Imports

import numpy as np
import os
import six.moves.urllib as urllib
import sys
import tarfile
import tensorflow as tf
import zipfile

from collections import defaultdict
from io import StringIO
from matplotlib import pyplot as plt
from PIL import Image

环境设置

# This is needed to display the images.
%matplotlib inline

# This is needed since the notebook is stored in the object_detection folder.
sys.path.append("..")

物体检测导入

from utils import label_map_util

from utils import visualization_utils as vis_util

变量

# What model to download.
MODEL_NAME = 'shard_graph'

# Path to frozen detection graph. This is the actual model that is used for the object detection.
PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb'

# List of the strings that is used to add correct label for each box.
PATH_TO_LABELS = os.path.join('training', 'object-detection.pbtxt')

NUM_CLASSES = 90

将(冻结的)Tensorflow 模型加载到内存中。

detection_graph = tf.Graph()
with detection_graph.as_default():
  od_graph_def = tf.GraphDef()
  with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
    serialized_graph = fid.read()
    od_graph_def.ParseFromString(serialized_graph)
    tf.import_graph_def(od_graph_def, name='')

加载标签图

label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories)

辅助代码

def load_image_into_numpy_array(image):
  (im_width, im_height) = image.size
  return np.array(image.getdata()).reshape(
      (im_height, im_width, 3)).astype(np.uint8)

检测

# For the sake of simplicity we will use only 2 images:
# image1.jpg
# image2.jpg
# If you want to test the code with your images, just add path to the images to the TEST_IMAGE_PATHS.
PATH_TO_TEST_IMAGES_DIR = 'test_images'
TEST_IMAGE_PATHS = [ os.path.join(PATH_TO_TEST_IMAGES_DIR, 'frame_{}.png'.format(i)) for i in range(0, 2) ]

# Size, in inches, of the output images.
IMAGE_SIZE = (12, 8)

-

with detection_graph.as_default():
  with tf.Session(graph=detection_graph) as sess:
    # Definite input and output Tensors for detection_graph
    image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
    # Each box represents a part of the image where a particular object was detected.
    detection_boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
    # Each score represent how level of confidence for each of the objects.
    # Score is shown on the result image, together with the class label.
    detection_scores = detection_graph.get_tensor_by_name('detection_scores:0')
    detection_classes = detection_graph.get_tensor_by_name('detection_classes:0')
    num_detections = detection_graph.get_tensor_by_name('num_detections:0')
    for image_path in TEST_IMAGE_PATHS:
      image = Image.open(image_path)
      # the array based representation of the image will be used later in order to prepare the
      # result image with boxes and labels on it.
      image_np = load_image_into_numpy_array(image)
      # Expand dimensions since the model expects images to have shape: [1, None, None, 3]
      image_np_expanded = np.expand_dims(image_np, axis=0)
      # Actual detection.
      (boxes, scores, classes, num) = sess.run(
          [detection_boxes, detection_scores, detection_classes, num_detections],
          feed_dict={image_tensor: image_np_expanded})
      # Visualization of the results of a detection.
      vis_util.visualize_boxes_and_labels_on_image_array(
          image_np,
          np.squeeze(boxes),
          np.squeeze(classes).astype(np.int32),
          np.squeeze(scores),
          category_index,
          use_normalized_coordinates=True,
          line_thickness=8)
      plt.figure(figsize=IMAGE_SIZE)
      plt.imshow(image_np)

脚本的最后一部分抛出错误:

----------------------------------------------------------------------
ValueError                           Traceback (most recent call last)
<ipython-input-62-7493eea60222> in <module>()
     14       # the array based representation of the image will be used later in order to prepare the
     15       # result image with boxes and labels on it.
---> 16       image_np = load_image_into_numpy_array(image)
     17       # Expand dimensions since the model expects images to have shape: [1, None, None, 3]
     18       image_np_expanded = np.expand_dims(image_np, axis=0)

<ipython-input-60-af094dcdd84a> in load_image_into_numpy_array(image)
      2   (im_width, im_height) = image.size
      3   return np.array(image.getdata()).reshape(
----> 4       (im_height, im_width, 3)).astype(np.uint8)

ValueError: cannot reshape array of size 344960 into shape (490,704,3)

Edit:

所以我改变了这个函数的最后一行:

def load_image_into_numpy_array(image):
  (im_width, im_height) = image.size
  return np.array(image.getdata()).reshape(
      (im_height, im_width, 3)).astype(np.uint8)

to:

(im_height, im_width)).astype(np.uint8)

ValueError 已解决。但现在引发了另一个与数组格式相关的 ValueError:

----------------------------------------------------------------------
ValueError                           Traceback (most recent call last)
<ipython-input-107-7493eea60222> in <module>()
     20       (boxes, scores, classes, num) = sess.run(
     21           [detection_boxes, detection_scores, detection_classes, num_detections],
---> 22           feed_dict={image_tensor: image_np_expanded})
     23       # Visualization of the results of a detection.
     24       vis_util.visualize_boxes_and_labels_on_image_array(

~/.local/lib/python3.6/site-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)
    898     try:
    899       result = self._run(None, fetches, feed_dict, options_ptr,
--> 900                          run_metadata_ptr)
    901       if run_metadata:
    902         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

~/.local/lib/python3.6/site-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
   1109                              'which has shape %r' %
   1110                              (np_val.shape, subfeed_t.name,
-> 1111                               str(subfeed_t.get_shape())))
   1112           if not self.graph.is_feedable(subfeed_t):
   1113             raise ValueError('Tensor %s may not be fed.' % subfeed_t)

ValueError: Cannot feed value of shape (1, 490, 704) for Tensor 'image_tensor:0', which has shape '(?, ?, ?, 3)'

这是否意味着这个张量流模型不是为灰度图像设计的?有办法让它发挥作用吗?

SOLUTION

感谢 Matan Hugi,它现在工作得很好。我所要做的就是将此函数更改为:

def load_image_into_numpy_array(image):
    # The function supports only grayscale images
    last_axis = -1
    dim_to_repeat = 2
    repeats = 3
    grscale_img_3dims = np.expand_dims(image, last_axis)
    training_image = np.repeat(grscale_img_3dims, repeats, dim_to_repeat).astype('uint8')
    assert len(training_image.shape) == 3
    assert training_image.shape[-1] == 3
    return training_image

Tensorflow 预期输入以 NHWC 格式格式化, 这意味着:(批次、高度、宽度、通道)。

第 1 步 - 添加最后一个维度:

last_axis = -1
grscale_img_3dims = np.expand_dims(image, last_axis)

步骤 2 - 重复最后一个维度 3 次:

dim_to_repeat = 2
repeats = 3
np.repeat(grscale_img_3dims, repeats, dim_to_repeat)

所以你的函数应该是:

def load_image_into_numpy_array(image):
    # The function supports only grayscale images
    assert len(image.shape) == 2, "Not a grayscale input image" 
    last_axis = -1
    dim_to_repeat = 2
    repeats = 3
    grscale_img_3dims = np.expand_dims(image, last_axis)
    training_image = np.repeat(grscale_img_3dims, repeats, dim_to_repeat).astype('uint8')
    assert len(training_image.shape) == 3
    assert training_image.shape[-1] == 3
    return training_image
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Tensorflow numpy 图像重塑 [灰度图像] 的相关文章

随机推荐

  • Java POI:如何读取Excel单元格值而不是公式计算?

    我正在使用 Apache POI API 从 Excel 文件中获取值 除了包含公式的单元格之外 一切都运行良好 事实上 cell getStringCellValue 返回单元格中使用的公式 而不是单元格的值 我尝试使用evaluateF
  • 使用 equals 方法比较字符串并 == [重复]

    这个问题在这里已经有答案了 可能的重复 如何在 Java 中比较字符串 https stackoverflow com questions 513832 how do i compare strings in java Java Strin
  • 如何使用 tqdm 迭代列表

    我想知道处理某个列表需要多长时间 for a in tqdm list1 if a in list2 do something 但这不起作用 如果我使用for a in tqdm range list1 我将无法检索列表值 你知道怎么做吗
  • 无法从 Django Docker 实例内部访问项目绝对 url

    我有一个使用 Cookiecutter Django 启动的项目 目前我正在添加 WeasyPrint 以将某些视图作为 PDF 文件提供 这在开发中运行良好 Cookiecutter Django 使用 Caddy 作为 HTTP 服务器
  • 禁止实例化为临时对象 (C++)

    我喜欢在 C 中使用哨兵类 但我似乎有一种精神困扰 导致反复编写如下错误 MySentryClass arg other code 不用说 这会失败 因为哨兵在创建后立即死亡 而不是按预期在作用域结束时死亡 有没有某种方法可以防止 MySe
  • Django CreateView 不保存对象

    我正在使用基本的博客应用程序练习 django 基于类的视图 然而 由于某种原因 我的 Post 模型的 CreateView 没有将帖子保存在数据库中 模型 py class Post models Model user models F
  • 如何正确使用头文件成为一个完整的类?

    初学者程序员 我遵循工作正常的头文件的样式 但我试图弄清楚在编译时如何不断收到所有这些错误 我正在 Cygwin 中使用 g 进行编译 Ingredient h 8 13 error expected unqualified id befo
  • 进化算法:最优重新群体分解

    这确实是标题中的全部内容 但对于任何对进化算法感兴趣的人来说 这里有一个细分 在 EA 中 基本前提是随机生成一定数量的有机体 实际上只是参数集 针对问题运行它们 然后让表现最好的有机体生存下来 然后 你会重新填充幸存者的杂交品种 幸存者的
  • 如何在 pandas 数据框中执行不同值的累积和

    我有一个像这样的数据框 id date company 123 2019 01 01 A 224 2019 01 01 B 345 2019 01 01 B 987 2019 01 03 C 334 2019 01 03 C 908 201
  • Delphi中从C DLL获取字符串返回值

    我有一个用 C 编写的遗留 DLL 其中包含一个返回字符串的函数 我需要从 Delphi 访问该函数 我所掌握的有关 DLL 的唯一信息是用于访问该函数的 VB 声明 公开声明函数 DecryptStr Lib strlib Str As
  • 根据标签对一行中的每个句子进行评分并总结文本。 (爪哇)

    我正在尝试用 Java 创建一个摘要器 我正在使用斯坦福对数线性词性标注器 http nlp stanford edu software tagger shtml标记单词 然后 对于某些标记 我对句子进行评分 最后在摘要中 我打印具有高分值
  • 无法读取 PNG 签名:文件不以 PNG 签名开头

    Gradle 构建失败并出现以下错误 Error C Users Roman gradle caches transforms 1 files 1 1 appcompat v7 26 0 2 aar bab547c3f1b8061ef942
  • 使用 GhostScript 将 pdf 转换为图像 - 如何引用 gsdll32.dll?

    我正在尝试使用 GhostScript 从 pdf 创建图像 这是我的代码 GhostscriptWrapper ConvertToBMP inputPDFFilePath outputBMPFilePath 这是我的Ghostscript
  • 复合组件属性中的枚举值

    我的问题非常简单 我想创建一个具有字符串属性 Type 的复合组件
  • 将处理3嵌入到swing中

    我正在尝试将Processing 3 集成到swing 应用程序中 但是因为PApplet 不再扩展Applet 所以我不能立即将其添加为组件 无论如何 是否可以将Processing 3 草图嵌入到Swing 中 如果我可以在没有PDE
  • Gradle 无法使用 OBJECT 库构建 CMake 项目,因为它需要输出文件

    My 构建 gradle文件包含以下内容以使用 CMake 构建项目 externalNativeBuild cmake Provides a relative path to your CMake build script version
  • 每个工作表循环的 Excel VBA

    我正在编写代码 基本上浏览工作簿中的每张工作表 然后更新列宽 下面是我写的代码 我没有收到任何错误 但它实际上也没有做任何事情 任何帮助是极大的赞赏 Option Explicit Dim ws As Worksheet a As Rang
  • 文本字体大小

    我创造了不同的layouts layout layout small layout normal layout large layout xlarge 并为values values values ldpi values mdpi valu
  • 如果其他类可见或显示,JQuery 隐藏类

    发现类似的问题 但没有什么能完全满足我的需要 我在示例中保持简单 并且我想使用 JQuery 我有两节课 如果页面加载时显示 类别 div 我想隐藏 过滤器 div 目前没有与这两个类别相关的样式 我相信我已经很接近了 但它不起作用 div
  • Tensorflow numpy 图像重塑 [灰度图像]

    我正在尝试使用我训练过的神经网络数据在 jupyter 笔记本中执行 Tensorflow object detection tutorial py 但它会抛出 ValueError 上面提到的文件是 YouTube 上用于对象检测的 Se