如何使用 Tensorflow 2/ Keras 保存和恢复训练具有多个模型部分的 GAN

2024-01-09

我目前正在尝试添加一个功能来中断和恢复通过此示例代码创建的 GAN 的训练:https://machinelearningmastery.com/how-to-develop-an-auxiliary-classifier-gan-ac-gan-from-scratch-with-keras/ https://machinelearningmastery.com/how-to-develop-an-auxiliary-classifier-gan-ac-gan-from-scratch-with-keras/

我设法让它工作,将整个复合 GAN 的权重保存在 summarise_performance 函数中,该函数每 10 个周期触发一次,如下所示:

# save all weights
filename3 = 'weights_%08d.h5' % (step+1)
gan_model.save_weights(filename3)
print('>Saved: %s and %s and %s' % (filename1, filename2, filename3))

它被加载到我添加到程序开头的一个名为 load_model 的函数中,该函数采用像平常一样构建的 gan 架构,但将其权重更新为最新值,如下所示:

#load model from file and return startBatch number
def load_model(gan_model):
   start_batch = 0
   files = glob.glob("./weights_0*.h5")
   if(len(files) > 0 ):
       most_recent_file = files[len(files)-1]
       gan_model.load_weights(most_recent_file)
       #TODO: breaks if using more than 8 digits for batches
       startBatch = int(most_recent_file[10:18])
       if (start_batch != 0):
           print("> found existing weights; starting at batch %d" % start_batch)
   return start_batch

其中 start_batch 被传递给 train 函数,以便跳过已经完成的 epoch。

虽然这种减重方法确实“有效”,但我仍然认为我的方法是错误的,因为我发现权重数据显然不包括 GAN 的优化器状态,​​因此训练不会像它那样继续进行没有被打断。

我发现保存进度同时保存优化器状态的方法显然是通过保存整个模型而不仅仅是权重来完成的

在这里我遇到了一个问题,因为在 GAN 中我不仅训练一个模型,而且有 3 个模型:

  • 生成器模型 g_model
  • 判别器模型 d_model
  • 和复合 GAN 模型 gan_model

它们都是相互关联、相互依赖的。如果我采用幼稚的方法并单独保存和恢复每个部分模型,我最终会得到 3 个独立的脱节模型,而不是 GAN

有没有一种方法可以保存和恢复整个 GAN,让我可以像没有发生中断一样恢复训练?


也许考虑使用tf.train.检查点 https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint,如果您想恢复整个 GAN:

### In your training loop

checkpoint_dir = '/checkpoints'
checkpoint = tf.train.Checkpoint(gan_optimizer=gan_optimizer,
                            discriminator_optimizer=discriminator_optimizer,
                                  generator=generator,
                                  discriminator=discriminator
                                  gan_model = gan_model)
  
ckpt_manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=3)
if ckpt_manager.latest_checkpoint:
    checkpoint.restore(ckpt_manager.latest_checkpoint)  
    print ('Latest checkpoint restored!!')

....
....


if (epoch + 1) % 40 == 0:
    ckpt_save_path = ckpt_manager.save()
    print ('Saving checkpoint for epoch {} at {}'.format(epoch+1,ckpt_save_path))

### After x number of epochs, just save your generator model for inference.

generator.save('your_model.h5')

