如何在 Tensorflow 上测试自己的图像到 Cifar-10 教程?

2024-04-21

我训练了 Tensorflow Cifar10 模型,我想为其提供自己的单个图像(32*32,jpg/png)。

我想将标签和每个标签的概率视为输出,但我对此遇到了一些麻烦。

搜索堆栈溢出后,我发现了一些帖子this https://stackoverflow.com/questions/37578876/how-to-feed-cifar10-trained-model-with-my-own-image-and-get-label-as-output我修改了cifar10_eval.py。

但这根本不起作用。

错误消息是:

InvalidArgumentError 回溯(最近一次调用最后一次) in () ----> 1 evaluate()

在评估() 86 # 从检查点恢复 87 print("ckpt.model_checkpoint_path", ckpt.model_checkpoint_path) ---> 88 saver.restore(sess, ckpt.model_checkpoint_path) 89 # 假设 model_checkpoint_path 看起来像: 90 # /我最喜欢的路径/cifar10_train/model.ckpt-0,

/home/huray/anaconda2/envs/tensorflow/lib/python2.7/site-packages/tensorflow/python/training/saver.pyc 在恢复(自我,sess,save_path)1127提高 ValueError(“使用无效的保存路径%s调用恢复”% save_path)
第1128章 -> 1129 {self.saver_def.filename_tensor_name: save_path}) 1130 1131 @staticmethod

/home/huray/anaconda2/envs/tensorflow/lib/python2.7/site-packages/tensorflow/python/client/session.pyc 在运行(自我,获取,feed_dict,选项,run_metadata) 第380章 第381章 --> 382 run_metadata_ptr) 第383章 第384章

/home/huray/anaconda2/envs/tensorflow/lib/python2.7/site-packages/tensorflow/python/client/session.pyc 在 _run(self, 句柄, 获取, feed_dict, 选项, run_metadata) 第653章 第654章 --> 655 feed_dict_string、选项、run_metadata) 第656章 第657章

/home/huray/anaconda2/envs/tensorflow/lib/python2.7/site-packages/tensorflow/python/client/session.pyc 在 _do_run(self, 句柄, target_list, fetch_list, feed_dict, 选项, 运行元数据) 第721章 第722章 --> 723 target_list、选项、run_metadata) 第724章: 第725章

/home/huray/anaconda2/envs/tensorflow/lib/python2.7/site-packages/tensorflow/python/client/session.pyc 在 _do_call(self, fn, *args) 中 741 除了 KeyError: 第742章 --> 743 引发类型(e)(node_def,op,消息) 第744章 第745章

InvalidArgumentError:分配需要两个张量的形状匹配。 左轴形状= [18,384] 右轴形状= [2304,384] [[节点:保存/分配_5 = 分配[T=DT_FLOAT, _class=["loc:@local3/weights"], use_locking=true, validate_shape=true, _device="/job:localhost/replica:0/task:0/cpu:0"](local3/weights, save/restore_slice_5)]]

任何有关 Cifar10 的帮助将不胜感激。

这是迄今为止已实现的代码,但存在编译问题:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from datetime import datetime
import math
import time

import numpy as np
import tensorflow as tf
import cifar10

FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_string('eval_dir', '/tmp/cifar10_eval',
                           """Directory where to write event logs.""")
tf.app.flags.DEFINE_string('eval_data', 'test',
                           """Either 'test' or 'train_eval'.""")
tf.app.flags.DEFINE_string('checkpoint_dir', '/tmp/cifar10_train',
                           """Directory where to read model checkpoints.""")
tf.app.flags.DEFINE_integer('eval_interval_secs', 5,
                            """How often to run the eval.""")
tf.app.flags.DEFINE_integer('num_examples', 1,
                            """Number of examples to run.""")
tf.app.flags.DEFINE_boolean('run_once', False,
                         """Whether to run eval only once.""")

