keras LSTM 层训练时间太长

2023-11-26

每当我在 Keras 上尝试 LSTM 模型时,似乎由于训练时间过长,该模型无法训练。

例如,像这样的模型每步需要 80 秒来训练:

def create_model(self):
        inputs = {}
        inputs['input'] = []
        lstm = []
        placeholder = {}
        for tf, v in self.env.timeframes.items():
            inputs[tf] = Input(shape = v['shape'], name = tf)
            lstm.append(LSTM(8)(inputs[tf]))
            inputs['input'].append(inputs[tf])
        account = Input(shape = (3,), name = 'account')
        account_ = Dense(8, activation = 'relu')(account)
        dt = Input(shape = (7,), name = 'dt')
        dt_ = Dense(16, activation = 'relu')(dt)
        inputs['input'].extend([account, dt])

        data = Concatenate(axis = 1)(lstm)
        data = Dense(128, activation = 'relu')(data)
        y = Concatenate(axis = 1)([data, account, dt])
        y = Dense(256, activation = 'relu')(y)
        y = Dense(64, activation = 'relu')(y)
        y = Dense(16, activation = 'relu')(y)
        output = Dense(3, activation = 'linear')(y)

        model = Model(inputs = inputs['input'], outputs = output)
        model.compile(loss = 'mse', optimizer = 'adam', metrics = ['mae'])
        return model

而使用 Flatten + Dense 替代 LSTM 的模型如下:

def create_model(self):
        inputs = {}
        inputs['input'] = []
        lstm = []
        placeholder = {}
        for tf, v in self.env.timeframes.items():
            inputs[tf] = Input(shape = v['shape'], name = tf)
            #lstm.append(LSTM(8)(inputs[tf]))
            placeholder[tf] = Flatten()(inputs[tf])
            lstm.append(Dense(32, activation = 'relu')(placeholder[tf]))
            inputs['input'].append(inputs[tf])
        account = Input(shape = (3,), name = 'account')
        account_ = Dense(8, activation = 'relu')(account)
        dt = Input(shape = (7,), name = 'dt')
        dt_ = Dense(16, activation = 'relu')(dt)
        inputs['input'].extend([account, dt])

        data = Concatenate(axis = 1)(lstm)
        data = Dense(128, activation = 'relu')(data)
        y = Concatenate(axis = 1)([data, account, dt])
        y = Dense(256, activation = 'relu')(y)
        y = Dense(64, activation = 'relu')(y)
        y = Dense(16, activation = 'relu')(y)
        output = Dense(3, activation = 'linear')(y)

        model = Model(inputs = inputs['input'], outputs = output)
        model.compile(loss = 'mse', optimizer = 'adam', metrics = ['mae'])
        return model

每步训练需要 45-50 毫秒。

模型中是否存在导致此问题的问题?或者这是否与该模型的运行速度一样快?

-- self.env.timeframes 看起来像这样:有 9 个项目的字典

timeframes = {
            's1': {
                'lookback': 86400,
                'word': '1 s',
                'unit': 1,
                'offset': 12
                },
            's5': {
                'lookback': 200,
                'word': '5 s',
                'unit': 5,
                'offset': 2
                },
            'm1': {
                'lookback': 100,
                'word': '1 min',
                'unit': 60,
                'offset': 0
                },
            'm5': {
                'lookback': 100,
                'word': '5 min',
                'unit': 300,
                'offset': 0
                },
            'm30': {
                'lookback': 100,
                'word': '30 min',
                'unit': 1800,
                'offset': 0
                },
            'h1': {
                'lookback': 200,
                'word': '1 h',
                'unit': 3600,
                'offset': 0
                },
            'h4': {
                'lookback': 200,
                'word': '4 h',
                'unit': 14400,
                'offset': 0
                },
            'h12': {
                'lookback': 100,
                'word': '12 h',
                'unit': 43200,
                'offset': 0
                },
            'd1': {
                'lookback': 200,
                'word': '1 d',
                'unit': 86400,
                'offset': 0
                }
            }

提示中的 GPU 信息 -