您还可以考虑完全摆脱复合模型。Here https://github.com/Yoan-D/text-to-image-synthesis/blob/master/text2image_gan_ms.py#L358这是我的意思的一个例子。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如何使用 Tensorflow 2/ Keras 保存和恢复训练具有多个模型部分的 GAN 的相关文章

  • 与区域指示符字符类匹配的 python 正则表达式

    我在 Mac 上使用 python 2 7 10 表情符号中的标志由一对表示区域指示符号 https en wikipedia org wiki Regional Indicator Symbol 我想编写一个 python 正则表达式来在
  • 在 django ORM 中查询时如何将 char 转换为整数?

    最近开始使用 Django ORM 我想执行这个查询 select student id from students where student id like 97318 order by CAST student id as UNSIG
  • 安装了 32 位的 Python,显示为 64 位

    我需要运行 32 位版本的 Python 我认为这就是我在我的机器上运行的 因为这是我下载的安装程序 当我重新运行安装程序时 它会将当前安装的 Python 版本称为 Python 3 5 32 位 然而当我跑步时platform arch
  • 用枢轴点拟合曲线 Python

    我有下面的图 我想用 2 条线来拟合它 使用 python 我设法适应上半部分 def func x a b x np array x return a x b popt pcov curve fit func up x up y 我想用另
  • 使用 Python 从文本中删除非英语单词

    我正在 python 上进行数据清理练习 我正在清理的文本包含我想删除的意大利语单词 我一直在网上搜索是否可以使用像 nltk 这样的工具包在 Python 上执行此操作 例如给出一些文本 Io andiamo to the beach w
  • 删除flask中的一对一关系

    我目前正在使用 Flask 开发一个应用程序 并且在删除一对一关系中的项目时遇到了一个大问题 我的模型中有以下结构 class User db Model tablename user user id db Column db String
  • 将 python2.7 与 Emacs 24.3 和 python-mode.el 一起使用

    我是 Emacs 新手 我正在尝试设置我的 python 环境 到目前为止 我已经了解到在 python 缓冲区中使用 python mode el C c C c将当前缓冲区的内容加载到交互式 python shell 中 显然使用了什么
  • 使用字典映射数据帧索引

    为什么不df index map dict 工作就像df column name map dict 这是尝试使用index map的一个小例子 import pandas as pd df pd DataFrame one A 10 B 2
  • datetime.datetime.now() 返回旧值

    我正在通过匹配日期查找 python 中的数据存储条目 我想要的是每天选择 今天 的条目 但由于某种原因 当我将代码上传到 gae 服务器时 它只能工作一天 第二天它仍然返回相同的值 例如当我上传代码并在 07 01 2014 执行它时 它
  • Python beautifulsoup 仅限 1 级文本

    我看过其他 beautifulsoup 得到相同级别类型的问题 看来我的有点不同 这是网站 我正试图拿到右边那张桌子 请注意表的第一行如何展开为该数据的详细细分 我不想要那个数据 我只想要最顶层的数据 您还可以看到其他行也可以展开 但在本例
  • Python,将函数的输出重定向到文件中

    我正在尝试将函数的输出存储到Python中的文件中 我想做的是这样的 def test print This is a Test file open Log a file write test file close 但是当我这样做时 我收到
  • 如何使用 Mysql Python 连接器检索二进制数据?

    如果我在 MySQL 中创建一个包含二进制数据的简单表 CREATE TABLE foo bar binary 4 INSERT INTO foo bar VALUES UNHEX de12 然后尝试使用 MySQL Connector P
  • 在 Sphinx 文档中*仅*显示文档字符串?

    Sphinx有一个功能叫做automethod从方法的文档字符串中提取文档并将其嵌入到文档中 但它不仅嵌入了文档字符串 还嵌入了方法签名 名称 参数 我如何嵌入only文档字符串 不包括方法签名 ref http www sphinx do
  • 如何使用 pybrain 黑盒优化训练神经网络来处理监督数据集?

    我玩了一下 pybrain 了解如何生成具有自定义架构的神经网络 并使用反向传播算法将它们训练为监督数据集 然而 我对优化算法以及任务 学习代理和环境的概念感到困惑 例如 我将如何实现一个神经网络 例如 1 以使用 pybrain 遗传算法
  • Cython 和类的构造函数

    我对 Cython 使用默认构造函数有疑问 我的 C 类 Node 如下 Node h class Node public Node std cerr lt lt calling no arg constructor lt lt std e
  • pip 列出活动 virtualenv 中的全局包

    将 pip 从 1 4 x 升级到 1 5 后pip freeze输出我的全局安装 系统 软件包的列表 而不是我的 virtualenv 中安装的软件包的列表 我尝试再次降级到 1 4 但这并不能解决我的问题 这有点类似于这个问题 http
  • 使用特定颜色和抖动在箱形图上绘制数据点

    我有一个plotly graph objects Box图 我显示了箱形 图中的所有点 我需要根据数据的属性为标记着色 如下所示 我还想抖动这些点 下面未显示 Using Box我可以绘制点并抖动它们 但我不认为我可以给它们着色 fig a
  • 如何在 Windows 命令行中使用参数运行 Python 脚本

    这是我的蟒蛇hello py script def hello a b print hello and that s your sum sum a b print sum import sys if name main hello sys
  • cv2.VideoWriter:请求一个元组作为 Size 参数,然后拒绝它

    我正在使用 OpenCV 4 0 和 Python 3 7 创建延时视频 构造 VideoWriter 对象时 文档表示 Size 参数应该是一个元组 当我给它一个元组时 它拒绝它 当我尝试用其他东西替换它时 它不会接受它 因为它说参数不是
  • Kivy - 单击按钮时编辑标签

    我希望 Button1 在单击时编辑标签 etykietka 但我不知道如何操作 你有什么想法吗 class Zastepstwa App def build self lista WebOps getList layout BoxLayo

