Tensorflow:加载预训练 ResNet 模型时出错

2024-04-21

我想使用 Tensorflow 中预先训练的 ResNet 模型。我下载了code https://github.com/tensorflow/models/blob/master/research/slim/nets/resnet_v1.py (resnet_v1.py)对于模型和检查站 http://download.tensorflow.org/models/resnet_v1_50_2016_08_28.tar.gz (resnet_v1_50.ckpt) file here https://github.com/tensorflow/models/tree/master/research/slim.

我已经可以解决该错误ImportError: No module named 'nets'通过使用以下帖子:请参阅here https://stackoverflow.com/questions/46030481/importerror-no-module-named-nets答案来自茨维蒂科 https://stackoverflow.com/users/4137497/tsveti-iko.

现在我收到以下错误并且不知道该怎么办:

NotFoundError (see above for traceback): Restoring from checkpoint failed. 
This is most likely due to a Variable name or other graph key that is missing from the checkpoint. 
Please ensure that you have not altered the graph expected based on the checkpoint. Original error:

    Tensor name "resnet_v1_50/block1/unit_1/bottleneck_v1/conv1/biases" 
not found in checkpoint files /home/resnet_v1_50.ckpt
         [[node save/RestoreV2 (defined at my_resnet.py:34)  = 
RestoreV2[dtypes=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, ...,
DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT], _device="/job:localhost
/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2
/tensor_names, save/RestoreV2/shape_and_slices)]]

这是我尝试加载模型时使用的代码:

import numpy as np
import tensorflow as tf
import resnet_v1

# Restore variables of resnet model
slim = tf.contrib.slim

# Paths
network_dir = "home/resnet_v1_50.ckpt"

# Image dimensions
in_width, in_height, in_channels = 224, 224, 3

# Placeholder
X = tf.placeholder(tf.float32, [None, in_width, in_height, in_channels])

# Define network graph
logits, activations = resnet_v1.resnet_v1_50(X, is_training=False)
prediction = tf.argmax(logits, 1)

with tf.Session() as sess:
    variables_to_restore = slim.get_variables_to_restore()
    saver = tf.train.Saver(variables_to_restore)
    saver.restore(sess, network_dir) 

    # Restore variables
    variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

    # Feed random image into resnet
    img = np.random.randn(1, in_width, in_height, in_channels)
    pred = sess.run(prediction, feed_dict={X:img})

谁能告诉我,为什么它不起作用?我必须如何更改代码才能使其运行?


也许你可以使用 ResNet50tf.keras.applications https://www.tensorflow.org/api_docs/python/tf/keras/applications?

根据错误,如果您没有以任何方式更改图表,并且这是您的整个源代码,那么调试可能真的非常困难。

如果你选择理智tf.keras.applications.resnet50 https://keras.io/applications/#resnet你可以像这样简单地做到这一点:

import tensorflow

in_width, in_height, in_channels = 224, 224, 3

pretrained_resnet = tensorflow.keras.applications.ResNet50(
    weights="imagenet",
    include_top=False,
    input_shape=(in_width, in_height, in_channels),
)

# You can freeze some layers if you want, depends on your task
# Make "top" (last 3 layers below) whatever fits your task as well

model = tensorflow.keras.models.Sequential(
    [
        pretrained_resnet,
        tensorflow.keras.layers.Flatten(),
        tensorflow.keras.layers.Dense(1024, activation="relu"),
        tensorflow.keras.layers.Dense(10, activation="softmax"),
    ]
)

print(model.summary())

现在推荐使用这种方法,特别是考虑到即将推出的 Tensorflow 2.0、合理性和可读性。 顺便提一句。该模型与Tensorflow提供的模型相同,是从IIRC转来的。

您可以阅读更多有关tf.keras.applications在链接的文档和各种博客文章中,例如this one https://www.learnopencv.com/keras-tutorial-fine-tuning-using-pre-trained-models/或其他网络资源。

我如何在 Keras 中

回答评论中的问题

  • How do I pass images to the network?: use model.predict(image)如果你想做出预测,图像在哪里np.array。就那么简单。

  • How do I access weights?:嗯,这个比较复杂……开玩笑吧,每一层都有.get_weights()返回其权重和偏差的方法,您可以使用以下方法迭代各层for layer in model.layers()。您可以使用一次获得所有权重model.get_weights()以及。

总而言之,您将学习 Keras,并且在比调试此问题更短的时间内比在 Tensorflow 中更有效率。他们有30秒指导 https://keras.io/#getting-started-30-seconds-to-keras因为某种原因。

BTW.Tensorflow 默认提供 Keras,因此 Tensorflow 的 Keras 风格isTensorflow 的一部分(无论这听起来多么令人困惑)。这就是为什么我用过tensorflow在我的例子中。

使用 Tensorflow Hub 的解决方案

看来你可以加载和微调Resnet50使用集线器,如所述这个链接 https://tfhub.dev/google/imagenet/resnet_v1_50/feature_vector/1.

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Tensorflow:加载预训练 ResNet 模型时出错 的相关文章

