将 Tensorflow 模型转换为 tensorflow-lite (.tflite) 格式时出现问题

2024-02-18

我用 python 制作了一个用于图像分类的张量流模型。我使用的是 Windows 10。

我有一个Train.py我在其中定义图形的类build_graph()并训练模型train()。这里是main.py script:

#import fire
import numpy as np
import data_import as di
import os
import tensorflow as tf


class Train:
    __x_ = []
    __y_ = []
    __logits = []
    __loss = []
    __train_step = []
    __merged_summary_op = []
    __saver = []
    __session = []
    __writer = []
    __is_training = []
    __loss_val = []
    __train_summary = []
    __val_summary = []

    def __init__(self):
        pass

    def build_graph(self):
        self.__x_ = tf.placeholder("float", shape=[None, 60, 60, 3], name='X')
        self.__y_ = tf.placeholder("int32", shape=[None, 3], name='Y')
        self.__is_training = tf.placeholder(tf.bool)


        with tf.name_scope("model") as scope:
            conv1 = tf.layers.conv2d(inputs=self.__x_, filters=64,
                                 kernel_size=[5, 5],
                                 padding="same", activation=tf.nn.relu)
            pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2)

            conv2 = tf.layers.conv2d(inputs=pool1, filters=64, kernel_size=[5, 5], padding="same",
                                 activation=tf.nn.relu)

            pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2)

            conv3 = tf.layers.conv2d(inputs=pool2, filters=32, kernel_size=[5, 5], padding="same",
                                 activation=tf.nn.relu)

            pool3 = tf.layers.max_pooling2d(inputs=conv3, pool_size=[2, 2], strides=2)

            pool3_flat = tf.reshape(pool3, [-1, 7 * 7 * 32])

            # FC layers
            FC1 = tf.layers.dense(inputs=pool3_flat, units=128, activation=tf.nn.relu)
            FC2 = tf.layers.dense(inputs=FC1, units=64, activation=tf.nn.relu)
            self.__logits = tf.layers.dense(inputs=FC2, units=3)


        # TensorFlow summary data to display in TensorBoard later
        with tf.name_scope("loss_func") as scope:
            self.__loss = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits(logits=self.__logits, labels=self.__y_))
            self.__loss_val = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits(logits=self.__logits, labels=self.__y_))

            # Add loss to tensorboard
            self.__train_summary = tf.summary.scalar("loss_train", self.__loss)
            self.__val_summary = tf.summary.scalar("loss_val", self.__loss_val)


        # summary data to be displayed on TensorBoard during training:
        with tf.name_scope("optimizer") as scope:
            global_step = tf.Variable(0, trainable=False)
            starter_learning_rate = 1e-3
            # decay every 10000 steps with a base of 0.96 function
            learning_rate = tf.train.exponential_decay(starter_learning_rate, global_step, 1000, 0.9,
                                                       staircase=True)
            self.__train_step = tf.train.AdamOptimizer(learning_rate).minimize(self.__loss, global_step=global_step)
            tf.summary.scalar("learning_rate", learning_rate)
            tf.summary.scalar("global_step", global_step)


        # Merge op for tensorboard
        self.__merged_summary_op = tf.summary.merge_all()
        # Build graph
        init = tf.global_variables_initializer()
        # Saver for checkpoints
        self.__saver = tf.train.Saver(max_to_keep=None)

        # Configure summary to output at given directory
        self.__session = tf.Session()
        self.__writer = tf.summary.FileWriter("./logs/flight_path", self.__session.graph)
        self.__session.run(init)