2018-06-30 07:35:16.204320: I T:\src\github\tensorflow\tensorflow\core\platform\cpu_feature_guard.cc:140] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2
2018-06-30 07:35:16.495832: I T:\src\github\tensorflow\tensorflow\core\common_runtime\gpu\gpu_device.cc:1356] Found device 0 with properties:
name: GeForce GTX 1080 major: 6 minor: 1 memoryClockRate(GHz): 1.86
pciBusID: 0000:01:00.0
totalMemory: 8.00GiB freeMemory: 6.59GiB
2018-06-30 07:35:16.495981: I T:\src\github\tensorflow\tensorflow\core\common_runtime\gpu\gpu_device.cc:1435] Adding visible gpu devices: 0
2018-06-30 07:35:16.956743: I T:\src\github\tensorflow\tensorflow\core\common_runtime\gpu\gpu_device.cc:923] Device interconnect StreamExecutor with strength 1 edge matrix:
2018-06-30 07:35:16.956827: I T:\src\github\tensorflow\tensorflow\core\common_runtime\gpu\gpu_device.cc:929]      0
2018-06-30 07:35:16.957540: I T:\src\github\tensorflow\tensorflow\core\common_runtime\gpu\gpu_device.cc:942] 0:   N
2018-06-30 07:35:16.957865: I T:\src\github\tensorflow\tensorflow\core\common_runtime\gpu\gpu_device.cc:1053] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 6370 MB memory) -> physical GPU (device: 0, name: GeForce GTX 1080, pci bus id: 0000:01:00.0, compute capability: 6.1)

如果您使用 GPU,请将所有 LSTM 层替换为 CuDNNLSTM 层。您可以从以下位置导入它keras.layers:

from keras.layers import  CuDNNLSTM

def create_model(self):
    inputs = {}
    inputs['input'] = []
    lstm = []
    placeholder = {}
    for tf, v in self.env.timeframes.items():
        inputs[tf] = Input(shape = v['shape'], name = tf)
        lstm.append(CuDNNLSTM(8)(inputs[tf]))
        inputs['input'].append(inputs[tf])
    account = Input(shape = (3,), name = 'account')
    account_ = Dense(8, activation = 'relu')(account)
    dt = Input(shape = (7,), name = 'dt')
    dt_ = Dense(16, activation = 'relu')(dt)
    inputs['input'].extend([account, dt])

    data = Concatenate(axis = 1)(lstm)
    data = Dense(128, activation = 'relu')(data)
    y = Concatenate(axis = 1)([data, account, dt])
    y = Dense(256, activation = 'relu')(y)
    y = Dense(64, activation = 'relu')(y)
    y = Dense(16, activation = 'relu')(y)
    output = Dense(3, activation = 'linear')(y)

    model = Model(inputs = inputs['input'], outputs = output)
    model.compile(loss = 'mse', optimizer = 'adam', metrics = ['mae'])
    return model

以下是更多信息:https://keras.io/layers/recurrent/#cudnnlstm

这将显着加快模型速度 =)

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

keras LSTM 层训练时间太长 的相关文章

