MobileNetV2 的 Keras 和 TensorFlow Hub 版本之间的差异

2023-12-03

我正在研究一种迁移学习方法,并且在使用 MobileNetV2 时得到了非常不同的结果keras.applications以及 TensorFlow Hub 上提供的一个。这对我来说似乎很奇怪,因为两个版本都声称here and here从同一检查点提取它们的权重mobilenet_v2_1.0_224。 这是如何重现差异的,你可以找到 Colab Notebookhere:

!pip install tensorflow-gpu==2.1.0
import tensorflow as tf
import numpy as np
import tensorflow_hub as hub
from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2

def create_model_keras():
  image_input = tf.keras.Input(shape=(224, 224, 3))
  out = MobileNetV2(input_shape=(224, 224, 3),
                  include_top=True)(image_input)
  model = tf.keras.models.Model(inputs=image_input, outputs=out)
  model.compile(optimizer='adam', loss=["categorical_crossentropy"])
  return model

def create_model_tf():
  image_input = tf.keras.Input(shape=(224, 224 ,3))
  out = hub.KerasLayer("https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/4",
                      input_shape=(224, 224, 3))(image_input)
  model = tf.keras.models.Model(inputs=image_input, outputs=out)
  model.compile(optimizer='adam', loss=["categorical_crossentropy"])
  return model

当我尝试对随机批次进行预测时,结果不相等:

keras_model = create_model_keras()
tf_model = create_model_tf()
np.random.seed(42)
data = np.random.rand(32,224,224,3)
out_keras = keras_model.predict_on_batch(data)
out_tf = tf_model.predict_on_batch(data)
np.array_equal(out_keras, out_tf)

版本的输出来自keras.applications总和为 1,但 TensorFlow Hub 的版本不是。而且两个版本的形状也不同:TensorFlow Hub 有 1001 个标签,keras.applications有 1000 个。

np.sum(out_keras[0]), np.sum(out_tf[0])

prints (1.0000001, -14.166359)

造成这些差异的原因是什么?我错过了什么吗?

编辑 2020年2月18日

正如 Szymon Maszke 指出的,TFHub 版本返回 logits。这就是为什么我添加了一个 Softmax 层create_model_tf如下:out = tf.keras.layers.Softmax()(x)

arnoegw提到TfHub版本需要将图像标准化为[0,1],而keras版本需要标准化为[-1,1]。当我对测试图像使用以下预处理时:

from tensorflow.keras.applications.mobilenet_v2 import preprocess_input
img = tf.keras.preprocessing.image.load_img("/content/panda.jpeg", target_size=(224,224))
img = tf.keras.preprocessing.image.img_to_array(img)
img = preprocess_input(img)
img = tf.io.read_file("/content/panda.jpeg")
img = tf.image.decode_jpeg(img)
img = tf.image.convert_image_dtype(img, tf.float32)
img = tf.image.resize(img, (224,224))

两者都正确预测相同的标签,并且以下条件为真:np.allclose(out_keras, out_tf[:,1:], rtol=0.8)

编辑 2 2020 年 2 月 18 日在我写之前,不可能将格式相互转换。这是由一个错误引起的。


有几个已记录的差异:

  • 正如 Szymon 所说,TF Hub 版本返回 logits(在将其转换为概率的 softmax 函数之前),这是一种常见的做法,因为可以根据 logits 计算出具有更高数值稳定性的交叉熵损失。

  • TF Hub 模型假设 float32 输入在 [0,1] 范围内,这是您得到的tf.image.decode_jpeg(...)其次是tf.image.convert_image_dtype(..., tf.float32)。 Keras 代码使用特定于模型的范围(可能是 [-1,+1])。

  • TF Hub 模型在返回其所有 1001 个输出类时更完整地反映了原始 SLIM 检查点。正如文档中链接的 ImageNetLabels.txt 中所述,添加的类 0 是“背景”(又名“东西”)。这就是对象检测用来指示图像背景而不是任何已知类别的对象的方法。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

