在 Keras 中,LSTM 状态何时在 model.predict 调用中重置?

2024-03-18

该模型将 LSTM 作为第一层。

当调用 model.predict 时,假设您传递了几个样本:

>sam = np.array([ [[.5, .6, .3]], [[.6, .6, .3]], [[.5, .6, .3]] ])
>model.predict(sam)
array([[ 0.23589483],
       [ 0.2327884 ],
       [ 0.23589483]])

上面我们看到了映射: [[.5, .6, .3]] -> 0.23589483 等(1 个元素的序列,它是长度为 3 的向量,映射到实数)

该模型的 input_length 为 1,input_dim 为 3。请注意,第一个和最后一个相同,并且具有相同的输出 (0.23589483)。所以我的假设是,Keras 处理样本(在本例中是 1 个 3-D 向量的序列)后,它会重置模型的内存。即每个序列基本上是独立的。这个观点有什么不正确或者误导的地方吗?

再举一个 input_length 3 和 input_dim 1 的例子。这一次,切换序列中的值并看到不同的结果(将第二个列表与最后一个列表进行比较)。因此,当 Keras 处理序列时,内存会发生变化,但处理完成后,内存会重置(第一个和第二个序列具有相同的结果)。

sam = np.array([ [[.1],[.1],[.9]], [[.1],[.9],[.1]], [[.1],[.1],[.9]]   ])
model.predict(sam)
array([[ 0.69906837],
   [ 0.1454899 ],
   [ 0.69906837]])

上面我们看到了映射 [[.1],[.1],[.9]] -> 0.69906837 等(3 个元素到实数的序列)


我知道这是一个老问题,但希望这个答案可以帮助像我这样的其他 Keras 初学者。

我在我的机器上运行这个例子,观察到 LSTM 的隐藏状态和单元状态确实随着调用而改变model.predict.

import numpy as np
import keras.backend as K
from keras.models import Model
from keras.layers import LSTM

batch_size = 1
timestep_size = 2
num_features = 4

