模型不学习

2024-01-10

背景

我有一个非常简单的脚本,它创建了一个 keras 模型,旨在充当异或门。

我在中生成了 40000 个数据点get_data功能。它创建两个数组;一个按某种顺序包含 1 和 0 的输入数组,以及一个 1 或 0 的输出。

Issue

当我运行代码时,它似乎没有学习,并且每次训练时得到的结果都有很大差异。

Code

from keras import models
from keras import layers

import numpy as np

from random import randint


def get_output(a, b): return 0 if a == b else 1


def get_data ():
    data = []
    targets = []

    for _ in range(40010):
        a, b = randint(0, 1), randint(0, 1)

        targets.append(get_output(a, b))
        data.append([a, b])

    return data, targets


data, targets = get_data()

data = np.array(data).astype("float32")
targets = np.array(targets).astype("float32")

test_x = data[40000:]
test_y = targets[40000:]

train_x = data[:40000]
train_y = targets[:40000]

model = models.Sequential()

# input
model.add(layers.Dense(2, activation='relu', input_shape=(2,)))

# hidden
# model.add(layers.Dropout(0.3, noise_shape=None, seed=None))
model.add(layers.Dense(2, activation='relu'))
# model.add(layers.Dropout(0.2, noise_shape=None, seed=None))
model.add(layers.Dense(2, activation='relu'))

# output
model.add(layers.Dense(1, activation='sigmoid')) # sigmoid puts between 0 and 1

model.summary() # print out summary of model

model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

res = model.fit(train_x, train_y, epochs=2000, batch_size=200, validation_data=(test_x, test_y)) # train
    
print 'predict: \n', test_x
print model.predict(test_x)

Output

[[0. 1.]
 [1. 1.]
 [1. 1.]
 [0. 0.]
 [1. 0.]
 [0. 0.]
 [0. 0.]
 [0. 1.]
 [1. 1.]
 [1. 0.]]
[[0.6629775 ]
 [0.00603844]
 [0.00603844]
 [0.6629775 ]
 [0.6629775 ]
 [0.6629775 ]
 [0.6629775 ]
 [0.6629775 ]
 [0.00603844]
 [0.6629775 ]]

即使没有丢失层,我也得到了非常相似的结果。


你的问题有几个问题。

首先,您的导入相当非正统(与您的问题无关,确实如此,但它有助于遵守一些约定):

from keras.models import Sequential
from keras.layers import Dense
import numpy as np

其次,您不需要数千个 XOR 问题的示例;只有四种组合:

X = np.array([[0,0],[0,1],[1,0],[1,1]])
y = np.array([[0],[1],[1],[0]])

就这样。

第三,出于同样的原因,您实际上无法通过 XOR 获得“验证”或“测试”数据;在最简单的方法中(即您可以说在这里尝试做的事情),您只能使用这 4 种组合来测试模型学习该函数的程度(因为没有更多了!)。

第四,你应该从一个简单的单隐藏层模型开始(有点多于 2 个单元并且没有 dropout),然后逐步进行如果需要的话:

model = Sequential()
model.add(Dense(8, activation="relu", input_dim=2))
model.add(Dense(1, activation="sigmoid"))

model.compile(loss='binary_crossentropy', optimizer='adam')
model.fit(X, y, batch_size=1, epochs=1000)

这应该会将您的损失降至约 0.12;它对这个功能的学习效果如何?

model.predict(X)
# result:
array([[0.31054294],
       [0.9702552 ],
       [0.93392825],
       [0.04611744]], dtype=float32)

y
# result:
array([[0],
       [1],
       [1],
       [0]])

这够好吗?好吧,我不知道 - 正确的答案始终是“这取决于”!但你现在有了一个起点(即可以说是一个网络)learns某事),您可以从中进行进一步的实验......

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

模型不学习 的相关文章

