Keras LSTM 输入形状的输入形状错误

2024-04-17

在 Keras 中使用时间序列时出现此错误:

ValueError: Error when checking input: expected lstm_1_input to have 3 dimensions, but got array with shape (31, 3)

这是我的功能:

def CreateModel(shape):
  """Creates Keras Model.

  Args:
    shape: (set) Dataset shape. Example: (31,3).

  Returns:
    A Keras Model.

  Raises:
    ValueError: Invalid shape
  """

  if not shape:
    raise ValueError('Invalid shape')

  logging.info('Creating model')
  model = Sequential()
  model.add(LSTM(4, input_shape=(31, 3)))
  model.add(Dense(1))
  model.compile(loss='mean_squared_error', optimizer='adam')
  return model

主要代码:

print(training_features.shape)
model = CreateModel(training_features.shape)
model.fit(
      training_features,
      training_label,
      epochs=FLAGS.epochs,
      batch_size=FLAGS.batch_size,
      verbose=FLAGS.keras_verbose_level)

完整错误:

Traceback (most recent call last):
  File "<embedded module '_launcher'>", line 149, in run_filename_as_main
  File "<embedded module '_launcher'>", line 33, in _run_code_in_main
  File "model.py", line 300, in <module>
    app.run(main)
  File "absl/app.py", line 433, in run
    _run_main(main, argv)
  File "absl/app.py", line 380, in _run_main
    sys.exit(main(argv))
  File "model.py", line 274, in main
    verbose=FLAGS.keras_verbose_level)
  File "keras/models.py", line 960, in fit
    validation_steps=validation_steps)
  File "keras/engine/training.py", line 1581, in fit
    batch_size=batch_size)
  File "keras/engine/training.py", line 1414, in _standardize_user_data
    exception_prefix='input')
  File "keras/engine/training.py", line 141, in _standardize_input_data
    str(array.shape))
ValueError: Error when checking input: expected lstm_1_input to have 3 dimensions, but got array with shape (31, 3)

代码最初来自here https://machinelearningmastery.com/time-series-prediction-with-deep-learning-in-python-with-keras/

我努力了:

training_features = numpy.reshape(
      training_features,
      (training_features.shape[0], 1, training_features.shape[1]))

但我得到:

ValueError: Input 0 is incompatible with layer lstm_1: expected ndim=3, found ndim=4

如果您的原始数据是(31,3)那么我认为您正在寻找的是training_features.shape =(31,3,1)。您可以通过以下行获得它...

training_features = training_features.reshape(-1, 3, 1)

这将简单地向现有数据添加一个新轴(-1 只是告诉 numpy 使用原始数据中的值计算出这个维度)。

您还需要修复模型的输入形状。 31 应该是数据中的样本数。这不包含在 Keras 中input_shape范围。你应该使用...

model.add(LSTM(4, input_shape=(3, 1)))