MobileNetV2 的 Keras 和 TensorFlow Hub 版本之间的差异 的相关文章

  • Tensorflow Hub - 获取模型的输入形状和问题域?

    我正在使用最新版本的tensorflow hub 想知道如何获取有关模型的预期输入形状以及模型属于什么类型的集合的信息 例如 有没有办法以这种方式在 Python 中加载模型后获取有关预期图像形状的信息 model hub load htt
  • Tensorflow 不分配完整的 GPU 内存

    Tensorflow 默认分配所有 GPU 内存 但我的新设置实际上只有 9588 MiB 11264 MiB 我预计大约 11 000MiB 就像我的旧设置一样 张量流信息在这里 from tensorflow python client
  • ValueError:没有为“dense_input”提供数据

    我正在使用以下简单的代码使用tensorflow加载csv并使用keras执行建模 无法弄清楚这个错误 import tensorflow as tf train dataset fp tf keras utils get file fna
  • NumPy 相当于 Keras 函数 utils.to_categorical

    我有一个使用 Keras 进行机器学习的 Python 脚本 我正在构建 X 和 Y 它们分别是特征和标签 标签的构建方式如下 def main depth 10 nclass 101 skip True output True video
  • 使用大数据集在 Google Colab TPU 上训练 seq2seq 模型 - Keras

    我正在尝试使用 Google Colab TPU 上的 Keras 训练用于机器翻译的序列到序列模型 我有一个可以加载到内存中的数据集 但我必须对其进行预处理才能将其提供给模型 特别是 我需要将目标单词转换为一个热向量 并且在许多示例中 我
  • 在 Keras 模型中删除然后插入新的中间层

    给定一个预定义的 Keras 模型 我尝试首先加载预先训练的权重 然后删除一到三个模型内部 非最后几层 层 然后用另一层替换它 我似乎找不到任何有关的文档keras io https keras io 即将做这样的事情或从预定义的模型中删除
  • Keras:如何保存模型或权重?

    如果这个问题看起来很简单 我很抱歉 但是阅读 Keras 保存和恢复帮助页面 https www tensorflow org beta tutorials keras save and restore models https www t
  • Tensorflow 与 Keras 的兼容性

    我正在使用 Python 3 6 和 Tensorflow 2 0 并且有一些 Keras 代码 import keras from keras models import Sequential from keras layers impo
  • 无法使用 Keras 中的 multi_gpu_model 后的 model.save 保存模型

    升级到 Keras 2 0 9 后 我一直在使用multi gpu model实用程序 但我无法使用保存我的模型或最佳权重 model save path 我得到的错误是 类型错误 无法pickle模块对象 我怀疑访问模型对象时存在一些问题
  • 将 Keras 集成到 SKLearn 管道?

    我有一个 sklearn 管道 对异构数据类型 布尔 分类 数字 文本 执行特征工程 并想尝试使用神经网络作为我的学习算法来拟合模型 我遇到了输入数据形状的一些问题 我想知道我想做的事情是否可能 或者我是否应该尝试不同的方法 我尝试了几种不
  • 卷积神经网络 (CNN) 输入形状

    我是 CNN 的新手 我有一个关于 CNN 的问题 我对 CNN 特别是 Keras 的输入形状有点困惑 我的数据是不同时隙的二维数据 比方说10X10 因此 我有 3D 数据 我将把这些数据输入到我的模型中来预测即将到来的时间段 所以 我
  • 增加 sigmoid 预测输出值?

    我创建了一个用于文本分类的 Conv1D 模型 当在最后一个密集处使用 softmax sigmoid 时 它产生的结果为 softmax gt 0 98502016 0 0149798 sigmoid gt 0 03902826 0 00
  • Tensorflow 可变图像输入大小(自动编码器、放大......)

    Edit WARNING不建议使用不同图像大小的图像 因为张量需要具有相同的大小才能实现并行化 我一直在寻找解决方案 了解如何使用不同大小的图像作为神经网络的输入 Numpy 第一个想法是使用numpy 然而 由于每个图像的大小不同 我无法
  • Keras conv1d 层参数:过滤器和 kernel_size

    我对 keras 的 conv1d 层中的这两个参数感到非常困惑 https keras io layers convolutional conv1d https keras io layers convolutional conv1d 文
  • Keras IndexError:索引超出范围

    我是 Keras 新手 我尝试在数据集上执行二进制 MLP 并且不断使索引超出范围 但不知道为什么 from keras models import Sequential from keras layers core import Dens
  • Keras - Nan 总结直方图 LSTM

    我使用 Keras 编写了一个 LSTM 模型 并使用 LeakyReLU 高级激活 ADAM Optimizer with learning rate decay opt optimizers Adam lr 0 0001 beta 1
  • 使用 Keras 和 fit_generator 绘制 TensorBoard 分布和直方图

    我正在使用 Keras 使用 fit generator 函数训练 CNN 这似乎是一个已知问题 https github com fchollet keras issues 3358TensorBoard 在此设置中不显示直方图和分布 有
  • 收到的标签值 1 超出了 [0, 1) 的有效范围 - Python、Keras

    我正在使用具有张量流背景的 keras 开发一个简单的 cnn 分类器 def cnnKeras training data training labels test data test labels n dim print Initiat
  • 如何使用 lstm 执行多类多输出分类

    I have multiclass multioutput classification see https scikit learn org stable modules multiclass html https scikit lear
  • 尝试校准keras模型

    我正在尝试通过 Sklearn 实现来校准我的 CNN 模型CalibratedClassifierCV 尝试将其包装为KerasClassifier并覆盖预测功能但没有成功 有人可以说我做错了什么吗 这是模型代码 def create m

