tf.keras 损失变为 NaN

2023-12-09

我正在 tf.keras 中编写一个 3 层的神经网络。我的数据集是 MNIST 数据集。我减少了数据集中的示例数量,因此运行时间较短。这是我的代码:

import tensorflow as tf
from tensorflow.keras import layers
import numpy as np
import pandas as pd

!git clone https://github.com/DanorRon/data
%cd data
!ls

batch_size = 32
epochs = 10
alpha = 0.0001
lambda_ = 0
h1 = 50

train = pd.read_csv('/content/first-repository/mnist_train.csv.zip')
test = pd.read_csv('/content/first-repository/mnist_test.csv.zip')

train = train.loc['1':'5000', :]
test = test.loc['1':'2000', :]

train = train.sample(frac=1).reset_index(drop=True)
test = test.sample(frac=1).reset_index(drop=True)

x_train = train.loc[:, '1x1':'28x28']
y_train = train.loc[:, 'label']

x_test = test.loc[:, '1x1':'28x28']
y_test = test.loc[:, 'label']

x_train = x_train.values
y_train = y_train.values

x_test = x_test.values
y_test = y_test.values

nb_classes = 10
targets = y_train.reshape(-1)
y_train_onehot = np.eye(nb_classes)[targets]

nb_classes = 10
targets = y_test.reshape(-1)
y_test_onehot = np.eye(nb_classes)[targets]

model = tf.keras.Sequential()
model.add(layers.Dense(784, input_shape=(784,)))
model.add(layers.Dense(h1, activation='relu', kernel_regularizer=tf.keras.regularizers.l2(lambda_)))
model.add(layers.Dense(10, activation='sigmoid', kernel_regularizer=tf.keras.regularizers.l2(lambda_)))

model.compile(optimizer=tf.train.GradientDescentOptimizer(alpha), 
             loss = 'categorical_crossentropy',
             metrics = ['accuracy'])

model.fit(x_train, y_train_onehot, epochs=epochs, batch_size=batch_size)

每当我运行它时,都会发生以下三件事之一:

  1. 在几个时期内,损失会减少,准确度会增加,直到损失无明显原因变为 NaN,准确度直线下降。

  2. 每个时期的损失和准确性保持不变。通常损失为2.3025,精度为0.0986。

  3. 损失从 NaN 开始(并保持这种状态),而准确度仍然很低。

大多数时候,模型会执行其中一项操作,但有时它会执行随机操作。发生的不稳定行为类型似乎是完全随机的。我不知道问题是什么。我该如何解决这个问题?

编辑:有时,损失会减少,但准确性保持不变。另外,有时损失会减少,准确度会增加,然后一段时间后,准确度会下降,但损失仍然会减少。或者,损失减少,准确度增加,然后切换,损失快速上升,而准确度直线下降,最终以损失结束:2.3025 acc:0.0986。

编辑2:这是有时会发生的事情的一个例子:

Epoch 1/100
49999/49999 [==============================] - 5s 92us/sample - loss: 1.8548 - acc: 0.2390

Epoch 2/100
49999/49999 [==============================] - 5s 104us/sample - loss: 0.6894 - acc: 0.8050

Epoch 3/100
49999/49999 [==============================] - 4s 90us/sample - loss: 0.4317 - acc: 0.8821

Epoch 4/100
49999/49999 [==============================] - 5s 104us/sample - loss: 2.2178 - acc: 0.1345

Epoch 5/100
49999/49999 [==============================] - 5s 90us/sample - loss: 2.3025 - acc: 0.0986

Epoch 6/100
49999/49999 [==============================] - 4s 90us/sample - loss: 2.3025 - acc: 0.0986

Epoch 7/100
49999/49999 [==============================] - 4s 89us/sample - loss: 2.3025 - acc: 0.0986

编辑 3:我将损失更改为均方误差,网络现在运行良好。有没有办法让它保持交叉熵而不收敛到局部最小值?


我将损失更改为均方误差,网络现在运行良好

MSE is not针对此类分类问题的适当损失函数;你当然应该坚持loss = 'categorical_crossentropy'.

最有可能的是,该问题是由于您的 MNIST 数据未标准化所致;你应该将你的最终变量标准化为

x_train = x_train.values/255
x_test = x_test.values/255

不规范输入数据是导致梯度爆炸问题的已知原因,这可能就是这里发生的情况。

其他建议:设置activation='relu'对于您的第一个密集层,并摆脱所有层中的正则化器和初始化器参数(默认glorot_uniform实际上是一个更好的初始化器,而这里的正则化实际上可能对性能有害)。

作为一般建议,请尝试not重新发明轮子——从一个开始喀拉拉邦示例使用内置 MNIST 数据...

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

tf.keras 损失变为 NaN 的相关文章

