在 Tensorflow 中使用队列将数据馈送到网络时分开验证和训练图

2024-04-28

我一直在做大量关于如何使用队列将数据正确输入网络的研究。但是,我在互联网上找不到任何解决方案。

目前我的代码能够读取训练数据并执行训练,但无需验证和测试。这里有一些重要的行构成了我的代码:

images, volumes = utils.inputs(FLAGS.train_file_path, FLAGS.batch_size, FLAGS.num_epochs)

print("Initiliaze training")
logits = utils.inference(images)
loss_intermediate, loss = utils.get_loss(logits, volumes)

train_optimizer = utils.pre_training(loss, FLAGS.learning_rate)

summary_train = tf.summary.merge_all('train')
summary_test = tf.summary.merge_all('test')

init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())

saver = tf.train.Saver(max_to_keep=2)
with tf.Session() as sess:

    summary_writer = tf.summary.FileWriter(FLAGS.train_dir + FLAGS.run, sess.graph)
    summary_writer_test = tf.summary.FileWriter(FLAGS.train_dir + FLAGS.run_test, sess.graph)
    sess.run(init)

    # Start input enqueue threads.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    print("Start training")

    try:
        step = 0
        while not coord.should_stop():
            start_time = time.time()

            _, loss_intermediate_value, loss_value = sess.run([train_optimizer, loss_intermediate, loss])
            duration = time.time() - start_time
            if step % FLAGS.show_step == 0:
                print('Step %d: loss_intermediate = %.2f, loss = %.5f (%.3f sec)' % (step, loss_intermediate_value, loss_value, duration))
                summary_str = sess.run(summary_train)
                summary_writer.add_summary(summary_str, step)
                summary_writer.flush()

            if step % FLAGS.test_interval == 0:
               ###### HERE VALIDATION HOW ? ############
            step += 1
    except tf.errors.OutOfRangeError:
        print('ERROR IN CODE')
    finally:
        print('Done training for %d epochs, %d steps.' % (FLAGS.num_epochs, step))
        # When done, ask the threads to stop.
        coord.request_stop()
        # Wait for threads to finish.
        coord.join(threads)

而这个函数就是用来读取数据的。

def inputs(train, batch_size, num_epochs):

  if not num_epochs: num_epochs = None
  filename = os.path.join(train)

  with tf.name_scope('input'):
    filename_queue = tf.train.string_input_producer([filename], num_epochs=num_epochs)

    image, volume = read_and_decode(filename_queue)

    images, volumes = tf.train.shuffle_batch([image, volume], batch_size=batch_size, num_threads=2, capacity=1000 * batch_size, min_after_dequeue=500)

    return images, volume

我不明白如何使用张量流创建另一个输入队列或输入图来进行验证。有人能帮我吗?任何帮助表示赞赏!

EDIT

def _conv(self, inputs, nb_filter, kernel_size=1, strides=1, pad='VALID', name='conv'):
        with tf.name_scope(name) as scope:

            #kernel = tf.Variable(tf.truncated_normal([kernel_size, kernel_size,int(inputs.get_shape().as_list()[3]),int(nb_filter)], mean=0.0, stddev=0.0001), name='weights')
            kernel = tf.Variable(tf.contrib.layers.xavier_initializer(uniform=False)([kernel_size, kernel_size,int(inputs.get_shape().as_list()[3]),int(nb_filter)]), name='weights')
            conv = tf.nn.conv2d(inputs, kernel, [1,strides,strides,1], padding=pad, data_format='NHWC')
            return conv

