Python - 基于 LSTM 的 RNN 需要 3D 输入?

2024-03-27

我正在尝试构建一个基于 LSTM RNN 的深度学习网络,这是尝试过的

from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation
from keras.layers import Embedding
from keras.layers import LSTM
import numpy as np

train = np.loadtxt("TrainDatasetFinal.txt", delimiter=",")
test = np.loadtxt("testDatasetFinal.txt", delimiter=",")

y_train = train[:,7]
y_test = test[:,7]

train_spec = train[:,6]
test_spec = test[:,6]


model = Sequential()
model.add(LSTM(32, input_shape=(1415684, 8),return_sequences=True))
model.add(LSTM(64, input_dim=8, input_length=1415684, return_sequences=True))
##model.add(Embedding(1, 256, input_length=5000))
##model.add(LSTM(64,input_dim=1, input_length=10, activation='sigmoid',
##               return_sequences=True, inner_activation='hard_sigmoid'))
model.add(Dropout(0.5))
model.add(Dense(1))
model.add(Activation('sigmoid'))

model.compile(loss='binary_crossentropy', optimizer='rmsprop')

model.fit(train_spec, y_train, batch_size=2000, nb_epoch=11)
score = model.evaluate(test_spec, y_test, batch_size=2000)

但它给我带来了以下错误

ValueError: Error when checking input: expected lstm_1_input to have 3 dimensions, but got array with shape (1415684, 1)

这是数据集的示例

(患者编号、时间(毫秒)、加速度计 x 轴、y 轴、z 轴、幅度、频谱图、标签(0 或 1))

1,15,70,39,-970,947321,596768455815000,0
1,31,70,39,-970,947321,612882670787000,0
1,46,60,49,-960,927601,602179976392000,0
1,62,60,49,-960,927601,808020878060000,0
1,78,50,39,-960,925621,726154800929000,0

在数据集中,我仅使用频谱图作为输入特征,使用标签(0或1)作为输出,总训练样本为 1,415,684


您的主要错误是误解了 LSTM(或实际上任何 RNN)的工作原理以及它接受的输入内容。 LSTM 网络的单个训练示例包含sequence和一个标签。例如,这个...

1,15,70,39,-970,947321,596768455815000,0
1,31,70,39,-970,947321,612882670787000,0
1,46,60,49,-960,927601,602179976392000,0
1,62,60,49,-960,927601,808020878060000,0
1,78,50,39,-960,925621,726154800929000,0

...是长度为 5 的序列,具有 8 个特征。整个序列的标签是下一行的标签列。请注意,这只是一个例子;一个批次意味着多个这样的序列和标签。


现在,关于 Keras,从这个答案 https://stackoverflow.com/a/48141688/712995:

LSTM 层是一个循环层,因此它需要 3 维输入(batch_size, timesteps, input_dim).

让我们仔细看看您的规格:input_shape=(1415684, 8)告诉 keras 期望序列的长度1415684,其中每个项目都有8特征。所有这些都没有考虑批量大小,即2000.

这显然行不通,因为1415684LSTM 序列太长了。经验证据表明 LSTM 最多可以学习 100 个时间步,因此输入更大的序列并不会让它学得更好。更不用说它不节省内存和时间。

