如何使用有状态 LSTM 和 batch_size > 1 布置训练数据

2024-03-29

背景

我想在 Keras 中对“有状态”LSTM 进行小批量训练。我的输入训练数据位于一个大矩阵“X”中,其维度为 m x n,其中

m = number-of-subsequences
n = number-of-time-steps-per-sequence

X 的每一行都包含一个子序列,该子序列接续前一行上的子序列离开的位置。因此,给定一长串数据,

Data = ( t01, t02, t03, ... )

其中“tK”表示原始数据中位置 K 的标记,序列在 X 中布局如下:

X = [
  t01 t02 t03 t04
  t05 t06 t07 t08
  t09 t10 t11 t12
  t13 t14 t15 t16
  t17 t18 t19 t20
  t21 t22 t23 t24
]

Question

我的问题是,当我使用有状态 LSTM 对以这种方式布置的数据进行小批量训练时会发生什么。具体来说,小批量训练通常一次对“连续”的行组进行训练。因此,如果我使用大小为 2 的小批量,则 X 将分为三个小批量 X1、X2 和 X3,其中

X1 = [
  t01 t02 t03 t04
  t05 t06 t07 t08
]

X2 = [
  t09 t10 t11 t12
  t13 t14 t15 t16
]

X3 = [
  t17 t18 t19 t20
  t21 t22 t23 t25
]

请注意,这种类型的小批量处理与训练并不相符statefulLSTM,因为通过处理前一批的最后一列产生的隐藏状态不是与后续批次的第一列之前的时间步相对应的隐藏状态。

要看到这一点,请注意小批量将按照从左到右的方式进行处理,如下所示:

------ X1 ------+------- X2 ------+------- X3 -----
t01 t02 t03 t04 | t09 t10 t11 t12 | t17 t18 t19 t20
t05 t06 t07 t08 | t13 t14 t15 t16 | t21 t22 t23 t24

暗示着

- Token t04 comes immediately before t09
- Token t08 comes immediately before t13
- Token t12 comes immediately before t17
- Token t16 comes immediately before t21

但我希望小批量对行进行分组,以便我们在小批量之间获得这种时间对齐:

------ X1 ------+------- X2 ------+------- X3 -----
t01 t02 t03 t04 | t05 t06 t07 t08 | t09 t10 t11 t12
t13 t14 t15 t16 | t17 t18 t19 t20 | t21 t22 t23 t24

在 Keras 中训练 LSTM 时实现此目标的标准方法是什么?

感谢您在这里的任何指点。


解决方案 1 - 批量大小 = 1

好吧,既然看起来你实际上只有一个序列(虽然分开了,但它仍然是一个序列,对吧?),你确实必须使用等于 1 的批量大小进行训练。

如果您不想更改或重新组织数据,只需:

 X = X.reshape((-1,length,features))

     #where
         #length = 4 by your description    
         #features = 1 (if you have only one var over time, as it seems)

解决方案 2 - 重新组合长度 = 8

仍在使用一个批量大小为 1,重塑输入数据(在将其传递给模型之前),使其具有双倍长度。

最终结果将与您使用所描述的大小为 2 的小批量进行训练完全相同。(但请确保在模型的输入形状中将批量大小设置为 1,否则这会给您带来错误的结果)。

X = X.reshape((-1, 2 * length, features)) 

这会给你:

X = [
  [t01 t02 t03 t04 t05 t06 t07 t08]
  [t09 t10 t11 t12 t13 t14 t15 t16]
  [t17 t18 t19 t20 t21 t22 t23 t24]
]

解决方案 3 - 仅当您实际上有两个不同的序列时才可能

根据你的描述,你似乎只有一个序列。如果您确实有两个不同/独立的序列,那么您可以制作一批大小为 2 的批次。

如果将序列一分为二(并失去它们之间的连接)不是问题,您可以重新排列数据:

X = X.reshape((2,-1,length, features))

Then:

X0 = X[:,0]
X1 = X[:,1]
...

您可以尝试将其分组在一个数组中:

X = X.reshape((2,-1,length, features))
X = np.swapaxes(X,0,1).reshape((-1,length,features))

Then:

