为什么这个 TensorFlow 实现远不如 Matlab 的神经网络成功?

2024-02-17

作为一个玩具示例,我正在尝试拟合一个函数f(x) = 1/x来自 100 个无噪声数据点。 matlab 默认实现非常成功,均方差约为 10^-10,并且插值完美。

我实现了一个神经网络,其中一个隐藏层包含 10 个 S 型神经元。我是神经网络的初学者,所以要警惕愚蠢的代码。

import tensorflow as tf
import numpy as np

def weight_variable(shape):
  initial = tf.truncated_normal(shape, stddev=0.1)
  return tf.Variable(initial)

def bias_variable(shape):
  initial = tf.constant(0.1, shape=shape)
  return tf.Variable(initial)

#Can't make tensorflow consume ordinary lists unless they're parsed to ndarray
def toNd(lst):
    lgt = len(lst)
    x = np.zeros((1, lgt), dtype='float32')
    for i in range(0, lgt):
        x[0,i] = lst[i]
    return x

xBasic = np.linspace(0.2, 0.8, 101)
xTrain = toNd(xBasic)
yTrain = toNd(map(lambda x: 1/x, xBasic))

x = tf.placeholder("float", [1,None])
hiddenDim = 10

b = bias_variable([hiddenDim,1])
W = weight_variable([hiddenDim, 1])

b2 = bias_variable([1])
W2 = weight_variable([1, hiddenDim])

hidden = tf.nn.sigmoid(tf.matmul(W, x) + b)
y = tf.matmul(W2, hidden) + b2

# Minimize the squared errors.
loss = tf.reduce_mean(tf.square(y - yTrain))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)

# For initializing the variables.
init = tf.initialize_all_variables()

# Launch the graph
sess = tf.Session()
sess.run(init)

for step in xrange(0, 4001):
    train.run({x: xTrain}, sess)
    if step % 500 == 0:
        print loss.eval({x: xTrain}, sess)

均方差值约为 2*10^-3,因此比 matlab 差约 7 个数量级。可视化与

xTest = np.linspace(0.2, 0.8, 1001)
yTest = y.eval({x:toNd(xTest)}, sess)  
import matplotlib.pyplot as plt
plt.plot(xTest,yTest.transpose().tolist())
plt.plot(xTest,map(lambda x: 1/x, xTest))
plt.show()

we can see the fit is systematically imperfect: enter image description here while the matlab one looks perfect to the naked eye with the differences uniformly < 10^-5: enter image description here I have tried to replicate with TensorFlow the diagram of the Matlab network:

顺便说一句,该图似乎暗示了 tanh 而不是 sigmoid 激活函数。可以肯定的是,我在文档中找不到它。然而,当我尝试在 TensorFlow 中使用 tanh 神经元时,拟合很快就会失败nan对于变量。我不知道为什么。

Matlab 使用 Levenberg–Marquardt 训练算法。贝叶斯正则化在均方为 10^-12 时更加成功(我们可能处于浮点算术的蒸汽领域)。

为什么 TensorFlow 的实现如此糟糕,我能做些什么来让它变得更好?


我尝试训练 50000 次迭代,误差为 0.00012。 Tesla K40 上大约需要 180 秒。

对于这类问题,一阶梯度下降似乎不太合适(双关语),你需要 Levenberg–Marquardt 或 l-BFGS。我认为还没有人在 TensorFlow 中实现它们。

Edit Use tf.train.AdamOptimizer(0.1)对于这个问题。它到达3.13729e-054000次迭代后。此外,使用默认策略的 GPU 对于这个问题来说似乎也是一个坏主意。有许多小操作,开销导致 GPU 版本的运行速度比我的机器上的 CPU 慢 3 倍。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

