在 Keras 和 Tensorflow 中复制模型以实现多线程设置

2024-03-03

我正在尝试在 Keras 和 TensorFlow 中实现 actor-critic 的异步版本。我使用 Keras 作为构建网络层的前端(我直接使用张量流更新参数)。我有一个global_model和一个主要的张量流会话。但在每个线程内我正在创建一个local_model它从复制参数global_model。我的代码看起来像这样

def main(args):
    config=tf.ConfigProto(log_device_placement=False,allow_soft_placement=True)
    sess = tf.Session(config=config)
    K.set_session(sess) # K is keras backend
    global_model = ConvNetA3C(84,84,4,num_actions=3)

    threads = [threading.Thread(target=a3c_thread, args=(i, sess, global_model)) for i in range(NUM_THREADS)]

    for t in threads:
        t.start()

def a3c_thread(i, sess, global_model):
    K.set_session(sess) # registering a session for each thread (don't know if it matters)
    local_model = ConvNetA3C(84,84,4,num_actions=3)
    sync = local_model.get_from(global_model) # I get the error here

    #in the get_from function I do tf.assign(dest.params[i], src.params[i])

我收到来自 Keras 的用户警告

用户警告:默认的 TensorFlow 图不是关联的图 当前已在 Keras 中注册的 TensorFlow 会话,以及 这样 Keras 无法自动初始化变量。你 应该考虑通过 Keras 注册正确的会话K.set_session(sess)

随后出现张量流错误tf.assign操作表示操作必须在同一个图上。

值错误:张量(“conv1_W:0”,形状=(8,8,4,16), dtype=float32_ref, device=/device:CPU:0) 必须来自同一个图表 作为张量(“conv1_W:0”,形状=(8,8,4,16),dtype=float32_ref)

我不太确定出了什么问题。

Thanks


该错误来自 Keras 因为tf.get_default_graph() is sess.graph正在返回False。从 TF 文档中,我看到tf.get_default_graph()返回当前线程的默认图表。当我启动一个新线程并创建一个图时,它会被构建为特定于该线程的单独图。我可以通过执行以下操作来解决这个问题,

with sess.graph.as_default():
   local_model = ConvNetA3C(84,84,4,3)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