X0 = X[0]
X1 = X[1]
...

你可以尝试通过完整的X只要在模型中明确将批量大小设置为 2 即可进行训练输入形状.

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如何使用有状态 LSTM 和 batch_size > 1 布置训练数据 的相关文章

  • Keras LSTM 密集层多维输入

    我正在尝试创建一个 keras LSTM 来预测时间序列 我的 x train 形状像 3000 15 10 示例 时间步长 特征 y train 形状像 3000 15 1 我正在尝试构建一个多对多模型 每个序列 10 个输入特征产生 1
  • 支持向量机或人工神经网络进行文本处理? [关闭]

    Closed 这个问题不符合堆栈溢出指南 help closed questions 目前不接受答案 对于某些文本处理项目 我们需要在支持向量机和快速人工神经网络之间做出选择 它包括上下文拼写纠正 然后将文本标记为某些短语及其同义词 哪种方
  • 如何将one-hot向量转换为多标签?

    我有一项多分类任务 并且我得到了像这样的单热类型预测 0 1 1 0 1 0 1 0 1 我希望将这个单热向量转换为标签 例如 1 2 1 0 2 我已经尝试过 tf argmax 但它不起作用 那么我该如何处理呢 使用列表理解 oheLi
  • OutOfRangeError(请参阅上面的回溯):FIFOQueue '_1_batch/fifo_queue' 已关闭并且元素不足(请求 32,当前大小 0)

    我在使用队列中张量流读取图像时遇到问题 请让我知道我犯了什么错误 下面是代码 import tensorflow as tf slim tf contrib slim from tensorflow python framework imp
  • keras 层教程和示例

    我正在尝试编码和学习不同的神经网络模型 我对输入维度有很多复杂性 我正在寻找一些教程 显示层的差异以及如何设置每个层的输入和输出 Keras 文档 https keras io layers core 向您展示所有input shape每层
  • ValueError:维度 (-1) 必须在 [0, 2) 范围内

    我的python版本是3 5 2 我已经安装了keras和tensorflow 并尝试了官方的一些示例 示例链接 示例标题 用于多类 softmax 分类的多层感知器 MLP https keras io getting started s
  • 如何反转 dropout 来补偿 dropout 的影响并保持期望值不变?

    我正在学习神经网络中的正则化deeplearning ai课程 在dropout正则化中 教授说 如果应用dropout 计算出的激活值将比不应用dropout时 测试时 更小 因此 我们需要扩展激活以使测试阶段更简单 我理解这个事实 但我
  • keras 中的增量学习

    我正在寻找 scikit learn 的 keras 等效项partial fit https scikit learn org 0 15 modules scaling strategies html incremental learni
  • Python 上每个系数具有特定约束的多元线性回归

    我目前正在数据集上运行多元线性回归 起初 我没有意识到我需要限制自己的体重 事实上 我需要有特定的正权重和负权重 更准确地说 我正在做一个评分系统 这就是为什么我的一些变量应该对音符产生积极或消极的影响 然而 当运行我的模型时 结果不符合我
  • 如何使用 pytorch 同时迭代两个数据加载器?

    我正在尝试实现一个接收两张图像的暹罗网络 我加载这些图像并创建两个单独的数据加载器 在我的循环中 我想同时遍历两个数据加载器 以便我可以在两个图像上训练网络 for i data in enumerate zip dataloaders1
  • 用于分布式计算的 Tensorflow 设置

    任何人都可以提供有关如何设置张量流以在网络上的许多CPU上工作的指导吗 到目前为止 我发现的所有示例最多只使用一个本地盒子和多个 GPU 我发现我可以在 session opts 中传递目标列表 但我不确定如何在每个盒子上设置张量流来侦听网
  • Native TF 与 Keras TF 性能比较

    我使用本机和后端张量流创建了完全相同的网络 但在使用多个不同参数进行了多个小时的测试后 仍然无法弄清楚为什么 keras 优于本机张量流并产生更好 稍微但更好 的结果 Keras 是否实现了不同的权重初始化方法 或者执行除 tf train
  • 在 TensorFlow 中,tf.identity 有何用途?

    我见过tf identity在一些地方使用过 例如官方 CIFAR 10 教程和 stackoverflow 上的批量规范化实现 但我不明白为什么有必要 它是用来做什么的 谁能给出一两个用例吗 一种建议的答案是它可以用于 CPU 和 GPU
  • Keras ImageDataGenerator 相当于 csv 文件

    我在文件夹中排序了一堆数据 如下图所示 我需要构建一个 DataIterator 以便将数据放入神经网络模型中 当数据是图像时 我找到了很多例子来解决这个问题 使用 Keras 类图像数据生成器及其方法流自目录 但当数据是 csv 结构时则
  • Tensorflow Hub - 获取模型的输入形状和问题域?

    我正在使用最新版本的tensorflow hub 想知道如何获取有关模型的预期输入形状以及模型属于什么类型的集合的信息 例如 有没有办法以这种方式在 Python 中加载模型后获取有关预期图像形状的信息 model hub load htt
  • Keras ZeroDivisionError:整数除法或以零为模

    我正在尝试使用 Keras 和 Tensorflow 实现卷积神经网络 我有以下代码 from keras models import Sequential from keras layers import Conv2D MaxPoolin
  • Caffe 的 LSTM 模块

    有谁知道 Caffe 是否有一个不错的 LSTM 模块 我从 russel91 的 github 帐户中找到了一个 但显然包含示例和解释的网页消失了 以前是http apollo deepmatter io http apollo deep
  • 在 Tensorflow 对象检测 API 中绘制验证损失

    我正在使用 Tensorflow 对象检测 API 来检测和定位图像中的一类对象 为了这些目的 我使用预先训练的faster rcnn resnet50 coco 2018 01 28 model 我想在训练模型后检测拟合不足 过度拟合 我
  • 预测测试图像时出现错误 - 无法重塑大小数组

    我正在尝试使用 TensorFlow 和 Keras 在 Python 中进行图像识别 并且我已经关注了下面的博客 https stackabuse com image recognition in python with tensorfl
  • 在不同的 GPU 上同时训练多个 keras/tensorflow 模型

    我想在 Jupyter Notebook 中同时在多个 GPU 上训练多个模型 我正在使用 4GPU 的节点上工作 我想将一个 GPU 分配给一个模型并同时训练 4 个不同的模型 现在 我通过 例如 为一台笔记本选择 GPU import