def train(self, save_dir='./model_files', batch_size=20):
    #Load dataset and labels
    x = np.asarray(di.load_images())
    y = np.asarray(di.load_labels())

    #Shuffle dataset
    np.random.seed(0)
    shuffled_indeces = np.arange(len(y))
    np.random.shuffle(shuffled_indeces)
    shuffled_x = x[shuffled_indeces].tolist()
    shuffled_y = y[shuffled_indeces].tolist()
    shuffled_y = tf.keras.utils.to_categorical(shuffled_y, 3)

    dataset = (shuffled_x, shuffled_y)
    dataset = tf.data.Dataset.from_tensor_slices(dataset)
    #dataset = dataset.shuffle(buffer_size=300)

    # Using Tensorflow data Api to handle batches
    dataset_train = dataset.take(200)
    dataset_train = dataset_train.repeat()
    dataset_train = dataset_train.batch(batch_size)

    dataset_test = dataset.skip(200)
    dataset_test = dataset_test.repeat()
    dataset_test = dataset_test.batch(batch_size)

    # Create an iterator
    iter_train = dataset_train.make_one_shot_iterator()
    iter_train_op = iter_train.get_next()
    iter_test = dataset_test.make_one_shot_iterator()
    iter_test_op = iter_test.get_next()

    # Build model graph
    self.build_graph()


    # Train Loop
    for i in range(10):
        batch_train = self.__session.run([iter_train_op])
        batch_x_train, batch_y_train = batch_train[0]
        # Print loss from time to time
        if i % 100 == 0:
            batch_test = self.__session.run([iter_test_op])
            batch_x_test, batch_y_test = batch_test[0]
            loss_train, summary_1 = self.__session.run([self.__loss,
                                                    self.__merged_summary_op],
                                                   feed_dict={self.__x_:
                                                                  batch_x_train,
                                                              self.__y_:
                                                                  batch_y_train,
                                                              self.__is_training: True})
            loss_val, summary_2 = self.__session.run([self.__loss_val,
                                              self.__val_summary],
                                             feed_dict={self.__x_: batch_x_test,
                                                        self.__y_: batch_y_test,
                                                        self.__is_training: False})
            print("Loss Train: {0} Loss Val: {1}".format(loss_train,
                                                 loss_val))
            # Write to tensorboard summary
            self.__writer.add_summary(summary_1, i)
            self.__writer.add_summary(summary_2, i)

        # Execute train op
        self.__train_step.run(session=self.__session, feed_dict={
            self.__x_: batch_x_train, self.__y_: batch_y_train,
            self.__is_training: True})
        print(i)


    # Once the training loop is over, we store the final model into a checkpoint file with op
    # __saver.save:

    # converter = tf.contrib.lite.TFLiteConverter.from_session(self.__session, [self.__x_], [self.__y_])
    # tflite_model = converter.convert()
    # open("MobileNet/ConvertedModelFile.tflite", "wb").write(tflite_model)

    # Save model
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
        checkpoint_path = os.path.join(save_dir, "model.ckpt")
        filename = self.__saver.save(self.__session, checkpoint_path)
        tf.train.write_graph(self.__session.graph_def, save_dir, "save_graph.pbtxt")
        print("Model saved in file: %s" % filename)


if __name__ == '__main__':
    cnn = Train()
    cnn.train()

我尝试通过导出将 GraphDef 导出到 .tflite 文件GraphDef from tf.Session, Exporting a GraphDef from file and Exporting a SavedModel。全部都描述在这里转换器 Python API 指南 https://www.tensorflow.org/lite/convert/python_api.

来自 tf.Session 的 GraphDef

当我尝试导出时GraphDef from tf.Session指导,我收到以下错误:

Traceback (most recent call last):
  File "C:/Users/nermi/PycharmProjects/DronePathTracking/main.py", line 226, in <module>
    cnn.train()
  File "C:/Users/nermi/PycharmProjects/DronePathTracking/main.py", line 212, in train
    tflite_model = converter.convert()
  File "C:\Users\nermi\Python\Python36\lib\site-packages\tensorflow\contrib\lite\python\lite.py", line 453, in convert
    **converter_kwargs)
  File "C:\Users\nermi\Python\Python36\lib\site-packages\tensorflow\contrib\lite\python\convert.py", line 342, in toco_convert_impl
    input_data.SerializeToString())
  File "C:\Users\nermi\Python\Python36\lib\site-packages\tensorflow\contrib\lite\python\convert.py", line 135, in toco_convert_protos
    (stdout, stderr))
