tf.estimator.train_and_evaluate 出错了 评估精度和损失

2024-04-02

I use tf.estimator.train_and_evaluate训练和评估我的模型。这是我的代码:

import tensorflow as tf
import numpy as np
from tensorflow.contrib.slim.nets import resnet_v2
import tensorflow.contrib.slim as slim

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data(path='mnist.npz')
x_train = np.expand_dims(x_train, 3).astype(np.float32)[:5000]
y_train = y_train.astype(np.int32)[:5000]
x_test = np.expand_dims(x_test, 3).astype(np.float32)[:1000]
y_test = y_test.astype(np.int32)[:1000]

tf.logging.set_verbosity(tf.logging.INFO)

cls_num = 10


def model_fn(features, labels, mode):
    is_training = False
    if mode == tf.estimator.ModeKeys.TRAIN:
        is_training = True
    with slim.arg_scope(resnet_v2.resnet_arg_scope()):
        logits, endpoints = resnet_v2.resnet_v2_50(features, 
                num_classes=cls_num,
                is_training=is_training,
                reuse=None)

    logits = tf.squeeze(logits, [1, 2])
    preds = tf.argmax(logits, 1)

    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

    accuracy = tf.metrics.accuracy(labels=labels, predictions=preds)
    metrics = {'accuracy': accuracy}

    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(mode, loss=loss, eval_metric_ops=metrics)

    optimizer = tf.train.AdagradOptimizer(learning_rate=0.1)
    train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op) 


def process_fn(feature, label):
    feature = tf.expand_dims(feature, 3)
    return feature, label


def train_input_fn():
    dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
    #dataset.map(process_fn)
    dataset = dataset.repeat(1).batch(8)
    return dataset

def eval_input_fn():
    dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
    dataset = dataset.repeat(1).batch(8)
    return dataset


estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir='logs')
train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn)
eval_specs = tf.estimator.EvalSpec(input_fn=eval_input_fn)
for _ in xrange(10):
    tf.estimator.train_and_evaluate(estimator, train_spec, eval_specs)

训练步长没问题,loss变得很小(0.001左右),但是评估结果错误(以下是评估日志):

...
INFO:tensorflow:Saving dict for global step 625: accuracy = 0.5, global_step = 625, loss = 1330830600000.0
...

任务很简单,就是二分类。我不认为这是过度拟合。我的评估代码有问题吗?


None

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

tf.estimator.train_and_evaluate 出错了 评估精度和损失 的相关文章