inputs = Input(batch_shape=(batch_size, timestep_size, num_features)
x = LSTM(num_features, stateful=True)(inputs)

model = Model(inputs=inputs, outputs=x)
model.compile(loss="mse",
              optimizer="rmsprop",
              metrics=["accuracy"])

x = np.random.randint((10,2,4))
y = np.ones((10,4))
model.fit(x,y, epochs=100, batch_size=1)

def get_internal_state(model):
    # get the internal state of the LSTM
    # see https://github.com/fchollet/keras/issues/218
    h, c = [K.get_value(s) for s, _ in model.state_updates]
    return h, c

print "After fitting:", get_internal_state(model)

for i in range(3):
    x = np.random.randint((10,2,4))
    model.predict(x)
    print "After predict:", get_internal_state(model)

以下是调用的输出示例get_internal_state训练结束后:

After_fitting: (array([[ 1.,  1.,  1.,  1.]], dtype=float32), array([[  11.33725166,   11.8036108 ,  181.75688171,   25.50110626]], dtype=float32))
After predict (array([[ 1.        ,  0.99999994,  1.        ,  1.        ]], dtype=float32), array([[   9.26870918,    8.83847237,  179.92633057,   28.89341927]], dtype=float32))
After predict (array([[ 0.99999571,  0.9992013 ,  1.        ,  0.9915328 ]], dtype=float32), array([[   6.5174489 ,    8.55165958,  171.42166138,   25.49199104]], dtype=float32))
After predict (array([[ 1.,  1.,  1.,  1.]], dtype=float32), array([[   9.78496075,    9.27927303,  169.95401001,   28.74017715]], dtype=float32))
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

在 Keras 中,LSTM 状态何时在 model.predict 调用中重置? 的相关文章

  • 加载视频数据集(Keras)

    我正在尝试实现 LRCN C LSTM RNN 来对视频中的情绪进行分类 我的数据集结构分为两个文件夹 train set 和 valid set 当你打开其中任何一个时 你可以找到3个文件夹 积极 消极 和 惊喜 最后 这 3 个文件夹中
  • Google Colab:为什么 CPU 比 TPU 快?

    我正在使用 Google colabTPU训练一个简单的Keras模型 删除分布式strategy并在CPU比TPU 这怎么可能 import timeit import os import tensorflow as tf from sk
  • 了解 keras 中不同序列的 lstm 输入形状

    我对 keras 和 python 都很陌生 我有一个具有不同序列长度的时间序列数据集 例如第一个序列是 484000x128 第二个序列是 563110x128 等 我已将序列放入 3D 数组中 我的问题是如何定义输入形状 因为我很困惑
  • ImageDataGenerator 预测类 - 为什么预测未正确从概率转换为预测类?

    我有一个这样设置的目录 images val class1 class2 test all classes train class1 class2 每个目录中都有一组图像 我想预测测试中的每个图像是否属于 1 类或 2 类 我写这个是为了读
  • 结合两个 CNN

    我想在 Keras 中将两个 CNN 合并为一个 我的意思是我希望神经网络拍摄两张图像并在单独的 CNN 中处理每一张图像 然后将它们连接在一起进入扁平化层并使用全连接层来做最后的工作 我做了什么 Start With First Bran
  • 为 Keras 编写自定义数据生成器

    我将每个数据点存储在 npy 文件中 其中shape 1024 7 8 我想通过类似的方式将它们加载到 Keras 模型中ImageDataGenerator 所以我编写并尝试了不同的自定义生成器 但它们都不起作用 这是我改编的一个this
  • Keras 中的 model.fit() 和 model.evaluate() 有什么区别?

    我使用 Keras 和 TensorFlow 后端来训练 CNN 模型 之间是什么model fit and model evaluate 我应该最好使用哪一种 我在用model fit 截至目前 我知道的用处model fit and m
  • tf.keras.utils.image_dataset_from_directory,但标签来自 csv?

    请告诉我哪里出错了 我正在研究 Kaggle 狗品种分类挑战 我想尝试 one hot 编码与标签编码 图像未在图像目录中拆分 因此我无法将 推断 与 tf keras utils image dataset from directory
  • 使用 Keras 的 ImageDataGenerator 预测单个图像

    我对深度学习很陌生 所以请原谅我这个可能很简单的问题 我训练了一个网络来分类positive and negative 为了简化图像生成和拟合过程 我使用了ImageDataGenerator和fit generator函数 如下图 imp
  • 具有多个输入的 Keras TimeDistributed 层

    我正在尝试使以下代码行正常工作 low encoder out TimeDistributed AutoregressiveDecoder X tf embeddings Where AutoregressiveDecoder是一个需要两个
  • 在 Tensorflow 2.0 中的简单 LSTM 层之上添加 Attention

    我有一个由一个 LSTM 和两个 Dense 层组成的简单网络 如下所示 model tf keras Sequential model add layers LSTM 20 input shape train X shape 1 trai
  • Keras 序列模型中的数据增强层

    我正在尝试将数据增强作为一个层添加到模型中 但我遇到了我认为是形状问题 我也尝试在增强层中指定输入形状 当我取出data augmentation模型中的图层运行良好 preprocessing RandomFlip horizontal
  • Keras 中的损失函数和度量有什么区别? [复制]

    这个问题在这里已经有答案了 我不清楚 Keras 中损失函数和指标之间的区别 该文档对我没有帮助 损失函数用于优化您的模型 这是优化器将最小化的函数 指标用于判断模型的性能 这仅供您查看 与优化过程无关
  • 将预训练的手套词嵌入与 scikit-learn 结合使用

    我已经使用 keras 来使用预先训练的词嵌入 但我不太确定如何在 scikit learn 模型上执行此操作 我也需要在 sklearn 中执行此操作 因为我正在使用vecstack集成 keras 序列模型和 sklearn 模型 这就
  • 如何确定 Keras Conv2D 函数中的“filter”参数

    我刚刚开始我的 ML 之旅 并且已经完成了一些教程 对我而言 不清楚的一件事是如何为 Keras Conv2D 确定 过滤器 参数 我读过的大多数资料只是将参数设置为 32 没有任何解释 这只是经验法则还是输入图像的尺寸起作用 例如 CIF
  • 了解 YOLO 是如何训练的

    我试图了解 YOLO v2 是如何训练的 为此 我使用这个 keras 实现https github com experiencor keras yolo2 https github com experiencor keras yolo2在
  • 将 Dropout 与 Keras 和 LSTM/GRU 单元结合使用

    在 Keras 中 您可以像这样指定 dropout 层 model add Dropout 0 5 但对于 GRU 单元 您可以将 dropout 指定为构造函数中的参数 model add GRU units 512 return se
  • 在相同任务上,Keras 比 TensorFlow 慢

    我正在使用 Python 运行斩首 DCNN 本例中为 Inception V3 来获取图像特征 我使用的是 Anaconda Py3 6 和 Windows7 使用 TensorFlow 时 我将会话保存在变量中 感谢 jdehesa 并
  • 如何使用 Tensorflow-GPU 和 Keras 修复低易失性 GPU-Util?

    我有一台 4 GPU 机器 在上面运行带有 Keras 的 Tensorflow GPU 我的一些分类问题需要几个小时才能完成 nvidia smi returns Volatile GPU Util which never exceeds
  • Keras model.predict 函数给出输入形状错误

    我已经在 Tensorflow 中实现了通用句子编码器 现在我正在尝试预测句子的类概率 我也将字符串转换为数组 Code if model model type universal classifier basic class probs

随机推荐

  • .jcall(cell, "V", "setCellValue", value) 中的错误:尝试 write.xlsx 时未找到带有签名 ([D)V 的 setCellValue 方法

    library dtplyr library xlsx library lubridate data frame 612 obs of 7 variables Company Factor w 10 levels Harbor HCG 6
  • JQuery:委托和日期选择器

    我需要给定类中的每个文本输入都是一个日期选择器 就像是 input type text time datepicker 但我通过 Jquery load 添加了很多代码 所以我相信我需要一个委托 问题是我不知道该怎么做 因为据我所知 加载事
  • Ninject 3.0 MVC kernel.bind 错误自动注册

    kernel Bind 上的获取和错误scanner gt 在 VS 2010 中 scanner 下面有一条小错误线 无法将 lambda 表达式转换为类型 System Type 因为它不是代表 类型 尝试像 2 0 中的旧 kerne
  • Xcode 7:将数组控制器绑定到单选按钮组

    我有一小组对象 用户应该能够使用单选按钮组从中选择一个对象 这些对象已绑定到数组控制器 有没有办法将该阵列控制器绑定到单选按钮组 以便动态生成其他单选按钮 如果可能 首选 IB 解决方案 示例项目 https scriptreactor c
  • 无法为 Kindle Fire HD 安装 ADB

    我正在尝试root它 尽管在我安装了正确的ADB驱动程序之后 当我插入我的Kindle fire HD 7 时 点燃火 gt Android 复合 ADB 接口 没有出现在设备管理器中 因此我无法执行root 我已将 0x1949 添加到
  • Elasticsearch 使用 jest 通过查询删除[关闭]

    Closed 这个问题正在寻求书籍 工具 软件库等的推荐 不满足堆栈溢出指南 help closed questions 目前不接受答案 我发现一个有趣的功能叫做通过查询删除 https www elastic co guide en el
  • 如何使用 python 从文本文件的行中读取特定字符?

    我有多个 txt 文件 其中包含与此类似的多行 class1 1 28 9 315 13 354227 2 36 247 17 342 8 34 14 3825 class2 14 31 8679 7 32 3582 2 32 4127 1
  • 组合常见搭配的 NLP 流程

    我有一个语料库 我在 R 中使用 tm 包 并且还在 python 中的 NLTK 中镜像相同的脚本 我正在使用一元组 但希望某种解析器能够将通常位于同一位置的单词组合成一个单词 即 我不想再在我的单词中分别看到 New 和 York 当它
  • Matlab 替换轴范围

    我的 x 轴从 0 到 96 其中每个数字代表一天中的一刻钟 96 4 24 小时 我需要轴来显示 0 到 24 小时 有没有办法在绘图后仅修改轴 您可以使用 gt gt set gca XTick 0 4 96 gt gt set gca
  • 如何将嵌套字典传递给 Flask 的 GET 请求处理程序

    我试图将嵌套字典作为参数传递给 GET 请求 该请求由 Flask 工作线程处理 整个设置是Nginx Gunicorn Flask 在客户端 我正在执行以下操作 import requests def find cabin party P
  • Numpy:了解行名称的 numpy 数组概念

    也许是一个非常模糊的问题 但是挖掘 numpy 上的链接对我没有帮助 我需要使用以下分层聚类对如下所示的二进制数组进行相似度矩阵计算 name val1 val2 val3 val4 val5 comp1 0 0 1 0 1 comp2 1
  • 使用PHP批量删除域共享联系人

    我正在使用 Google API PHP客户端库 v2 1 3 https github com google google api php client 我正在关注以下文档域共享联系人 https developers google co
  • 使 saxon-c 在 Python 中可用

    我刚刚读到 Saxon 现在可用于 Python 这非常有趣而且很好 但是任何人都可以写一篇关于如何使其可用于 Python Anaconda WingIDE 或类似的教程吗 我习惯于使用 pip 或 conda 安装 并指向一个包 轮子以
  • xcode 5 问题:“iOS 模拟器无法安装应用程序”

    我刚刚将我的 xcode 版本升级到 5 0 运行应用程序 2 3 次后 它给我这样的错误 iOS模拟器无法安装应用程序 这在旧的 xcode 中工作正常 当我重置模拟器时 它工作正常 但这一次又一次令人恼火 谁能告诉我真正的问题是什么 我
  • DateTimePicker 显示今天的日期而不是显示其实际值

    我们在表单上的自定义用户控件上有几个 DateTimePicker 它们是可见的 但未启用 仅用于显示目的 当加载 UserControl 时 DateTimePicker 会从来自 DataSet 的 DataRow 分配值 该 Data
  • 什么时候需要在 Ruby C 扩展中声明易失性值?

    我找不到太多关于何时适合声明的文档VALUE as volatileRuby 扩展中以避免过早对正在使用的对象进行垃圾回收 这是我到目前为止所学到的 有人可以填空吗 When volatile does not需要使用 在 C 对象成员中
  • 获取线程的输出

    您认为获取线程工作结果的最佳方式是什么 想象一个线程执行一些计算 如何警告主程序计算已完成 您可以每隔 X 毫秒轮询一些名为 作业完成 的公共变量或顺便说一句 但是您会收到比可用结果更晚的结果 主代码将浪费时间等待它们 另一方面 如果您使用
  • 如何从文件(即 SVG)创建 CGPath

    是否可以从给定文件创建 CGPath SVG 是首选 但任何东西都可以 袖珍SVG https github com arielelkin PocketSVG会将 SVG 文件转换为 UIBezierPath 从中您可以获得 CGPath
  • Mac OS X 上的 Heroku Local 和 PHP

    目前 除了始终在线的 apache 代理 php fpm 之外 我只使用额外的终端选项卡来手动启动工作进程和时钟进程 当我开始使用heroku时 我尝试了heroku local 但它的设置打败了我 现在我想再试一次 我在 High Sie
  • 在 Keras 中,LSTM 状态何时在 model.predict 调用中重置?

    该模型将 LSTM 作为第一层 当调用 model predict 时 假设您传递了几个样本 gt sam np array 5 6 3 6 6 3 5 6 3 gt model predict sam array 0 23589483 0