TF 对象检测 Zoo 模型没有可训练变量?

2024-03-07

中的模型TF 异议检测动物园 https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md有meta+ckpt文件、Frozen.pb文件和Saved_model文件。

我尝试使用 meta+ckpt 文件进行进一步训练,并为特定张量提取一些权重以用于研究目的。我发现这些模型没有任何可训练的变量。

vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
print(vars)

上面的代码片段给出了[]列表。我也尝试使用以下内容。

vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
print(vars)

我再次得到一个[] list.

这怎么可能 ?模型是否去掉了变量?或者是tf.Variable(trainable=False)?我在哪里可以获得包含有效可训练变量的meta+ckpt 文件。我专门关注SSD+mobilnet型号

UPDATE:

以下是我用于恢复的代码片段。它位于类内,因为我正在为某些应用程序制作自定义工具。

def _importer(self):
    sess = tf.InteractiveSession()
    with sess.as_default():
        reader = tf.train.import_meta_graph(self.metafile,
                                            clear_devices=True)
        reader.restore(sess, self.ckptfile)

def _read_graph(self):
    sess = tf.get_default_session()
    with sess.as_default():
        vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        print(vars)

更新2:

我还尝试了以下代码片段。简约还原风格。

model_dir = 'ssd_mobilenet_v2/'

meta = glob.glob(model_dir+"*.meta")[0]
ckpt = meta.replace('.meta','').strip()

sess = tf.InteractiveSession()
graph = tf.Graph()
with graph.as_default():
    with tf.Session() as sess:
        reader = tf.train.import_meta_graph(meta,clear_devices=True)
        reader.restore(sess,ckpt)

        vari = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        for var in vari:
            print(var.name,"\n")

上面的代码片段还给出了[]变量列表


经过一番研究后,您问题的最终答案是不,他们不。这是非常明显的,直到您意识到variables目录在saved_model是空的。

对象检测模型zoo提供的检查点文件包含以下文件:

.
|-- checkpoint
|-- frozen_inference_graph.pb
|-- model.ckpt.data-00000-of-00001
|-- model.ckpt.index
|-- model.ckpt.meta
|-- pipeline.config
`-- saved_model
    |-- saved_model.pb
    `-- variables

The pipeline.config是保存的模型的配置文件,frozen_inference_graph.pb用于现成的推理。请注意checkpoint, model.ckpt.data-00000-of-00001, model.ckpt.meta and model.ckpt.index 都对应检查点. (Here https://stackoverflow.com/a/44521818/1621414你可以找到一个很好的解释)

因此,当您想要获得可训练变量时,唯一有用的是saved_model目录。

使用 SavedModel 保存和加载模型 - 变量、图形和图形的元数据。这是一种语言中立、可恢复、密封的序列化格式,使更高级别的系统和工具能够生成、使用和转换 TensorFlow 模型。

为了恢复SavedModel你可以使用APItf.saved_model.loader.load() https://www.tensorflow.org/guide/saved_model#loading_a_savedmodel_in_python,并且此 api 包含一个称为tags,它指定了类型MetaGraphDef。所以如果你想得到可训练的变量,你需要指定tag_constants.TRAINING调用api时。

我尝试调用此 api 来恢复变量,但它给了我错误:

在 SavedModel 中找不到与标签“train”关联的 MetaGraphDef。要检查 SavedModel 中的可用标签集,请使用 SavedModel CLI:saved_model_cli

所以我这样做了saved_model_cli命令检查所有可用的标签SavedModel.

#from directory saved_model
saved_model_cli show --dir . --all

输出是

MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:
...
signature_def['serving_default']:
  ...

所以没有标签train但只有serve在此之内SavedModel. The SavedModel因此这里仅用于张量流服务。这意味着当创建这些文件时,未使用标签指定training,无法从这些文件中恢复任何训练变量。

P.S.:以下代码是我用来恢复的SavedModel。设置时tag_constants.TRAINING,加载无法完成但设置时tag_constants.SERVING,加载成功但变量为空。

graph = tf.Graph()
with tf.Session(graph=graph) as sess:
  tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.TRAINING], export_dir)
  variables = graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
  print(variables)

P.P.S:我找到了创建的脚本SavedModel here https://github.com/tensorflow/models/blob/e7b4d364de5dc9d66d23085a5c52b5d7631576a7/research/object_detection/exporter.py#L289。可以看出确实没有train创建时的标签SavedModel.

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

TF 对象检测 Zoo 模型没有可训练变量? 的相关文章

随机推荐