保存和恢复 Keras BLSTM CTC 模型

2023-12-21

我一直在研究语音情感识别深度神经网络。我使用了具有 CTC 损失的 keras 双向 LSTM。我训练了模型并保存了它

model_json = model.to_json() with open("ctc_model.json", "w") as json_file: json_file.write(model_json) model.save_weights("ctc_weights.h5")

问题是我无法使用这个模型来测试看不见的数据,因为该模型接受 4 个参数作为输入并计算 ctc 损失..只需构建模型并训练。那么我怎样才能以只需要一个输入的方式保存模型呢?不是标签和长度。基本上我如何将模型保存为这个函数test_func = K.function([net_input], [output])

def ctc_lambda_func(args):
    y_pred, labels, input_length, label_length = args

   shift = 2
   y_pred = y_pred[:, shift:, :]
   input_length -= shift
   return K.ctc_batch_cost(labels, y_pred, input_length, label_length)
def build_model(nb_feat, nb_class, optimizer='Adadelta'):
    net_input = Input(name="the_input", shape=(200, nb_feat))
    forward_lstm1  = LSTM(output_dim=64, 
                      return_sequences=True, 
                      activation="tanh"
                     )(net_input)
    backward_lstm1 = LSTM(output_dim=64, 
                      return_sequences=True, 
                      activation="tanh",
                      go_backwards=True
                     )(net_input)
    blstm_output1  = Merge(mode='concat')([forward_lstm1, backward_lstm1])

    forward_lstm2  = LSTM(output_dim=64, 
                      return_sequences=True, 
                      activation="tanh"
                     )(blstm_output1)
    backward_lstm2 = LSTM(output_dim=64, 
                      return_sequences=True, 
                      activation="tanh",
                      go_backwards=True
                     )(blstm_output1)
    blstm_output2  = Merge(mode='concat')([forward_lstm2, backward_lstm2])

    hidden = TimeDistributed(Dense(512, activation='tanh'))(blstm_output2)
    output = TimeDistributed(Dense(nb_class + 1, activation='softmax')) (hidden)

    labels = Input(name='the_labels', shape=[1], dtype='float32')
    input_length = Input(name='input_length', shape=[1], dtype='int64')
    label_length = Input(name='label_length', shape=[1], dtype='int64')
    loss_out = Lambda(ctc_lambda_func, output_shape=(1,), name="ctc")([output, labels, input_length, label_length])
    model = Model(input=[net_input, labels, input_length, label_length], output=[loss_out])
    model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer=optimizer, metrics=[])

    test_func = K.function([net_input], [output])

    return model, test_func
model, test_func = build_model(nb_feat=nb_feat, nb_class=nb_class, optimizer=optimizer)
 for epoch in range(number_epoches):
     inputs_train = {'the_input': X_train[i:i+batch_size],
                    'the_labels': y_train[i:i+batch_size],
                    'input_length': np.sum(X_train_mask[i:i+batch_size], axis=1, dtype=np.int32),
                    'label_length': np.squeeze(y_train_mask[i:i+batch_size]),
                   }
     outputs_train = {'ctc': np.zeros([inputs_train["the_labels"].shape[0]])}

    ctcloss = model.train_on_batch(x=inputs_train, y=outputs_train)

    total_ctcloss += ctcloss * inputs_train["the_input"].shape[0] * 1.
loss_train[epoch] = total_ctcloss / X_train.shape[0]
Here is the my model summary

尝试以下解决方案:

import keras.backend as K

def get_prediction_function(model):
    input_tensor = model.layers[0].input
    output_tensor = model.layers[-5].output
    net_function = K.function([input_tensor, K.learning_phase()], [output_tensor])
    def _result_function(x):
        return net_function([x, 0])[0]
    return _result_function

现在您的网络功能可以通过以下方式获得:

test_function = get_prediction_function(model)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

