I want to use a pre-trained ResNet model in TensorFlow. From here https://github.com/tensorflow/models/tree/master/research/slim I downloaded two files. The first is the code for the model (resnet_v1.py, https://github.com/tensorflow/models/blob/master/research/slim/nets/resnet_v1.py). The second is the checkpoint file (resnet_v1_50.ckpt, from http://download.tensorflow.org/models/resnet_v1_50_2016_08_28.tar.gz).
I already fixed the error ImportError: No module named 'nets' using the answer by Tsveti Iko https://stackoverflow.com/users/4137497/tsveti-iko to this post: https://stackoverflow.com/questions/46030481/importerror-no-module-named-nets.
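This is for readers who hit the same ImportError. resnet_v1.py itself does `from nets import resnet_utils`, so the research/slim directory of the tensorflow/models repository has to be on the import path. The sketch below shows one way to do that from inside the script. It assumes the repository was cloned to ~/models, which is a hypothetical location you should adjust. It is not necessarily what the linked answer does.

```python
import os
import sys

# Hypothetical clone location of https://github.com/tensorflow/models; adjust it.
MODELS_DIR = os.path.expanduser("~/models")
SLIM_DIR = os.path.join(MODELS_DIR, "research", "slim")

# The `nets` package lives in research/slim, so that directory must be
# importable before `import resnet_v1` runs `from nets import resnet_utils`.
if SLIM_DIR not in sys.path:
    sys.path.insert(0, SLIM_DIR)
```

Exporting PYTHONPATH to include the same directory before starting Python should have the same effect.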
Now I get the following error and don't know what to do:
NotFoundError (see above for traceback): Restoring from checkpoint failed.
This is most likely due to a Variable name or other graph key that is missing from the checkpoint.
Please ensure that you have not altered the graph expected based on the checkpoint. Original error:
Tensor name "resnet_v1_50/block1/unit_1/bottleneck_v1/conv1/biases" not found in checkpoint files /home/resnet_v1_50.ckpt
[[node save/RestoreV2 (defined at my_resnet.py:34) = RestoreV2[dtypes=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, ..., DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2/tensor_names, save/RestoreV2/shape_and_slices)]]
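One way to narrow this error down is to compare the variable names the graph creates with the names stored in the checkpoint. The sketch below assumes TF 1.x. `missing_from_checkpoint` and `checkpoint_var_names` are hypothetical helper names. TensorFlow is imported only inside the helper that reads the checkpoint, so the comparison logic can be used without it.

```python
def missing_from_checkpoint(graph_var_names, ckpt_var_names):
    """Return graph variable names that have no entry in the checkpoint.

    Graph names such as "resnet_v1_50/conv1/weights:0" carry an output
    suffix (":0") that checkpoint keys lack, so it is stripped first.
    """
    ckpt = set(ckpt_var_names)
    stripped = (name.split(":")[0] for name in graph_var_names)
    return sorted(n for n in stripped if n not in ckpt)


def checkpoint_var_names(ckpt_path):
    """List the variable names stored in a TF 1.x checkpoint file."""
    import tensorflow as tf  # imported lazily: only this helper needs TF
    return [name for name, _shape in tf.train.list_variables(ckpt_path)]
```

First build the graph. Then call `missing_from_checkpoint([v.name for v in tf.global_variables()], checkpoint_var_names("/home/resnet_v1_50.ckpt"))`. This should list every variable the Saver cannot find, including the `.../conv1/biases` tensor named in the error.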
This is the code I use to load the model:
import numpy as np
import tensorflow as tf
import resnet_v1
# Restore variables of resnet model
slim = tf.contrib.slim
# Paths
network_dir = "/home/resnet_v1_50.ckpt"
# Image dimensions
in_width, in_height, in_channels = 224, 224, 3
# Placeholder
X = tf.placeholder(tf.float32, [None, in_width, in_height, in_channels])
# Define network graph
logits, activations = resnet_v1.resnet_v1_50(X, is_training=False)
prediction = tf.argmax(logits, 1)
with tf.Session() as sess:
    variables_to_restore = slim.get_variables_to_restore()
    saver = tf.train.Saver(variables_to_restore)
    saver.restore(sess, network_dir)
    # Collect the trainable variables (not used below)
    variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
    # Feed random image into resnet
    img = np.random.randn(1, in_width, in_height, in_channels)
    pred = sess.run(prediction, feed_dict={X: img})
Can anyone tell me why it doesn't work? How do I have to change the code to make it run?
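The post does not confirm the cause, but the following is a likely one. The slim ResNet checkpoints are meant to be loaded into a graph built under `resnet_v1.resnet_arg_scope()`. That scope makes `slim.conv2d` use batch normalization instead of biases. Without it, `slim.conv2d` falls back to creating a `biases` variable, which the checkpoint does not contain. The sketch below wraps model construction in that scope. Its assumptions are:

- TF 1.x with `tf.contrib.slim`.
- The same importable `resnet_v1` module as above.
- `num_classes=1000`, assuming the resnet_v1_50 checkpoint has a 1000-class ImageNet head.

`predict_random_image` is a hypothetical name.

```python
CKPT_PATH = "/home/resnet_v1_50.ckpt"  # path taken from the error message
IN_HEIGHT, IN_WIDTH, IN_CHANNELS = 224, 224, 3


def predict_random_image(ckpt_path=CKPT_PATH):
    """Build ResNet-50 under resnet_arg_scope, restore it, classify noise."""
    # Third-party imports live here so the module loads without them.
    import numpy as np
    import tensorflow as tf  # assumes TF 1.x (tf.contrib.slim, tf.Session)
    import resnet_v1

    slim = tf.contrib.slim
    tf.reset_default_graph()

    # NHWC input: batch, height, width, channels.
    X = tf.placeholder(tf.float32, [None, IN_HEIGHT, IN_WIDTH, IN_CHANNELS])
    # The arg scope switches slim.conv2d to batch norm (no biases), matching
    # the variable names stored in the checkpoint.
    with slim.arg_scope(resnet_v1.resnet_arg_scope()):
        logits, end_points = resnet_v1.resnet_v1_50(
            X, num_classes=1000, is_training=False)
    prediction = tf.argmax(logits, 1)

    saver = tf.train.Saver(slim.get_variables_to_restore())
    with tf.Session() as sess:
        saver.restore(sess, ckpt_path)
        img = np.random.randn(1, IN_HEIGHT, IN_WIDTH, IN_CHANNELS)
        return sess.run(prediction, feed_dict={X: img.astype(np.float32)})
```

Passing `num_classes` explicitly matters as well. In recent versions of resnet_v1.py it may default to None, in which case the network returns pooled features rather than class logits.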