keras上手系列之: 模型的保存

2023-05-16

如何将训练好的网络进行保存以便以后使用, 进行后续的研究呢?
首先,定义一个简单的LSTM模型:

from keras.models import Sequential
from keras.layers import LSTM, Dense
model = Sequential()
model.add(LSTM(4,input_shape=(1,8)))
model.add(Dense(1))

整体保存模型及参数

首先,安装python的h5py包.
sudo pip3 install h5py
之后调用model.save(filepath)将Keras模型和权重保存在一个HDF5文件中,该文件将包含:
- 模型的结构,以便重构该模型
- 模型的权重
- 训练配置(损失函数,优化器等)
- 优化器的状态,以便于从上次训练中断的地方开始
使用keras.models.load_model(filepath)来重新实例化之前训练好的模型,如果文件中存储了训练配置的话,该函数还会同时完成模型的编译

from keras.models import load_model
model.save('my_model.h5')  # creates a HDF5 file 'my_model.h5'
del model  # deletes the existing model
# returns a compiled model identical to the previous one
model = load_model('my_model.h5')

只保存模型的结构

可以用model.to_jason()将模型序列化保存为json文件.

# save as JSON
json_string = model.to_json()

例如上面LSTM网络的json_string就是:

json_string
Out[10]: '{"class_name": "Sequential", "config": [{"class_name": "LSTM", "config": {"name": "lstm_1", "trainable": true, "batch_input_shape": [null, 1, 8], "dtype": "float32", "return_sequences": false, "return_state": false, "go_backwards": false, "stateful": false, "unroll": false, "implementation": 0, "units": 4, "activation": "tanh", "recurrent_activation": "hard_sigmoid", "use_bias": true, "kernel_initializer": {"class_name": "VarianceScaling", "config": {"scale": 1.0, "mode": "fan_avg", "distribution": "uniform", "seed": null}}, "recurrent_initializer": {"class_name": "Orthogonal", "config": {"gain": 1.0, "seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "unit_forget_bias": true, "kernel_regularizer": null, "recurrent_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "recurrent_constraint": null, "bias_constraint": null, "dropout": 0.0, "recurrent_dropout": 0.0}}, {"class_name": "Dense", "config": {"name": "dense_1", "trainable": true, "units": 1, "activation": "linear", "use_bias": true, "kernel_initializer": {"class_name": "VarianceScaling", "config": {"scale": 1.0, "mode": "fan_avg", "distribution": "uniform", "seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}}], "keras_version": "2.0.8", "backend": "tensorflow"}'

里面记录了网络的整体结构, 各个层的参数设置等信息. 将json字符串保存到文件.

open('my_model_architecture.json','w').write(json_string)

当然,你也可以从保存好的json文件或yaml文件中载入模型:

# 读取json文件
from keras.models import model_from_json
json_string = open('my_model_architecture.json').read()
model = model_from_json(json_string)

除了json格式,还可以保存为yaml格式的字符串:

# save as YAML
yaml_string = model.to_yaml()
# 类似地,读取yaml文件
from keras.models import model_from_yaml
model = model_from_yaml(yaml_string)

保存模型权重等配置信息

经过调参后网络的输出精度比较满意后,可以将训练好的网络权重参数保存下来.
可通过下面的代码利用HDF5进行保存
model.save_weights(‘my_model_weights.h5’)

以后用的时候可以像这样加载模型:
model.load_weights(‘my_model_weights.h5’)

如果你需要加载权重到不同的网络结构(有些层一样)中,例如fine-tune或transfer-learning,你可以通过层名字来加载模型:
model.load_weights('my_model_weights.h5', by_name=True)
首先在建模时,最好对每一层都指定名字, 例如:

# 定义模型
model = Sequential()
model.add(LSTM(4, input_shape=(1, 8), name="lstm_old"))
model.add(Dense(1, name="dense_old"))
...
model.save_weights('my_model_weights.h5')

# 新模型, 重载了前一个模型训练好的LSTM层

model_new = Sequential()
model_new.add(LSTM(4, input_shape=(1, 8), name="lstm_old"))  # will be loaded
model_new.add(Dense(10, name="dense_new"))  # will not be loaded

# 载入LSTM层训练好的参数
model.load_weights('my_model_weights.h5', by_name=True)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

keras上手系列之: 模型的保存 的相关文章

