迁移学习:模型给出不变的损失结果。难道不是训练吗? [关闭]

2023-12-30

我正在尝试训练一个回归Inception V3 上的模型。输入是大小为 (96,320,3) 的图像。总共有 16k+ 图像,其中 12k+ 用于训练,其余用于验证。我已经冻结了 Inception 中的所有图层,但是解冻它们也没有帮助(已经尝试过)。我已经用几个层替换了预训练模型的顶部,如下面的代码所示。

X_train = preprocess_input(X_train)
inception = InceptionV3(weights='imagenet', include_top=False, input_shape=(299,299,3))
inception.trainable = False
print(inception.summary())

driving_input = Input(shape=(96,320,3))
resized_input = Lambda(lambda image: tf.image.resize(image,(299,299)))(driving_input)
inp = inception(resized_input)

x = GlobalAveragePooling2D()(inp)

x = Dense(512, activation = 'relu')(x)
x = Dense(256, activation = 'relu')(x)
x = Dropout(0.25)(x)
x = Dense(128, activation = 'relu')(x)
x = Dense(64, activation = 'relu')(x)
x = Dropout(0.25)(x)
result = Dense(1, activation = 'relu')(x)

lr_schedule = ExponentialDecay(initial_learning_rate=0.1, decay_steps=100000, decay_rate=0.95)
optimizer = Adam(learning_rate=lr_schedule)
loss = Huber(delta=0.5, reduction="auto", name="huber_loss")
model = Model(inputs = driving_input, outputs = result)
model.compile(optimizer=optimizer, loss=loss)

checkpoint = ModelCheckpoint(filepath="./ckpts/model.h5", monitor='val_loss', save_best_only=True)
stopper = EarlyStopping(monitor='val_loss', min_delta=0.0003, patience = 10)

batch_size = 32
epochs = 100

model.fit(x=X_train, y=y_train, shuffle=True, validation_split=0.2, epochs=epochs, 
          batch_size=batch_size, verbose=1, callbacks=[checkpoint, stopper])

This results in this: enter image description here

为什么我的模型没有进行训练,我可以采取什么措施来解决这个问题?


由于您的问题是回归问题,因此最后一层的激活应该是linear代替relu。而且学习率太高,你应该根据你的整体设置考虑降低它。在这里,我展示了 MNIST 的代码示例。

# data 
(xtrain, train_target), (xtest, test_target) = tf.keras.datasets.mnist.load_data()
# train_x, MNIST is gray scale, so in order to use it in pretrained weights , extending it to 3 axix
x_train = np.expand_dims(xtrain, axis=-1)
x_train = np.repeat(x_train, 3, axis=-1)
x_train = x_train.astype('float32') / 255
# prepare the label for regression model 
ytrain4 = tf.square(tf.cast(train_target, tf.float32))

# base model 
inception = InceptionV3(weights='imagenet', include_top=False, input_shape=(75,75,3))
inception.trainable = False

# inputs layer
driving_input = tf.keras.layers.Input(shape=(28,28,3))
resized_input = tf.keras.layers.Lambda(lambda image: tf.image.resize(image,(75,75)))(driving_input)
inp = inception(resized_input)

# top model 
x = GlobalAveragePooling2D()(inp)
x = Dense(512, activation = 'relu')(x)
x = Dense(256, activation = 'relu')(x)
x = Dropout(0.25)(x)
x = Dense(128, activation = 'relu')(x)
x = Dense(64, activation = 'relu')(x)
x = Dropout(0.25)(x)
result = Dense(1, activation = 'linear')(x)

# hyper-param
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(initial_learning_rate=0.0001, 
                                                             decay_steps=100000, decay_rate=0.95)
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
loss = tf.keras.losses.Huber(delta=0.5, reduction="auto", name="huber_loss")

# build models
model = tf.keras.Model(inputs = driving_input, outputs = result)
model.compile(optimizer=optimizer, loss=loss)

# callbacks
checkpoint = tf.keras.callbacks.ModelCheckpoint(filepath="./ckpts/model.h5", monitor='val_loss', save_best_only=True)
stopper = tf.keras.callbacks.EarlyStopping(monitor='val_loss', min_delta=0.0003, patience = 10)

batch_size = 32
epochs = 10