在 Keras 和 Tensorflow 中复制模型以实现多线程设置 的相关文章

  • 查找 python 数据框中每行的最高值

    我想找到每行中的最高值并返回 python 中该值的列标题 例如 我想找到每行的前两个 df A B C D 5 9 8 2 4 1 2 3 我希望我的输出看起来像这样 df B C A D 您可以使用字典理解来生成largest n数据帧
  • 为什么我们应该在 Keras 中对深度学习数据进行标准化?

    我正在 Keras 中测试一些网络架构 以对 MNIST 数据集进行分类 我已经实现了一个类似于 LeNet 的方法 我看到在网上找到的例子中 有一个数据标准化的步骤 例如 X train 255 我在没有这种标准化的情况下进行了测试 我发
  • 如何使用 boto3 从 AWS Cognito 获取经过身份验证的身份响应

    我想使用 boto3 获取访问 AWS 服务的临时凭证 用例是这样的 我的 Cognito 用户池中的用户登录到我的服务器 我希望服务器代码为该用户提供访问其他 AWS 服务的临时凭证 我有一个存储我的用户的 Cognito 用户池 我有一
  • 修复类以在 Flask 会话中启用对象存储[重复]

    这个问题在这里已经有答案了 我有一个自定义类 Passport 其中包含活动用户身份和权限 我曾经将它存储在会话中 如下所示 p Passport p do something fancy session passport p 它就奏效了
  • SQLAlchemy:检查给定值是否在列表中

    问题 在 PostgreSQL 中 检查某个字段是否在给定列表中是使用IN操作员 SELECT FROM stars WHERE star type IN Nova Planet SQLAlchemy 的等价物是什么INSQL查询 我尝试过
  • 检查多维 numpy 数组的所有边是否都是零数组

    n 维数组有 2n 个边 1 维数组有 2 个端点 2 维数组有 4 个边或边 3 维数组有 6 个 2 维面 4 维数组有 8 个边 ETC 这类似于抽象 n 维立方体发生的情况 我想检查 n 维数组的所有边是否仅由零组成 以下是边由零组
  • 使用Python进行图像识别[关闭]

    Closed 这个问题需要多问focused help closed questions 目前不接受答案 我有一个想法 就是我想识别图像中的字母 可能是 bmp或 jpg 例如 这是一个包含字母 S 的 bmp 图像 我想做的是使用Pyth
  • C:如果文件描述符被删除,阻塞读取应该返回

    我正在以阻塞的方式从设备 文件描述符中读取 可能会发生这样的情况 在不同的线程中 设备被关闭并且文件描述符被删除 不幸的是 读取没有返回或注意到并且一直阻塞 作为一种解决方法 我可以使用 select 作为超时来执行 while 循环 如果
  • 为什么我在将数据上传到数据库时不断看到“正在重置断开的连接”?

    我正在通过 REST API 将数亿个项目从 Heroku 上的云服务器上传到 AWS EC2 中的数据库 我正在使用 Python 并且经常在日志中看到以下 INFO 日志消息 requests packages urllib3 conn
  • Python代码执行时自动打开浏览器

    我正在 Python Flask 中实现 GUI Flask 的设计方式是 必须 手动 打开本地主机以及端口号 有没有一种方法可以使其自动化 以便在运行代码时自动打开浏览器 本地主机 我尝试使用 webbrowser 包 但它在会话终止后打
  • Microsoft Azure 数据仓库和 SqlAlchemy

    我正在尝试使用 python 的 sqlalchemy 库连接到 microsoft azure 数据仓库 并收到以下错误 pyodbc Error HY000 HY000 Microsoft ODBC SQL Server Driver
  • Numpy 通过一个数组的值总结另一个数组

    我正在尝试找到一种矢量化方法来完成以下任务 假设我有一个 x 和 y 值的数组 请注意 x 值并不总是整数并且可以为负数 import numpy as np x np array 1 1 1 3 2 2 2 5 4 4 dtype flo
  • Learning_rate 不是合法参数

    我正在尝试通过实现 GridSearchCV 来测试我的模型 但我似乎无法在 GridSearch 中添加学习率和动量作为参数 每当我尝试通过添加这些代码来执行代码时 我都会收到错误 这是我创建的模型 def define model op
  • Python:使用for循环更改变量后缀

    我知道这个问题被问了很多 但到目前为止我无法使用 理解答案 我想改变for循环中变量的后缀 我尝试了 stackoverflow 搜索提供的所有答案 但很难理解提问者经常提出的具体代码 因此 为了清楚起见 我使用一个简单的示例 这并不意味着
  • Python 可以替代 Java 小程序吗?

    除了制作用于物理模拟 如抛射运动 重力等 的教育性 Java 小程序之外 还有其他选择吗 如果你想让它在浏览器中运行 你可以使用PyJamas http pyjs org 这是一个 Python 到 Javascript 的编译器和工具集
  • 在哪里可以找到Python内置序列类型的时间和空间复杂度

    我一直无法找到此信息的来源 无法亲自查看 Python 源代码来确定这些对象是如何工作的 有谁知道我可以在网上找到这个吗 结帐时间复杂度 http wiki python org moin TimeComplexitypy dot org
  • Matplotlib 渲染日期、图像的问题

    我在使用 conda forge 的 Matplotlib v 3 1 3 和 python 3 7 时遇到问题 我拥有 Matplotlib 所需的所有依赖项 当我输入这段代码时 它应该可以工作 我得到了泼溅艺术 它基于此 YouTube
  • 使用 Python 进行 Google 搜索网页抓取 [关闭]

    Closed 这个问题正在寻求书籍 工具 软件库等的推荐 不满足堆栈溢出指南 help closed questions 目前不接受答案 最近为了工作中的一些项目 学习了很多python 目前我需要使用谷歌搜索结果进行一些网络抓取 我发现几
  • Chrome + 另一个进程:进程间通信比 HTTP/XHR 请求更快?

    我有一个进程 1 对视频流进行实时图像处理 我需要在 Chrome 中的 HTML 页面中渲染该视频 同一台计算机上的进程 2 在canvas or img or videoHTML5 元素 由于我有 1000x1000 像素 x 3 字节
  • Java,如何管理线程读取socket(websocket)?

    我有一个 WebSocket 服务器 我的服务器创建一个新线程来处理新连接 该线程一直处于活动状态 直到 websocket 中断 我的问题 对于 1 000 000 个连接 我需要 1 000 000 个线程 我如何通过一个线程处理多个

随机推荐