随机推荐

  • 无法从“方法组”转换为“System.Action<对象>”错误

    我创建了以下函数 public void DelegatedCall Action delegatedMethod 并定义了以下方法 public void foo1 String str 但是 当我尝试打电话时DelegateCall w
  • 如何以最佳方式将 SQL 查询转换为 cypher?

    我是 neo4j 的新手 使用 3 0 版本 我有一个巨大的事务数据集 我将其转换为图形模型 我需要将下面的 SQL 查询转换为 cypher create table calc base as select a ticket id tic
  • 保存为自动填充对话框未显示

    我有一个显示用户名 UI 的活动 输入该活动并点击继续按钮后会显示输入密码 UI 输入密码并点击登录按钮后 完成当前活动并启动新活动 在我的设备上 我选择了 Google 自动填充服务 因此在第一个活动完成后 我想要 保存以供自动填充 对话
  • Maven AppAssembler 找不到类

    尝试修改现有的 Java Tomcat 应用程序以按照其部署在 Heroku 上tutorial https devcenter heroku com articles create a java web application using
  • has_many :autosave => true 保存子项时跳过验证

    在 Rails 2 和 Rails 3 中 如果 autosave gt true 是一个 has many 关联 则循环遍历集合并对每个子关联调用 save validate gt false 这是为什么 我们需要为该子对象运行 befo
  • 创建 SKShapeNode 的子类

    class ColorRectangle SKShapeNode var width CGFloat var height CGFloat var rectColor UIColor convenience init rectOfSize
  • 使用 NSValueTransformer 加密 iOS 核心数据

    我正在尝试使用 Core Data 和 CommonCrypto 加密数据 我正在尝试使用 NSValueTransformer 来延迟加密和解密 但是 当我现在尝试将加密数据保存到持久存储协调器时 它失败了 每次我尝试将数据保存到数据库时
  • RecyclerView - 获取 Activity 内的位置而不是 RecyclerView 适配器

    这是我处理视图点击的第三天 我原来用的是ListView 然后我切换到RecyclerView 我已经添加了android onclick我的每个控件的元素row layout我正在处理它们MainActivity像这样 public vo
  • Moment js - 获取日期而不考虑时区

    我确实阅读了不同的 StackOverflow 帖子 他们建议从一开始就使用 utc 但它不起作用 Note 我在 PST 区域 const start 2018 06 10T21 00 00 04 00 const end 2018 06
  • MS2015中的MvcBuildViews需要很长时间

    我们正在转换解决方案以使用新的 Roslyn 编译器 当我在发布模式下通过 teamCity 构建它时 MVCBuildViews 步骤仍然使用 aspnet compiler exe 并且预编译视图需要大约 15 分钟 在 NET 4 5
  • Factory_girl 与 validates_presence_of 有关系

    我有 2 个型号 user rb class User lt ActiveRecord Base has one profile dependent gt destroy end profile rb class Profile lt Ac
  • Linux 中允许的 c/c++ 最大互斥体数量

    我一直在尝试找出 Linux 中 c c 进程的最大互斥体数量是多少 但没有成功 另外 有没有办法修改这个数字 我正在读的书提到了如何找到Linux中允许的最大线程数以及如何修改这个数字 但没有提到互斥体 检查这个pthread mutex
  • Django Postgresql 在迁移时删除列默认值

    我面临表默认值的问题 例如我有这个模型 class model1 models Model field1 models CharField max length 50 default My Default Value 1 db column
  • 如何完全静音 bash 脚本中的 vlc 输出?

    我有一个为自己编写的脚本 它在接近结束的地方使用 vlc 我需要它停止输出它想要的任何内容 但保留我自己的输出 所以没有 清除 我使用了参数 q 和 no sout x264 quiet 但无济于事 它仍然输出丑陋的消息 即 警告 调用 r
  • 使用两个具有相同命名空间的 .NET 库

    我目前正在为一家公司维护一些旧代码 正如所发生的那样 我正在修改的当前应用程序使用旧版本的内部库 我们将其称为 Lib1 dll 他们还有一个名为 Lib2 dll 的新版本库 它在许多方面对以前的库进行了改进 不幸的是 Lib2 不向后兼
  • VBS 脚本 getElementbyID 错误(自动登录脚本)

    我正在编写适用于不同站点的 vbs 脚本文件 但我正在为我的大学网页编写用于互联网页面登录的自动登录脚本 所以我一直在工作直到填写用户名和密码 但我无法让它点击登录 这是大学登录的链接 我不确定您是否可以从网络外访问它 请注意编辑请不要将其
  • 从 DataReader 读取数据的通用方法

    我目前正在使用此方法从 DataReader 读取数据 private T GetValue
  • LinkedList“节点跳转”

    试图找出为什么我的list 类指针被第三个节点覆盖 插入函数 如下 中发生的情况是 第三次调用插入函数headByName gt nextByName当第三个节点应该指向第二个节点时 节点指针被第三个节点覆盖 因此 您可以猜测第 4 个节点
  • 使用 Velocity 生成基于 HTML 的电子邮件

    我尝试这个教程http www java2s com Code Java Velocity UseVelocitytogenerateHTMLbasedemail htm http www java2s com Code Java Velo
  • Tensorflow:加载预训练 ResNet 模型时出错

    我想使用 Tensorflow 中预先训练的 ResNet 模型 我下载了code https github com tensorflow models blob master research slim nets resnet v1 py