# fit 
model.fit(x=x_train, y=ytrain4, shuffle=True, validation_split=0.2, epochs=epochs, 
          batch_size=batch_size, verbose=1, callbacks=[checkpoint, stopper])

Output

1500/1500 [==============================] - 27s 18ms/step - loss: 5.2239 - val_loss: 3.6060
Epoch 2/10
1500/1500 [==============================] - 26s 17ms/step - loss: 3.5634 - val_loss: 2.9022
Epoch 3/10
1500/1500 [==============================] - 26s 17ms/step - loss: 3.0629 - val_loss: 2.5063
Epoch 4/10
1500/1500 [==============================] - 26s 17ms/step - loss: 2.7615 - val_loss: 2.3764
Epoch 5/10
1500/1500 [==============================] - 26s 17ms/step - loss: 2.5371 - val_loss: 2.1303
Epoch 6/10
1500/1500 [==============================] - 26s 17ms/step - loss: 2.3848 - val_loss: 2.1373
Epoch 7/10
1500/1500 [==============================] - 26s 17ms/step - loss: 2.2653 - val_loss: 1.9039
Epoch 8/10
1500/1500 [==============================] - 26s 17ms/step - loss: 2.1581 - val_loss: 1.9087
Epoch 9/10
1500/1500 [==============================] - 26s 17ms/step - loss: 2.0518 - val_loss: 1.7193
Epoch 10/10
1500/1500 [==============================] - 26s 17ms/step - loss: 1.9699 - val_loss: 1.8837

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