你应该做的是选择一个较小的timesteps参数,说timesteps=64,并将您的数据分割成块timesteps后续行。块可能会相交。这些手段的批次batch_size * timesteps总共行数,每行有8列。这y_train应包含每个训练序列的基本事实。 Keras 不会执行此准备步骤,因此您必须手动执行此操作。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Python - 基于 LSTM 的 RNN 需要 3D 输入? 的相关文章

  • Keras AttributeError:“顺序”对象没有属性“predict_classes”

    我试图按照本指南找到模型性能指标 F1 分数 准确性 召回率 https machinelearningmastery com how to calculate precision recall f1 and more for deep l
  • Python - 重写 print()

    我正在使用 mod wsgi 想知道是否可以覆盖 print 命令 因为它没用 这样做是行不通的 print myPrintFunction 因为这是一个语法错误 Print 不是 Python 2 x 中的函数 因此这不能直接实现 但是
  • Keras 中的 Tensorflow 自定义损失函数 - 张量循环

    我正在尝试在 Keras 中编写自定义损失函数 如下所示 Keras 中的自定义损失函数 https stackoverflow com questions 43818584 custom loss function in keras 我的
  • 找不到 Jupyter 命令 `jupyter-lab`

    我尝试在我的 Kubuntu 机器上安装 jupyter lab 如果我使用 pip3 install jupyter jupyterlab 安装 jupyter lab 则命令 jupyter notebook 完全可以正常工作 但是 如
  • 从“stdin”读取文件后如何使用“input()”?

    Context 我想要一个简单的脚本 它可以选择多个管道输入中的一个 而不需要EOF when reading a lineUnix Linux 上的错误 它试图 接受多行管道文本 等待用户选择一个选项 将该选项打印到标准输出 所需用途 p
  • 猴子修补@property

    是否有可能对 a 的值进行猴子修补 property我无法控制的类的实例 class Foo property def bar self return here be dragons f Foo print f bar baz f bar
  • 具有多个输入的 Keras TimeDistributed 层

    我正在尝试使以下代码行正常工作 low encoder out TimeDistributed AutoregressiveDecoder X tf embeddings Where AutoregressiveDecoder是一个需要两个
  • 如何测量异步发电机所花费的时间?

    我想测量生成器花费的时间 阻塞主循环的时间 假设我有以下两个生成器 async def run for i in range 5 await asyncio sleep 0 2 yield i return async def walk f
  • Python 3.7 Windows 不支持 dbm.gnu 吗?

    做的时候 import dbm gnu 在适用于 Windows 的标准 Python 3 7 6 64 上 我得到 文件 C Python37 lib dbm gnu py 第 3 行 位于从 gdbm 导入 ModuleNotFound
  • Python:球体的交集

    我对编程非常陌生 但我决定承担一个有趣的项目 因为我最近学会了如何以参数形式表示球体 当三个球体相交时 有两个不同的交点 除非它们仅在一个奇点处重叠 球体的参数表示 我的代码是根据答案修改的Python matplotlib 绘制 3d 立
  • 将画布的鼠标坐标转换为地理坐标

    我正在尝试使用 Python Tkinter 创建包含意大利所有城市的地图Canvas 我在网上找到了一张意大利地图的图片 其中突出显示了一些城市 并将其插入到我的Canvas 之后 我使用一个函数来确定 2 个突出显示的城市的画布坐标 i
  • django 2.0 中的错误 404 静态文件

    我试图在 django 2 0 中找到我的所有静态文件 但是当我只运行服务器时 我收到 404 错误 这是我的设置代码 STATIC URL static STATIC ROOT var www example com static STA
  • Python3:如何“不”四舍五入到最接近的偶数?

    我知道Python3的round 函数四舍五入到最接近的偶数 我怎样才能防止这种情况发生并使其像Python2那样从零舍入一半 您可以使用Decimal and ROUND HALF UP from decimal https docs p
  • 在python中访问超级(父)类变量

    我是Python新手 我尝试使用 super 方法访问子类中的父类变量 但它抛出错误 无参数 使用类名访问类变量是可行的 但我想知道是否可以使用 super 方法访问它们 class Parent object props a str a
  • 使用 python-shell 持续交换数据

    我需要从节点运行一些 python 脚本 由于我的 python 脚本使用复杂的结构 我认为如果只加载这些结构一次 然后使用这些结构运行一些特定的脚本 任务 会更好 在节点上 我想永远运行一个脚本 或者直到我说它可以终止 并继续向该脚本发送
  • 从数据帧字典中获取单独的数据帧 Python

    我有一本字典d充满了数据帧的集合 key type size value gm1 dataframe mxn gm2 dataframe mxN gm10 dataframe nxM 我想使用它们来一一输出这些数据帧keys作为新数据框的名
  • Python ttk.combobox 强制发布/打开

    我正在尝试扩展 ttk 组合框类以允许自动建议 我到目前为止的代码运行良好 但我想让它在输入一些文本后显示下拉列表 而不从小部件的输入部分移除焦点 我正在努力解决的部分是找到一种强制下拉的方法 在 python 文档中我找不到任何提及这一点
  • Python 小数.InvalidOperation 错误

    当我运行这样的东西时 我总是收到此错误 from decimal import getcontext prec 30 b 2 3 Decimal b Error Traceback most recent call last File Te
  • 在 keras 中使用自定义张量流操作

    我在张量流中有一个脚本 其中包含自定义张量流操作 我想将代码移植到 keras 但我不确定如何在 keras 代码中调用自定义操作 我想在 keras 中使用tensorflow 所以到目前为止我发现的教程描述了与我想要的相反的内容 htt
  • 如何得到将外力映射到广义力的矩阵?

    给定一个多体植物 我需要找到将外力 lambda 转换为广义力的矩阵 IE 以下方程中的 Phi 取自 Scott Kuindersma Frank Permenter 和 Russ Tedrake 的 稳定动态运动的有效可解二次规划 我的

