键盘中断张量流运行并在此时保存

2024-04-16

有没有办法通过键盘中断来中断张量流会话,并可以选择在此时保存模型?我目前让会话运行过夜,但需要停止它,以便我可以释放内存供电脑在白天使用。随着训练的进行,每个时期都会变慢,因此有时我可能需要等待几个小时才能进行程序中的下一个计划保存。我想要能够随时进入运行并从该点保存的功能。我什至不知道这是否可能。希望能得到指点。


一种选择是子类化tf.Session对象并创建一个__exit__当键盘中断通过时保存当前状态的函数。仅当新对象作为对象的一部分被调用时,这才有效。with block.

这是子类:

import tensorflow as tf

class SessionWithExitSave(tf.Session):
    def __init__(self, *args, saver=None, exit_save_path=None, **kwargs):
        self.saver = saver
        self.exit_save_path = exit_save_path
        super().__init__(*args, **kwargs)

    def __exit__(self, exc_type, exc_value, exc_tb):
        if exc_type is KeyboardInterrupt:
            if self.saver:
                self.saver.save(self, self.exit_save_path)
                print('Output saved to: "{}./*"'.format(self.exit_save_path))
        super().__exit__(exc_type, exc_value, exc_tb)

TensorFlow mnist 演练中的示例用法。

import tensorflow as tf
import datetime as dt
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('U:/mnist/', one_hot=True)
x = tf.placeholder(tf.float32, [None, 784])
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
y = tf.matmul(x, W) + b
# Define loss and optimizer
y_ = tf.placeholder(tf.float32, [None, 10])
cross_entropy = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(cross_entropy)

saver = tf.train.Saver()

with SessionWithExitSave(
        saver=saver, 
        exit_save_path='./tf-saves/_lastest.ckpt') as sess:
    sess.run(tf.global_variables_initializer())
    total_epochs = 50
    for epoch in range(1, total_epochs+1):
        for _ in range(1000):
            batch_xs, batch_ys = mnist.train.next_batch(100)
            sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
        # Test trained model
        correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

        print(f'Epoch {epoch} of {total_epochs} :: accuracy = ', end='')
        print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))
        save_time = dt.datetime.now().strftime('%Y%m%d-%H.%M.%S')
        saver.save(sess, f'./tf-saves/mnist-{save_time}.ckpt')

在从键盘发送中断信号之前,我让它运行 10 个纪元。这是输出:

Epoch 1 of 50 :: accuracy = 0.9169
Epoch 2 of 50 :: accuracy = 0.919
Epoch 3 of 50 :: accuracy = 0.9205
Epoch 4 of 50 :: accuracy = 0.9221
Epoch 5 of 50 :: accuracy = 0.92
Epoch 6 of 50 :: accuracy = 0.9229
Epoch 7 of 50 :: accuracy = 0.9234
Epoch 8 of 50 :: accuracy = 0.9234
Epoch 9 of 50 :: accuracy = 0.9252
Epoch 10 of 50 :: accuracy = 0.9248
Output saved to: "./tf-saves/_lastest.ckpt./*"
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
...
--> 768   elif item[0].cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    769     return item[1]._is_present_in_parent
    770   else:
KeyboardInterrupt:

事实上,我确实拥有所有保存的文件,包括发送到系统的键盘中断的保存。

import os

os.listdir('./tf-saves/')
# returns:
['checkpoint',
 'mnist-20171207-23.05.18.ckpt.data-00000-of-00001',
 'mnist-20171207-23.05.18.ckpt.index',
 'mnist-20171207-23.05.18.ckpt.meta',
 'mnist-20171207-23.05.22.ckpt.data-00000-of-00001',
 'mnist-20171207-23.05.22.ckpt.index',
 'mnist-20171207-23.05.22.ckpt.meta',
 'mnist-20171207-23.05.26.ckpt.data-00000-of-00001',
 'mnist-20171207-23.05.26.ckpt.index',
 'mnist-20171207-23.05.26.ckpt.meta',
 'mnist-20171207-23.05.31.ckpt.data-00000-of-00001',
 'mnist-20171207-23.05.31.ckpt.index',
 '_lastest.ckpt.data-00000-of-00001',
 '_lastest.ckpt.index',
 '_lastest.ckpt.meta']
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