随机推荐

  • 属性的命名约定[关闭]

    Closed 这个问题是基于意见的 help closed questions 目前不接受答案 哪一个更好或更清楚 public int FrozenRegionWidth get set Or public int WidthOfFroz
  • 将多个 Pandas DataFrame 列设置为单列中的值或同时设置多个标量值

    我正在尝试将多个新列设置为一列 并分别将多个新列设置为多个标量值 也做不到 除了单独设置之外还有什么办法吗 df pd DataFrame columns A B data np arange 6 reshape 3 2 df loc C
  • 如何使用express-uploadfile从POST读取文本文件?

    我正在尝试制作 Node js 服务器 用于在那里上传文本文件 所以我使用 POST 来获取本地用户的文本文件 然后我想让服务器读取该文件 我想我可以让用户上传他的本地文本文件 我可以获取上传文件的描述 但很难让服务器读取文件的实际字符串
  • Chrome 的自动填充隐藏文本输入的背景图像

    成功禁用自动填充黄色背景颜色后 我偶然发现了另一个功能 我的每个输入元素都有一个背景图像 每次我关注文本输入时 浏览器都会在下拉列表中建议我之前使用的值 选择一个值后 自动填充会覆盖整个背景并隐藏图像 这是我的 html 和 css 在 J
  • NHibernate:System.Argument异常:已添加具有相同键的项目

    我遇到了一个很难重现的偶发错误 我的第一个猜测是 不知怎的 我有一个泄漏的休眠会话 但是当我运行休眠分析器 http nhprof com 我没有看到太多异常 MVC 2 0 流畅版本1 1 0 685 NHibernate 版本 2 1
  • 使用三元运算符初始化结构

    为什么三元运算符不能用于初始化结构类型 而可以用于初始化基类型 例如int 示例代码 include
  • javascript向所有函数添加原型方法?

    有没有一种方法可以在不使用原型库的情况下向所有 javascript 函数添加方法 类似于 Function prototype methodName function return dowhateverto this 这是我到目前为止所尝
  • Perl 6 中有快速并行“for”循环吗?

    给定一些对 1 到 500000 之间的每个数字进行一些数学 转换的代码 我们有选择 简单的for循环 for 500000 gt i my result i 2 Str 在我的不科学基准测试中 这需要 2 8 秒 最规范的并行版本在一个P
  • 新的 SQL Server 用户登录失败

    我已在 SQL Server Management Studio SQL Server 2008 Express 的安全选项卡中创建了新用户 指定登录名 SQL Server 身份验证 输入密码 分配服务器角色sysadmin 映射到我的数
  • 如何检查元素是否在 iframe 内

    假设您有一个 DOM 节点 并且您想知道它是否位于 iframe 内 一种方法是检查它的父链 看看您是否在到达父窗口之前到达了 iframe 不过 我想知道是否有更快的方法来做到这一点 你也许可以检查ownerDocument财产 http
  • 强制使用 SSL:尝试确定托管应用程序的 DNC 进程的进程 ID 时发生错误

    我想在我的网站上强制使用 https 如果发现本文 https azure microsoft com en us documentation articles web sites configure ssl certificate 4 e
  • 使用 Perl 发送电子邮件

    我正在尝试使用 Perl 发送电子邮件 基本上我有一个 Perl 脚本 可以以良好的格式打印出报告 我希望通过电子邮件发送该报告 我怎样才能做到这一点 如果机器没有配置sendmail 我通常使用邮件 发送邮件 https metacpan
  • 如何根据位置分割字符串

    我想根据字符的位置拆分变量 生成的第一个字符串应具有指定位置之前的前一个位置 另一个字符串应包含其他部分 假设如果我有一个变量 var 2013AD 我想 var1 2013 and var2 AD 我怎样才能实现这个目标 嗯 要在这里使用
  • 如何使用 Selenium WebDriver 检查单选按钮?

    我想检查这个单选按钮 但我不知道如何检查 我的 HTML 是 div class appendContent div id contentContainer class grid list template gt div div div d
  • AEM Scheduler 的配置发生变化吗?

    我正在尝试为我的项目需求实现简单的调度程序 我的项目正在使用Adobe AEM 截至目前 我浏览了 Adob e 网站并尝试实现所提供的给定示例 但没有一个更新我的error log file package sling docu exam
  • Android 项目未解析任何静态资源

    由于某种原因 我的 android 项目无法解析 js css 图像的任何静态路径 而它在 web 和 ios 上运行良好 我没有使用离子 所以也许我错过了一些特定的东西 然而 所有这些文件都可以在 android 项目中使用 这是突出显示
  • 训练 Keras 模型会产生多个优化器错误

    所以我需要使用我自己的数据集重新训练 Tiny YOLO 我正在使用的模型可以在这里找到 keras yolo3 https github com qqwweee keras yolo3 我开始训练 遇到多个优化器错误 添加了错误代码以防止
  • Java 中的 Process.exitValue()

    下面是我用来简单地从命令行程序打开和关闭 Internet Explorer 的程序 我在 Windows XP 操作系统上使用 Java 6 运行我的程序 Runtime runtime Runtime getRuntime Proces
  • PHPUnit 严格模式 - 如何更改默认超时

    我想继续在严格模式下运行我的单元测试 以便我可以轻松地了解任何异常长的测试 但同时 1 秒的默认超时是不够的 我可以为所有测试更改它吗 我知道我可以使用以下命令为每个课程 和单独的测试 设置超时 short medium long注释 但是
  • 模型不学习

    背景 我有一个非常简单的脚本 它创建了一个 keras 模型 旨在充当异或门 我在中生成了 40000 个数据点get data功能 它创建两个数组 一个按某种顺序包含 1 和 0 的输入数组 以及一个 1 或 0 的输出 Issue 当我