将tensorpack的inference改为pytorch

2023-10-26

最近在跑一个OCR模型,模型是用Tensorpack写的,模型做inference的时候,显存,速度都不是很理想,改成pytorch后,显存占用,速度比之前好了很多。记录下改inference的过程遇到的一些坑。

  • 将pb文件转为pth文件
import torch
from collections import OrderedDict
import tensorflow as tf
from tensorflow.python.framework import tensor_util
def view_params():
    pb_file = 'ocr/checkpoint/text_recognition_377500.pb'
    graph = tf.Graph()
    with graph.as_default():
        with tf.gfile.FastGFile(pb_file, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
            _ = tf.import_graph_def(graph_def, name='')
            graph_nodes=[n for n in graph_def.node]
            wts = [n for n in graph_nodes if n.op=='Const']

    odic = OrderedDict()
    for n in wts:
        param = tensor_util.MakeNdarray(n.attr['value'].tensor)
        if not param.size == 0:
            odic[n.name] = tensor_util.MakeNdarray(n.attr['value'].tensor)
    torch.save(odic, 'pb_377500.pth')
  • 模型代码
class TextRecognition(nn.Module):
    def __init__(self):
        super(TextRecognition, self).__init__()
        self.features = nn.Sequential(OrderedDict([
            ('Conv2d_1a_3x3', BasicConv2d(3, 32, kernel_size=3, stride=2, padding='SAME')),
            ('Conv2d_2a_3x3', BasicConv2d(32, 32, kernel_size=3, stride=1, padding='SAME')),
           	...
            ('Mixed_6h', Inception_B()),
        ]))
        self.attention_lstm = AttentionLstm()
        
    def forward(self, x):
        x = self.features(x)
        x = self.attention_lstm(x)
        return x
        
class LinearBias(nn.Module):
    def __init__(self, size):
        super(LinearBias, self).__init__()
        self.param = nn.Parameter(torch.Tensor(size))

    def forward(self, x):
        x = x + self.param
        return x
        
class AttentionLstm(nn.Module):
    def __init__(self, seq_len=33, is_training=False, num_classes=7569,
                    wemb_size=256, channel=1024, lstm_size=512):
        super(AttentionLstm, self).__init__()
        self.seq_len = seq_len  # 33
		...
        self.W_wemb = nn.Linear(self.num_classes, self.wemb_size, bias=False)
        self.lstm_b = LinearBias(self.lstm_size*4)
        self.tanh = nn.Tanh()
        self.softmax_1d = nn.Softmax(dim=1)
        self.sigmoid = nn.Sigmoid()
        self.dropout_1d = nn.Dropout(0.)

    def forward(self, cnn_feature):  # bs, 1024, h, w
        _, _, self.height, self.width = cnn_feature.size()
        ...
        return output_array, attention_array

Pytorch 与 TensorFlow 二维卷积(Conv2d)填充(padding)上的差异,写卷积层的时候遇到的坑。
这种差异是由 TensorFlow 和 Pytorch 在卷积运算时使用的填充方式不同导致的。Pytorch 在填充的时候,上、下、左、右各方向填充的大小是一样的,但 TensorFlow 却允许不一样。
参考博客1
参考博客2

在AttentionLstm中,有一个LinearBias类,该类会将pack和self.lstm_b加起来,但是如果在forward中写成相加的形式,就不能将该self.lstm_b保存下来,写成类可以使模型加载参数的时候可以一次加载完成。

class AttentionLstm(nn.Module):
    def __init__(self):
        super(AttentionLstm, self).__init__()
        self.seq_len = 33  # 33
        self.W_wemb = nn.Linear(10, 20, bias=False)
        self.lstm_b = LinearBias(4)
        self.a = nn.Parameter(torch.Tensor(1))
        self.b = torch.randn(1, 3)

test = AttentionLstm()
# odict_keys(['a', 'W_wemb.weight', 'lstm_b.param']),self.b不会保存在state_dict中,而self.lstm_b会保存
print(test.state_dict())
pack = self.lstm_W(wemb_prev) + self.lstm_U(h_prev) + self.lstm_Z(attention_feature)  # bs, 2048
pack_with_bias = self.lstm_b(pack)

原代码使用的大都是tensorflow的函数,所以要改成相应的pytorch的函数。

tensorflow pytorch
tf.matmul torch.matmul
tf.multiply torch.mul
tf.sigmoid torch.nn.Sigmoid
tf.nn.dropout torch.nn.Dropout
tf.nn.softmax torch.nn.Softmax
tf.tanh torch.tanh
tf.split torch.split
tf.shape torch.size
tf.reshape / tf.transpose torch.reshape / view
tf.expand_dims torch.unsqueeze
tf.add_n/tf.add torch.add
tf.reduce_sum torch.sum
tf.reduce_mean torch.mean
tf.transpose torch.permute
tf.concat torch.cat
tf.nn.embedding_lookup torch.index_select
  • 加载参数

最后加载参数验证

net = attention_ocr_pytorch.TextRecognition()
net.load_state_dict(torch.load('log/pytorch/pb_377500_fl.pth'))
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

将tensorpack的inference改为pytorch 的相关文章

  • Tensorflow 版本与 Tensorboard 版本

    我想问一下tensorflow版本是否可以与tensorboard版本不同 我有个问题 404 problem 有人建议安装一个新版本的张量板 https github com tensorflow tensorboard issues 9
  • 模块“tensorflow._api.v2.train”没有属性“GradientDescentOptimizer”

    我使用Python 3 7 3并安装了tensorflow 2 0 0 alpha0 但是存在一些问题 例如 模块 tensorflow api v2 train 没有属性 GradientDescentOptimizer 这是我的全部代码
  • 如何在 Tensorflow 上测试自己的图像到 Cifar-10 教程?

    我训练了 Tensorflow Cifar10 模型 我想为其提供自己的单个图像 32 32 jpg png 我想将标签和每个标签的概率视为输出 但我对此遇到了一些麻烦 搜索堆栈溢出后 我发现了一些帖子this https stackove
  • Tensorflow 中多维时间序列预测中的向量表示

    我有一个大型数据集 约 3000 万个数据点 具有 5 个特征 我已使用 K 均值将其减少到 200 000 个集群 数据是大约 150 000 个时间步长的时间序列 我想要训练模型的数据是每个时间步上特定簇的存在 预测模型的目的是生成一个
  • TF 数据 API:如何有效地从图像中采样小块

    考虑创建从高分辨率图像目录中采样随机小图像块的数据集的问题 Tensorflow 数据集 API 提供了一种非常简单的方法来实现此目的 即构建图像名称的数据集 对它们进行排序 将其映射到加载的图像 然后映射到随机裁剪的补丁 然而 这种幼稚的
  • CNN 模型分类错误:logits 和标签必须可广播:logits_size=[32,10] labels_size=[32,13]

    这里我尝试在图像分类上运行 CNN 模型 这是批量大小和 13 个标签 Image batch shape 32 32 32 3 Label batch shape 32 13 Watch Back Watch Chargers Watch
  • Odroid XU4 上的 Tensorflow 编译

    我正在尝试在 Odroid XU4 16GB eMMc Ubuntu 16 上编译 Tensorflow 尝试了完整和精简版 但出现如图所示的错误 https www dropbox com sh j86ysncze1q0eka AAB8R
  • 如何将 .pb 文件转换为 .h5。 (张量流模型到keras)

    我已经使用重新训练了我的模型tensorflow现在想使用keras以避免会话内容 我怎样才能转换 pb文件至 h5 import tensorflow as tf from tensorflow keras models import s
  • BERT - 池化输出与序列输出的第一个向量不同

    我在 Tensorflow 中使用 BERT 有一个细节我不太明白 根据文档 https tfhub dev google bert uncased L 12 H 768 A 12 1 https tfhub dev google bert
  • Tensorflow 何时更新权重和偏差?

    张量流什么时候更新for循环中的权重和偏差 下面是tf的github上的代码 mnist softmax py https github com tensorflow tensorflow blob master tensorflow ex
  • Blenderbot 微调

    我一直在尝试微调 HuggingFace 的对话模型 Blendebot 我已经尝试过官方拥抱脸网站上给出的传统方法 该方法要求我们使用 trainer train 方法来完成此操作 我使用 compile 方法尝试了它 我尝试过使用 Py
  • 分布式张量流中的并行进程

    我有带有训练参数的张量流神经网络 它是代理的 策略 网络正在核心程序的主张量流会话的训练循环中进行更新 在每个训练周期结束时 我需要将该网络传递给几个并行进程 工作人员 这些进程将使用它来从代理策略与环境的交互中收集样本 我需要并行执行 因
  • 在 Tensorflow 中使用队列将数据馈送到网络时分开验证和训练图

    我一直在做大量关于如何使用队列将数据正确输入网络的研究 但是 我在互联网上找不到任何解决方案 目前我的代码能够读取训练数据并执行训练 但无需验证和测试 这里有一些重要的行构成了我的代码 images volumes utils inputs
  • 如何仅从源代码构建 TensorFlow lite 而不是所有 TensorFlow?

    我正在尝试使用 Edgetpu USB 加速器与 Intel ATOM 单板计算机和 C API 进行实时推理 Edgetpu 的 C API 基于 TensorFlow lite C API 我需要包含来自tensorflow lite目
  • Tensorflow构建量化工具-bazel构建错误

    我正在尝试编译量化脚本 如下所述皮特 沃登的博客 https petewarden com 2016 05 03 how to quantize neural networks with tensorflow 但是 在运行以下 bazel
  • TensorFlow:训练时参数不更新

    我正在使用 TensorFlow 实现分类模型 我面临的问题是 当我运行训练步骤时 我的权重和误差没有更新 结果 我的网络不断返回相同的结果 我根据以下内容开发了我的模型MNIST 示例 https www tensorflow org v
  • Tensorflow:使用 Adam 优化器

    我正在张量流中试验一些简单的模型 包括一个看起来与第一个非常相似的模型面向 ML 初学者的 MNIST 示例 http www tensorflow org tutorials mnist beginners index md 但维数稍大一
  • 使用 TFLite 量化模型的参数进行计算操作

    我正在尝试使用量化的 Mobilenetv2 模型在硬件中实现图像分类here https www tensorflow org lite guide hosted models 为此 我首先需要从头到尾重现推理过程 以确保我理解对数据执行
  • Google Colab:为什么 CPU 比 TPU 快?

    我正在使用 Google colabTPU训练一个简单的Keras模型 删除分布式strategy并在CPU比TPU 这怎么可能 import timeit import os import tensorflow as tf from sk
  • Tensorflow:np数组的next_batch函数

    我的火车数据为 xTrain numpy asarray 100 1 5 6 yTrain numpy asarray 200 2 10 12 如何定义 next batch size 方法以从训练数据中获取随机元素的 size 个数 您可

随机推荐

  • 武邑中学2021高考成绩查询,武邑中学高考成绩

    问 衡水武邑中学怎么样 答 收费2 3万 大前年中考400多分的进衡水二中的去年高考600多分 而中考400多分去武中的去年高考300分都不到 这可是有名有姓的真人的真实情况 光复习生每年就60多个班 应届考生30多个班 每年六七千人都抬不
  • 箭头函数()=>{}与function的区别

    1 箭头函数与function定义函数的写法 function function fn a b return a b arrow function var foo a b gt return a b 2 this的指向 使用function
  • uni-app开发总结分享

    目录 一 uni app介绍 二 uni app和vue的具体区别 1 组件 标签的变化 2 js 3 uniapp自带路由和请求方式 三 环境搭建 1 安装HbuilderX 2 创建uni app项目 四 项目目录结构 五 运行uni
  • 安装mysql提示3306端口已经被占用解决方案

    今天遇到的问题是这样的 之前已经安装过mysql了 一直用的好好的 但是今天开启服务时报异常 无法启动 为了省事 于是想到卸载重装 在安装的过程中发现3306已经被占用 这也是一开始服务无法启动的原因 看到有人说用fport查看端口号 于是
  • JSP学生网上选课系统设计(源代码+论文+答辩PPT)

    QQ 19966519194 摘要 随着科学技术的不断提高 计算机科学日渐成熟 其强大的功能已为人们深刻认识 它已进入人类社会的各个领域并发挥着越来越重要的作用 学生选课系统作为一种现代化的教学技术 以越来越受到人民的重视 是一个学校不可缺
  • [Unity][ShaderGraph][FlowCanvas] SetFloat 无效:通过脚本控制 shader 的动态参数时需要使用参数的引用名

    我的 shader 很简单 就是一个 tiling and offset 制作滚动效果 然后我想用一个脚本控制 speed 但是实际运行没有起效果 一开始我看的这个 然后用的 sharedmaterial https forum unity
  • Stable Diffuse AI 绘画 之 ControlNet 插件及其对应模型的下载安装

    Stable Diffuse AI 绘画 之 ControlNet 插件及其对应模型的下载安装 目录 Stable Diffuse AI 绘画 之 ControlNet 插件及其对应模型的下载安装 一 简单介绍 二 ControlNet 插
  • Swift - 将String类型的数字转换成数字类型(支持十进制、十六进制)

    https www cnblogs com Free Thinker p 7243683 html 1 十进制的字符串转成数字 Swift中 如果要把字符串转换成数字类型 比如整型 浮点型等 可以先转成NSString类型 让后再转 1 2
  • JAVA:jdbc:sqlserver 连接SQLserver实例名

    weChatjdbc driverClassName com microsoft sqlserver jdbc SQLServerDriver weChatjdbc url jdbc sqlserver 127 0 0 1 instance
  • Ubuntu服务器下安装FastDFS及nginx配置访问等问题记录

    Ubuntu服务器下安装FastDFS及nginx配置访问 下载对应包 编译环境 包解压环境配置 配置nginx模块和安装nginx来进行访问该图片 下载对应包 下载方式一 直接使用 wget 下载 如果太慢 可以去github下载 然后上
  • 基于Matlab开发的动态机器人轨迹仿真

    基于Matlab开发的动态机器人轨迹仿真 近年来 机器人技术的发展已经进入了高速发展时期 控制与仿真技术作为机器人领域中至关重要的一环 也随之发展壮大 而在动态机器人轨迹仿真方面 Matlab作为一款具备强大数学计算能力的软件 在该领域中得
  • QT实现sqlite数据库连接池

    ifndef CONNECTIONPOOL H define CONNECTIONPOOL H FileName 数据库连接池 Function 获取连接时不需要了解连接的名字 支持多线程 保证获取到的连接一定是没有被其他线程正在使用 按需
  • MySQL 远程登录与其常用命令的介绍

    以下的文章主要介绍的是MySQL 远程登录与其常用命令的介绍 MySQL 远程登录与其常用命令之所以能在很短的时间内被人们广泛的应用 原因也是因为它们的独特功能 以下的文章就有对其相关内容的介绍 MySQL 远程登录及常用命令 第一招 My
  • Unbantu22.04使用DevStack一键部署OpenStack(使用nat静态IP)

    d 学习openstack的小白 第一步就遇到了大麻烦 下载并部署Openstack 传统的基于组件 一个个的安装配置更加麻烦 使用DevStack工具 一键部署可能是个不错的选择 But devstack部署期间总是会出现各种各样的错误
  • JavaScript面向对象

    JavaScript面向对象 面向过程 面向过程就是讲需求一步一步自己完全实现 如 一堆衣服 需要自己一件一件洗 面向对象 面向对象是把有共同特征的方法抽取为类 比如 一堆衣服 都需要洗 创建洗衣机类 女朋友类 让她洗 类的定义和使用 定义
  • 机器学习 day09(如何设置学习率α,特征工程,多项式回归)

    1 常见的错误的学习曲线图 上方两个 当关于迭代次数的学习曲线图 出现波浪型或向上递增型 表示梯度下降算法出错 该情况可由 学习率 过大 或代码有bug导致 2 常用的调试方法 选择一个非常非常小的学习率 来查看学习曲线是否还是有误 即在某
  • uni-app网络请求的封装

    uni app网络请求的封装 这几天没事干 就去小程序开发小团队里看看 顺便看了一下代码 在网络请求上发现了一些问题 差点没忍住破口大骂 最终想了想 他们之前没做过 都是第一次就算了 其实是安慰自己而已 网络请求都写在page里 每个请求都
  • 池化方法总结(Pooling)

    在卷积神经网络中 我们经常会碰到池化操作 而池化层往往在卷积层后面 通过池化来降低卷积层输出的特征向量 同时改善结果 不易出现过拟合 为什么可以通过降低维度呢 因为图像具有一种 静态性 的属性 这也就意味着在一个图像区域有用的特征极有可能在
  • JAVA-while循环语句

    while循环语句用法比for语句用起来简单 格式也对的简单 while 判断条件 循环体 public class WhileTest public static void main String args int i 1 while i
  • 将tensorpack的inference改为pytorch

    最近在跑一个OCR模型 模型是用Tensorpack写的 模型做inference的时候 显存 速度都不是很理想 改成pytorch后 显存占用 速度比之前好了很多 记录下改inference的过程遇到的一些坑 将pb文件转为pth文件 i