保存和恢复 Keras BLSTM CTC 模型 的相关文章

  • ca 证书 Mac OS X

    我需要在emacs 上安装offlineimap 和mu4e 问题是配置 当我运行 Offlineimap 时 我得到 OfflineIMAP 6 5 5 Licensed under the GNU GPL v2 v2 or any la
  • Python GTK + webkit - 在 gtk.main() 之后插入 JavaScript

    我在终端中尝试了这个 一切正常 但是如果我在脚本内运行这个 我无法在 gtk main 之后插入 JavaScript import gtk import webkit w gtk Window b webkit WebView w add
  • Pandas 连接问题:列重叠但未指定后缀

    我有以下数据框 print df a mukey DI PI 0 100000 35 14 1 1000005 44 14 2 1000006 44 14 3 1000007 43 13 4 1000008 43 13 print df b
  • 组和平均 NumPy 矩阵

    假设我有一个任意的 numpy 矩阵 如下所示 arr 6 0 12 0 1 0 7 0 9 0 1 0 8 0 7 0 1 0 4 0 3 0 2 0 6 0 1 0 2 0 2 0 5 0 2 0 9 0 4 0 3 0 2 0 1 0
  • 神经网络不能立即重现?

    通过使用反向传播导数 弹性 的前馈神经网络中的随机权重初始化 误差图上的初始位置位于某个随机谷的顶部 该随机谷可能是也可能不是局部最小值 可以使用方法来克服局部最小值 但假设这些方法没有被使用 或者在给定的地形上不能很好地工作 那么神经网络
  • 如何使用 i18n 切换器将“LANGUAGE_CODE”保存到数据库,以便在 Django 中的不同浏览器中语言不会更改?

    有什么办法可以改变它的值LANGUAGE CODE单击按钮 发送请求 时 settings py 中的变量会动态变化吗 我希望用户设置自己的 默认语言 他们的帐户 现在 用户可以使用下拉列表选择他们的首选语言 并且网站会得到完美的翻译 并且
  • 在Python中以交互方式执行多行语句

    我是 Python 世界的新手 这是我用 Python 编写的第一个程序 我来自 R 世界 所以这对我来说有点不直观 当我执行时 In 15 import math import random random random math sqrt
  • 神经网络中的时间序列提前预测(N点提前预测)大规模迭代训练

    N 90 使用神经网络进行提前预测 我试图预测提前 3 分钟 即提前 180 点 因为我将时间序列数据压缩为每 2 个点的平均值为 1 所以我必须预测 N 90 超前预测 我的时间序列数据以秒为单位给出 值在 30 90 之间 它们通常从
  • reStructuredText:README.rst 未在 PyPI 上解析

    我有一个托管在 Github 和 PyPI 上的 Python 项目 在 Github 上 https github com sloria TextBlob blob master README rst https github com s
  • 如何使用 PyMongo 在重复键错误后继续插入

    如果我需要在 MongoDB 中插入尚不存在的文档 db stock update one document set document upsert True 将完成这项工作 如果我错了 请随时纠正我 但是 如果我有一个文档列表并想将它们全
  • 返回上个月的日期时间对象

    如果 timedelta 在它的构造函数中有一个月份参数就好了 那么最简单的方法是什么 EDIT 正如下面指出的那样 我并没有认真考虑这一点 我真正想要的是上个月的任何一天 因为最终我只会获取年份和月份 因此 给定一个日期时间对象 返回的最
  • Pandas:将 pytz.FixedOffset 应用于系列

    我有一个带有timestamp列看起来像这样 0 2020 01 26 05 00 00 08 00 1 2020 01 26 06 00 00 08 00 Name timestamp dtype datetime64 ns pytz F
  • 为什么 __instancecheck__ 没有被调用?

    我有以下 python3 代码 class BaseTypeClass type def new cls name bases namespace kwd result type new cls name bases namespace p
  • 线性同余生成器 - 如何选择种子和统计检验

    我需要做一个线性同余生成器 它将成功通过所选的统计测试 我的问题是 如何正确选择发电机的数字以及 我应该选择哪些统计检验 我想 均匀性的卡方频率测试 每代收集10 000个号码的方法 将 0 1 细分为10个相等的细分 柯尔莫哥洛夫 斯米尔
  • 在python中读取PASCAL VOC注释

    我在 xml 文件中有注释 例如这个 它遵循 PASCAL VOC 约定
  • 在 scipy 中创建新的发行版

    我试图根据我拥有的一些数据创建一个分布 然后从该分布中随机抽取 这是我所拥有的 from scipy import stats import numpy def getDistribution data kernel stats gauss
  • 在 HDF5 (PyTables) 中存储 numpy 稀疏矩阵

    我在使用 PyTables 存储 numpy csr matrix 时遇到问题 我收到此错误 TypeError objects of type csr matrix are not supported in this context so
  • Python:无法使用 os.system() 打开文件

    我正在编写一个使用该应用程序的 Python 脚本pdftk http www pdflabs com tools pdftk the pdf toolkit 几次来执行某些操作 例如 我可以在 Windows 命令行 shell 中使用
  • 更新 SQLAlchemy 中的特定行

    我将 SQLAlchemy 与 python 一起使用 我想更新表中等于此查询的特定行 UPDATE User SET name user WHERE id 3 我通过 sql alchemy 编写了这段代码 但它不起作用 session
  • Streamlabs API 405 响应代码

    我正在尝试使用Streamlabs API https dev streamlabs com Streamlabs API 使用 Oauth2 来创建应用程序 因此 首先我将使用我的应用程序的用户发送到一个授权链接 其中包含我的应用程序的客