Keras 会自动将批量大小设置为None这意味着任意数量的样本都适用于该模型。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Keras LSTM 输入形状的输入形状错误 的相关文章

  • python中热图的层次聚类

    我有一个 NxM 矩阵 其值范围为 0 到 20 我可以使用 Matplotlib 和 pcolor 轻松获得热图 现在我想使用 scipy 应用层次聚类和树状图 我想重新排序每个维度 行和列 以显示哪些元素相似 根据聚类结果 如果矩阵是方
  • 美丽的汤从谷歌搜索中提取href

    谷歌搜索给出了以下 HTML 的第一个结果 h3 class r a href https rads stackoverflow com amzn click com 0470284889 class l vst em Quantitati
  • 静态文件配置不正确

    我已经在 Heroku 上部署了简单的博客应用程序 它运行在Django 1 8 4 我在静态文件方面遇到了一些问题 当打开我的应用程序时 我看到Application Error页面 所以我尝试调试它并发现当我提交到 Heroku 时它无
  • 无法在 mysql 表中的值中使用破折号(-)[重复]

    这个问题在这里已经有答案了 我一直在尝试从 python 将数据插入 MYSQL 表 我的sql表中的字段是id token start time end time和no of trans 我想存储使用生成的令牌uuid4在令牌栏中 但由于
  • 如何在 Django 管理中以表格格式显示添加模型?

    我刚刚开始使用 Django 编写我的第一个应用程序 为我的家庭设计的家务图表管理器 在本教程中 它向您展示了如何添加相关对象 http docs djangoproject com en dev intro tutorial02 cust
  • 创建圆形图像 PIL Tkinter

    Currently I have a zoom feature in my application that works very well however I d like the actual zoom box to be a circ
  • 使用信号时出现 django TransactionManagementError

    我有一个与 django 的用户和 UserInfo 一对一的字段 我想订阅用户模型上的 post save 回调函数 以便我也可以保存 UserInfo receiver post save sender User def saveUse
  • 在 keras 中使用自定义张量流操作

    我在张量流中有一个脚本 其中包含自定义张量流操作 我想将代码移植到 keras 但我不确定如何在 keras 代码中调用自定义操作 我想在 keras 中使用tensorflow 所以到目前为止我发现的教程描述了与我想要的相反的内容 htt
  • 向 Python 2.6 添加 SSL 支持

    我尝试使用sslPython 2 6 中的模块 但我被告知它不可用 安装OpenSSL后 我重新编译2 6 但问题仍然存在 有什么建议么 您安装了 OpenSSL 开发库吗 我必须安装openssl devel例如 在 CentOS 上 在
  • 在 MATLAB 中创建共享库

    一位研究人员在 MATLAB 中创建了一个小型仿真 我们希望其他人也能使用它 我的计划是进行模拟 清理一些东西并将其变成一组函数 然后我打算将其编译成C库并使用SWIG https en wikipedia org wiki SWIG创建一
  • Floyd-Warshall 算法:获取最短路径

    假设一个图由一个表示n x n维数邻接矩阵 我知道如何获得所有对的最短路径矩阵 但我想知道有没有办法追踪所有最短路径 Blow是python代码实现 v len graph for k in range 0 v for i in range
  • 管理文件字段当前 url 不正确

    在 Django 管理中 只要有 FileField 编辑页面上就会有一个 当前 框 其中包含指向当前文件的超链接 但是 此链接会附加到当前页面 url 因此会导致 404 因为不存在这样的页面 例如 http 127 0 0 1 8000
  • 如何从数据框的单元格中获取值?

    我构建了一个条件 从我的数据框中提取一行 d2 df df l ext l ext df item item df wn wn df wd 1 现在我想从特定列中获取一个值 val d2 col name 但结果 我得到一个包含一行和一列
  • 使用python中的mysql连接器正确从mysql数据库获取blob

    当执行以下代码时 import mysql connector connection mysql connector connect connection params here cursor connection cursor curso
  • python中打印字符串的长度

    有没有什么方法可以找到 即使是最好的猜测 Python中字符串的 打印 长度 例如 potaa bto 是 8 个字符len但 tty 上只打印 6 个字符宽 预期用途 s potato x1b 01 32mpotato x1b 0 0mp
  • 向量化 numpy bincount

    我有一个 2d numpy 数组 A我要申请np bincount 到矩阵的每一列A生成另一个二维数组B由原始矩阵每列的 bincounts 组成A 我的问题是 np bincount 是一个采用一维数组的函数 它不是像这样的数组方法B A
  • 从 csv 中读取 pandas 数据帧,以非固定标头开始

    我有许多数据文件是由我的实验室中使用的一些相当黑客的脚本生成的 该脚本非常有趣 因为它在标头之前附加的行数因文件而异 尽管它们具有相同的格式并具有相同的标头 我正在编写一个批处理来将所有这些文件处理为数据帧 如果我不知道位置 如何让 pan
  • 从 C 线程调用 Python 代码

    我对从 C 或 C 线程调用 Python 代码时如何确保线程安全感到非常困惑 The Python 文档 http docs python org c api init html non python created threads似乎是
  • 在Python中从列表中获取n个项目组的惯用方法? [复制]

    这个问题在这里已经有答案了 给定一个列表 A 1 2 3 4 5 6 是否有任何惯用的 Pythonic 方式来迭代它 就好像它是 B 1 2 3 4 5 6 除了索引之外 这感觉像是 C 的遗留物 for a1 a2 in A i A i
  • Selenium Python 使用代理运行浏览器[重复]

    这个问题在这里已经有答案了 我正在尝试编写一个非常简单的脚本 该脚本从 txt 文件获取代理 不需要身份验证 并用它打开浏览器 然后沿着代理列表循环此操作一定次数 我确实知道如何打开 txt 文件并使用它 我的主要问题是让代理正常工作 我见

随机推荐