我的 TensorFlow 用例要求我为每个需要处理的实例构建一个新的计算图。这最终会增加内存需求。
除了少数几个tf.Variables
这些是模型参数,我想删除所有其他节点。其他有类似问题的人也发现了tf.reset_default_graph()
很有用,但这会消除我需要保留的模型参数。
我可以使用什么来删除除这些节点之外的所有节点?
编辑:
实例特定的计算实际上只是意味着我添加了很多新操作。我相信这些操作是内存问题背后的原因。
UPDATE:请参阅最近发布的张量流折叠(https://github.com/tensorflow/fold https://github.com/tensorflow/fold)允许动态构建计算图。
tf.graph 数据结构被设计为仅附加数据结构。因此不可能删除或修改现有节点。通常这不是问题,因为运行会话时仅处理必要的子图。
您可以尝试的是将图表的变量复制到新图表中并删除旧图表。要存档此文件,只需运行:
old_graph = tf.get_default_graph() # Save the old graph for later iteration
new_graph = tf.graph() # Create an empty graph
new_graph.set_default() # Makes the new graph default
如果您想迭代旧图中的所有节点,请使用:
for node in old_graph.get_operations():
if node.type == 'Variable':
# read value of variable and copy it into new Graph
或者您可以使用:
for node in old_graph.get_collection('trainable_variables'):
# iterates over all trainable Variabels
# read and create new variable
也看看python/framework/ops.py : 1759
查看操作图中节点的更多方法。
然而在你乱搞之前tf.Graph
我强烈建议考虑这是否真的需要。通常,人们可以尝试概括计算并使用共享变量构建一个图,以便您要处理的每个实例都是该图的子图。
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)