随机推荐

  • 将颠覆存储库编号放入代码中

    我想实现一种在代码中记录项目版本的方法 以便在测试时使用它并帮助跟踪错误 看起来最好使用的版本号就是 Subversion 的当前修订版号 有没有一种简单的方法可以将这个数字挂接到 在我的例子中是C 头文件或其他文件中 然后我可以在代码中获
  • 安装 GitHub 应用程序时在私有存储库中搜索时出现“验证失败”错误

    我创建了一个 GitHub 应用程序并将其安装在我的帐户中 使其能够访问我帐户中的私有存储库 GitHub 应用程序具有元数据的读取权限 然后 我按照此处的步骤生成了 JWT 并使用它来创建安装访问令牌 我尝试使用此令牌使用 GitHub
  • java中奇怪的平等行为[重复]

    这个问题在这里已经有答案了 看看下面的代码 Long minima 9223372036854775808L Long anotherminima 9223372036854775808L System out println minima
  • ||= 在 Ruby 中做什么[重复]

    这个问题在这里已经有答案了 我使用 Ruby 一段时间了 我不断看到这样的情况 foo bar 它是什么 这将分配bar to foo如果 且仅当 foo is nil or false 编辑 或者错误 谢谢 mopoke
  • VSCode 构建不起作用 - 未定义构建任务。在tasks.json 文件中使用“isBuildCommand”标记任务

    我全新安装了 VSCode 和这个小型的基本 TypeScript 应用程序 第一次 当我想要构建应用程序时 VScode 需要生成tasks json 而且它在很久以前就起作用了 今天我收到这个奇怪的消息 未定义构建任务 在tasks j
  • 如何使用 JDBC 读取 mysql 中的 JSON 数据类型

    Mysql 5 7 引入了 JSON 数据类型 它提供了大量的查询功能 由于没有兼容的结果集函数 我如何以及如何使用检索存储在此数据类型中的数据 它应该是rs getString 因为getString与一起使用VARCHAR TEXT 我
  • Rails4:image_url 未在 scss 中生成摘要

    我不明白为什么我的 css 文件没有使用辅助方法将摘要附加到我的资产中image url 我的资产已正确预编译 并且文件确实包含摘要 我还可以手动访问它们 使用摘要的网址 最奇怪的是 一开始它是有效的 这是我的配置 config asset
  • 实体框架预加载过滤器

    我有一个简单的查询 我想这样做 1 Products have ChildProducts其中有PriceTiers2 我想得到所有Products有一个Category with a ID1 和Display true 3 我想包括所有C
  • 视图的内边距和边距之间的区别

    视图的边距和填充有什么区别 帮助我记住的含义padding 我想到一件有很多的大衣厚棉垫 我在外套里面 但我和我的棉衣是在一起的 我们是一个单位 但要记住margin 我想 嘿嘿 给我一点余地吧 这是我和你之间的空白 不要进入我的舒适区 我
  • jOOQ 和缓存?

    我正在考虑从 Hibernate 迁移到 jOOQ 但我不确定是否可以不使用缓存 休眠有一个一级 二级缓存 https stackoverflow com questions 337072 what are first and second
  • Apache CXF LoggingInInterceptor 已弃用 - 可以使用什么替代?

    我在 Spring Boot 的帮助下使用 Apache CXFcxf spring boot starter jaxws3 2 7版本的插件 我的目的是自定义日志拦截器 但是当我创建以下类时 public class CustomLogg
  • 在 C++ 中打印浮点数的二进制表示形式[重复]

    这个问题在这里已经有答案了 可能的重复 C 中浮点数转换为二进制 https stackoverflow com questions 2746380 float to binary in c 我想在 C 中打印出浮点数的二进制表示形式 不太
  • 将 MongoCursor 从 ->find() 转换为数组

    jokes collection gt find 我如何转换 jokes进入数组 你可以使用 PHP 的iterator to array http php net manual en function iterator to array
  • Roundcube问题:与存储服务器的连接失败

    我在 Roundcube 中收到此错误 连接到存储服务器失败 行 我已经检查了所有内容 配置 数据库用户名密码 服务器详细信息都是干净的 谁能告诉我可能是什么问题 这里我给出了整个配置文件
  • asp.net 中的 Convert.ToDateTime 问题

    我有一个应用程序在西班牙服务器上运行没有任何问题 当我将应用程序上传到在线服务器 英文窗口 时 我收到 Convert ToDateTime 和 Convert ToInt32 的异常 类型为 输入字符串不是有效的 Datetime Int
  • 在 Symfony/Doctrine 中删除记录时执行一些清理

    将 Symfony 1 4 5 与 Doctrine 结合使用 我有一个模型 其中包含上传的图像作为其中一列 创建和更新记录很好 使用 doSave 方法来处理上传和对文件的任何更改 我遇到的问题是 如果记录被删除 我希望它也删除关联的文件
  • 如何控制表格视图滚动速度?

    我想要控制表视图滚动速度 如何以编程方式做到这一点 请帮忙 提前致谢 简森 雅各布 您可以设置tableView decelerationRate财产 它是一个浮点值 决定用户抬起手指后的减速率 并且 您的应用程序可以使用UIScrollV
  • iPhone - 将字典写入文件:处理错误

    使用以下命令将 NSDictionary 保存到文件时 BOOL writeToFile NSString path atomically BOOL flag 可以返回 YES 或 NO 有一些编写接受 NSError 参数的文件的方法 对
  • JQuery - 摆脱 .serialize() 中的 %5B%5D

    我正在使用 AJAX 提交序列化表单 数据传递到action php最终包含 5B 5D 而不是 是否有办法取回 或者数据能够以相同的方式处理 即像数组一样 action php 该表格通过以下方式序列化 var form data for
  • 如何使用 Tensorflow 2/ Keras 保存和恢复训练具有多个模型部分的 GAN

    我目前正在尝试添加一个功能来中断和恢复通过此示例代码创建的 GAN 的训练 https machinelearningmastery com how to develop an auxiliary classifier gan ac gan