为什么这个 TensorFlow 实现远不如 Matlab 的神经网络成功? 的相关文章

  • 我怎样才能更多地了解Python的内部原理? [关闭]

    Closed 这个问题正在寻求书籍 工具 软件库等的推荐 不满足堆栈溢出指南 help closed questions 目前不接受答案 我使用Python编程已经有半年多了 我对Python内部更感兴趣 而不是使用Python开发应用程序
  • 如何使用 imaplib 获取“消息 ID”

    我尝试获取一个在操作期间不会更改的唯一 ID 我觉得UID不好 所以我认为 Message ID 是正确的 但我不知道如何获取它 我只知道 imap fetch uid XXXX 有人有解决方案吗 来自 IMAP 文档本身 IMAP4消息号
  • 将数据帧行转换为字典

    我有像下面的示例数据这样的数据帧 我正在尝试将数据帧中的一行转换为类似于下面所需输出的字典 但是当我使用 to dict 时 我得到了索引和列值 有谁知道如何将行转换为像所需输出那样的字典 任何提示都非常感激 Sample data pri
  • if 语句未命中中的 continue 断点

    在下面的代码中 两者a and b是生成器函数的输出 并且可以评估为None或者有一个值 def testBehaviour self a None b 5 while True if not a or not b continue pri
  • Pandas 中允许重复列

    我将一个大的 CSV 包含股票财务数据 文件分割成更小的块 CSV 文件的格式不同 像 Excel 数据透视表之类的东西 第一列的前几行包含一些标题 公司名称 ID 等在以下列中重复 因为一家公司有多个属性 而不是一家公司只有一栏 在前几行
  • 从零开始的 numpy 形状意味着什么

    好的 我发现数组的形状中可以包含 0 对于将 0 作为唯一维度的情况 这对我来说是有意义的 它是一个空数组 np zeros 0 但如果你有这样的情况 np zeros 0 100 让我很困惑 为什么这么定义呢 据我所知 这只是表达空数组的
  • 如何创建一个语句来打印以特定单词开头的单词? [关闭]

    Closed 这个问题需要多问focused help closed questions 目前不接受答案 如何在 python 中打印从特定字母开始的单词 而不使用函数 而是使用方法或循环 1 我有一个字符串 想要打印以 m 开头的单词 S
  • 切片 Dataframe 时出现 KeyError

    我的代码如下所示 d pd read csv Collector Output csv df pd DataFrame data d dfa df copy dfa dfa rename columns OBJECTID Object ID
  • 对图像块进行多重处理

    我有一个函数必须循环遍历图像的各个像素并计算一些几何形状 此函数需要很长时间才能运行 在 24 兆像素图像上大约需要 5 小时 但似乎应该很容易在多个内核上并行运行 然而 我一生都找不到一个有据可查 解释充分的例子来使用 Multiproc
  • 按元组分隔符拆分列表

    我有清单 print L I WW am XX newbie YY ZZ You WW are XX cool YY ZZ 我想用分隔符将列表拆分为子列表 ZZ print new L I WW am XX newbie YY ZZ You
  • Seaborn Pairplot 图例不显示颜色

    我一直在学习如何在Python中使用seaborn和pairplot 这里的一切似乎都工作正常 但由于某种原因 图例不会显示相关的颜色 我无法找到解决方案 因此如果有人有任何建议 请告诉我 x sns pairplot stats2 hue
  • 使用 Firefox 绕过弹出窗口下载文件:Selenium Python

    我正在使用 selenium 和 python 来从中下载某些文件web page http www oceanenergyireland com testfacility corkharbour observations 我之前一直使用设
  • 使用yield 进行字典理解

    作为一个人为的例子 myset set a b c d mydict item yield join item s for item in myset and list mydict gives as cs bs ds a None b N
  • 使用队列从多个输入文件中统一采样

    我的数据集中的每个类都有一个序列化文件 我想使用队列来加载每个文件 然后将它们放入 RandomShuffleQueue 中 这样我就可以从每个类中获得随机的示例组合 我认为这段代码会起作用 在此示例中 每个文件有 10 个示例 filen
  • 迭代 my_dict.keys() 并修改字典中的值是否会使迭代器失效?

    我的例子是这样的 for my key in my dict keys my dict my key mutate 上述代码的行为是否已定义 假设my dict是一本字典并且mutate是一个改变其对象的方法 我担心的是 改变字典中的值可能
  • 当鼠标悬停在上面时,intellisense vscode 不显示参数或文档

    我正在尝试将整个工作流程从 Eclipse 和 Jupyter Notebook 迁移到 VS Code 我安装了 python 扩展 它应该带有 Intellisense 但它只是部分更糟糕 我在输入句点后收到建议 但当将鼠标悬停在其上方
  • 限制 django 应用程序模型中的单个记录?

    我想使用模型来保存 django 应用程序的系统设置 因此 我想限制该模型 使其只能有一条记录 极限怎么办 尝试这个 class MyModel models Model onefield models CharField The fiel
  • 如何读取Python字节码?

    我很难理解 Python 的字节码及其dis module import dis def func x 1 dis dis func 上述代码在解释器中输入时会产生以下输出 0 LOAD CONST 1 1 3 STORE FAST 0 x
  • 检查字典键是否有空值

    我有以下字典 dict1 city name yass region zipcode phone address tehsil planet mars 我正在尝试创建一个基于 dict1 的新字典 但是 它不会包含带有空字符串的键 它不会包
  • Scrapy Spider不存储状态(持久状态)

    您好 有一个基本的蜘蛛 可以运行以获取给定域上的所有链接 我想确保它保持其状态 以便它可以从离开的位置恢复 我已按照给定的网址进行操作http doc scrapy org en latest topics jobs html http d