迁移学习:模型给出不变的损失结果。难道不是训练吗? [关闭] 的相关文章

  • 如何使用 conda 在一行中安装多个包?

    我需要使用 conda 安装以下多个软件包 我不确定 conda forge 是什么 有些使用 conda forge 有些不使用它 是否可以将它们安装成一行而不需要一一安装 谢谢 conda install c conda forge d
  • ca 证书 Mac OS X

    我需要在emacs 上安装offlineimap 和mu4e 问题是配置 当我运行 Offlineimap 时 我得到 OfflineIMAP 6 5 5 Licensed under the GNU GPL v2 v2 or any la
  • Tipfy:如何在模板中显示blob?

    鉴于在 gae 上使用tipfy http www tipfy org python 以下模型 greeting avatar db Blob avatar 显示 blob 此处为图像 的模板标签是什么 在这种情况下 斑点是一个图像 这很棒
  • Paramiko SSHException 通道已关闭

    我一直在使用 Paramiko 在 Linux Windows 机器上发送命令 它可以很好地在 Ubuntu 机器上远程执行测试 但是 它不适用于 Windows 7 主机 以下是我收到的错误 def unit for event self
  • 将一维数组转换为下三角矩阵

    我想将一维数组转换为较低的零对角矩阵 同时保留所有数字 我知道numpy tril函数 但它用零替换了一些元素 我需要扩展矩阵以包含所有原始数字 例如 10 20 40 46 33 14 12 46 52 30 59 18 11 22 30
  • numpy:大量线段/点的快速规则间隔平均值

    我沿着一维线有许多 约 100 万个 不规则间隔的点 P 这些标记线段 这样 如果点是 0 x a x b x c x d 则线段从 0 gt x a x a gt x b x b gt x c x c gt x d 等 我还有每个段的 y
  • 如何在Python代码中查找列号

    简短问题 当按上述方式调用函数时 我可以找到行号here https stackoverflow com questions 3056048 filename and line number of python script 同样 如何找到
  • Django 不会以奇怪的错误“AttributeError: 'module' object has no attribute 'getargspec'”启动

    我对 Django 的内部结构有点缺乏经验 所以我现在完全陷入困境 它昨天起作用了 但我不记得我改变过任何重要的东西 当我转身时DEBUG True任何恰好位于列表中第一个的模块上都有堆栈跟踪 Traceback most recent c
  • 按多个键分组并对字典列表的值进行汇总/平均值

    在Python中按多个键进行分组并对字典列表进行汇总 平均值的最Pythonic方法是什么 假设我有一个字典列表 如下所示 input dept 001 sku foo transId uniqueId1 qty 100 dept 001
  • uri 警告中缺少端口:使用 Python OpenCV cv2.VideoCapture() 打开文件时出错

    当我尝试流式传输 ipcam 时 出现了如下所示的错误 tcp 000000000048c640 uri 中缺少端口 警告 打开文件时出错 build opencv modules videoio src cap ffmpeg impl h
  • 根据第三个变量更改散点图中的标记样式

    我正在处理多列字典 我想绘制两列 然后根据第三列和第四列更改标记的颜色和样式 我很难改变 pylab 散点图中的标记样式 我的方法适用于颜色 不幸的是不适用于标记样式 x 1 2 3 4 5 6 y 1 3 4 5 6 7 m k l l
  • 为什么 __instancecheck__ 没有被调用?

    我有以下 python3 代码 class BaseTypeClass type def new cls name bases namespace kwd result type new cls name bases namespace p
  • 在python中读取PASCAL VOC注释

    我在 xml 文件中有注释 例如这个 它遵循 PASCAL VOC 约定
  • Python 导入非常慢 - Anaconda python 2.7

    我的 python import 语句变得非常慢 我使用 Anaconda 包在本地运行 python 2 7 导入模块后 我编写的代码运行得非常快 似乎只是导入需要很长时间 例如 我使用以下代码运行了一个 tester py 文件 imp
  • Python:无法使用 os.system() 打开文件

    我正在编写一个使用该应用程序的 Python 脚本pdftk http www pdflabs com tools pdftk the pdf toolkit 几次来执行某些操作 例如 我可以在 Windows 命令行 shell 中使用
  • 沿轴 0 重复 scipy csr 稀疏矩阵

    我想重复 scipy csr 稀疏矩阵的行 但是当我尝试调用 numpy 的重复方法时 它只是将稀疏矩阵视为对象 并且只会将其作为 ndarray 中的对象重复 我浏览了文档 但找不到任何实用程序来重复 scipy csr 稀疏矩阵的行 我
  • Streamlabs API 405 响应代码

    我正在尝试使用Streamlabs API https dev streamlabs com Streamlabs API 使用 Oauth2 来创建应用程序 因此 首先我将使用我的应用程序的用户发送到一个授权链接 其中包含我的应用程序的客
  • Java/Python 中的快速 IPC/Socket 通信

    我的应用程序中需要两个进程 Java 和 Python 进行通信 我注意到套接字通信占用了 93 的运行时间 为什么通讯这么慢 我应该寻找套接字通信的替代方案还是可以使其更快 更新 我发现了一个简单的修复方法 由于某些未知原因 缓冲输出流似
  • 如何使用 Python 3 正确显示倒计时日期

    我正在尝试获取将显示的倒计时 基本上就像一个世界末日时钟哈哈 有人可以帮忙吗 import os import sys import time import datetime def timer endTime datetime datet
  • 使用 SERVER_NAME 时出现 Flask 404

    在我的 Flask 配置中 我将 SERVER NAME 设置为 app example com 之类的域 我这样做是因为我需要使用url for with external网址 如果未设置 SERVER NAME Flask 会认为服务器