随机推荐

  • 在 PHP 中提取字符串的特定部分

    我只是想知道在 PHP 中提取动态字符串的特定部分最简单 最有效的方法是什么 例如 在此字符串中 http www dailymotion com video xclep1 school gyrls something like a par
  • Android Firebase云功能通知

    我已成功设置 firebase 云功能来向主题发送通知 问题是它发送给包括发件人在内的所有用户 我如何设置我的云功能 以便它不向发件人显示通知 请帮忙 以下是我如何发送到主题 exports sendNotesNotification fu
  • 父级上的 CKEditor“溢出:滚动”导致工具栏冻结在初始位置

    当您使用以下命令将 CKEditor 添加到 div 内的 div 时 overflow scroll 滚动父 div 时工具栏不会移动 div div This is the ckedito div div 可以在这里找到一个例子 htt
  • 当我导入客户端库时,为什么会出现 ReferenceError: self is not Defined ?

    试图创建一个xterm反应组件Next js我陷入了困境 因为我无法克服以前从未收到过的错误消息 我正在尝试导入一个名为的 npm 客户端模块xterm 但是如果我添加导入行 应用程序就会崩溃 import Terminal from xt
  • 正则表达式按空格分割但不转义空格

    我想按标准空白进行分割 但没有转义空格 例如 使用字符串 my name is max 单引号所以 是字面意思 我想要得到 my name is max 我试过这个正则表达式 s 但结果是这样的 gt m name is max 这很接近
  • 如何在 Dartlang 中检索元数据?

    Dartlang教程介绍package metahttps www dartlang org docs dart up and running contents ch02 html ch02 metadata DartEditor 识别元数
  • 从字符串中提取电话号码

    我正在尝试从给定的字符串中提取java中的电话号码 即电话号码可以位于字符串中的任何位置 例如 bla bla TELEPHONE NUMBER bla bla 现在我想在另一个字符串中提取这个电话号码 在使用时 matcher match
  • 如何将保存的 localStorage Web 数据传递到 php 脚本?

    好吧 所以我在尝试找出如何将我保存在 localStorage 中的一些数据传递到我编写的 php 脚本时遇到了一些问题 这样我就可以将其发送到服务器上的数据库 我之前确实找到了一些代码 https developer mozilla or
  • 发送 Outlook 日历邀请 PHP

    该代码的目标是使用 PHP 发送约会和阻止人员日历 我这里有两页 测试 php
  • 通过缓存电子表格值提高脚本性能

    我正在尝试使用 Google Apps 脚本开发一个网络应用程序 将其嵌入到 Google 站点中 该站点仅显示 Google 表格的内容并使用一些简单的参数对其进行过滤 至少目前是这样 稍后我可能会添加更多功能 我得到了一个功能齐全的应用
  • 将密码重置发送到其他电子邮件 - Devise

    我正在使用 Ruby on Rails 5 和 devise 我需要将密码重置电子邮件发送到与我的用户表中存储的电子邮件不同的电子邮件 如何才能实现这一目标 请注意 这是非常不推荐的实现方式 它不在最佳实践的范围内 它又脏又脆弱 但如果你真
  • Apple 文件系统从照片库读取的权限

    我的 ios 应用程序中有一个 UIWebView 它将响应式网站加载到我的 webview 中 在 asp net 中开发 网站有一个按钮用于从设备照片库中选择视频 另一个按钮用于上传视频 在 ios 版本 10 2 之前 它可以成功地将
  • 在帆和水线中混合使用 AND 和 OR 子句

    如何在 Sailsjs 及其 ORM Waterline 中使用 OR 和 AND 子句 例如我有一张书表 book name author free public Book A Author 1 false true Book B Aut
  • 错误标记主机:等待条件超时 [kubernetes]

    我刚刚开始学习 Kubernetes 我已经通过 Kubernetes YUM 存储库安装了 CentOS 7 5 并禁用了 SELinux 的 kubectl kubeadm 和 kubelet 然而 当我想开始一个kubeadm ini
  • 撇号 cms - 自定义小部件中富文本的内联编辑?

    在某些情况下 我无法将富文本的内联编辑保存回数据库 请耐心等待 这里将粘贴一些代码 因为这是我描述我正在做的事情的唯一方式 我的项目中有两种自定义小部件 一种只有一个小部件实例 通常在lib modules目录 article widget
  • 依赖注入类型选择

    最近我遇到一个问题 我必须根据参数选择类型 例如 用于发送通知的类 应根据输入参数选择正确的渠道 电子邮件 短信等 我看起来像这样 public class NotificationManager IEmail email ISms sms
  • Google URLShortener API 返回 ipRefererBlocked

    我正在尝试将 Google URL 缩短 API 与 PHP 结合使用 apiKey ABC url http www stackoverflow com postData array longUrl gt url jsonData jso
  • 正则表达式匹配除空格之外的单个字符

    我需要匹配一个不是空格的单个字符 但我不知道如何使用正则表达式来做到这一点 以下应该足够了 如果您想将其扩展到除空白之外的任何内容 换行符 制表符 空格 硬空格 s or S Note this is a CAPITAL S
  • 将数据从操作传递到另一个操作

    如何通过 RedirectAction 方法将模型从 GetDate 操作传递到另一个 ProcessP 操作 这是源代码 HttpPost public ActionResult GetDate FormCollection values
  • MobileNetV2 的 Keras 和 TensorFlow Hub 版本之间的差异

    我正在研究一种迁移学习方法 并且在使用 MobileNetV2 时得到了非常不同的结果keras applications以及 TensorFlow Hub 上提供的一个 这对我来说似乎很奇怪 因为两个版本都声称here and here从