EDIT 2

  with tf.Graph().as_default():
    print("Load Data...")
    images, volumes = utils.inputs(FLAGS.train_file_path, FLAGS.batch_size, FLAGS.num_epochs)
    v_images, v_volumes = utils.inputs(FLAGS.val_file_path, FLAGS.batch_size)

    print("input shape: " + str(images.get_shape()))
    print("output shape: " + str(volumes.get_shape()))

    print("Initialize training")
    logits = utils.inference(images, FLAGS.stacks, True)
    v_logits = utils.inference(v_images, FLAGS.stacks, False)

    tf.add_to_collection("logits", v_logits)

    loss = utils.get_loss(logits, volumes, FLAGS.stacks, 'train')
    v_loss = utils.get_loss(v_logits, v_volumes, FLAGS.stacks, 'val')

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_optimizer = utils.pre_training(loss, FLAGS.learning_rate)

    validate = utils.validate(v_images, v_logits, v_volumes, FLAGS.scale)

    summary_train_op = tf.summary.merge_all('train')
    summary_val_op = tf.summary.merge_all('val')

    init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())

    saver = tf.train.Saver(max_to_keep=2)
    with tf.Session() as sess:

        summary_writer = tf.summary.FileWriter(FLAGS.train_dir + FLAGS.run, sess.graph)
        summary_writer_val = tf.summary.FileWriter(FLAGS.train_dir + FLAGS.run + FLAGS.run_val, sess.graph)
        sess.run(init)

        # Start input enqueue threads.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            print("Start training")
            step = 0
            while not coord.should_stop():

                start_time = time.time()
                _, loss_list, image_batch, volume_batch, summary_str = sess.run([train_optimizer, loss, images, volumes, summary_train_op])
                duration = time.time() - start_time

                if (step + 1) % FLAGS.show_step == 0:
                    print('Step %d: (%.3f sec)' % (step, duration), end= ': ')
                    print (", ".join('%.5f'%float(x) for x in loss_list))
                    summary_writer.add_summary(summary_str, step)

                if (step + 1) % FLAGS.val_interval == 0:

                    val_loss_sum_list = [0] * len(v_loss)

                    for val_step in range(0, FLAGS.val_iter):
                        _, val_loss_list, summary_str_val, image_input, volume_estimated, volume_ground_truth = sess.run([validate, v_loss, summary_val_op, v_images, v_logits, v_volumes])
                        val_loss_sum_list = [sum(x) for x in zip(val_loss_sum_list, val_loss_list)]

                        if (val_step + 1) == FLAGS.val_iter:
                            print('Validation Interval %d: ' % (step / FLAGS.val_interval), end= '')
                            print (", ".join('%.5f'%float(x / FLAGS.val_iter) for x in val_loss_sum_list))
                            summary_writer_val.add_summary(summary_str_val, step)

                            #image_input, volume_estimated, volume_ground_truth = sess.run([v_images, v_logits, v_volumes])
                            #summary_val_images_op = utils.validate(image_input, volume_estimated, volume_ground_truth, FLAGS.scale, int(step / FLAGS.val_interval))

                if (step + 1) % FLAGS.step_save_checkpoint == 0:
                    checkpoint_file = os.path.join(FLAGS.train_dir + FLAGS.run, 'hourglass-model')
                    saver.save(sess, checkpoint_file, global_step=step)
                    print('Step: ' + str(step))
                    print('Saved: ' + checkpoint_file)

                step += 1
        except tf.errors.OutOfRangeError:
            print('OUT OF RANGE ERROR')
        except Exception as e:
            print(sys.exc_info())
            print('Unexpected error in code')
            exc_type, exc_obj, exc_tb = sys.exc_info()
            fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
            print(exc_type, fname, exc_tb.tb_lineno)
        finally:
            print('Done training for %d epochs, %d steps.' % (FLAGS.num_epochs, step))
            checkpoint_file = os.path.join(FLAGS.train_dir + FLAGS.run, '-model')
            saver.save(sess, checkpoint_file, global_step=step)
            print('Step: ' + str(step))
            print('Saved: ' + checkpoint_file)

            # When done, ask the threads to stop.
            coord.request_stop()
            # Wait for threads to finish.
            coord.join(threads)

如果您已经将数据分为训练数据集和验证数据集,那么您所要做的就是为验证数据创建另一个输入管道。使用您提供的代码,它应该看起来像这样



images, volumes = utils.inputs(FLAGS.train_file_path, FLAGS.batch_size, FLAGS.num_epochs)
# create validation pipeline
v_images, v_volumes = utils.inputs(FLAGS.valid_file_path, FLAGS.batch_size, None)

logits = utils.inference(images)
loss_intermediate, loss = utils.get_loss(logits, volumes)
# define validation ops
v_logits = utils.inference(v_images)
accuracy = utils.accuracy(v_logits, v_volumes)

... a bunch of code here ...

with tf.Session() as sess:
    ... more code here ...
    if step % FLAGS.test_interval == 0:
        acc = sess.run([accuracy])
        print('Accuracy on validation data: {}'.format(acc))
    ... more code here ...
  

这是您一直在寻找的吗?

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

