背景:
我有一些复杂的强化学习算法,我想在多个线程中运行。
Problem
当尝试打电话时sess.run
在一个线程中我收到以下错误消息:
RuntimeError: The Session graph is empty. Add operations to the graph before calling run().
重现错误的代码:
import tensorflow as tf
import threading
def thread_function(sess, i):
inn = [1.3, 4.5]
A = tf.placeholder(dtype=float, shape=(None), name="input")
P = tf.Print(A, [A])
Q = tf.add(A, P)
sess.run(Q, feed_dict={A: inn})
def main(sess):
thread_list = []
for i in range(0, 4):
t = threading.Thread(target=thread_function, args=(sess, i))
thread_list.append(t)
t.start()
for t in thread_list:
t.join()
if __name__ == '__main__':
sess = tf.Session()
main(sess)
如果我在线程外运行相同的代码,它将正常工作。
有人可以深入了解如何在 python 线程中正确使用 Tensorflow 会话吗?
Session不仅可以是当前线程的默认值,还可以是图。
当您进入会话并调用时run
在其上,默认图表将是不同的。
您可以修改您的线程函数像这样让它工作:
def thread_function(sess, i):
with sess.graph.as_default():
inn = [1.3, 4.5]
A = tf.placeholder(dtype=float, shape=(None), name="input")
P = tf.Print(A, [A])
Q = tf.add(A, P)
sess.run(Q, feed_dict={A: inn})
但是,我不希望有任何显着的加速。 Python 线程与其他语言中的含义不同,只有某些操作(例如 io)会并行运行。对于 CPU 密集型操作来说,它不是很有用。多处理可以真正并行运行代码,但您不会共享同一个会话。
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)