随机推荐

  • 按上下文获取所有标签以实现 acts-as-taggable-on

    We use https github com mbleigh acts as taggable on https github com mbleigh acts as taggable on对于我们的 Rails 应用程序 我们遇到了问题
  • 如何与 React Test Renderer / Jest 渲染的组件交互

    我正在使用 Jest 和快照测试 我想做的是渲染一个组件ReactTestRenderer 然后模拟单击其中的按钮 然后验证快照 ReactTestRenderer 返回的对象create呼叫有一个getInstance函数允许我直接调用它
  • 不兼容的片段类型

    你好 我在 android 中有一个小应用程序 我在其中使用带导航抽屉的片段作为菜单 但现在我想在用户单击某些内容时在片段对话框弹出窗口中显示 并且出现以下错误 主要活动 private void displayView int posit
  • shell 脚本参数非位置

    有没有办法将非位置参数提供给 shell 脚本 意思是明确指定某种标志 myscript sh value1 value2 myscript sh val1 value1 val2 value2 您可以使用getopts 但我不喜欢它 因为
  • MySQL 错误 1241:操作数应包含 1 列

    我正在尝试将表1中的数据插入表2中 insert into table2 Name Subject student id result select Name Subject student id result from table1 表2
  • 在.Net Framework中使用最新版本的System.Net.Http

    最新版本System Net Http https www nuget org packages System Net Http nuget 上的版本是 4 3 4 但即使是最新的 Net Framework 4 8 也附带了该库的 4 2
  • 拼写检查等统计句子建议模型

    已经有可用的拼写检查模型 可以帮助我们根据经过训练的正确拼写语料库找到建议的正确拼写 是否可以将粒度从字母表增加到 单词 以便我们可以有均匀的短语建议 这样如果输入了错误的短语 那么它应该从正确短语的语料库中建议最接近的正确短语 当然它是从
  • Google 地图信息窗口中的 YouTube 视频

    我正在尝试将 YouTube 视频放入 Google 地图 v3 信息窗口中 它在 Firefox 和 Internet Explorer 中运行良好 It does not在 Safari 和 Chrome 中工作 在这些浏览器中 定位已
  • 在 Activity.onCreate(..) 中显示警报

    我是 Android 新手 这是我的第一个问题 所以请放轻松 是否可以检查 Activity 的 onCreate 内的某些条件并显示 AlertDialog 我在 Oncreate 中匿名创建一个 AlertDialog 并在该实例上调用
  • 使用 R Markdown 的 Beamer 演示

    我正在使用 R Markdown 来制作投影仪演示我对幻灯片水平有疑问 我选择法兰克福主题 该主题允许制定演示计划 标题中的项目符号 我的问题 当我输入 slide level 2 时 我有内容 但没有演示文稿的计划 当我输入 slide
  • 如何从命令行将错误列表(或任何自定义查询)从 TFS 导出到 Excel?

    我需要将错误列表从 Team Foundation Server 导出到 Excel 手动执行此操作很简单 但我需要命令行版本 因为该任务需要自动化 有人知道该怎么做吗 回答你原来的问题 在 TFS 中添加新查询 创建查询并单击 保存 这应
  • 使用 Visual Studio 创建 MSI 并强制所有用户

    我使用 Visual Studio 2015 带有 Visual Studio 安装程序插件 创建了一个安装程序 目标是始终使用相同的本地资源运行应用程序 无论谁登录 因此我们的目标是 CommonAppDataFolder Win10 上
  • 淡出旧元素,淡入新元素

    我是新来反应并尝试过反应动画 http facebook github io react docs animation html 当在 TransitionGroup 中添加或删除元素时 它们非常有用 但是 如果我用类似的元素替换单个元素
  • 非活动类中的警报对话框

    我有一个代码可以检查一些数据并在非活动类中显示警报 但是在运行应用程序时崩溃并且不显示警报对话框 我使用下面的代码 if str isEmpty strPort isEmpty new AlertDialog Builder Mtx get
  • Kivy (Python) - 椭圆点击事件

    我正在尝试翻译的开头一个简单的画布应用程序 https bloom510 github io pitch canvas 我用 JavaScript 编写了 Kivy 框架 我已经能够沿着圆的周长分布顶点 但是无论是用 Python 还是 K
  • 即使在 conda 中安装后也无法导入 Poppler

    我正在尝试使用 pdf 渲染包 Poppler 我在这里找到了相同的 Anaconda 安装 https anaconda org conda forge poppler https anaconda org conda forge pop
  • Python读取文件时出现UnicodeDecodeError,如何忽略错误并跳转到下一行?

    我必须将文本文件读入Python 文件编码为 file bi test csv text plain charset us ascii 这是第三方文件 我每天都会收到一个新文件 所以我宁愿不更改它 该文件包含非 ascii 字符 例如 我需
  • Laravel - 有没有办法将 whereHas 和 with 结合起来

    我目前面临一个小问题 我只想在存在特定条件的关系时返回模型 这与 whereHas 方法配合得很好 m Model whereHas programs function q q gt active 但是 将关系称为这样的属性将为我提供所有内
  • Flask 将 pyaudio 发送到浏览器

    我正在将服务器麦克风的音频发送到浏览器 主要是this https stackoverflow com a 56037682 4871482发布但有一些修改的选项 一切都工作正常 直到你转到手机或野生动物园 它根本不起作用 我尝试过使用类似
  • 迁移学习:模型给出不变的损失结果。难道不是训练吗? [关闭]

    Closed 这个问题是基于意见的 help closed questions 目前不接受答案 我正在尝试训练一个回归Inception V3 上的模型 输入是大小为 96 320 3 的图像 总共有 16k 图像 其中 12k 用于训练