在 Tensorflow 中使用队列将数据馈送到网络时分开验证和训练图 的相关文章

  • keras 层教程和示例

    我正在尝试编码和学习不同的神经网络模型 我对输入维度有很多复杂性 我正在寻找一些教程 显示层的差异以及如何设置每个层的输入和输出 Keras 文档 https keras io layers core 向您展示所有input shape每层
  • 按 A 列删除重复项,保留 B 列中具有最高值的行

    我有一个数据框 A 列中有重复值 我想删除重复项 保留 B 列中具有最高值的行 So this A B 1 10 1 20 2 30 2 40 3 10 应该变成这样 A B 1 20 2 40 3 10 我猜想可能有一种简单的方法可以做到
  • Python3 类型错误:replace() 参数 1 必须是 str,而不是 int

    我已经尝试了几天让这段代码在 MacOS 上运行 但没有成功 你能看一下我错过了什么吗 运行 python 3 6 我已经上传了整个代码 多谢 usr bin env python3 from future import print fun
  • 二进制数据的Python字符串表示

    我试图理解 Python 显示表示二进制数据的字符串的方式 这是一个使用的示例乌兰多姆操作系统 http docs python org library os html os urandom In 1 random bytes os ura
  • 如何使用 QWebView 显示 html。 Python?

    如何在控制台中显示 HTML 格式的网页 import sys from PyQt4 QtGui import QApplication from PyQt4 QtCore import QUrl from PyQt4 QtWebKit i
  • 如何创建毫秒粒度的 Python 时间戳?

    我需要一个自纪元以来的毫秒 ms 时间戳 这应该不难 我确信我只是缺少一些方法datetime或类似的东西 实际上微秒 s 粒度也很好 我只需要亚 1 10 秒的计时 例子 我有一个每 750 毫秒发生一次的事件 假设它检查灯是否打开或关闭
  • 从内存中发送图像

    我正在尝试为 Discord 机器人实现一个系统 该系统可以动态修改图像并将其发送给机器人用户 为此 我决定使用 Pillow PIL 库 因为它对于我的目的来说似乎简单明了 这是我的工作代码的示例 它加载一个示例图像 作为测试修改 在其上
  • 使用正则表达式检查整个字符串

    我正在尝试检查字符串是否是数字 因此正则表达式 d 似乎不错 然而 由于某种原因 该正则表达式也适合 78 46 92 168 8000 这是我不想要的 一些代码 class Foo rex re compile d def bar sel
  • 导入 scipy.stats 时,出现“ImportError: DLL load failed: 找不到指定的过程”

    我无法导入 scipy stats 并收到以下错误 但不知何故 import scipy as sp 仍然可以正常工作 其他库如numpy pandas都可以毫无问题地导入 我尝试在 Anaconda 中重新安装 scipy 1 2 1 降
  • ValueError:维度 (-1) 必须在 [0, 2) 范围内

    我的python版本是3 5 2 我已经安装了keras和tensorflow 并尝试了官方的一些示例 示例链接 示例标题 用于多类 softmax 分类的多层感知器 MLP https keras io getting started s
  • 如何反转 dropout 来补偿 dropout 的影响并保持期望值不变?

    我正在学习神经网络中的正则化deeplearning ai课程 在dropout正则化中 教授说 如果应用dropout 计算出的激活值将比不应用dropout时 测试时 更小 因此 我们需要扩展激活以使测试阶段更简单 我理解这个事实 但我
  • Python 排列(包括子字符串)

    我遇到过这个帖子 如何在Python中生成列表的所有排列 https stackoverflow com questions 104420 how to generate all permutations of a list in pyth
  • 配置 Flask 以正确加载 Bootstrap js 和 css 文件

    如何使用 Flask 中的 url for 指令来正确设置 以便使用 Bootstrap 和 RGraph 的 html 页面可以正常工作 假设我的 html 页面看起来像这样 部分片段
  • Buildozer Numpy RuntimeError:工具链损坏:无法链接简单的 C 程序

    用 Python 编写我的第一个 Android 应用程序并使用 Buildozer 对其进行打包 因为稍后在项目中需要使用numpy 所以我尝试打包以下测试代码 import numpy import kivy kivy require
  • 将索引数组转换为 NumPy 中的 one-hot 编码数组

    给定一个一维索引数组 a array 1 0 3 我想将其一次性编码为二维数组 b array 0 1 0 0 1 0 0 0 0 0 0 1 创建归零数组b有足够的列 即a max 1 然后 对于每一行i 设置a i 第 列 至1 gt
  • 访问 Scrapy 内的 django 模型

    是否可以在 Scrapy 管道内访问我的 django 模型 以便我可以将抓取的数据直接保存到我的模型中 我见过this https scrapy readthedocs org en latest topics djangoitem ht
  • 在 python 中使用 re.sub 将字母变成大写?

    在许多编程语言中 以下内容 find foo a z bar并替换为GOO U 1GAR 将导致整个匹配项变为大写 我似乎无法在 python 中找到等效项 它存在吗 您可以将函数传递给re sub http docs python org
  • 使用 pyspark 计算所有可能的单词对

    我有一个文本文档 我需要找到整个文档中重复单词对的可能数量 例如 我有下面的word文档 该文档有两行 每行用 分隔 文档 My name is Sam My name is Sam My name is Sam My name is Sa
  • Spark (Python) 中的 Kolmogorov Smirnov 测试不起作用?

    我正在 Python Spark ml 中进行正态性测试 看到了我的结果think是一个错误 这是设置 我有一个标准化的数据集 范围 1 到 1 当我做直方图时 我可以清楚地看到数据不正常 gt gt gt prices norm hist
  • 捕获 SQLAlchemy 异常

    我可以使用什么捕获 SQLAlechmy 异常的上层异常 gt gt gt from sqlalchemy import exc gt gt gt dir exc ArgumentError CircularDependencyError

随机推荐