键盘中断张量流运行并在此时保存 的相关文章

  • 如何在groupby之后将pandas数据框拆分为许多列

    我希望能够在 pandas 中使用 groupby 按列对数据进行分组 然后将其拆分 以便每个组都是数据框中自己的列 e g time data 0 1 2 0 1 2 3 0 2 3 4 0 3 1 2 1 4 2 3 1 5 3 4 1
  • 如何 json_normalize() df 中的特定字段并保留其他列? [复制]

    这个问题在这里已经有答案了 这是我的简单示例 我的实际数据集中的 json 字段非常嵌套 因此我一次解压一层 我需要在 json normalize 之后保留数据集上的某些列 https pandas pydata org docs ref
  • 无法将 datetime.datetime 与 datetime.date 进行比较

    我有以下代码并收到上述错误 由于我是 python 新手 我无法理解这里的语法以及如何修复错误 if not start or date lt start start date 有一个datetime date 从日期时间转换为日期的方法
  • 为 PyCharm 中的所有配置设置相同的环境变量

    我有一个与 Celery 和很多不同的工作人员一起的项目 如何避免每次将 PyCharm 中的环境变量复制粘贴到每个运行 调试配置 有什么方法可以在项目设置中设置它们吗 找到解决方案here https stackoverflow com
  • multiprocessing.freeze_support()

    为什么多处理模块需要调用特定的function http docs python org dev library multiprocessing html multiprocessing freeze support在被 冻结 以生成 Wi
  • Python 相当于 Bit Twiddling Hacks 中的 C 代码?

    我有一个位计数方法 我正在尝试尽可能快地实现 我想尝试下面的算法位摆弄黑客 http graphics stanford edu seander bithacks html CountBitsSetParallel 但我不知道 C 什么是
  • 如何将同步函数包装在异步协程中?

    我在用着aiohttp https github com aio libs aiohttp构建一个 API 服务器 将 TCP 请求发送到单独的服务器 发送 TCP 请求的模块是同步的 对于我来说是一个黑匣子 所以我的问题是这些请求阻塞了整
  • 如何解码 dtype=numpy.string_ 的 numpy 数组?

    我需要使用 Python 3 解码按以下方式编码的字符串 gt gt gt s numpy asarray numpy string hello nworld gt gt gt s array b hello nworld dtype S1
  • Pandas,按最大返回值进行分组 AssertionError:

    熊猫有问题 我想听听你的意见 我有这个数据框 我需要在其中获取最大值 代码就在下面 df stack pd DataFrame 1 0 2016 0 NonResidential Hotel 98101 0 DOWNTOWN 47 6122
  • 如何将reportlab与Google应用程序引擎一起使用

    我无法在谷歌应用程序引擎下正确导入reportlab 根据以下guide http blog notdot net 2010 04 Generating PDFs on App Engine Python and introducing M
  • 与函数复合 UniqueConstraint

    一个快速的 SQLAlchemy 问题 我有一个 文档 类 其属性为 数字 和 日期 我需要确保没有重复的号码同年 是 有没有办法对 数字 年份 日期 进行UniqueConstraint 我应该使用唯一索引吗 我如何声明功能部分 SQLA
  • 从 Apache 运行 python 脚本的最简单方法

    我花了很长时间试图弄清楚这一点 我基本上正在尝试开发一个网站 当用户单击特定按钮时 我必须在其中执行 python 脚本 在研究了 Stack Overflow 和 Google 之后 我需要配置 Apache 以便能够运行 CGI 脚本
  • dask allocate() 或 apply() 中的变量列名

    我有适用于pandas 但我在将其转换为使用时遇到问题dask 有一个部分解决方案here https stackoverflow com questions 32363114 how do i change rows and column
  • python csv按列转换为字典

    是否可以将 csv 文件中的数据读取到字典中 使得列的第一行是键 同一列的其余行构成列表的值 例如 我有一个 csv 文件 strings numbers colors string1 1 blue string2 2 red string
  • DRF:以编程方式从 TextChoices 字段获取默认选择

    我们的网站是 Vue 前端 DRF 后端 在一个serializer validate 方法 我需要以编程方式确定哪个选项TextChoices类已被指定为模型字段的默认值 TextChoices 类 缩写示例 class PaymentM
  • Windows 与 Linux 文本文件读取

    问题是 我最近从 Windows 切换到 Ubuntu 我的一些用于分析数据文件的 python 脚本给了我错误 我不确定如何正确解决 我当前仪器的数据文件输出如下 Header 有关仪器等的各种信息 Data 状态 代码 温度 字段等 0
  • SpaCy 中的自定义句子边界检测

    我正在尝试在 spaCy 中编写一个自定义句子分段器 它将整个文档作为单个句子返回 我编写了一个自定义管道组件 它使用以下代码来执行此操作here https github com explosion spaCy issues 1850 但
  • LSTM 批次与时间步

    我按照 TensorFlow RNN 教程创建了 LSTM 模型 然而 在这个过程中 我对 批次 和 时间步长 之间的差异 如果有的话 感到困惑 并且我希望得到帮助来澄清这个问题 教程代码 见下文 本质上是根据指定数量的步骤创建 批次 wi
  • scrapy python 请求未定义

    我在这里找到了答案 code for site in sites Link site xpath a href extract CompleteLink urlparse urljoin response url Link yield Re
  • Django - 缺少 1 个必需的位置参数:'request'

    我收到错误 get indiceComercioVarejista 缺少 1 个必需的位置参数 要求 当尝试访问 get indiceComercioVarejista 方法时 我不知道这是怎么回事 views from django ht

随机推荐