【Tensorflow2.0】8、tensorflow2.0_hdf5_savedmodel_pb模型转换[1]

2023-11-01


2022年7月更新:现在tensorflow2版本已经发展到2.9,这些模型间的互转可以看官方文档中h5 saved_model各自的缺限,默认使用saved model来保存,推荐这种方式。
本文于2020年11月2日(本博客一年之后)增加更新实验,思路与本文大体相似,理解模型间的转换建议两者都看看,参见https://blog.csdn.net/u011119817/article/details/109447755。

1、训练模型

我们将把训练好的模型分别保存成hdf5和saved model格式,然后完成它们之间的互相转换以及分别转tensorflow1.x的pb格式,具体有:

  1. hdf5转saved model,并验证转换后的saved model与直接保存的saved model的无差异性(大小,精度)
  2. saved model转hdf5,并验证转换后的hdf5与直接保存的hdf5的无差异性(大小,精度)
  3. hdf5转pb,并验证转换后的pb与直接原始的的hdf5的无差异性(大小,精度)
  4. saved mode转pb,并验证转换后的pb与直接原始的的saved mode的无差异性(大小,精度)
  5. 对比hdf5所转pb与saved model所转pb的区别
import tensorflow as tf
import os
from functools import partial
import numpy as np
import shutil
print("tf.__version__")
tf.__version__
batch_size=64
epochs=6
regularizer=1e-3
total_train_samples=60000
total_test_samples=10000

output_folder="/tmp/test/hdf5_model"
output_folder1="/tmp/test/saved_model"
output_folder2="/tmp/test/pb_model"
for m in (output_folder,output_folder1,output_folder2):
    if os.path.exists(m):
        inc=input("The model(%s) saved path has exist,Do you want to delete and remake it?(y/n)"%m)
        while(inc.lower() not in ['y','n']):
            inc=input("The model saved path has exist,Do you want to delete and remake it?(y/n)")
        if inc.lower()=='y':
            shutil.rmtree(m)
            os.makedirs(m)
    elif not os.path.exists(m):
        os.makedirs(m)

The model(/tmp/test/hdf5_model) saved path has exist,Do you want to delete and remake it?(y/n)y
The model(/tmp/test/saved_model) saved path has exist,Do you want to delete and remake it?(y/n)y
#指定显卡
physical_devices = tf.config.experimental.list_physical_devices('GPU')#列出所有可见显卡
print("All the available GPUs:\n",physical_devices)
if physical_devices:
    gpu=physical_devices[0]#显示第一块显卡
    tf.config.experimental.set_memory_growth(gpu, True)#根据需要自动增长显存
    tf.config.experimental.set_visible_devices(gpu, 'GPU')#只选择第一块
All the available GPUs:
 [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
#准备数据
fashion_mnist=tf.keras.datasets.fashion_mnist
(train_x,train_y),(test_x,test_y)=fashion_mnist.load_data()

train_x,test_x = train_x[...,np.newaxis]/255.0,test_x[...,np.newaxis]/255.0
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat','Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
train_ds = tf.data.Dataset.from_tensor_slices((train_x,train_y))
test_ds = tf.data.Dataset.from_tensor_slices((test_x,test_y))
 
train_ds=train_ds.shuffle(buffer_size=batch_size*10).batch(batch_size).prefetch(buffer_size = tf.data.experimental.AUTOTUNE).repeat()
test_ds = test_ds.batch(batch_size).prefetch(buffer_size = tf.data.experimental.AUTOTUNE)#不加repeat,执行一次就行
#定义模型
l2 = tf.keras.regularizers.l2(regularizer)#定义模型正则化方法
ini = tf.keras.initializers.he_normal()#定义参数初始化方法
conv2d = partial(tf.keras.layers.Conv2D,activation='relu',padding='same',kernel_regularizer=l2,bias_regularizer=l2)
fc = partial(tf.keras.layers.Dense,activation='relu',kernel_regularizer=l2,bias_regularizer=l2)
maxpool=tf.keras.layers.MaxPooling2D
dropout=tf.keras.layers.Dropout
def test_model():
    x_input = tf.keras.layers.Input(shape=(28,28,1),name='input_node')
    x = conv2d(128,(5,5))(x_input)
    x = maxpool((2,2))(x)
    x = conv2d(256,(5,5))(x)
    x = maxpool((2,2))(x)
    x = tf.keras.layers.Flatten()(x)
    x = fc(128)(x)
    x_output=fc(10,activation=None,name='output_node')(x)
    model = tf.keras.models.Model(inputs=x_input,outputs=x_output) 
    return model
model = test_model()
print(model.summary())
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_node (InputLayer)      [(None, 28, 28, 1)]       0         
_________________________________________________________________
conv2d (Conv2D)              (None, 28, 28, 128)       3328      
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 14, 14, 256)       819456    
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 7, 7, 256)         0         
_________________________________________________________________
flatten (Flatten)            (None, 12544)             0         
_________________________________________________________________
dense (Dense)                (None, 128)               1605760   
_________________________________________________________________
output_node (Dense)          (None, 10)                1290      
=================================================================
Total params: 2,429,834
Trainable params: 2,429,834
Non-trainable params: 0
_________________________________________________________________
None
#编译模型
initial_learning_rate=0.01

optimizer = tf.keras.optimizers.SGD(learning_rate=initial_learning_rate,momentum=0.95)
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metrics=['accuracy','sparse_categorical_crossentropy']
model.compile(optimizer=optimizer,loss=loss,metrics=metrics)
#训练模型
H=model.fit(train_ds,epochs=6,
            steps_per_epoch=np.floor(len(train_x)/batch_size).astype(np.int32),
            validation_data=test_ds,
            validation_steps=np.ceil(len(test_x)/batch_size).astype(np.int32),
            verbose=1)