随机推荐

  • 具有多个边界的类型参数

    此代码编译 import java io Serializable import java util Arrays class Test
  • JQuery 垃圾收集 - 这会干净吗?

    许多文章 例如msdn http msdn microsoft com en us library dd361842 28VS 85 29 aspx 已经说过 当涉及到循环引用时 在某些浏览器中无法清除循环引用DOM对象和一个JS obje
  • WinForms C# 中自定义对象类型的跨进程拖放

    这个问题 https stackoverflow com questions 1213074 winforms interop drag drop from winforms wpf与我感兴趣的内容很接近 但又不完全是 我有一个用 C 编写
  • 在 Tornadoweb 中禁用模板处理

    我必须使用 Tornado Web 作为我们现有 AngularJs 应用程序的 RESTful 后端 在 Angular 应用程序中大量使用 我想将龙卷风的角度文件作为静态文件提供 有没有办法禁用tornado的处理模板以避免与torna
  • Ruby on Rails 中的延迟作业如何工作?

    我对此很陌生 对延迟工作的工作原理不太困惑 我知道它会创建一个表并将作业放入表中 然后我需要运行 rake jobs work 启动后台进程 现在我的问题是 DJ 脚本是否每分钟检查一次表 当时间与 job at 时间匹配时 它会运行该作业
  • 删除与模型关联的文件 - django

    我的一个模型中有以下代码 class PostImage models Model post models ForeignKey Post related name images figure out a way to have image
  • Xamarin.Android 绑定无效操作码

    最近收到一个新的 Android SDK aar 来绑定在 Xamarin 中 最初开始绑定时 我收到错误 COMPILETODALVIK Uncaught translation error com android dx cf code
  • 如何解析xsd:dateTime格式?

    xsd dateTime 类型的值可以有多种形式 如描述于RELAX NG http books xmlschemata org relaxng ch19 77049 html 如何将所有表单解析为时间或日期时间对象 它实际上是一种非常受限
  • 如何通过我自己的模板使用内置密码重置/更改视图

    例如我可以指出url accounts password reset to django contrib auth views password reset在上下文中使用我的模板文件名 但我认为需要发送更多上下文详细信息 我需要确切地知道为
  • 使用python打印月份和日期

    我试图在 python 中仅打印月份和日期 如下所示 09 December 08 October 我怎么能这么做呢 Try this import datetime now datetime datetime now print now
  • 命令行参数 - 所需对象:'objshell.NameSpace(...)'

    我正在编写一个脚本 该脚本将利用 Windows 的内置功能来解压缩提供的 zip 文件 我对 vbscript 还很陌生 所以有些语法让我有点困惑 我正在使用一些现有代码并尝试修改它 以便它将采用命令行选项作为文件名 如果我使用命令行传递
  • 仅向一个应用程序发送广播意图,而不使用显式意图

    我有个问题 我正在做一个外部 android 服务 应用程序可以注册它来接收信息 信息通过广播从服务返回到应用程序 并通过broadcastReceiver 问题是如果我这样做sendBroadcast 任何应用程序都可以监听其他应用程序的
  • 结构末尾的大小为 0 的数组[重复]

    这个问题在这里已经有答案了 我正在学习的系统编程课程的教授今天告诉我们要定义一个末尾带有零长度数组的结构体 struct array size t size int data 0 typedef struct array array 这是一
  • 极长工作流程的 Cucumber 场景

    我们需要为一个功能测试一个漫长的步骤过程 从登录到许多模式对话框 多步骤表单以及不同角色的用户都在交互 我们如何将这个过程的各个部分分解为单独的场景 这是一个例子 Scenario New Manuscript Given I am on
  • 如何获取用户当前在 Spotify 应用程序中收听的内容的信息 [关闭]

    Closed 这个问题需要多问focused help closed questions 目前不接受答案 Android 应用程序 在后台运行并使用 Spotify SDK 能否获取用户当前在 Spotify Android 应用程序中收听
  • 是否有复杂的 Java WorkQueue API?

    我正在寻找具有以下功能的 WorkQueue API java util Queue兼容的 优惠 可选 集合语义 单处理和批处理 并发 当然 调度 different processing policies 等到下一次计划执行 如果批量大小
  • PowerPivot 中的滚动 12 个月总和

    在 PowerPivot Excel 2016 中 我编写了滚动 12 个月销售额总和的公式 如下所示 Rolling Sum CALCULATE Sales DATESBETWEEN Sales Date FIRSTDATE DATEAD
  • Python:使用 Openpyxl 读取大型 Excel 工作表

    我有一个 Excel 文件 其中包含大约 400 个工作表 其中 375 个工作表需要保存为 CSV 文件 我尝试过 VBA 解决方案 但 Excel 在打开此工作簿时遇到问题 我创建了一个 python 脚本来做到这一点 然而 它会迅速消
  • UISegmentedControl 截断段标题

    我的 iPhone 应用程序中有一个分段控件 在 ios6 上运行良好 但在 ios7 上 分段图块被截断 有足够的空间容纳文本 但无论如何都会截断它们 self segmentedControl segmentedControlStyle
  • tf.estimator.train_and_evaluate 出错了 评估精度和损失

    I use tf estimator train and evaluate训练和评估我的模型 这是我的代码 import tensorflow as tf import numpy as np from tensorflow contrib