随机推荐

  • 如何导入 iTerm2 配置文件?

    这个问题 https stackoverflow com a 23356086 332936帮助我如何export配置文件配置文件 但我该如何import该文件到我的新机器上的 iterm2 中吗 我导出的文件名为com googlecod
  • 可可触摸问题。是否应该保留[NSMutableArray array]?

    这是我正在编写的一些代码的要点 我担心我没有正确解决 NSMutableArray 上数组类方法的保留 释放问题 下面的内容真的会泄漏内存吗 for a while do stuff NSMutableArray a nil do stuf
  • Spring编码不正确

    RequestMapping value getCategoryAspectValues method RequestMethod GET ResponseBody public Map
  • mysql中如何限制变量的增量

    这是我的桌子 id idName fldName fld Date 1 1 Marlon 2013 06 03 2 1 Marlon 2013 06 05 3 1 Marlon 2013 06 07 4 1 Marlon 2013 06 0
  • Spark 行转 JSON

    我想从 Spark v 1 6 使用 scala 数据帧创建 JSON 我知道有一个简单的解决方案df toJSON 但是 我的问题看起来有点不同 例如 考虑具有以下列的数据框 A B C1 C2 C3 1 test ab 22 TRUE
  • 在 XCode 编译上运行 Bash 脚本 - 在哪里获取构建变量列表?

    当使用 XCode 编译 Cocoa 应用程序时 我在构建阶段运行自定义 Bash 脚本 不幸的是 我必须拼写出完整路径 相反 我几乎可以肯定 我可以在 Bash 中使用一些变量 其中之一可能会覆盖它 这是我正在运行的 Users mike
  • 为什么 Meteor 使用纤程而不是 Promise 或异步或其他东西?

    为什么 Meteor 使用纤程而不是 Promise 或 async 或者可能留下异步调用 纤维有什么好处 有人可以解释一下这个架构决定吗 直接从马嘴里说出来 http www quora com Node js Why is Meteor
  • 在从基类派生类的对象上调用时的“this”关键字类型

    如果我有这样的事情 class Base public void Write if this is Derived this Name calls Name Method of Base class i e prints Base Deri
  • 如何计算不接触 UIView 的随机 CGPoint

    这是一个示例视图 我想计算一个框架CGPoint我可以在哪里生成另一张卡 UIView 而不触及任何现有卡 当然 这是可选的 因为视图可能充满卡片 因此没有空闲位置 这就是我如何在屏幕上看到任何卡片以及我的功能现在的样子 func free
  • 在 Windows 上安装 pymssql 时遇到问题

    我在 Windows 上找不到对 pymssql 安装支持的强大支持 我正在尝试通过另一个员工的 python 包装器连接到企业数据库 这个包装器需要我安装 pymssql 这RTFM http pymssql org en v2 1 2
  • 在 Rails 4 中使用 Foundation 和 Turbolinks 时出现问题

    我有一个带有 2 个按钮的标题 登录和注册像这样 http postimg org image tfcwebeoz 当我点击其中一个时 会出现一个窗口 窗户打开 http postimg org image lv7gn2ks3 为此 我使用
  • 创建一个 HTMLCollection

    我正在尝试垫片Element prototype children http www w3 org TR domcore dom element children应该返回一个HTML集合 http www w3 org TR domcore
  • 在ExtJS中,当我显示网格时如何加载商店?

    在ExtJS中 当我显示网格时如何加载商店 我希望商店仅在显示网格时加载 用户单击按钮来显示网格 因此预先加载商店是浪费的 我尝试过afterrender侦听器 但它在错误的位置呈现负载掩码 并且afterlayout每次调整网格大小时 侦
  • SpringBoot @WebMvcTest 和 @MockBean 未按预期工作

    看起来 WebMvcTest and MockBean没有按预期工作 也许我错过了一些东西 我有一个带有一些我正在嘲笑的依赖项的控制器 MockBean 但是应用程序无法启动 因为它找不到另一个我认为在这种情况下不需要的 bean 控制器
  • 为什么 joint_tests 函数(emmeans 包)的结果没有显示模型的交互之一?

    我运行 GLMM adaptive 模型 我正在执行资源选择函数 并且使用 joint tests 函数 emmeans 包 来计算模型中项的联合测试 问题是其中一种相互作用没有出现在结果中 模型是 mod hinc lt mixed mo
  • 批量 - 根据最后 2 个字符复制文件夹

    我在网上搜索后找不到解决方案 或者无法使它们适应我的问题 我希望仅当任何子文件夹的最后两个字符为 14 时 才能批量从文件夹 TEMP 几千个子文件夹 复制子文件夹 for d f in temp 14 do md c somewhere
  • Js 音频音量滑块

    我对此很陌生 我有一个问题 如何放置音量滑块 谢谢你 我暂时找不到任何适合我的代码 希望您的帮助 HTML a class fa fa play JavaScript a
  • PyPlot 将替代 y 轴移动到背景

    在 pyplot 中 您可以使用以下命令更改不同图形的顺序zorder选项或通过更改顺序plot 命令 但是 当您通过添加替代轴时ax2 twinx 新轴将始终覆盖旧轴 如文档 http matplotlib org api pyplot
  • Jersey 2.0 中 GZIPContentEncodingFilter 的等价物是什么

    我正在将 Jerset 1 x 客户端项目迁移到 Jersey 2 0 我找到GZIPContentEncodingFilter不再存在 有类似的东西吗 我绊倒了GZIPEncoder但不知道如何插入 在 Jersey 1 17 中我使用
  • 保存和恢复 Keras BLSTM CTC 模型

    我一直在研究语音情感识别深度神经网络 我使用了具有 CTC 损失的 keras 双向 LSTM 我训练了模型并保存了它 model json model to json with open ctc model json w as json