随机推荐

  • 仅使用 data.table 将 NA 替换为 data.table 中的最后一个非 NA

    我想更换NA最后一个非 NA 值的值data table并使用data table 我有一个解决方案 但它比na locf library data table library zoo library microbenchmark f1 l
  • Google Drive API 403 禁止

    我们使用 Google Drive API 来允许用户浏览并选择要在报告中使用的文件 我们的一位用户 该问题并不普遍 在尝试获取文件列表时遇到错误 如下 从 Google 返回的 JSON 正文 error errors domain gl
  • HttpClient - 如何判断服务器是否更快地关闭?

    我正在使用 NETHttpClient向我的服务器发送请求 我已经设定HttpClient Timeout属性为 10 秒 所以我得到了A task was cancelled每当服务器无法在 10 秒内处理我的请求时 就会出现异常 到这里
  • Windows快捷方式的内部结构是怎样的?

    一台计算机上有 3 个硬盘 2 个 Windows XP 1 个 Windows 7 依次从每个硬盘加载操作系统 我发现在第一个 XP 中创建的一些工作快捷方式 不是全部 在第二个 XP 和 Windows 7 中不起作用 不可用于查看快捷
  • Xcode 4.5 iOS 6.0 模拟器方向不起作用

    我已经将我的 Xcode 更新到 4 5 我已经实现了如下方向方法 BOOL shouldAutorotate return YES NSUInteger supportedInterfaceOrientations return UIIn
  • ASMX 操作 404s,但 ASMX 服务描述没有,url 路由问题?

    所以我发现自己遇到了一个难题 我们的应用程序中有一些旧的 asmx Web 服务 多年来一直运行良好 突然间 他们停止了构建服务器 CI 上的工作 我说停止工作 因为即使当我导航到服务时显示服务描述 调用任何操作都不会路由到服务 Web 表
  • 在 React 中,ref 是引用虚拟 DOM 还是实际 DOM?

    我假设虚拟 DOM 并且 React 通过比较来处理它 但我有一位招聘人员说 ref 会影响实际的 DOM 我不明白这是怎么回事 我认为他们只是误会了 Refs 应该引用实际的 DOM Refs 的一种用法是与第三方 DOM 库集成 因此您
  • 使用 Lodash 合并复杂对象数组

    我是 Lodash 的新手 正在尝试解决这个问题 但可以找到一个好方法 我有一个从数据库返回的对象数组 数据结构如下 var data index 1 Aoo a1 Boo b2 index 1 Aoo a2 Boo b2 index 2
  • 无效的 Swift 支持/无效的 Swift 实现

    我想上传一个用 swift 编写的应用程序 应用程序加载器成功交付应用程序 但几分钟后我收到苹果的回复 无效的 Swift 支持 该捆绑包包含无效的 Swift 实现 该应用程序可能是使用不合规或预发布的工具构建或签名的 访问develop
  • 如何更新已从 BOT 发送给用户的自适应卡?

    我已经发送了包含捕获详细信息和按钮的卡片 从任务模块单击提交后 该模块将通过 http API 保存详细信息 此处的活动类型为 调用 现在我必须更新现有的自适应卡 我有更新消息的代码 但如何更新卡或再次重新发送卡 connector new
  • Webpack 提供的 Angular 2 应用程序基于环境的属性?

    我正在使用由 JHipster 生成并由 Spring Boot 服务器提供服务的独立 Angular 控制台 我希望根据环境 本地 开发 产品等 提供具有不同属性的应用程序 我看到很多关于配置每个环境的 webpack 构建的帖子 但我需
  • PHP/Regex:bbcode [s] 或 [strike] 的简单正则表达式无法工作

    对于一个愚蠢的 bbcode 解析器 我想将两个定义添加到一个中 我最初的 preg replace 定义是这样的 s s si
  • 无法在 Heroku 上使用 Gmail 发送电子邮件

    我无法让我的 Rails 应用程序使用 Gmail 发送电子邮件 我可以在本地开发环境中发送电子邮件 但无法从 Heroku 发送 这是我的配置文件 应用程序 rb config action mailer smtp settings ad
  • Spark Streaming StreamingContext.start() - 启动接收器时出错 0

    我有一个使用 Spark Streaming 的项目 我使用 spark submit 运行它 但遇到了以下错误 15 01 14 10 34 18 ERROR ReceiverTracker Deregistered receiver f
  • 如何使 Satchmo 在 Google App Engine 中工作

    我知道数据存储方面存在很大差异 但既然 django 是捆绑的并且它从 Satchmo 中抽象出数据存储 那么可以做些什么吗 事实上 我不是 Python 爱好者 到目前为止主要是 Java PHP 但我愿意学习 另外 如果今天不可能 让我
  • 如何将 Node.js 应用程序上传到 FTP 服务器?

    我对 Node js 有点陌生 但我构建了一个应用程序 并对它非常满意 我想知道如何将 Node js 应用程序上传到 FTP 服务器 有可能做到这一点吗 Node JS 应用程序只是文件的集合 您可以像任何其他文件一样使用 FTP 将它们
  • 实体框架代码首先将 TPT 转换为 TPH

    我使用 EF Code First 使用 TPT 开发了一个应用程序 发布附件 评论等 它运行良好 并且正在与许多客户进行 beta 测试 但是 存在许多层次结构 因此 我有一个包含各种继承模型的基本模型 每个模型都包含许多属性 这些属性本
  • HTMLAgilityPack 使用 C# 解析 HTML 时出现问题

    我只是想了解 HTMLAgilityPack 和 XPath 我试图从纳斯达克网站获取 HTML 链接 公司列表 http www nasdaq com quotes nasdaq 100 stocks aspx 我目前有以下代码 Html
  • 使用 jquery 显示/隐藏文本

    基本上我有 6 个按钮和 6 个段落 每个按钮与特定段落相关 我想在单击某个按钮时显示一段文本 然后在再次单击该按钮时隐藏该段落 我浏览过类似的问题 但似乎无法让它发挥作用 我认为这是因为我才开始尝试使用 jquery 并且没有真正理解这个
  • tf.keras 损失变为 NaN

    我正在 tf keras 中编写一个 3 层的神经网络 我的数据集是 MNIST 数据集 我减少了数据集中的示例数量 因此运行时间较短 这是我的代码 import tensorflow as tf from tensorflow keras