RuntimeError: TOCO failed see console for info.
b'Traceback (most recent call last):\r\n  File "c:\\users\\nermi\\python\\python36\\lib\\site-packages\\tensorflow\\contrib\\lite\\toco\\python\\tensorflow_wrap_toco.py", line 18, in swig_import_helper\r\n    fp, pathname, description = imp.find_module(\'_tensorflow_wrap_toco\', [dirname(__file__)])\r\n  File "c:\\users\\nermi\\python\\python36\\lib\\imp.py", line 297, in find_module\r\n    raise ImportError(_ERR_MSG.format(name), name=name)\r\nImportError: No module named \'_tensorflow_wrap_toco\'\r\n\r\nDuring handling of the above exception, another exception occurred:\r\n\r\nTraceback (most recent call last):\r\n  File "c:\\users\\nermi\\python\\python36\\lib\\runpy.py", line 193, in _run_module_as_main\r\n    "__main__", mod_spec)\r\n  File "c:\\users\\nermi\\python\\python36\\lib\\runpy.py", line 85, in _run_code\r\n    exec(code, run_globals)\r\n  File "C:\\Users\\nermi\\Python\\Python36\\Scripts\\toco_from_protos.exe\\__main__.py", line 5, in <module>\r\n  File "c:\\users\\nermi\\python\\python36\\lib\\site-packages\\tensorflow\\contrib\\lite\\toco\\python\\toco_from_protos.py", line 22, in <module>\r\n    from tensorflow.contrib.lite.toco.python import tensorflow_wrap_toco\r\n  File "c:\\users\\nermi\\python\\python36\\lib\\site-packages\\tensorflow\\contrib\\lite\\toco\\python\\tensorflow_wrap_toco.py", line 28, in <module>\r\n    _tensorflow_wrap_toco = swig_import_helper()\r\n  File "c:\\users\\nermi\\python\\python36\\lib\\site-packages\\tensorflow\\contrib\\lite\\toco\\python\\tensorflow_wrap_toco.py", line 20, in swig_import_helper\r\n    import _tensorflow_wrap_toco\r\nModuleNotFoundError: No module named \'_tensorflow_wrap_toco\'\r\n'
None

导出已保存的模型

当我尝试使用导出时Exporting a SavedModel指导我的export_saved_model.py脚本我收到以下错误:

Traceback (most recent call last):
  File "C:/Users/nermi/PycharmProjects/DronePathTracking/export_saved_model.py", line 5, in <module>
    converter = tf.contrib.lite.TFLiteConverter.from_saved_model(saved_model_dir)
  File "C:\Users\nermi\Python\Python36\lib\site-packages\tensorflow\contrib\lite\python\lite.py", line 340, in from_saved_model
    output_arrays, tag_set, signature_key)
  File "C:\Users\nermi\Python\Python36\lib\site-packages\tensorflow\contrib\lite\python\convert_saved_model.py", line 239, in freeze_saved_model
    meta_graph = get_meta_graph_def(saved_model_dir, tag_set)
  File "C:\Users\nermi\Python\Python36\lib\site-packages\tensorflow\contrib\lite\python\convert_saved_model.py", line 61, in get_meta_graph_def
    return loader.load(sess, tag_set, saved_model_dir)
  File "C:\Users\nermi\Python\Python36\lib\site-packages\tensorflow\python\saved_model\loader_impl.py", line 196, in load
    loader = SavedModelLoader(export_dir)
  File "C:\Users\nermi\Python\Python36\lib\site-packages\tensorflow\python\saved_model\loader_impl.py", line 212, in __init__
    self._saved_model = _parse_saved_model(export_dir)
  File "C:\Users\nermi\Python\Python36\lib\site-packages\tensorflow\python\saved_model\loader_impl.py", line 82, in _parse_saved_model
    constants.SAVED_MODEL_FILENAME_PB))
OSError: SavedModel file does not exist at: model_files/{saved_model.pbtxt|saved_model.pb}