Train for 937 steps, validate for 157 steps
Epoch 1/6
937/937 [==============================] - 78s 84ms/step - loss: 0.9485 - accuracy: 0.8000 - sparse_categorical_crossentropy: 1.5262 - val_loss: 0.6896 - val_accuracy: 0.8527 - val_sparse_categorical_crossentropy: 1.3110
Epoch 2/6
937/937 [==============================] - 76s 82ms/step - loss: 0.5635 - accuracy: 0.8824 - sparse_categorical_crossentropy: 1.2657 - val_loss: 0.5024 - val_accuracy: 0.8892 - val_sparse_categorical_crossentropy: 1.2297
Epoch 3/6
937/937 [==============================] - 76s 81ms/step - loss: 0.4517 - accuracy: 0.8971 - sparse_categorical_crossentropy: 1.2169 - val_loss: 0.4438 - val_accuracy: 0.8928 - val_sparse_categorical_crossentropy: 1.1814
Epoch 4/6
937/937 [==============================] - 76s 81ms/step - loss: 0.4077 - accuracy: 0.9015 - sparse_categorical_crossentropy: 1.1756 - val_loss: 0.4136 - val_accuracy: 0.8995 - val_sparse_categorical_crossentropy: 1.1607
Epoch 5/6
937/937 [==============================] - 76s 81ms/step - loss: 0.3850 - accuracy: 0.9051 - sparse_categorical_crossentropy: 1.1504 - val_loss: 0.3992 - val_accuracy: 0.8978 - val_sparse_categorical_crossentropy: 1.1329
Epoch 6/6
937/937 [==============================] - 76s 81ms/step - loss: 0.3719 - accuracy: 0.9093 - sparse_categorical_crossentropy: 1.1388 - val_loss: 0.4391 - val_accuracy: 0.8844 - val_sparse_categorical_crossentropy: 1.1219
#分别保存两种格式的模型
model.save(filepath=os.path.join(output_folder,'hdf5_model.h5'),save_format='h5')
model.save(filepath=output_folder1,save_format='tf')
#报的warning信息在tensorflow官网的例子中同样存在
INFO:tensorflow:Assets written to: /tmp/test/saved_model/assets
%%bash
echo -e "hdf5 model information...\n"
tree "/tmp/test/hdf5_model"

du -ah "/tmp/test/hdf5_model"

echo -e "saved model information...\n"

tree "/tmp/test/saved_model"

du -ah "/tmp/test/saved_model"
hdf5 model information...

/tmp/test/hdf5_model
└── hdf5_model.h5

0 directories, 1 file
19M	/tmp/test/hdf5_model/hdf5_model.h5
19M	/tmp/test/hdf5_model
saved model information...

/tmp/test/saved_model
├── assets
├── saved_model.pb
└── variables
    ├── variables.data-00000-of-00002
    ├── variables.data-00001-of-00002
    └── variables.index

2 directories, 4 files
4.0K	/tmp/test/saved_model/assets
216K	/tmp/test/saved_model/saved_model.pb
19M	/tmp/test/saved_model/variables/variables.data-00001-of-00002
4.0K	/tmp/test/saved_model/variables/variables.index
8.0K	/tmp/test/saved_model/variables/variables.data-00000-of-00002
19M	/tmp/test/saved_model/variables
19M	/tmp/test/saved_model
#选取第一个样本来做为测试样本,用来评估不同模型之间的精度
test_sample = train_x[0:1]
test_y=train_y[0]
out = model.predict(test_sample)
print("probs:",out[0])
print("true label:{} pred label:{}".format(test_y,np.argmax(out)))

probs: [-2.0793445e+00 -2.2612031e+00 -1.8440809e+00 -1.1460640e+00
 -1.9762940e+00  8.9537799e-03 -3.4592066e+00  3.3828874e+00
  1.5507856e-01  9.2633562e+00]
true label:9 pred label:9

2、各种模型间互转并验证

2.1 hdf5转saved model

rm -rf /tmp/test/hdf52saved_model
ls /tmp/test/hdf5_model
hdf5_model.h5
tf.keras.backend.clear_session()
hdf5_model = tf.keras.models.load_model("/tmp/test/hdf5_model/hdf5_model.h5")
hdf5_model.save("/tmp/test/hdf52saved_model",save_format='tf')
INFO:tensorflow:Assets written to: /tmp/test/hdf52saved_model/assets
%%bash
echo -e "hdf5 model information...\n"
tree "/tmp/test/hdf5_model"

du -ah "/tmp/test/hdf5_model"

echo -e "saved model information...\n"

tree "/tmp/test/saved_model"

du -ah "/tmp/test/saved_model"

echo -e "saved model information...\n"

tree "/tmp/test/hdf52saved_model"

du -ah "/tmp/test/hdf52saved_model"
hdf5 model information...

/tmp/test/hdf5_model
└── hdf5_model.h5

0 directories, 1 file
19M	/tmp/test/hdf5_model/hdf5_model.h5
19M	/tmp/test/hdf5_model
saved model information...

/tmp/test/saved_model
├── assets
├── saved_model.pb
└── variables
    ├── variables.data-00000-of-00002
    ├── variables.data-00001-of-00002
    └── variables.index

2 directories, 4 files
4.0K	/tmp/test/saved_model/assets
216K	/tmp/test/saved_model/saved_model.pb
19M	/tmp/test/saved_model/variables/variables.data-00001-of-00002
4.0K	/tmp/test/saved_model/variables/variables.index
8.0K	/tmp/test/saved_model/variables/variables.data-00000-of-00002
19M	/tmp/test/saved_model/variables
19M	/tmp/test/saved_model
saved model information...

/tmp/test/hdf52saved_model
├── assets
├── saved_model.pb
└── variables
    ├── variables.data-00000-of-00002
    ├── variables.data-00001-of-00002
    └── variables.index