随机推荐

  • C# 仅使用代码添加引用(没有 IDE“添加引用”功能)

    我正在为一个程序编写一个插件 我想将我的代码放入 DLL 中 这样我就可以自由地共享该插件 而无需暴露 泄露 我的代码 这是我可以访问的基本结构 using System public class Plugin public void In
  • flutter - 自更新到 firebase 9.0.X 以来出现错误 event.snapshot.value

    我收到错误event snapshot value自从更新到 firebase 9 0 5 以来 我有很多这样的函数 它们在 firebase 8 X 中运行良好 Stream
  • 如何根据指定的行数拆分 CSV 文件?

    我有 CSV 文件 大约 10 000 行 每行有 300 列 存储在 LINUX 服务器上 我想将此 CSV 文件分成 500 个 CSV 文件 每个文件有 20 条记录 每个都具有与原始 CSV 中相同的 CSV 标头 有没有什么lin
  • Java HTTP/2 服务器套接字

    我想让服务器套接字在 Java 中支持 HTTP 2 最好是 TLS https 我有一个 TLS 服务器套接字工作正常 但浏览器只能与它对话 HTTP 1 1 如果我理解正确的话 您需要 ALPN 来让 HTTP 2 浏览器连接到您的 T
  • 如何在 SELECT 语句中使用 BOOLEAN 类型

    我有一个参数为 BOOLEAN 的 PL SQL 函数 function get something name in varchar2 ignore notfound in boolean 此功能是第三方工具的一部分 我无法更改它 我想在
  • 如何授予 ASP.NET 写入 Windows 7 文件夹的权限?

    我有一个新的 Win7 工作站 我正在尝试让 ScrewTurn Wiki 在该机器上运行 我的 STW 安装使用文件系统选项来存储其数据 因此我需要向网站安装文件夹中的 ASP NET 工作进程授予写入权限 然而 我似乎无法想出 Win7
  • 如何更改 openshift 容器平台中的权限?

    我是 Openshift 的新手 我已经在 openshift 中部署了一个应用程序 当我检查日志时 某些文件存在权限被拒绝错误 现在 我想更改已部署在 Openshift 中的容器的权限 但收到 不允许操作 警告 我该如何解决 这是针对运
  • 面试:为集合的集合设计一个迭代器

    在java中为集合的集合设计一个迭代器 迭代器应该隐藏嵌套 允许您迭代属于所有集合的所有元素 就像使用单个集合一样 这是一个老问题 但现在 2019 年 我们有了 JDK8 的好东西 特别是 我们有流 这使得这项任务变得简单 public
  • 自定义 Mathematica 快捷键

    Is there a place I can view change global shortcut options like Command 9 turn into Input style 特别是 我需要一种更快的方法来创建项目符号列表
  • 从元组列表中格式化 JSON 字符串的更 Pythonic 方式

    目前我正在这样做 def getJSONString lst join rs for i in lst rs join str i 0 str i 1 join return rs 我称之为 rs getJSONString name va
  • 字符串中的 JSON 转义序列无效

    我正在使用一个 MySQL 数据库 它为谷歌地图编码了多边形 当我尝试以 json 形式返回查询时 jsonlint 抱怨 我不确定为什么它抱怨 我确实尝试转义 latlon 中的 但仍然得到相同的错误 Parse error on lin
  • JavaScript:打印前 12 个月——“March”打印两次?

    我正在尝试编写一个脚本来打印过去 12 个月的名称 由于本月是一月 因此应该打印 十二月 十一月 十月 九月 八月 七月 六月 可能 四月 行进 二月 一月 相反 它打印 March 两次 http jsfiddle net h69gm04
  • 有什么方法可以分析 firestore 数据库吗?

    我的 Firestore 数据库中的实体写入数量非常多 大多数路径的写入权限都受到限制 通过后端服务器使用 admin SDK 完成 只有极少数路径具有写访问权限 特别是仅对已通过身份验证 注册 加入和批准的特定组的用户而言 因此即使滥用的
  • 连接到 Amazon RDS Oracle 实例时如何解决“读取调用减一”错误

    我在 Amazon RDS 实例上运行 Oracle 11GR2 有时我会得到一个IO Error Got minus one from a read call当打电话给DriverManager getConnection getUrl
  • Git 哈希重复

    Git 允许使用以下命令检索提交的哈希值 git rev parse HEAD 这使33b316c or git rev parse short HEAD 这使33b316cbeeab3d69e79b9fb659414af4e7829a32
  • 错误:未找到名称“ngModel”的导出

    构建我的角度项目后 我收到错误 错误 未找到名称 ngModel 的导出 我的 UI 在 Docker 容器中运行 甚至不知道在哪里寻找这个 它在开发中工作正常 发球 有任何想法吗 我有同样的错误 尽管在开发中 事实证明我没有添加表单模块模
  • 如何从 Google Analytics 获取原始日志?

    是否可以从 Google Analytic 获取原始日志 有没有可以从GA生成原始日志的工具 不 您无法获取原始日志 但没有什么可以阻止您将完全相同的数据记录到您自己的 Web 服务器日志中 看看顽童代码并借用它 将以下两行更改为指向您的
  • 如何创建 AND 或 OR 表达式?

    我写了这个 if a 11 b 1 if a 1 AND b 1 但两者都不起作用 我也有同样的问题OR 如何编写包含以下内容的表达式OR or AND You use 对于 和 以及 为 或
  • 如何将 JavaScript onClick 处理程序添加到嵌入的 html 对象?

    我正在尝试将 onClick 处理程序添加到嵌入对象中 处理程序需要执行外部 js 文件中的函数 该文件通过链接到当前 html 文件button svg id buttonEmbed width 95 height 53 type ima
  • keras LSTM 层训练时间太长

    每当我在 Keras 上尝试 LSTM 模型时 似乎由于训练时间过长 该模型无法训练 例如 像这样的模型每步需要 80 秒来训练 def create model self inputs inputs input lstm placehol