The export_saved_model.py:

import tensorflow as tf

saved_model_dir = "model_files"

converter = tf.contrib.lite.TFLiteConverter.from_saved_model(saved_model_dir)
tflite_model = converter.convert()
open("MobileNet/converted_model.tflite", "wb").write(tflite_model)

从文件导出 GraphDef

最后,我有以下内容freeze_model.py冻结已保存模型的脚本:

from tensorflow.python.tools import freeze_graph

# Freeze the graph
save_path="C:/Users/nermi/PycharmProjects/DronePathTracking/model_files/" #directory to model files
MODEL_NAME = 'my_model' #name of the model optional
input_graph_path = save_path+'save_graph.pbtxt'#complete path to the input graph
checkpoint_path = save_path+'model.ckpt' #complete path to the model's checkpoint file
input_saver_def_path = ""
input_binary = False
output_node_names = "X, Y" #output node's name. Should match to that mentioned in your code
restore_op_name = "save/restore_all"
filename_tensor_name = "save/Const:0"
output_frozen_graph_name = save_path+'frozen_'+MODEL_NAME+'.pb' # the name of .pb file you would like to give
clear_devices = True


def freeze():
    freeze_graph.freeze_graph(input_graph_path, input_saver_def_path,
                              input_binary, checkpoint_path, output_node_names,
                              restore_op_name, filename_tensor_name,
                              output_frozen_graph_name, clear_devices, "")


freeze()

但是当我尝试转换我的frozen_my_model.pb与我的tfliteexport_to_tflite.py script:

import tensorflow as tf

grap_def_file = "model_files/frozen_my_model.pb" # the .pb file

input_arrays = ["X"] #Input node
output_arrays = ["Y"] #Output node

converter = tf.contrib.lite.TFLiteConverter.from_frozen_graph(
    grap_def_file, input_arrays, output_arrays
)

tflite_model = converter.convert()

open("MobileNet/my_model.tflite", "wb").write(tflite_model)

我收到以下错误:

Traceback (most recent call last):
  File "C:/Users/nermi/PycharmProjects/DronePathTracking/export_to_tflite.py", line 12, in <module>
    tflite_model = converter.convert()
  File "C:\Users\nermi\Python\Python36\lib\site-packages\tensorflow\contrib\lite\python\lite.py", line 453, in convert
    **converter_kwargs)
  File "C:\Users\nermi\Python\Python36\lib\site-packages\tensorflow\contrib\lite\python\convert.py", line 342, in toco_convert_impl
    input_data.SerializeToString())
  File "C:\Users\nermi\Python\Python36\lib\site-packages\tensorflow\contrib\lite\python\convert.py", line 135, in toco_convert_protos
    (stdout, stderr))
RuntimeError: TOCO failed see console for info.
b'Traceback (most recent call last):\r\n  File "c:\\users\\nermi\\python\\python36\\lib\\site-packages\\tensorflow\\contrib\\lite\\toco\\python\\tensorflow_wrap_toco.py", line 18, in swig_import_helper\r\n    fp, pathname, description = imp.find_module(\'_tensorflow_wrap_toco\', [dirname(__file__)])\r\n  File "c:\\users\\nermi\\python\\python36\\lib\\imp.py", line 297, in find_module\r\n    raise ImportError(_ERR_MSG.format(name), name=name)\r\nImportError: No module named \'_tensorflow_wrap_toco\'\r\n\r\nDuring handling of the above exception, another exception occurred:\r\n\r\nTraceback (most recent call last):\r\n  File "c:\\users\\nermi\\python\\python36\\lib\\runpy.py", line 193, in _run_module_as_main\r\n    "__main__", mod_spec)\r\n  File "c:\\users\\nermi\\python\\python36\\lib\\runpy.py", line 85, in _run_code\r\n    exec(code, run_globals)\r\n  File "C:\\Users\\nermi\\Python\\Python36\\Scripts\\toco_from_protos.exe\\__main__.py", line 5, in <module>\r\n  File "c:\\users\\nermi\\python\\python36\\lib\\site-packages\\tensorflow\\contrib\\lite\\toco\\python\\toco_from_protos.py", line 22, in <module>\r\n    from tensorflow.contrib.lite.toco.python import tensorflow_wrap_toco\r\n  File "c:\\users\\nermi\\python\\python36\\lib\\site-packages\\tensorflow\\contrib\\lite\\toco\\python\\tensorflow_wrap_toco.py", line 28, in <module>\r\n    _tensorflow_wrap_toco = swig_import_helper()\r\n  File "c:\\users\\nermi\\python\\python36\\lib\\site-packages\\tensorflow\\contrib\\lite\\toco\\python\\tensorflow_wrap_toco.py", line 20, in swig_import_helper\r\n    import _tensorflow_wrap_toco\r\nModuleNotFoundError: No module named \'_tensorflow_wrap_toco\'\r\n'
None