2 directories, 4 files
4.0K	/tmp/test/hdf52saved_model/assets
224K	/tmp/test/hdf52saved_model/saved_model.pb
8.0K	/tmp/test/hdf52saved_model/variables/variables.data-00001-of-00002
4.0K	/tmp/test/hdf52saved_model/variables/variables.index
19M	/tmp/test/hdf52saved_model/variables/variables.data-00000-of-00002
19M	/tmp/test/hdf52saved_model/variables
19M	/tmp/test/hdf52saved_model

2.2 saved model转hdf5

tf.keras.backend.clear_session()
saved_model = tf.keras.models.load_model("/tmp/test/saved_model")
saved_model.save("/tmp/test/hdf5_model/saved2hdf5_model.h5",save_format='h5')
---------------------------------------------------------------------------

NotImplementedError                       Traceback (most recent call last)

<ipython-input-144-322a2edc9a6e> in <module>
      1 tf.keras.backend.clear_session()
      2 saved_model = tf.keras.models.load_model("/tmp/test/saved_model")
----> 3 saved_model.save("/tmp/test/hdf5_model/saved2hdf5_model.h5",save_format='h5')


~/miniconda3/envs/tf2/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/network.py in save(self, filepath, overwrite, include_optimizer, save_format, signatures, options)
    973     """
    974     saving.save_model(self, filepath, overwrite, include_optimizer, save_format,
--> 975                       signatures, options)
    976 
    977   def save_weights(self, filepath, overwrite=True, save_format=None):


~/miniconda3/envs/tf2/lib/python3.6/site-packages/tensorflow_core/python/keras/saving/save.py in save_model(model, filepath, overwrite, include_optimizer, save_format, signatures, options)
    103         not isinstance(model, sequential.Sequential)):
    104       raise NotImplementedError(
--> 105           'Saving the model to HDF5 format requires the model to be a '
    106           'Functional model or a Sequential model. It does not work for '
    107           'subclassed models, because such models are defined via the body of '


NotImplementedError: Saving the model to HDF5 format requires the model to be a Functional model or a Sequential model. It does not work for subclassed models, because such models are defined via the body of a Python method, which isn't safely serializable. Consider saving to the Tensorflow SavedModel format (by setting save_format="tf") or using `save_weights`.

以上信息说明,从saved model是无法转换成hdf5模型的,所以个人感觉在训练过程中保存hdf5格式的模型比较好

2.3 所有模型精度测试

测试三个模型的精度,原始hdf5模型,原始saved model,hdf5转换的saved model

origin_hdf5_model = tf.keras.models.load_model("/tmp/test/hdf5_model/hdf5_model.h5")
origin_saved_model = tf.keras.models.load_model("/tmp/test/saved_model")
converted_saved_model = tf.keras.models.load_model("/tmp/test/hdf52saved_model")
out1 = origin_hdf5_model.predict(test_sample)
out2 = origin_saved_model.predict(test_sample)
out3 = converted_saved_model.predict(test_sample)
print("probs:",out1[0])
print("true label:{} pred label:{}".format(test_y,np.argmax(out1)))
print("probs:",out2[0])
print("true label:{} pred label:{}".format(test_y,np.argmax(out2)))
print("probs:",out3[0])
print("true label:{} pred label:{}".format(test_y,np.argmax(out3)))
np.testing.assert_array_almost_equal(out,out1)
np.testing.assert_array_almost_equal(out,out2)
np.testing.assert_array_almost_equal(out,out3)
probs: [-2.0793445e+00 -2.2612031e+00 -1.8440809e+00 -1.1460640e+00
 -1.9762940e+00  8.9537799e-03 -3.4592066e+00  3.3828874e+00
  1.5507856e-01  9.2633562e+00]
true label:9 pred label:9
probs: [-2.0793445e+00 -2.2612031e+00 -1.8440809e+00 -1.1460640e+00
 -1.9762940e+00  8.9537799e-03 -3.4592066e+00  3.3828874e+00
  1.5507856e-01  9.2633562e+00]
true label:9 pred label:9
probs: [-2.0793445e+00 -2.2612031e+00 -1.8440809e+00 -1.1460640e+00
 -1.9762940e+00  8.9537799e-03 -3.4592066e+00  3.3828874e+00
  1.5507856e-01  9.2633562e+00]
true label:9 pred label:9

可以看到结果完全一致,模型转换没有问题

2.4 hdf5和saved模型转tensorflow1.x pb模型

我们需要在tensorflow2.0中使用tensorflow1.x内容

以下是hdf5转pb模型

import tensorflow.compat.v1 as tf1
tf1.reset_default_graph()
tf1.keras.backend.set_learning_phase(0) #调用模型前一定要执行该命令
tf1.disable_v2_behavior() #禁止tensorflow2.0的行为
#加载hdf5模型
hdf5_pb_model = tf.keras.models.load_model("/tmp/test/hdf5_model/hdf5_model.h5")
def freeze_session(session,keep_var_names=None,output_names=None,clear_devices=True):
    graph = session.graph
    with graph.as_default():
#         freeze_var_names = list(set(v.op.name for v in tf1.global_variables()).difference(keep_var_names or []))
        output_names = output_names or []
#         output_names += [v.op.name for v in tf1.global_variables()]
        print("output_names",output_names)
        input_graph_def = graph.as_graph_def()
#         for node in input_graph_def.node:
#             print('node:', node.name)
        print("len node1",len(input_graph_def.node))
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""
        frozen_graph =  tf1.graph_util.convert_variables_to_constants(session, input_graph_def,
                                                      output_names)
        
        outgraph = tf1.graph_util.remove_training_nodes(frozen_graph)#云掉与推理无关的内容
        print("##################################################################")
        for node in outgraph.node:
            print('node:', node.name)
        print("len node1",len(outgraph.node))
        return outgraph

frozen_graph = freeze_session(tf1.keras.backend.get_session(),output_names=[out.op.name for out in hdf5_pb_model.outputs])
tf1.train.write_graph(frozen_graph, output_folder2, "hdf52pb.pb", as_text=False)
output_names ['output_node/BiasAdd']
len node1 626
INFO:tensorflow:Froze 8 variables.
INFO:tensorflow:Converted 8 variables to const ops.
##################################################################
node: input_node
node: conv2d/kernel
node: conv2d/bias
node: conv2d/Conv2D
node: conv2d/BiasAdd
node: conv2d/Relu
node: max_pooling2d/MaxPool
node: conv2d_1/kernel
node: conv2d_1/bias
node: conv2d_1/Conv2D
node: conv2d_1/BiasAdd
node: conv2d_1/Relu
node: max_pooling2d_1/MaxPool
node: flatten/Reshape/shape
node: flatten/Reshape
node: dense/kernel
node: dense/bias
node: dense/MatMul
node: dense/BiasAdd
node: dense/Relu
node: output_node/kernel
node: output_node/bias
node: output_node/MatMul
node: output_node/BiasAdd
len node1 24





'/tmp/test/pb_model/hdf52pb.pb'

以下是saved model转pb模型

import tensorflow.compat.v1 as tf1
tf1.reset_default_graph()
tf1.keras.backend.set_learning_phase(0) #调用模型前一定要执行该命令
tf1.disable_v2_behavior() #禁止tensorflow2.0的行为
#加载hdf5模型
saved_pb_model = tf.keras.models.load_model("/tmp/test/saved_model")
def freeze_session(session,keep_var_names=None,output_names=None,clear_devices=True):
    graph = session.graph
    with graph.as_default():
#         freeze_var_names = list(set(v.op.name for v in tf1.global_variables()).difference(keep_var_names or []))
        output_names = output_names or []
#         output_names += [v.op.name for v in tf1.global_variables()]
        print("output_names",output_names)
        input_graph_def = graph.as_graph_def()
#         for node in input_graph_def.node:
#             print('node:', node.name)
        print("len node1",len(input_graph_def.node))
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""
        frozen_graph =  tf1.graph_util.convert_variables_to_constants(session, input_graph_def,
                                                      output_names)
        
        outgraph = tf1.graph_util.remove_training_nodes(frozen_graph)#云掉与推理无关的内容
        print("##################################################################")
        for node in outgraph.node:
            print('node:', node.name)
        print("len node1",len(outgraph.node))
        return outgraph

frozen_graph = freeze_session(tf1.keras.backend.get_session(),output_names=[out.op.name for out in saved_pb_model.outputs])
tf1.train.write_graph(frozen_graph, output_folder2, "saved2pb.pb", as_text=False)
output_names ['model/StatefulPartitionedCall']
len node1 304
INFO:tensorflow:Froze 8 variables.
INFO:tensorflow:Converted 8 variables to const ops.
##################################################################
node: conv2d/kernel
node: conv2d/bias
node: conv2d_1/kernel
node: conv2d_1/bias
node: dense/kernel
node: dense/bias
node: output_node/kernel
node: output_node/bias
node: input_1
node: model/StatefulPartitionedCall
len node1 10





'/tmp/test/pb_model/saved2pb.pb'

以下是转换后的saved model转pb

import tensorflow.compat.v1 as tf1
tf1.reset_default_graph()
tf1.keras.backend.set_learning_phase(0) #调用模型前一定要执行该命令
tf1.disable_v2_behavior() #禁止tensorflow2.0的行为
#加载hdf5模型
hdf52saved_pb_model = tf.keras.models.load_model("/tmp/test/hdf52saved_model/")
def freeze_session(session,keep_var_names=None,output_names=None,clear_devices=True):
    graph = session.graph
    with graph.as_default():
#         freeze_var_names = list(set(v.op.name for v in tf1.global_variables()).difference(keep_var_names or []))
        output_names = output_names or []
#         output_names += [v.op.name for v in tf1.global_variables()]
        print("output_names",output_names)
        input_graph_def = graph.as_graph_def()
#         for node in input_graph_def.node:
#             print('node:', node.name)
        print("len node1",len(input_graph_def.node))
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""
        frozen_graph =  tf1.graph_util.convert_variables_to_constants(session, input_graph_def,
                                                      output_names)
        
        outgraph = tf1.graph_util.remove_training_nodes(frozen_graph)#云掉与推理无关的内容
        print("##################################################################")
        for node in outgraph.node:
            print('node:', node.name)
        print("len node1",len(outgraph.node))
        return outgraph

frozen_graph = freeze_session(tf1.keras.backend.get_session(),output_names=[out.op.name for out in saved_pb_model.outputs])
tf1.train.write_graph(frozen_graph, output_folder2, "hdf52saved2pb.pb", as_text=False)
output_names ['model/StatefulPartitionedCall']
len node1 304
INFO:tensorflow:Froze 8 variables.
INFO:tensorflow:Converted 8 variables to const ops.
##################################################################
node: conv2d/kernel
node: conv2d/bias
node: conv2d_1/kernel
node: conv2d_1/bias
node: dense/kernel
node: dense/bias
node: output_node/kernel
node: output_node/bias
node: input_1
node: model/StatefulPartitionedCall
len node1 10





'/tmp/test/pb_model/hdf52saved2pb.pb'
%%bash
tree /tmp/test

du -ah /tmp/test
/tmp/test
├── hdf52saved_model
│   ├── assets
│   ├── saved_model.pb
│   └── variables
│       ├── variables.data-00000-of-00002
│       ├── variables.data-00001-of-00002
│       └── variables.index
├── hdf5_model
│   └── hdf5_model.h5
├── pb_model
│   ├── hdf52pb.pb
│   ├── hdf52saved2pb.pb
│   └── saved2pb.pb
└── saved_model
    ├── assets
    ├── saved_model.pb
    └── variables
        ├── variables.data-00000-of-00002
        ├── variables.data-00001-of-00002
        └── variables.index

8 directories, 12 files
19M	/tmp/test/hdf5_model/hdf5_model.h5
19M	/tmp/test/hdf5_model
9.3M	/tmp/test/pb_model/hdf52saved2pb.pb
9.3M	/tmp/test/pb_model/saved2pb.pb
9.3M	/tmp/test/pb_model/hdf52pb.pb
28M	/tmp/test/pb_model
4.0K	/tmp/test/saved_model/assets
216K	/tmp/test/saved_model/saved_model.pb
19M	/tmp/test/saved_model/variables/variables.data-00001-of-00002
4.0K	/tmp/test/saved_model/variables/variables.index
8.0K	/tmp/test/saved_model/variables/variables.data-00000-of-00002
19M	/tmp/test/saved_model/variables
19M	/tmp/test/saved_model
4.0K	/tmp/test/hdf52saved_model/assets
224K	/tmp/test/hdf52saved_model/saved_model.pb
8.0K	/tmp/test/hdf52saved_model/variables/variables.data-00001-of-00002
4.0K	/tmp/test/hdf52saved_model/variables/variables.index
19M	/tmp/test/hdf52saved_model/variables/variables.data-00000-of-00002
19M	/tmp/test/hdf52saved_model/variables
19M	/tmp/test/hdf52saved_model
84M	/tmp/test

可以看到三个pb模型的大小是相同的,但了节点名称不一样,且打印出来的名称顺序在hdf5模型体现更好。
另外,hdf5模型不经保存也可以直接保存成pb,但是保存再读取和直接保存节点名称会变,但精度还是相同的。所以可以把节点名称打印出来分析。

2.5 加载并测试pb模型

有三个pb模型分别进行测试

import tensorflow.compat.v1 as tf1
import numpy as np

def load_graph(file_path):
    with tf1.gfile.GFile(file_path,'rb') as f:
        graph_def = tf1.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf1.Graph().as_default() as graph:
        tf1.import_graph_def(graph_def,input_map = None,return_elements = None,name = "",op_dict = None,producer_op_list = None)
    graph_nodes = [n for n in graph_def.node]
    return graph,graph_nodes

三个模型依次调用:

  • /tmp/test/pb_model/hdf52saved2pb.pb
  • /tmp/test/pb_model/saved2pb.pb
  • /tmp/test/pb_model/hdf52pb.pb

第一个模型

file_path='/tmp/test/pb_model/hdf52pb.pb'
graph,graph_nodes = load_graph(file_path)
print("num nodes",len(graph_nodes))
for node in graph_nodes:
    print('node:', node.name) 

num nodes 24
node: input_node
node: conv2d/kernel
node: conv2d/bias
node: conv2d/Conv2D
node: conv2d/BiasAdd
node: conv2d/Relu
node: max_pooling2d/MaxPool
node: conv2d_1/kernel
node: conv2d_1/bias
node: conv2d_1/Conv2D
node: conv2d_1/BiasAdd
node: conv2d_1/Relu
node: max_pooling2d_1/MaxPool
node: flatten/Reshape/shape
node: flatten/Reshape
node: dense/kernel
node: dense/bias
node: dense/MatMul
node: dense/BiasAdd
node: dense/Relu
node: output_node/kernel
node: output_node/bias
node: output_node/MatMul
node: output_node/BiasAdd
input_node = graph.get_tensor_by_name('input_node:0')
output = graph.get_tensor_by_name('output_node/BiasAdd:0')

config = tf1.ConfigProto()
config.gpu_options.allow_growth = True
# config.gpu_options.per_process_gpu_memory_fraction = 0.25# 设定GPU使用占比
config.gpu_options.visible_device_list = '0'  # '0,1'
config.allow_soft_placement = True
config.log_device_placement = False

with tf1.Session(config=config,graph=graph) as sess:
        logits = sess.run(output, feed_dict = {input_node:test_sample})
print("logits:",logits)
np.testing.assert_array_almost_equal(out,logits)
logits: [[-2.0793445e+00 -2.2612031e+00 -1.8440809e+00 -1.1460640e+00
  -1.9762940e+00  8.9537799e-03 -3.4592066e+00  3.3828874e+00
   1.5507856e-01  9.2633562e+00]]

从以上结果可以看到,hdf5转换的pb结果完全正确
另外经过验证,使用tensorflow2.x中tf.compat.v1 api转换成的pb模型,只能用tensorflow1.14和tensorflow1.15两个版本调用使用。
第二个模型

file_path='/tmp/test/pb_model/saved2pb.pb'
graph,graph_nodes = load_graph(file_path)
print("num nodes",len(graph_nodes))
for node in graph_nodes:
    print('node:', node.name)
---------------------------------------------------------------------------

InvalidArgumentError                      Traceback (most recent call last)

~/miniconda3/envs/tf2/lib/python3.6/site-packages/tensorflow_core/python/framework/importer.py in _import_graph_def_internal(graph_def, input_map, return_elements, validate_colocation_constraints, name, op_dict, producer_op_list)
    500         results = c_api.TF_GraphImportGraphDefWithResults(
--> 501             graph._c_graph, serialized, options)  # pylint: disable=protected-access
    502         results = c_api_util.ScopedTFImportGraphDefResults(results)


InvalidArgumentError: Input 1 of node model/StatefulPartitionedCall was passed float from conv2d/kernel:0 incompatible with expected resource.


During handling of the above exception, another exception occurred:


ValueError                                Traceback (most recent call last)

<ipython-input-108-4ad5ab9ba9bb> in <module>
      1 file_path='/tmp/test/pb_model/saved2pb.pb'
----> 2 graph,graph_nodes = load_graph(file_path)
      3 print("num nodes",len(graph_nodes))
      4 for node in graph_nodes:
      5     print('node:', node.name)


<ipython-input-103-6c7963fc55a7> in load_graph(file_path)
      7         graph_def.ParseFromString(f.read())
      8     with tf1.Graph().as_default() as graph:
----> 9         tf1.import_graph_def(graph_def,input_map = None,return_elements = None,name = "",op_dict = None,producer_op_list = None)
     10     graph_nodes = [n for n in graph_def.node]
     11     return graph,graph_nodes


~/miniconda3/envs/tf2/lib/python3.6/site-packages/tensorflow_core/python/util/deprecation.py in new_func(*args, **kwargs)
    505                 'in a future version' if date is None else ('after %s' % date),
    506                 instructions)
--> 507       return func(*args, **kwargs)
    508 
    509     doc = _add_deprecated_arg_notice_to_docstring(


~/miniconda3/envs/tf2/lib/python3.6/site-packages/tensorflow_core/python/framework/importer.py in import_graph_def(graph_def, input_map, return_elements, name, op_dict, producer_op_list)
    403       name=name,
    404       op_dict=op_dict,
--> 405       producer_op_list=producer_op_list)
    406 
    407 


~/miniconda3/envs/tf2/lib/python3.6/site-packages/tensorflow_core/python/framework/importer.py in _import_graph_def_internal(graph_def, input_map, return_elements, validate_colocation_constraints, name, op_dict, producer_op_list)
    503       except errors.InvalidArgumentError as e:
    504         # Convert to ValueError for backwards compatibility.
--> 505         raise ValueError(str(e))
    506 
    507     # Create _DefinedFunctions for any imported functions.


ValueError: Input 1 of node model/StatefulPartitionedCall was passed float from conv2d/kernel:0 incompatible with expected resource.

可以看出,不论是原始的saved model还是hdf5转成的saved_model,都可以转pb,但使用中还是会报错。
另外一种调用hdf5转换的pb(或在tensorflow2.x中调用tensorflow1.x转的pb)

tf1.reset_default_graph()
tf1.enable_v2_behavior()#tensorflow2.x中调用tensorflow1.x的内容需要激活tensorflow2.x的特性
tf.keras.backend.clear_session()
def wrap_frozen_graph(graph_def, inputs, outputs):
    def _imports_graph_def():
        tf1.import_graph_def(graph_def, name="")
    wrapped_import = tf1.wrap_function(_imports_graph_def, [])
    import_graph = wrapped_import.graph
    return wrapped_import.prune(
        tf.nest.map_structure(import_graph.as_graph_element, inputs),
        tf.nest.map_structure(import_graph.as_graph_element, outputs))
file_path='/tmp/test/pb_model/hdf52pb.pb'
with open(file_path,'rb') as f:
    graph_def = tf1.GraphDef()
    graph_def.ParseFromString(f.read())
    for node in graph_def.node:
        print("node.name",node.name)
node.name input_node
node.name conv2d/kernel
node.name conv2d/bias
node.name conv2d/Conv2D
node.name conv2d/BiasAdd
node.name conv2d/Relu
node.name max_pooling2d/MaxPool
node.name conv2d_1/kernel
node.name conv2d_1/bias
node.name conv2d_1/Conv2D
node.name conv2d_1/BiasAdd
node.name conv2d_1/Relu
node.name max_pooling2d_1/MaxPool
node.name flatten/Reshape/shape
node.name flatten/Reshape
node.name dense/kernel
node.name dense/bias
node.name dense/MatMul
node.name dense/BiasAdd
node.name dense/Relu
node.name output_node/kernel
node.name output_node/bias
node.name output_node/MatMul
node.name output_node/BiasAdd
model_func = wrap_frozen_graph(
    graph_def, inputs='input_node:0',
    outputs='output_node/BiasAdd:0')

o=model_func(tf.constant(test_sample,dtype=tf.float32))

print(o)

np.testing.assert_array_almost_equal(out,o.numpy())
tf.Tensor(
[[-2.0793445e+00 -2.2612031e+00 -1.8440809e+00 -1.1460640e+00
  -1.9762940e+00  8.9537799e-03 -3.4592066e+00  3.3828874e+00
   1.5507856e-01  9.2633562e+00]], shape=(1, 10), dtype=float32)

总结

  1. tensorflow2.x保存的hdf5模型可以转tensorflow1.x的pb ,也可以转tensorflow2.x saved model
  2. saved model可以转pb ,但是转换后无法使用
  3. saved model不可以转换成hdf5模型
  4. 在tensorflow2.x中可以使用tensorflow1.x或tensorflow2.x的语法来调用,从而选则不同版本
  5. tensorflow2.x训练的模型 ,转换成pb后,只能用tensorflow1.14和1.15来调用。
    所以我们在以后可以只保存hdf5模型,这样可以使用tensorflow2.x来训练模型,如果要用tensorflow2.x推荐的格式,就把hdf5转换成save_model来用;如果要用旧的tensorflow1.x版本,可以把hdf5转换成pb来用。
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

【Tensorflow2.0】8、tensorflow2.0_hdf5_savedmodel_pb模型转换[1] 的相关文章

  • Facebook Messenger 机器人的日期选择器 webview - 无法将字段值带回机器人的输入字段

    我正在使用 Dialogflow 和 Messenger 开发聊天机器人 Webhook 是用 Python 3 x 编写的 我面临着如何再次将数据从 webview 传输到信使聊天窗口以继续与用户对话的问题 Messenger 聊天机器人
  • 如何在jsp页面中包含javascript

    我是 J2EE 和 Web 开发的新手 这是我的问题 我想在网页中包含 angular js 这是有效的版本 但我也想要一些本地的 javascript 文件 并且希望我想在本地目录中导入 angularjs
  • 获取Java中ResultSet返回的行数

    我用过一个ResultSet返回一定数量的行 我的代码是这样的 ResultSet res getData if res next System out println No Data Found while res next code t
  • 将主题应用到 v7 支持操作栏

    我正在使用support v7库来实现ActionBar在我的应用程序中 我的styles xml file
  • 如何使用 Word Automation 获取页面范围

    如何使用办公自动化找到 Microsoft Word 中第 n 页的范围 似乎没有 getPageRange n 函数 并且不清楚它们是如何划分的 这就是您从 VBA 执行此操作的方法 转换为 Matlab COM 调用应该相当简单 Pub
  • CSS3 背景渐变未验证,有人可以告诉我为什么吗?里面的代码示例

    有人能告诉我为什么下面的 css 没有验证吗 我一直在尝试自己研究这个问题 但没有运气 我读过的所有文档都说这是在 css3 中进行渐变的正确原因 header color white font size 12px font family
  • 在客户端系统中安装后桌面应用程序无法打开

    我目前正在使用 Visual Studio 2017 和 4 6 1 net 框架 我为桌面应用程序创建了安装文件 安装程序在我的系统中完美安装并运行 问题是安装程序在其他计算机上成功安装 但应用程序无法打开 edit 在客户端系统中下载了
  • JSF 2.0:如何添加 UIComponent 及其内容以查看根?

    我正在建立一个自定义UIComponent并在其中添加元素 和其他库存 UIComponents 该组件呈现正常 但无法从ViewRoot 假设我有 ResponseWriter writer Override public void en
  • Mac 操作系统屏幕上的 Git 自动补全

    我在 mac 上使用 git 并配置了自动完成功能 如下所示http www codethatmatters com 2010 01 git autocomplete in mac os x http www codethatmatters
  • 如何正确使用 std::condition_variable?

    我很困惑conditions variables以及如何 安全 使用它们 在我的应用程序中 我有一个创建 gui 线程的类 但是当 gui 是由 gui 线程构造时 主线程需要等待 情况与下面的函数相同 主线程创建互斥体 锁和conditi
  • Xpages SSJS 如何显示数组?

    我一直在学习 Xpages 编程 我们目前使用的是 domino 8 5 2 我逐渐熟悉显示 输入控件 并且使用它们显示来自后端多米诺骨牌文档 视图 非数组的作用域变量的信息 并取得了一些成功 我无法发现的是如何显示动态创建的作用域变量数组
  • aws-s3 gem 和 right_aws gem 之间的 Rails Paperclip 冲突。怎么解决?

    对于新应用程序 我想使用回形针将文件存储到 S3 我已经为另一个应用程序安装了 aws s3 gem 这似乎会导致一些问题 因为 Paperclip 应该使用 right aws 但正在尝试使用 aws s3 gem 但我不想从我的系统中删
  • WebAPI 自定义 ExceptionFilterAttribute 中的 Ninject 属性注入不起作用

    我正在尝试使用 Ninject 将 EventLogger 实例注入自定义 ExceptionFilterAttribute 中 每当我运行代码时 EventLogger 实例为空 我已经实现了 IFilterProvider 以类似的方式
  • AngularJS 应用程序:如何将 .js 文件包含到 index.html 中

    我是 angularJS 的新手 我设法使用 AngularJS 构建了一个phonegap应用程序 该应用程序正常并且运行良好 问题是 现在我对 angularJS 的工作原理有了更多的了解 至少我认为我已经了解了 我担心我的应用程序文件
  • 找不到 securityToken 的有效键映射

    我正在开发测试应用程序 用于在 MVC ASP net Visual studio 2013 中显示经过身份验证的身份声明 我已通过以下方式从活动目录进行身份验证 1 在解决方案中添加新的mvc项目 2 单击更改身份验证 3 选择组织账户
  • django admin 中内联模型的分页器

    我有这个简单的 django 模型 由一个传感器和特定传感器的值组成 每个日射强度计的值数量很多 gt 30k 是否可以以某种方式分页PyranometerValues在特定日期或一般情况下将分页器应用于管理内联视图 class Pyran
  • Microsoft SQL 数据库的 WebSocket 侦听器

    我目前正在开发一个项目 该项目必须使用 WebSockets 作为将数据传输到客户端的方式 基础设施看起来像这样 客户端 gt Web 服务器 gt Microsoft SQL 数据库 我想最理想的情况应该是这样的 客户端打开一个到服务器的
  • android.view.WindowLeaked - 使用对话框和新意图时

    我已经尝试了 stackoverflow 上提供的所有可能的解决方案 但我仍然在 logcat 中遇到此错误 活动 com xyz MainActivity 泄露了最初在此处添加的窗口 com android internal policy
  • Java 可变 BigInteger 类

    我正在使用 BigIntegers 进行计算 该计算使用一个调用 multiply 大约 1000 亿次的循环 并且从 BigInteger 创建新对象使其非常慢 我希望有人编写或找到了 MutableBigInteger 类 我在 jav
  • FCM(Firebase Cloud Messaging)如何发送到所有手机?

    我创建了一个小型应用程序 能够从 FCM 控制台接收推送通知 我现在想做的是向所有使用 API 安装应用程序的 Android 手机发送推送通知 这就是我完全迷失的地方 有没有办法在不收集所有注册ID的情况下将其发送到所有手机 这是否仅适用

随机推荐

  • 【JAVAWEB开发】基于Java+Servlet+Ajax+jsp网上购物系统设计实现

    哈喽 大家好呀 这篇给的大家带来的是网上购物系统设计 在传统电商时代 用户是先有需求再购买 用户对平台较为依赖 商家对消费者很难有直接的影响力 而如今社交 电商解决了产品质量的信息不对称问题 电商已经成为当今经济发展的一个重要领域 而网上购
  • 一张图看明白GPU原理

    GPU直通实现方式 通过虚拟化平台的直通技术可以将显卡直接给虚拟机使用 与物理机接入显卡效果基本一致 在询价上只要安装了对应显卡的显示驱动 显卡就可以为这个虚拟机提供高性能的图形能力 GPU虚拟化 共享能够将一个物理存在的显卡分享给多个虚拟
  • QPainterPath全功能解锁

    QPainterPath可以自动计算bounding和shap 前者决定了重绘区域 后者决定了碰撞边界 可以说 QPainterPath是绘制的最优解之一 但QPainterPath内并未直接提供缩放 旋转等功能 很多人借助QPainter
  • Qt QList和QLinkedList使用

    文章目录 1 QList 1 1 链表基础使用 添加 修改 查找 删除 1 2 迭代器使用 STL风格 Java风格 2 QLinkedList 1 QList 1 1 链表基础使用 添加 修改 查找 删除 链表初始化 添加元素 QList
  • 2022第三届全国大学生网络安全精英赛练习题(6)

    全国大学生网络安全精英赛 2022第三届全国大学生网络安全精英赛练习题 6 文章目录 全国大学生网络安全精英赛 2022第三届全国大学生网络安全精英赛练习题 6 总结 501 下列有关代理服务器说法错误的是 A 代理服务器访问模式是浏览器不
  • sort函数与结构体

    include
  • Java高级编程实验_java高级编程项目实践.ppt

    java高级编程项目实践 ppt 由会员分享 可在线阅读 更多相关 java高级编程项目实践 ppt 32页珍藏版 请在人人文库网上搜索 1 Java高级编程项目实践 徐铭 课程目录 第一部分 需求定义 第二部分 用户界面设计 第三部分 数
  • 停止开发GPT-4?我更加关注数据版权、信息安全和数字鸿沟问题

    近日 随着ChatGPT和GPT 4的迅猛发展 人工智能对于人类社会以及文明的影响将是我们需要重视的问题 有人认为ChatGPT的表现引人入胜 但同时也让人感到毛骨悚然 因此 AI是否可靠 是否会导致灾难 机器智能超过人类的 奇点 是否真正
  • 公共IPV6 dns大全

    dns是什么和公共ipv4可阅读本篇文章 dns大全 一 阿里ipv6 dns 阿里的dns好在于自家的服务器遍布全球 加上自家研究的CDN技术快稳定 强大的阿里云团队技术坚持也是国内首家支持IPv4和IPv6 双端加持 安全快速 2400
  • js数组常见操作方法总结

    0 将数组中所有name改成ChName Name改成EnName var arr1 name aa Name ss children name ww Name nn name ff Name ee let arr2 JSON parse
  • 过滤器 和 拦截器 的区别

    1 过滤器 Filter 过滤器配置比较简单 直接实现Filter 接口即可 也可以通过 WebFilter注解实现对特定URL拦截 看到Filter 接口中定义了三个方法 init 该方法在容器启动初始化过滤器时被调用 它在 Filter
  • matplotlib图表多曲线多纵轴绘制工具方法

    matplotlib是常用的可视化库 画折线图只要把列表plot进去就可以轻松展示 这里只弄折线图 其它图暂时不管 同一图上不同曲线数值大小差太多就能绘制成地板和天花板还不能给人家量纲去了 所以不同曲线需要不同纵轴才能清晰看出细小波动 要是
  • vscode代码上传到gitlab

    1 打开终端 1 1输入一下内容提交到本地仓库 PS D VueProject2 mall admin web gt git add PS D VueProject2 mall admin web gt git commit m 商品优化
  • PCB翘曲度

    为了正确放置 SMT 组件 PCB 必须保持完全平整 为了准确放置 贴片机必须将 SMT 组件释放到所有组件的电路板上方相同高度 如果 PCB 有翘曲 也就是说不平整 则机器在将元件放置在电路板上时 在释放元件时无法保持恒定的高度 这会影响
  • 德隆现象给中国企业的反思

    德隆现象给中国企业的反思 刘亚军 萨尼威投资管理顾问公司董事长兼首席咨询顾问 一 德隆 一个资本扩张神话的终结 最近 德隆继创造了自1992年进入快速成长以来 十二年形成了220亿的资产规模 在国内股市长期低迷的情况下 旗下的 老三股 屹立
  • Matlab数字图像处理--分别采用 5×5,9×9,15×15 和 25×25 大小的拉普拉斯算子对图像进行锐化滤波,并完成图像的锐化增强

    题目 代码 初始化 B为灰度图 B rgb2gary img i表示生成尺寸为i i的拉普拉斯算子 function init B i lap genlaplacian i img lap imfilter B lap replicate
  • 计算机协会管理,计算机爱好者协会内部管理制度

    计算机爱好者协会内部管理制度 由会员分享 可在线阅读 更多相关 计算机爱好者协会内部管理制度 5页珍藏版 请在人人文库网上搜索 1 计算机爱好者协会内部管理制度 1内部管理考核制度目录第一章总章 3 第二章组织工作制度 4 第三章普通会员权
  • Injection of autowired dependencies failed; 的解决办法!

    错误信息 严重 Exception sending context initialized event to listener instance of class org springframework web context Contex
  • 趣解面向对象

    小白自述 过去就听说 到面向对象的时候即使没有女朋友 都可以new好多个 啥时候我也能想new多少new多少 面向对象听了很多老师的课 感觉好绕啊 这个类套那个类 怎么套的也是一头雾水 怎么才能学好了面向对象嘛 好多人都说面向对象是java
  • 【Tensorflow2.0】8、tensorflow2.0_hdf5_savedmodel_pb模型转换[1]

    文章目录 1 训练模型 2 各种模型间互转并验证 2 1 hdf5转saved model 2 2 saved model转hdf5 2 3 所有模型精度测试 2 4 hdf5和saved模型转tensorflow1 x pb模型 2 5