随机推荐

  • android + eclipse + maven + actionbarsherlock

    我读了很多关于 actionbarsherlock maven android 的东西 但我见过的解决方案都不适合我 我确信我已经非常接近解决方案 但我不明白 我需要一些帮助 所以这是我的问题 我尝试创建一个依赖于 Actionbarshe
  • 如何删除空值?

    如何删除底部计数中的空值 即 我只想查看实际销售单位的产品 我尝试过非空和非空但没有成功 with member Measures Amount Sold as Measures Internet Sales Amount format s
  • 为什么“超时”不适用于管道?

    以下命令行调用timeout 这没有意义 只是出于测试原因 无法按预期工作 它会等待 10 秒 并且在 3 秒后不会停止命令的运行 为什么 timeout 3 ls sleep 10 您的命令正在执行的操作正在运行timeout 3 ls并
  • 在 Windows 上的 XAMPP 中哪里可以更改 lower_case_table_names=2 的值?

    我正在使用 Windows 7 和 XAMPP 我正在尝试导出数据库 在此过程中表名称将转换为小写 我搜索了很多 我知道我必须改变的值lower case table names from 0 to 2 但是我必须在哪里更改这个值 在哪个文
  • 将 TypeScript 网站从 GitHub 部署到 Azure

    我有一个 NET 网站 其中包含一些 TypeScript 文件 我尝试将其从 GitHub 部署为 Azure 网站 但收到与 TypeScript 相关的错误 在我看来 这可能与我使用最新版本 1 0 有关 而 kudu 版本只有 0
  • Google 端点和公共 Api 密钥

    要使用 Google 服务 您可以使用 OAuth 身份验证 或者 如果您不需要用户登录 则可以使用公共 api 密钥 将授权域定义为请求的来源 现在 我正在使用 google 端点编写自己的 API 并且我将允许用户通过公共 api 密钥
  • 使用sessionStorage有什么好处? [复制]

    这个问题在这里已经有答案了 只是想知道在存储要在 Javascript 轮播中使用的 HTML 内容时使用 HTML5 的 sessionStorage 的实际好处是什么 与性能有关吗 加载时间 带宽 是的 您将使用更少的带宽 这会提高性能
  • 使用 ggdendro 在树状图的片段下显示变量标签

    我的问题与安德里的有关answer https i stack imgur com JW0m1 png我之前的问题 我的问题是是否可以在树状图的相应段下显示变量标签和汽车标签 library ggplot2 library ggdendro
  • 扩展 Android 的默认 Gmail/电子邮件应用程序

    我想通过插入 ContentProvider 或使用意图过滤器来扩展 Android 平台的默认 Gmail 电子邮件应用程序 本质上 我希望能够扫描传入的电子邮件以查找将在我的 Android 应用程序中触发事件的特殊规则 如果自动扫描电
  • 立即终止无循环线程,无需中止或挂起

    我正在实现一个协议库 这里有一个简化的描述 main 函数中的主线程将始终检查网络流 在 tcpclient 内 上是否有某些数据可用 假设响应是收到的消息 线程是正在运行的线程 thread new Thread new ThreadSt
  • 在 Sparklyr 中创建虚拟变量?

    我正在尝试扩展我的一些 ML 管道 我喜欢 Sparklyr 打开的 rstudio spark 和 h2o 的组合 http spark rstudio com http spark rstudio com 我试图弄清楚的一件事是如何使用
  • 多个组的可反应聚合函数

    使用 Rreactable包中 我试图使用两个 groupBy 变量显示标记读数的百分比 在较低级别的分组中 这是计算正确的百分比 但在分组的第二 外部 级别上 它没有显示正确的百分比 这是数据 dat lt structure list
  • PHP:查询 MySQL 最快的方法是什么?因为 PDO 慢得令人痛苦

    我需要执行一个简单的查询 从字面上看 我需要执行的是 SELECT price sqft zipcode FROM homes WHERE home id X 当我使用 PHP 时PDO 我读过的是连接到 MySQL 数据库的推荐方法 简单
  • 如何通过id查找页面上的控件

    有没有一种简单的方法可以通过 id 在任何嵌套容器中 在 ASP NET 中查找控件 除了遍历整个控件树之外 像这个例子 TextBox tb new TextBox ID textboxId panel3 Controls Add tb
  • Spring Boot如何选择外部化的Spring属性文件

    我有这个配置需要用于 Spring Boot 应用程序 server port 8085 server servlet context path authserver data source spring jpa hibernate ddl
  • Windows 上的 Python 包:pip 还是本机安装程序?

    与使用打包安装程序 exe msi 相比 使用 pip 在 Windows 上安装 python 软件包的相对优点是什么 对于初学者来说 有些对我来说不起作用 MySQLdb 是 我的新规则 Try pip or easy install
  • NodeJS + Mysql 与 Docker Compose 2

    我正在尝试构建一个 docker compose 文件来在本地部署连接到 mysql 服务器的 NodeJS 应用程序 我已经尝试了所有方法 在 Stackoverflow 中阅读了大量教程和一些问题 但我不断收到 ECONNREFUSED
  • Apache .htaccess:如何在 Firefox 上用斜杠重写反斜杠?

    如何重写反斜杠 带斜线 在火狐浏览器上 Chrome IE Safari Opera 已构建浏览器用斜杠重写反斜杠 但 Firefox 3 6 13 回归404错误页面 Why Firefox returns 404 error page
  • 使用history.pushState()更新URL中的参数

    我在用history pushState在我的页面上进行 AJAX 调用后 将一些参数附加到当前页面 URL 现在 在基于用户操作的同一页面上 我想使用相同或附加的参数集再次更新页面 URL 所以我的代码如下所示 var pageUrl w
  • 如何使用有状态 LSTM 和 batch_size > 1 布置训练数据

    背景 我想在 Keras 中对 有状态 LSTM 进行小批量训练 我的输入训练数据位于一个大矩阵 X 中 其维度为 m x n 其中 m number of subsequences n number of time steps per s