Keras LSTM:检查模型输入维度时出错

2024-02-16

我是 keras 的新用户,正在尝试实现 LSTM 模型。为了测试,我声明了如下所示的模型,但由于输入维度的差异而失败。虽然我在这个网站上发现了类似的问题,但我自己无法发现我的错误。

ValueError: 
Error when checking model input: 
expected lstm_input_4 to have 3 dimensions, but got array with shape (300, 100)

我的环境

  • 蟒蛇3.5.2
  • keras 1.2.0(Theano)

Code

from keras.layers import Input, Dense
from keras.models import Sequential
from keras.layers import LSTM
from keras.optimizers import RMSprop, Adadelta
from keras.layers.wrappers import TimeDistributed
import numpy as np

in_size = 100
out_size = 10
nb_hidden = 8

model = Sequential()
model.add(LSTM(nb_hidden, 
               name='lstm',
               activation='tanh',
               return_sequences=True,
               input_shape=(None, in_size)))
model.add(TimeDistributed(Dense(out_size, activation='softmax')))

adadelta = Adadelta(clipnorm=1.)
model.compile(optimizer=adadelta,
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# create dummy data
data_size = 300
train = np.zeros((data_size, in_size,), dtype=np.float32)
labels = np.zeros((data_size, out_size,), dtype=np.float32)
model.fit(train, labels)

编辑 1(在 Marcin Możejko 发表评论后不起作用)

谢谢马尔辛·莫泽科。但我有一个类似的错误,如下所示。我更新了虚拟数据以供检查。这段代码有什么问题?

ValueError:检查模型目标时出错:预期 timedistributed_36 具有 3 个维度,但得到具有形状的数组 (208, 1)

def create_dataset(X, Y, loop_back=1):
    dataX, dataY = [], []
    for i in range(len(X) - loop_back-1):
        a = X[i:(i+loop_back), :]
        dataX.append(a)
        dataY.append(Y[i+loop_back, :])
    return np.array(dataX), np.array(dataY)

data_size = 300
dataset = np.zeros((data_size, feature_size), dtype=np.float32)
dataset_labels = np.zeros((data_size, 1), dtype=np.float32)

train_size = int(data_size * 0.7)
trainX = dataset[0:train_size, :]
trainY = dataset_labels[0:train_size, :]
testX = dataset[train_size:, :]
testY = dataset_labels[train_size:, 0]
trainX, trainY = create_dataset(trainX, trainY)
print(trainX.shape, trainY.shape) # (208, 1, 1) (208, 1)

# in_size = 100
feature_size = 1
out_size = 1
nb_hidden = 8

model = Sequential()
model.add(LSTM(nb_hidden, 
               name='lstm',
               activation='tanh',
               return_sequences=True,
               input_shape=(1, feature_size)))

model.add(TimeDistributed(Dense(out_size, activation='softmax')))
adadelta = Adadelta(clipnorm=1.)
model.compile(optimizer=adadelta,
              loss='categorical_crossentropy',
              metrics=['accuracy'])
model.fit(trainX, trainY, nb_epoch=10, batch_size=1)

这是一个非常经典的问题LSTM in Keras. LSTM输入形状应该是2d- 有形状(sequence_length, nb_of_features)。额外的第三个维度来自示例维度 - 因此输入到模型的表格具有形状(nb_of_examples, sequence_length, nb_of_features)。这就是你的问题的根源。请记住,一个1-d序列应表示为2-d具有形状的数组(sequence_length, 1)。这应该是你的输入形状LSTM:

model.add(LSTM(nb_hidden, 
           name='lstm',
           activation='tanh',
           return_sequences=True,
           input_shape=(in_size, 1)))

并记住reshape将您的输入转换为适当的格式。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Keras LSTM:检查模型输入维度时出错 的相关文章

随机推荐

  • 如何在使用 XSL-FO 生成的 PDF 中插入换行符

    我正在使用 XSL FO 和 XML 生成 PDF 在文本框中 用户可以输入 1 等数据 然后按 ENTER 然后按 2 ENTER 3 等 但在 XML 和 PDF 中 输出是 1234567 如何保留换行符 我已经尝试过了white s
  • Vue props 数据未在子组件中更新

    大家好 我只是想要一些关于 vue props 数据的解释 所以我将值从父组件传递到子组件 问题是 当父数据发生数据更改 更新时 它不会在子组件中更新 Vue component child component template div c
  • Blazor JsInterop:调用 JS 时 Div 不可用

    该问题涉及客户端 Blazor 组件 该组件包含一个被组件变量隐藏的 div bool 打开 我需要组件在组件代码文件中显示 div 之后运行一些 Javascript 以便调整它在屏幕上的位置 下面的代码应该更好地解释这一点 组件 raz
  • 为什么要使用弹簧? [关闭]

    就目前情况而言 这个问题不太适合我们的问答形式 我们希望答案得到事实 参考资料或专业知识的支持 但这个问题可能会引发辩论 争论 民意调查或扩展讨论 如果您觉得这个问题可以改进并可能重新开放 访问帮助中心 help reopen questi
  • NetworkStream 和 Socket 类有什么区别?

    我有一个项目 我可能想抽象客户端和服务器之间的通信 我最初使用的是套接字和 TCP 然后我认为能够切换到进程间通信通道可能会很好 然后我查看了 System IO PipedStream 类 发现 PipeStream 和 Socket 类
  • PowerShell Start-Service无限运行

    Problem 因此 我有一段代码用于启动服务 如果服务花费太长时间并且在大多数情况下工作正常 则服务超时 不幸的是 当该服务尝试启动无法启动的服务时 它会显示以下警告消息 WARNING Waiting for ServiceName
  • 从数组中删除重复的字符串?

    如何在不使用 HashSet 的情况下从字符串数组中删除重复的字符串 我尝试使用循环 但没有删除的话 StringBuffer outString new StringBuffer Our aim and isn t easy you yo
  • 创建 OpenLayer 圈时出现问题

    如何在openlayer地图中画一个圆 我尝试过不同的方式 但它不起作用 请帮助我编写代码 我使用了以下代码 但它创建了多边形 var p1 new OpenLayers Geometry Point 439000 114000 var p
  • 我可以在我的视图模型中创建一个实时数据观察器吗?或者我应该始终观察片段/活动?

    我是 MVVM 新手 因此 我的片段 活动向服务器发出了 2 个请求 第一个请求的结果将用作第二个请求的输入参数 因此 首先在我的片段中 当单击按钮时 我会发出请求以检查用户是否被禁止 如果没有 则该用户可以创建帖子 所以首先我检查用户是否
  • 检测两年以上的浏览器

    这是一个拥有大约 10 000 个用户的私人公司网站 我已经看到了一些浏览器检测的努力 但与浏览器的年龄无关 有人对此有想法吗 相关项目 http fresh browsers com en http fresh browsers com
  • RESTEasy Mock 与异常映射器与上下文

    RESTEasy 模拟框架工作正常 没有异常映射器 接收请求并返回带有预期内容的实体 注册异常映射器并强制异常后 当 RESTEasy 内部调用 ResteasyProviderFactory getContextData type 时 调
  • 如果 div 包含

    标签,jQuery 返回 true 或 false

    让我们来看看 div p this div contains a p tag p div div this one is not div 如果 div 包含特定标签 如上例中的 p 如何为变量分配布尔值 true 或 false div h
  • Spark-单调递增 id 在数据帧中无法按预期工作?

    我有一个数据框df在 Spark 中 它看起来像这样 scala gt df show columna1 columna2 0 1 0 4 0 2 0 5 0 1 0 3 0 3 0 6 0 2 0 7 0 2 0 8 0 1 0 7 0
  • 模拟器:错误:x86 模拟当前需要硬件加速

    我尝试在 Android Studio 中运行我的 Hello World 应用程序 我收到以下错误 模拟器 错误 x86 模拟当前需要硬件 加速 请确保英特尔 HAXM 已正确安装且可用 CPU加速状态 HAX内核模块未安装 你能告诉我如
  • 如何映射网址?

    我想映射这样的页面domain content myProject home html to domain home html content myProject 不需要 我有以下代码 String newpath getResourceR
  • 如何在 Google Optimize 中的 Document Ready 上运行 Javascript?

    如何在 Google 优化广告系列中的窗口加载或文档就绪时运行 javascript 它似乎允许我选择 DOM 元素一直到 Body 但我需要在文档准备好时运行 js 这就是我的做法 在可视化编辑器中编辑您的实验变体 单击选择元素图标 左上
  • Flutter (Dart) 如何在应用程序中点击时将副本添加到剪贴板?

    我是 Flutter 的初学者 我刚刚开始遵循他们的名称生成器应用程序教程并制作了一个简单的名称生成应用程序 我想知道当用户点击名称时是否可以添加复制到剪贴板功能 我尝试实现在堆栈上找到的解决方案 但它不起作用 我的完整代码在这里 任何建议
  • 检查Python中的字符串是否包含日期或时间戳[关闭]

    Closed 这个问题不符合堆栈溢出指南 help closed questions 目前不接受答案 我需要想出一个函数 它将接受一个字符串 它将执行以下操作 检查它是否是 UTC 格式的时间戳 例如 如果它的形式为2014 05 10T1
  • 为什么 scanf() 在某些情况下需要 & 运算符(地址),而在其他情况下不需要? [复制]

    这个问题在这里已经有答案了 为什么我们需要放一个 运算符在scanf 用于将值存储在整数数组中 但不能将字符串存储在字符数组中 int a 5 for i 0 i lt 5 i scanf d a i but char s 5 scanf
  • Keras LSTM:检查模型输入维度时出错

    我是 keras 的新用户 正在尝试实现 LSTM 模型 为了测试 我声明了如下所示的模型 但由于输入维度的差异而失败 虽然我在这个网站上发现了类似的问题 但我自己无法发现我的错误 ValueError Error when checkin