Tensorflow 2.1/Keras - "output_node is not in graph" error when trying to freeze graph


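I am trying to save a model that I created with Keras and saved as an .h5 file, but every time I try to run the freeze_session function I get this error message: output_node/Identity is not in graph

This is my code (I am using Tensorflow 2.1.0):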

import tensorflow as tf

def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    """
    Freezes the state of a session into a pruned computation graph.

    Creates a new computation graph where variable nodes are replaced by
    constants taking their current value in the session. The new graph will be
    pruned so subgraphs that are not necessary to compute the requested
    outputs are removed.
    @param session The TensorFlow session to be frozen.
    @param keep_var_names A list of variable names that should not be frozen,
                          or None to freeze all the variables in the graph.
    @param output_names Names of the relevant graph outputs.
    @param clear_devices Remove the device directives from the graph for better portability.
    @return The frozen graph definition.
    """
    graph = session.graph
    with graph.as_default():
        freeze_var_names = list(set(v.op.name for v in tf.compat.v1.global_variables()).difference(keep_var_names or []))
        output_names = output_names or []
        output_names += [v.op.name for v in tf.compat.v1.global_variables()]
        input_graph_def = graph.as_graph_def()
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""
        frozen_graph = tf.compat.v1.graph_util.convert_variables_to_constants(
            session, input_graph_def, output_names, freeze_var_names)
        return frozen_graph
# inputs:
print('inputs: ', model.input.op.name)
# outputs: 
print('outputs: ', model.output.op.name)
layer_names=[layer.name for layer in model.layers]


inputs: input_node
outputs: output_node/Identity
['input_node', 'conv2d_6', 'max_pooling2d_6', 'conv2d_7', 'max_pooling2d_7', 'conv2d_8', 'max_pooling2d_8', 'flatten_2', 'dense_4', 'dense_5', 'output_node']

This is as expected (the same layer names and outputs as in the model I saved after training).

Then I try to call the freeze_session function and save the resulting frozen graph:

frozen_graph = freeze_session(K.get_session(), output_names=[out.op.name for out in model.outputs])
write_graph(frozen_graph, './', 'graph.pbtxt', as_text=True)
write_graph(frozen_graph, './', 'graph.pb', as_text=False)


AssertionError                            Traceback (most recent call last)
<ipython-input-4-1848000e99b7> in <module>
----> 1 frozen_graph = freeze_session(K.get_session(), output_names=[out.op.name for out in model.outputs])
      2 write_graph(frozen_graph, './', 'graph.pbtxt', as_text=True)
      3 write_graph(frozen_graph, './', 'graph.pb', as_text=False)

<ipython-input-2-3214992381a9> in freeze_session(session, keep_var_names, output_names, clear_devices)
     24                 node.device = ""
     25         frozen_graph = tf.compat.v1.graph_util.convert_variables_to_constants(
---> 26             session, input_graph_def, output_names, freeze_var_names)
     27         return frozen_graph

c:\users\marco\anaconda3\envs\tfv2\lib\site-packages\tensorflow_core\python\util\deprecation.py in new_func(*args, **kwargs)
    322               'in a future version' if date is None else ('after %s' % date),
    323               instructions)
--> 324       return func(*args, **kwargs)
    325     return tf_decorator.make_decorator(
    326         func, new_func, 'deprecated',

c:\users\marco\anaconda3\envs\tfv2\lib\site-packages\tensorflow_core\python\framework\graph_util_impl.py in convert_variables_to_constants(sess, input_graph_def, output_node_names, variable_names_whitelist, variable_names_blacklist)
    275   # This graph only includes the nodes needed to evaluate the output nodes, and
    276   # removes unneeded nodes like those involved in saving and assignment.
--> 277   inference_graph = extract_sub_graph(input_graph_def, output_node_names)
    279   # Identify the ops in the graph.

c:\users\marco\anaconda3\envs\tfv2\lib\site-packages\tensorflow_core\python\util\deprecation.py in new_func(*args, **kwargs)
    322               'in a future version' if date is None else ('after %s' % date),
    323               instructions)
--> 324       return func(*args, **kwargs)
    325     return tf_decorator.make_decorator(
    326         func, new_func, 'deprecated',

c:\users\marco\anaconda3\envs\tfv2\lib\site-packages\tensorflow_core\python\framework\graph_util_impl.py in extract_sub_graph(graph_def, dest_nodes)
    195   name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
    196       graph_def)
--> 197   _assert_nodes_are_present(name_to_node, dest_nodes)
    199   nodes_to_keep = _bfs_for_reachable_nodes(dest_nodes, name_to_input_name)

c:\users\marco\anaconda3\envs\tfv2\lib\site-packages\tensorflow_core\python\framework\graph_util_impl.py in _assert_nodes_are_present(name_to_node, nodes)
    150   """Assert that nodes are present in the graph."""
    151   for d in nodes:
--> 152     assert d in name_to_node, "%s is not in graph" % d

**AssertionError: output_node/Identity is not in graph** 
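
The last frame of the traceback shows that the failure is a plain name lookup: every requested output name has to be a key in the map of node names built from the graph definition. A minimal sketch of the same check, which could be run before freezing to see which names are absent. `missing_nodes` is a hypothetical helper, not part of TensorFlow, and the node names are made up for illustration.

```python
def missing_nodes(graph_node_names, requested_names):
    """Return the requested names that are absent from the graph, in order."""
    present = set(graph_node_names)
    return [name for name in requested_names if name not in present]


# Illustrative names only. With a real graph, the list could be built as
# [n.name for n in graph.as_graph_def().node].
graph_node_names = ["input_node", "dense_4/MatMul", "output_node/Softmax"]
requested = ["output_node/Identity"]

absent = missing_nodes(graph_node_names, requested)
print(absent)
```

An empty result would mean that every requested output exists in the graph that is about to be frozen. A non-empty one suggests that the model's ops live in a different graph than the one owned by the session.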


If you are using Tensorflow version 2.x, add:


This should work. I have not checked the resulting pb file, but it should work.
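
The line to add is missing from the text above. One call that is commonly used for this in TF 2.x is `tf.compat.v1.disable_eager_execution()`; that this is what was meant here is an assumption. The sketch below shows where such a call would go, using a small hypothetical stand-in model instead of the one from the question, and assuming a TF 2.1-style `tf.keras` that still exposes sessions through `tf.compat.v1.keras.backend.get_session()` (later Keras releases may not).

```python
import tensorflow as tf

# Assumption: this is the call that was meant. It has to run before any model
# or graph is created, so that Keras builds its ops in the default
# TF1-style graph that the session owns.
tf.compat.v1.disable_eager_execution()

# Hypothetical stand-in for the model from the question.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(4, activation="relu", input_shape=(8,), name="dense_4"),
    tf.keras.layers.Dense(2, activation="softmax", name="output_node"),
])

# In graph mode the Keras session and the model share the same graph.
session = tf.compat.v1.keras.backend.get_session()
output_names = [out.op.name for out in model.outputs]

# Same TF1-style conversion that freeze_session performs internally.
frozen_graph = tf.compat.v1.graph_util.convert_variables_to_constants(
    session, session.graph.as_graph_def(), output_names)

frozen_node_names = {node.name for node in frozen_graph.node}
print(output_names)
```

With the model from the question, the call would go at the very top of the script, before the .h5 file is loaded. Note that the output op name in graph mode may differ from the `output_node/Identity` name reported under eager execution.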


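edit: However, according to this thread, for example, https://github.com/tensorflow/models/issues/7508, TF1 and TF2 pb files are fundamentally different. My solution may not work properly, or may actually create a TF1 pb file.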
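
For comparison, here is a sketch of a different technique that keeps eager execution enabled and freezes a traced function with `convert_variables_to_constants_v2`. This is not the approach described above; the model is a hypothetical stand-in, the input shape and the file name `graph_tf2.pb` are assumptions, and the function is imported from an internal TensorFlow module path, so it may move between releases.

```python
import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import (
    convert_variables_to_constants_v2,
)

# Hypothetical stand-in model; eager execution stays enabled here.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(4, activation="relu", input_shape=(8,)),
    tf.keras.layers.Dense(2, activation="softmax"),
])

# Trace the model into a concrete function with a fixed input signature.
concrete_func = tf.function(lambda x: model(x)).get_concrete_function(
    tf.TensorSpec([None, 8], tf.float32))

# Replace the variables captured by the function with constants.
frozen_func = convert_variables_to_constants_v2(concrete_func)
frozen_graph_def = frozen_func.graph.as_graph_def()

# Tensor names come from the traced graph, not from the Keras layer names.
input_names = [t.name for t in frozen_func.inputs]
output_names = [t.name for t in frozen_func.outputs]
print(input_names, output_names)

# Write the frozen graph as a binary protobuf.
tf.io.write_graph(frozen_graph_def, "./", "graph_tf2.pb", as_text=False)
```

The printed names are the ones a consumer of the pb file would need, since they may not match `model.output.op.name`.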