随机推荐

  • tensorflow2.0系列(4): Eager Execution和Auto Graph

    目录 静态图的弊端Eager模式Eager execution的基本特性对 numpy 的支持 Auto Graph 动态图static analysis VS dynamic flow局部参数的可见域python collections
  • Eclipse离线安装ADT插件

    Eclipse安装 ADT插件 但是由于某些不可抗拒的原因连上 https dl ssl google com android eclipse 后 xff0c 始终无法更新 ADT插件 卡死在 Fetchingcontent jar上 解决
  • ubuntu在shell中把文件拷贝进U盘

    1 创建挂载位置 xff0c 例如 sudo mkdir mnt u 这个位置只要建好 xff0c 以后就可以不用再建了 2 用mount命令将U盘挂载在这个位置 sudo mount dev sdb1 mnt u 注意U盘的盘符不一定是
  • linux文件管理

    linux文件管理 计算机操作系统都采用了目录树的文件结构 linux中 xff1a 符号名称 根目录 bin常见用户口令 boot内核和启动文件 dev设备文件 home系统默认的普通用户主目录 etc系统和服务配置文件 lib系统函数库
  • C++和Windows平台的一些书籍

    从2010年学习编程以来 xff0c 到现在有差不多3年时间了 xff0c 过的真快啊 目前在深圳工作 xff0c 主要使用的是C 43 43 语言 xff0c 那么我就说说C 43 43 和Windows平台的书籍吧 1 C primer
  • ubuntu上运行C程序

    ubuntu版本为Ukylin14 04LTS 首先配置编辑器vim step1 xff1a 查看系统是否安装vim 打开终端 xff0c 输入vi xff0c 按下tab键 xff0c 如果列表里没有vim xff0c 说明系统没有安装
  • 怎么让ubuntu变得更加好用

    ubunut14 04LTS版本其实已经很好用了 但是也有一些小小的美中不足 以下设置是陆续收集 摸索到的可以让系统更好用的方法 1 在终端打开已经安装的应用程序时 xff0c 总是会显示一些错误信息 在 bin下添加x文件 xff1a c
  • linux命令(1):touch

    touch 命令 功能说明 xff1a 改变文件或目录时间 xff0c 包括存取时间和更改时间 语 法 xff1a 补充说明 xff1a 使用touch指令可更改文件或目录的日期时间 最常用用法 xff1a touch fileA 如果fi
  • bash shell命令(1);、&&、||

    xff1b 命令 按照先后顺序一次执行多个命令 xff0c 命令之间用 xff1b 分割 xff1a command 1 command 2 command 3 amp amp 命令 如果前一个命令 command 1 顺利执行 xff0c
  • linux命令(2):less

    less工具也是对文件或其它输出进行分页显示的工具 xff0c 比more的功能更强大 命令格式 xff1a less 参数 文件1 xff08 文件2 xff09 命令功能 xff1a less 与 more 类似 xff0c 但使用 l
  • [zz] linux下vi或vim操作Found a swap file by the name的原因及解决方法

    在linux下用vi或vim打开Test java文件时 root 64 localhost tmp vi Test java 出现了如下信息 xff1a E325 ATTENTION Found a swap file by the na
  • ubuntu中使用判断符号[]

    鸟哥的私房菜p270中13 3 2使用 符号有这样一个例子 xff1a vim sh06 sh 脚本内容如下 xff1a bin bash Program This program shows the user 39 s choice Hi
  • 深度学习caffe框架(1):如何快速上手caffe?

    初识caffe 安装caffe跑一个例子mnist配置caffe框架的深度学习网络结构输入数据 数据层的定义图片数据如何保存为lmdb格式 模型的保存和读取 caffe的代码层次参考 初识caffe 安装caffe 跑一个例子 mnist
  • 深度学习caffe框架(2): layer定义

    caffe的代码层次 首先让我们回顾一下caffe的代码层次 blob layer net和solver 其中blob是数据结构 layer是网络的层 net是将layer搭建成的网络 solver是网络BP时候的求解算法 本节主要介绍ca
  • 安装Qt及相关问题解决

    安装Qt及相关问题解决 Download Qt 1 Qt下载 关于Qt下载 xff0c 官网可以下载 但是需要填一大堆信息 非常麻烦 可以打开下面的链接 xff0c 里面有各版本Qt http download qt io archive
  • 可编程的SQL是什么样的?

    背景 如果你使用传统编程语言 xff0c 比如Python xff0c 那么恭喜你 xff0c 你可能需要解决大部分你不需要解决的问题 xff0c 用Python你相当于拿到了零部件 xff0c 而不是一辆能跑的汽车 你花了大量时间去组装汽
  • matplotlib 绘制动画

    matplotlib动画 载入matplotlib动画绘制工具 span class hljs import span class hljs keyword import span matplotlib animation span cla
  • Robust Real-Time Extreme Head Pose Estimation

    基本思路 xff1a 用RGB D 的摄像头 xff0c 利用RGB和深度信息对人脸进行三位建模和合成 之后建立了一个由33个人不同头部姿态点云合成数据组成的数据库Dali3DHP xff0c 基于级联决策树 xff08 5个 xff09
  • 如何将ipython的历史记录导出到.py文件中?

    python绝对是生产力工具 真的太好用了 python jupyter提供了非常好的交互编程方式 最棒的就是在数据分析过程中 可以把想法和代码实现放在一起 大大加速了分析过程 也使得代码的可读性更好 回到上面的问题 两种办法解决 xff1
  • keras上手系列之: 模型的保存

    如何将训练好的网络进行保存以便以后使用 进行后续的研究呢 首先 定义一个简单的LSTM模型 span class hljs keyword from span keras models span class hljs keyword imp