def eval_once(saver, summary_writer, top_k_op, summary_op):
  """Run Eval once.

  Args:
    saver: Saver.
    summary_writer: Summary writer.
    top_k_op: Top K op.
    summary_op: Summary op.
  """
  with tf.Session() as sess:
    ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
    if ckpt and ckpt.model_checkpoint_path:
      # Restores from checkpoint
      saver.restore(sess, ckpt.model_checkpoint_path)
      # Assuming model_checkpoint_path looks something like:
      #   /my-favorite-path/cifar10_train/model.ckpt-0,
      # extract global_step from it.
      global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
    else:
      print('No checkpoint file found')
      return
    print("Check point : %s" % ckpt.model_checkpoint_path)

    # Start the queue runners.
    coord = tf.train.Coordinator()
    try:
      threads = []
      for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
        threads.extend(qr.create_threads(sess, coord=coord, daemon=True,
                                         start=True))

      num_iter = int(math.ceil(FLAGS.num_examples / FLAGS.batch_size))
      true_count = 0  # Counts the number of correct predictions.
      total_sample_count = num_iter * FLAGS.batch_size
      step = 0
      while step < num_iter and not coord.should_stop():
        predictions = sess.run([top_k_op])
        true_count += np.sum(predictions)
        step += 1

      # Compute precision @ 1.
      precision = true_count / total_sample_count
      print('%s: precision @ 1 = %.3f' % (datetime.now(), precision))

      summary = tf.Summary()
      summary.ParseFromString(sess.run(summary_op))
      summary.value.add(tag='Precision @ 1', simple_value=precision)
      summary_writer.add_summary(summary, global_step)
    except Exception as e:  # pylint: disable=broad-except
      coord.request_stop(e)

    coord.request_stop()
    coord.join(threads, stop_grace_period_secs=10)


def evaluate():
  """Eval CIFAR-10 for a number of steps."""
  with tf.Graph().as_default() as g:
    # Get images and labels for CIFAR-10.
    eval_data = FLAGS.eval_data == 'test'
#     images, labels = cifar10.inputs(eval_data=eval_data)

    # TEST CODE
    img_path = "/TEST_IMAGEPATH/image.png"
    input_img = tf.image.decode_png(tf.read_file(img_path), channels=3)
    casted_image = tf.cast(input_img, tf.float32)

    reshaped_image = tf.image.resize_image_with_crop_or_pad(casted_image, 24, 24)
    float_image = tf.image.per_image_withening(reshaped_image)
    images = tf.expand_dims(reshaped_image, 0) 

    logits = cifar10.inference(images)
    _, top_k_pred = tf.nn.top_k(logits, k=1)


    with tf.Session() as sess:
        saver = tf.train.Saver()
        ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
          print("ckpt.model_checkpoint_path ", ckpt.model_checkpoint_path)
          saver.restore(sess, ckpt.model_checkpoint_path)
          global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
        else:
          print('No checkpoint file found')
          return

        print("Check point : %s" % ckpt.model_checkpoint_path)
        top_indices = sess.run([top_k_pred])
        print ("Predicted ", top_indices[0], " for your input image.")

evaluate()

视频https://youtu.be/d9mSWqfo0Xw https://youtu.be/d9mSWqfo0Xw显示了对单个图像进行分类的示例。

在网络经过 python cifar10_train.py 训练后,我们评估 CIFAR-10 数据库的单个图像 deer6.png 和火柴盒自己的照片。 TF教程原始源码最重要的修改如下:

首先,需要将这些图像转换为 cifar10_input.py 可以读取的二进制形式。通过使用可以在以下位置找到的代码片段可以轻松完成此操作如何创建类似于 cifar-10 的数据集 https://stackoverflow.com/questions/35032675/how-to-create-dataset-similar-to-cifar-10

然后,为了读取转换后的图像(称为 input.bin),我们需要修改 cifar10_input.py 中的函数 input():

  else:
    #filenames = [os.path.join(data_dir, 'test_batch.bin')]
    filenames = [os.path.join(data_dir, 'input.bin')]

(data_dir 等于“./”)

最后为了获得标签,我们修改了源 cifar10_eval.py 中的函数 eval_once():

      #while step < num_iter and not coord.should_stop():
      #  predictions = sess.run([top_k_op])
      print(sess.run(logits[0]))
      classification = sess.run(tf.argmax(logits[0], 0))
      cifar10classes = ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]
      print(cifar10classes[classification])

      #true_count += np.sum(predictions)
      step += 1

      # Compute precision @ 1.
      precision = true_count / total_sample_count
      # print('%s: precision @ 1 = %.3f' % (datetime.now(), precision))