额外信息

当我将模型保存在 model_files 目录中时,它看起来像这样:

我尝试了很多事情,但没有运气。

任何帮助表示赞赏!


Windows 上的 TOCO 是有问题的。我遇到过这样的问题,直到找到解决方案。解决方案是将所有保存的模型或 graphdef 上传到 Google Colab 笔记本. Then,

  1. 与 GPU 或 TPU 运行时连接。 (更改运行时类型选项)
  2. 在运行时上传saved_model。(左上角的文件部分)
  3. 在您提到的一个单元格中编写相同的脚本。
  4. 确保在运行时创建必要的目录。看到这个answer https://stackoverflow.com/a/50479784/10878733.
  5. 转换将在云端进行。

所以,TOCO没有问题。 看到这个notebook https://colab.research.google.com/drive/1IUIn9ffk5ICKujqPyuGaHL2irQ9Wmtpm#scrollTo=go_GFH86fLHr获取信息。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

将 Tensorflow 模型转换为 tensorflow-lite (.tflite) 格式时出现问题 的相关文章

  • 获取 .wav 文件长度或持续时间

    我正在寻找一种方法来找出 python 中音频文件 wav 的持续时间 到目前为止我已经了解了 pythonwave图书馆 mutagen pymedia pymad我无法获取 wav 文件的持续时间 Pymad给了我持续时间 但它不一致
  • 此 TypeError 消息中提到的“代码对象”是什么?

    在尝试使用Python时exec声明 我收到以下错误 TypeError exec arg 1 must be a string file or code object 我不想传递字符串或文件 但什么是代码对象 如何创建一个 创建代码对象的
  • Mypy 无法从文字列表推断项目的类型

    我有一个变量x和一个文字列表 例如 0 1 2 我想转换x这些文字之一 如果x在列表中 我将其退回 否则我返回一个后备值 from typing import Literal Set Foo Literal 0 1 2 foos Set F
  • 从 Azure ML 实验中访问 Azure Blob 存储

    Azure ML 实验提供了通过以下方式读取 CSV 文件并将其写入 Azure Blob 存储的方法 Reader and Writer模块 但是 我需要将 JSON 文件写入 blob 存储 由于没有模块可以执行此操作 因此我尝试在Ex
  • 使用 Paramiko 进行 DSA 密钥转发?

    我正在使用 Paramiko 在远程服务器上执行 bash 脚本 在其中一些脚本中 存在与其他服务器的 ssh 连接 如果我只使用 bash 不使用 Python 我的 DSA 密钥将被第一个远程服务器上的 bash 脚本转发并使用 以连接
  • 如何在 Tensorflow 对象检测 API 中查找边界框坐标

    我正在使用 Tensorflow 对象检测 API 代码 我训练了我的模型并获得了很高的检测百分比 我一直在尝试获取边界框坐标 但它不断打印出 100 个奇怪数组的列表 经过在线广泛搜索后 我发现数组中的数字意味着什么 边界框坐标相对于底层
  • python 中的 <> 运算符有什么作用?

    我刚刚遇到这个here http www feedparser org feedparser py 总是这样使用 if string1 find string2 lt gt 1 pass 什么是 lt gt 运算符这样做 为什么不使用通常的
  • Spark 和 Python 使用自定义文件格式/生成器作为 RDD 的输入

    我想问一下 Spark 中输入的可能性 我可以看到从http spark apache org docs latest programming guide html http spark apache org docs latest pro
  • Airflow 1.9 - 无法将日志写入 s3

    我在 aws 的 kubernetes 中运行气流 1 9 我希望将日志发送到 s3 因为气流容器本身的寿命并不长 我已经阅读了描述该过程的各种线程和文档 但我仍然无法让它工作 首先是一个测试 向我证明 s3 配置和权限是有效的 这是在我们
  • 在Raspberry pi上升级skimage版本

    我已经使用 Raspberry Pi 2 上的 synaptic 包管理器安装了 python 包 然而 skimage 模块版本 0 6 是 synaptic 中最新的可用版本 有人可以指导我如何将其升级到0 11 因为旧版本中缺少某些功
  • 可以使用哪些技术来衡量 pandas/numpy 解决方案的性能

    Question 如何简洁全面地衡量下面各个功能的性能 Example 考虑数据框df df pd DataFrame Group list QLCKPXNLNTIXAWYMWACA Value 29 52 71 51 45 76 68 6
  • 如何在亚马逊 EC2 上调试 python 网站?

    我是网络开发新手 这可能是一个愚蠢的问题 但我找不到可以帮助我的确切答案或教程 我工作的公司的网站 用 python django 构建 托管在亚马逊 EC2 上 我想知道从哪里开始调试这个生产站点并检查存储在那里的日志和数据库 我有帐户信
  • 如何给URL添加变量?

    我正在尝试从网站收集数据 我有一个 Excel 文件 其中包含该网站的所有不同扩展名 F i www example com example2 我有一个脚本可以成功从网站中提取 HTML 但现在我想为所有扩展自动执行此操作 然而 当我说 s
  • Django 管理器链接

    我想知道是否有可能 如果可以的话 如何 将多个管理器链接在一起以生成受两个单独管理器影响的查询集 我将解释我正在研究的具体示例 我有多个抽象模型类 用于为其他模型提供小型的特定功能 其中两个模型是DeleteMixin 和GlobalMix
  • 带 Flask 的 RPI dht22:无法将第 4 行设置为输入 - 等待 PulseIn 消息超时

    我正在尝试制作一个 Raspberry Pi 3 REST API 使用 DHT22 提供温度和湿度 整个代码 from flask import Flask jsonify request from sds011 import SDS01
  • 如何编写一个接受 int 或 float 的 C 函数?

    我想用 C 语言创建一个扩展 Python 的函数 该函数可以接受 float 或 int 类型的输入 所以基本上 我想要f 5 and f 5 5 成为可接受的输入 我认为我不能使用if PyArg ParseTuple args i v
  • 如何从namedtuple实例列表创建pandas DataFrame(带有索引或多索引)?

    简单的例子 from collections import namedtuple import pandas Price namedtuple Price ticker date price a Price GE 2010 01 01 30
  • 如何(安全)将 Python 对象发送到我的 Flask API?

    我目前正在尝试构建一个 Flask Web API 它能够在 POST 请求中接收 python 对象 我使用 Python 3 7 1 创建请求 使用 Python 2 7 运行 API 该 API 设置为在我的本地计算机上运行 我试图发
  • 用于插入或替换 URL 参数的 Django 模板标签

    有人知道 Django 模板标签可以获取当前路径和查询字符串并插入或替换查询字符串值吗 例如向 some custom path q how now brown cow page 3 filter person 发出请求 电话 urlpar
  • 无法安装最新版本的 Numpy (1.22.3)

    我正在尝试安装最新版本的 numpy 即 1 22 3 但看起来 pip 无法找到最后一个版本 我知道我可以从源代码本地安装它 但我想了解为什么我无法使用 pip 安装它 PS 我有最新版本的pip 22 0 4 ERROR Could n

随机推荐