随机推荐

  • 使用 SecItemImport 导入 PKCS12

    Apple s 文档 https developer apple com library mac documentation security Reference keychainservices Reference reference h
  • Plotly-Dash:如何确定客户端回调中的触发输入

    Dash 的文档描述了在服务器端回调的情况下如何确定哪个输入触发了回调 高级回调 https dash plotly com advanced callbacks 有没有办法确定哪个输入触发了客户端打回来 看起来这个功能是在1 13 0版本
  • REST - 修改部分资源 - PUT 或 POST

    我看到很多关于如何使用 REST 只更新部分资源 例如状态指示器 的问题 选项似乎是 抱怨 HTTP 没有 PATCH 或 MODIFY 命令 然而 接受的答案REST 的 HTTP MODIFY 动词 https stackoverflo
  • 获取当前行的长度

    我正在尝试在状态行中添加一个指示符来显示行的总长度 不仅仅是光标列位置 可以用 c 我该怎么做呢 要将一行内容作为字符串获取 请使用getline
  • 抓取无限滚动页面停止而不滚动

    我目前正在使用 PhantomJS 和 CasperJS 来抓取网站中的链接 该网站使用 JavaScript 动态加载结果 然而 下面的代码片段并没有让我获得页面包含的所有结果 我需要的是向下滚动到页面底部 查看微调器是否显示 意味着还有
  • 一次从单个设备登录,注销其他 MERN、JWT、Google 登录

    我无法理解应该如何防止同一用户多次登录 我在带有 JWT 令牌的 React Node 应用程序中使用带有 firebase 的 google 登录 如果从其他浏览器或其他设备登录 如何使用户注销 有没有任何库可以处理这个问题或者有什么方法
  • 如何在同一个库中拥有多个wpf自定义控件?

    我有一个 WPF 自定义控件项目 我想在其中包含许多自定义控件 默认情况下 VS2015 cummunity 创建一个 Theme 文件夹 其中包含 generic xaml 文件和包含交互逻辑的 cs 文件 我想要有很多用户控件 所以我尝
  • 从外部站点动态加载 js

    我想当用户单击按钮时从外部站点加载 JS 代码 例如
  • 如何在 ngTagsInput 中设置标签的颜色?

    我想在我的项目中使用 ng tags input 我尝试根据数组中的颜色属性对象为每个标签设置不同的颜色 Here is plunker http plnkr co edit W5bjrwN5riL94i2jhOP3 p preview我正
  • 反转字符串中元素的顺序

    我有以下字符串 1119 2 483 11021 我想反转该字符串中元素的顺序 期望的输出 11021 483 2 1119 T SQL 版本 2014 您需要一个有序的 split 函数 例如 灵感 https www sqlserver
  • 线程如何节省时间?

    我正在学习 C 中的线程 但是 我无法理解线程的哪些方面实际上提高了性能 考虑仅存在一个核心处理器的场景 将任务拆分为多个线程使用相同的进程上下文 共享资源 并且它们同时运行 由于线程只是共享时间 为什么它们的运行时间 周转时间 小于单线程
  • Html5画布文本交叉点

    我有一些话 所有话都在某个 物体 之王中 这些单词可以在画布上移动 我需要获取所有交叉点的数组 如本例所示 但不需要将文本转换为 SVG paperjs org examples path intersections 谢谢 您可以通过比较两
  • 如何在python源代码中找到运算符的定义?

    我对 in 的实现感到好奇 contains python 中的运算符由于这个问题 https stackoverflow com questions 9089400 python set in operator uses equality
  • 如果页面加载失败,如何运行 Tampermonkey 脚本?

    我有一个在服务器页面上运行的脚本 有时发送不会将任何内容发送回客户端 我得到未收到数据Chrome 中的错误 我想注册此事件 通过 AJAX 通知另一台服务器 然后重新加载页面 即使页面加载失败 如何确保脚本运行 None
  • 如何在Angular中的地图上动态绘制多边形形状

    如何动态绘制多边形形状 未预定义paths 以及如何存储多边形的经纬度值 我已经参考了AGMP多边形 https angular maps com api docs agm core directives AgmPolygon html但这
  • 如何检查从 C++ 字符串到无符号整数的转换

    我需要 1 找出我当前系统上最大的 unsigned int 值是多少 我在 limit h 上没有找到它 写起来安全吗unsigned int maxUnsInt 0 1 我也尝试过unsigned int maxUnsInt MAX I
  • 捕获目录内发生的事件

    我正在使用以下方式观看目录Java 7 nio WatchService通过使用以下方法 Path myDir Paths get rootDir try WatchService watcher myDir getFileSystem n
  • 访问 Access 2013 数据库的架构

    如果我尝试读取 Access 2013 数据库的架构 我会收到以下错误 no read permission on MSysRelationships 现在帮助告诉我 User level security features are not
  • 如何将类添加到 simple_form 2 包装器中的输入组件

    我正在努力拥有class text 在我的输入字段中使用名为 hinted in simple form 2 0 0 rc 的自定义包装器时 config wrappers hinted do b b use input class gt
  • 为什么这个 TensorFlow 实现远不如 Matlab 的神经网络成功?

    作为一个玩具示例 我正在尝试拟合一个函数f x 1 x来自 100 个无噪声数据点 matlab 默认实现非常成功 均方差约为 10 10 并且插值完美 我实现了一个神经网络 其中一个隐藏层包含 10 个 S 型神经元 我是神经网络的初学者