当然,您需要进行一些小的修改。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如何在 Tensorflow 上测试自己的图像到 Cifar-10 教程? 的相关文章

  • 使用Python的工业视觉相机[关闭]

    就目前情况而言 这个问题不太适合我们的问答形式 我们希望答案得到事实 参考资料或专业知识的支持 但这个问题可能会引发辩论 争论 民意调查或扩展讨论 如果您觉得这个问题可以改进并可能重新开放 访问帮助中心 help reopen questi
  • 在 python 2 和 3 的spyder之间切换

    根据我在文档中了解到的内容 它指出您只需使用命令提示符创建一个新变量即可轻松在 2 个 python 环境之间切换 如果我已经安装了 python 2 7 则 conda create n python34 python 3 4 anaco
  • 为什么方法无法访问类变量?

    我试图理解Python中的变量作用域 除了我不明白为什么类变量不能从其方法访问的部分之外 大多数事情对我来说都很清楚 在下面的例子中mydef1 无法访问a 但如果a可以在全局范围 类定义之外 声明 class MyClass1 a 25
  • Pytest:如何使用从夹具返回的列表来参数化测试?

    我想使用由固定装置动态创建的列表来参数化测试 如下所示 pytest fixture def my list returning fixture depends on other fixtures return a dynamically
  • 对打开文件的脚本进行单元测试

    我编写了一个脚本 它打开一个文件 读取内容并进行一些操作和计算 并将它们存储在集合和字典中 我该如何为这样的事情编写单元测试 我的问题具体是 我会测试文件是否打开 文件很大 这是unix字典文件 我如何对计算进行单元测试 我真的必须手动计算
  • 如何在“python setup.py test”中运行 py.test 和 linter

    我有一个项目setup py文件 我用pytest作为测试框架 我还在我的代码上运行各种 linter pep8 pylint pydocstyle pyflakes ETC 我用tox在多个 Python 版本中运行它们 并使用以下命令构
  • 如何从网站中提取冠状病毒病例?

    我正在尝试从网站中提取冠状病毒 https www trackcorona live https www trackcorona live 但我得到了一个错误 这是我的代码 response requests get https www t
  • 在python中调用subprocess.Popen时“系统找不到指定的文件”

    我正在尝试使用svnmerge py合并一些文件 它在底层使用 python 当我使用它时 我收到一个错误 系统找不到指定的文件 工作中的同事正在运行相同版本的svnmerge py 以及 python 2 5 2 特别是 r252 609
  • Python 相当于 Bit Twiddling Hacks 中的 C 代码?

    我有一个位计数方法 我正在尝试尽可能快地实现 我想尝试下面的算法位摆弄黑客 http graphics stanford edu seander bithacks html CountBitsSetParallel 但我不知道 C 什么是
  • 熊猫记忆

    我有冗长的计算 我重复了很多次 因此 我想使用记忆 诸如jug http packages python org Jug and joblib http packages python org joblib memory html 与Pan
  • 如何在python中递归复制目录并覆盖全部?

    我正在尝试复制 home myUser dir1 及其所有内容 及其内容等 home myuser dir2 在Python中 此外 我希望副本覆盖中的所有内容dir2 It looks like distutils dir util co
  • Matplotlib 将颜色图 tab20 更改为三种颜色

    Matplotlib 有一些新的且非常方便的颜色图 选项卡颜色图 https matplotlib org examples color colormaps reference html 我错过的是生成像 tab20b 或 tab20c 这
  • 与函数复合 UniqueConstraint

    一个快速的 SQLAlchemy 问题 我有一个 文档 类 其属性为 数字 和 日期 我需要确保没有重复的号码同年 是 有没有办法对 数字 年份 日期 进行UniqueConstraint 我应该使用唯一索引吗 我如何声明功能部分 SQLA
  • 将参数传递给 __enter__

    刚刚学习 with 语句尤其是这篇文章 http effbot org zone python with statement htm 问题是 我可以传递一个参数给 enter 我有这样的代码 class clippy runner def
  • 从 Apache 运行 python 脚本的最简单方法

    我花了很长时间试图弄清楚这一点 我基本上正在尝试开发一个网站 当用户单击特定按钮时 我必须在其中执行 python 脚本 在研究了 Stack Overflow 和 Google 之后 我需要配置 Apache 以便能够运行 CGI 脚本
  • 使用 pandas 绘制带有误差线的条形图

    我正在尝试从 DataFrame 生成条形图 如下所示 Pre Post Measure1 0 4 1 9 这些值是我从其他地方计算出来的中值 我还有它们的方差和标准差 以及标准误差 我想将结果绘制为具有适当误差线的条形图 但指定多个误差值
  • dask allocate() 或 apply() 中的变量列名

    我有适用于pandas 但我在将其转换为使用时遇到问题dask 有一个部分解决方案here https stackoverflow com questions 32363114 how do i change rows and column
  • Python列表对象属性“append”是只读的

    正如标题所说 在Python中 我试图做到这一点 以便当有人输入一个选择 在本例中为Choice13 时 它会从密码列表中删除旧密码并添加新密码 passwords mrjoebblock mrjoefblock mrjoegblock m
  • Melt() 函数复制数据集

    我有一个这样的表 id name doggo floofer puppo pupper 1 rowa NaN NaN NaN NaN 2 ray NaN NaN NaN NaN 3 emma NaN NaN NaN pupper 4 sop
  • 从 Flask 中的 S3 返回 PDF

    我正在尝试在 Flask 应用程序的浏览器中返回 PDF 我使用 AWS S3 来存储文件 并使用 boto3 作为与 S3 交互的 SDK 到目前为止我的代码是 s3 boto3 resource s3 aws access key id

随机推荐