随机推荐

  • 如何创建所有子类的实例

    我有超过 250 个子类需要由它们组成的实例 我不能坐在那里羞涩地粘贴new Class 250次 是否有使用反射来创建类的实例 创建实例时不需要构造函数 谢谢 我真的不明白你的意思 但我尝试猜测 未测试 public class Test
  • 参数“samples”的预期哈希值(获取数组)

    我一直在关注 Railscasts 的嵌套形式和复杂形式的剧集 在以单个表单创建多个模型的过程中 我能够编辑 更新 删除和创建嵌套在批处理模型中的示例模型的记录 我很长时间以来一直在绞尽脑汁 也尝试四处寻找 但找不到任何正确的解决方案来解决
  • 如何离线存储密码

    虽然这是针对Windows Phone 7的 但我想这个原理是通用的 我想在我的应用程序中设置一个密码保护区 但是 我的应用程序完全离线 因此我必须在手机上存储凭据详细信息 我最初的想法是存储密码和盐的哈希值 这是最好的方法吗 如果是这样
  • 更改特定索引而不在 Vuejs 中重新渲染整个数组

    In a Vuejs项目 我有一个array in my 数据对象并将其呈现在视图中v for指示 现在 如果我更改该数组中的特定索引 Vue 会在视图中重新渲染整个数组 有没有办法在不重新渲染整个数组的情况下查看视图的变化 这个问题背后的
  • 如何处理 JSON 字符串中的 unicode 值?

    我正在用 C 编写 JSON 解析器 在解析 JSON 字符串时遇到问题 JSON 规范规定 JSON 字符串可以包含以下形式的 unicode 字符 here comes a unicode character u05d9 我的 JSON
  • 如何获取要执行的 PTX 文件

    我知道如何生成 ptx文件来自 cu以及如何生成 cubin文件来自 ptx 但我不知道如何获得最终的可执行文件 更具体地说 我有一个sample cu文件 编译为sample ptx 然后我使用 nvcc 来编译sample ptx to
  • 如何在Oracle中查找模式名称?当您使用只读用户连接到 SQL 会话时

    我使用只读用户连接到 Oracle 数据库 并且在 sql Developer 中设置连接时使用了服务名称 因此我不知道 SID 架构 如何找到我连接到的架构名称 我正在寻找这个 因为我想要生成 ER 图 https stackoverfl
  • 按方案中的第一个元素对列表列表进行排序

    例如 我正在研究按第一个元素对列表列表进行排序 排序 列表 2 1 6 7 4 3 1 2 4 5 1 1 预期输出 gt 1 1 2 1 6 7 4 3 1 2 4 5 我使用的算法是冒泡排序 我修改了它来处理列表 但是 该代码无法编译
  • jQuery Mobile 范围滑块响应不够灵敏

    各位互联网界的好心人 大家好 我正在尝试使用 jQuery Mobile 滑块 范围 虽然它们工作得相当好并且在桌面浏览器上响应良好 但它们似乎在实际手机 例如 Android 与互联网网页交互时 Android 上使用触摸屏的滑块交互非常
  • Facebook SDK:ApiException:代理应用程序在未事先安装的情况下无法请求发布权限

    我正在努力使用 Android facebook SDK 3 5 riigth ow 我的账户一切都很完美 现在我把这个应用程序给了我的一个朋友 当他登录时 他并没有因为这个失败而被卡住 ApiException The proxied a
  • Azure 表存储将数据导出到 SQL 的平面或 XML 文件

    I am looking for capability to Export data from SQL Azure Azure Table Storage to Some Flat file or XML file so that we c
  • 如何将我的表单放在 css/html 中的图像之上?

    开发者们好 我想问一下如何才能让我的表单出现在我的图片之上 问题是我的表格出现在底部 这是我的屏幕截图 这是我的代码 HTML div class container align center div img src assets img
  • Fabric 不断要求输入密码

    我有 fab 文件 其中包含 env hosts localhost env user code env password searce def mk dirtree sudo mkdir s PROJECT DIR sudo chown
  • Java中int是如何实现的?

    根据文档Integer class Integer 类将基本类型 int 的值包装在对象中 Integer 类型的对象包含一个类型为 int 的字段 和文档int 默认情况下 int 数据类型是 32 位有符号二进制补码整数 其最小值为 2
  • 在组件安装之前反应设置滚动位置

    我有下面的反应组件 它本质上是一个聊天框 render const messages this props messages return div h1 this props project 0 project h1 div div div
  • 如何在 XCode4 中复制项目目标

    我想为测试环境创建一个具有不同捆绑 ID 的目标 我尝试使用 复制 功能来克隆目标并更改捆绑 ID 发现原始目标也发生了更改 感谢您的任何提示 更新 解决复制目标后的链接错误 这是一个xcode bug 搜索路径中的引号字符 更改为 目标的
  • PostgreSQL bigserial 和 nextval

    我有一个 PgSQL 9 4 3 服务器设置 之前我只使用公共模式 例如我创建了一个如下表 CREATE TABLE ma accessed by members tracking reference bigserial NOT NULL
  • 给定一个 4x4 齐次矩阵,我如何获得 3D 世界坐标?

    所以我有一个正在旋转然后再次平移和旋转的对象 我将这些翻译的矩阵存储为对象成员 现在 当我进行对象拾取时 我需要知道该对象的 3D 世界坐标 目前我已经能够像这样获得物体的位置 coords 0 finalMatrix 12 坐标 1 最终
  • h2o.saveModel 在 Windows 8 上抛出目录异常

    我在 R 中使用 h2o 版本 3 0 0 22 并尝试保存我的模型 但我似乎无法弄清楚预期的格式 我尝试了各种变化 但遇到了各种不同的异常 h2o saveModel model dir c temp name my model ERRO
  • Python - 基于 LSTM 的 RNN 需要 3D 输入?

    我正在尝试构建一个基于 LSTM RNN 的深度学习网络 这是尝试过的 from keras models import Sequential from keras layers import Dense Dropout Activatio