lstmcell转onnx报错 自定义pytorch模型替换

2023-10-27

lstmcell在转onnx的时候会遇到不支持的情况,如果模型已经训练好,可以通过自己实现lstmcell的方式,加载训练好的权重;以下是实现代码

class MyLSTMCell(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(MyLSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.weight_ih = nn.Parameter(torch.Tensor(4 * hidden_size, input_size))
        self.weight_hh = nn.Parameter(torch.Tensor(4 * hidden_size, hidden_size))
        self.bias_ih = nn.Parameter(torch.Tensor(4 * hidden_size))
        self.bias_hh = nn.Parameter(torch.Tensor(4 * hidden_size))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight_ih, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.weight_hh, a=math.sqrt(5))
        nn.init.zeros_(self.bias_ih)
        nn.init.zeros_(self.bias_hh)

    def forward(self, input, hx):
        # input: (batch_size, input_size)
        # hx: (batch_size, hidden_size)
        hx = hx[0] if isinstance(hx, tuple) else hx
        gates = (input @ self.weight_ih.t() + self.bias_ih +
                 hx @ self.weight_hh.t() + self.bias_hh)
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
        ingate = torch.sigmoid(ingate)
        forgetgate = torch.sigmoid(forgetgate)
        cellgate = torch.tanh(cellgate)
        outgate = torch.sigmoid(outgate)
        cy = (forgetgate * hx) + (ingate * cellgate)
        hy = outgate * torch.tanh(cy)
        return hy, cy

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

lstmcell转onnx报错 自定义pytorch模型替换 的相关文章

随机推荐

  • Linux笔记

    命令 提供一定功能的工具 ssh 提供远程登录功能 参数 命令的作用对象 193 3 3 3 远程登录的作用主机 选项 命令作用的方式 p 22 通过22端口登录到主机 电脑 外壳shell 内核 输入输出设备 用户 提供意愿 转化为命令与
  • nestjs:Cannot read property ‘retryAttempts‘ of undefined

    描述 Cannot read property retryAttempts of undefined 解决 检查数据库的配置是否有问题
  • 日期格式化方法

    时间格式化 有时候我们会用到时间的展示 时间的展示种类也是各种各样 对于不用的产品需要不同的样式 这时候就需要我们做一下时间的格式化处理 下面是一种常见的日期显示方式 代码如下 格式化时间 param String date 原始时间格式
  • 23种设计模式(七) —— 手写实现 Builder 模式 (组装复杂实例)

    文章目录 一 Builder 模式 二 示例 2 1 示例实现功能 2 2 具体实现 2 3 运行结果 三 Builder 模式中登场的角色 四 原文链接 Author Gorit Date 2021 10 24 2021年发表博文 22
  • 你还不知道的简历准备及面试技巧

    最近已经不止听到一位朋友吐槽工作不好找了 一波又一波的裁员潮 ChatGPT 等人工智能工具的爆火 1158 万的应届毕业生 都让今年 IT 行业的就业状况雪上加霜 面对愈加激烈的求职竞争 作为程序员 应该掌握哪些面试技巧 本文邀请了 2
  • Internet的路由选择协议(RIP、OSPF)

    有关路由选择协议的几个概念 1 理想的路由算法 路由选择协议的核心就是路由算法 即路由器通过算法来获得路由 一个理想的路由算法应该具有以下的特点 算法必须是正确和完整的 算法在计算上应简单 算法应能适应通信量和网络拓扑的变化 算法应具有稳定
  • OSG仿真案例(9)——JY61陀螺仪控制飞机姿态

    前言 在调试osg中模型运动姿态时 总觉得直观性不够强 所以有了想买个硬件陀螺仪 当时并不知道这个硬件应该叫什么名字 在淘宝搜索角度传感器的 几个驱动 1 CH340驱动 这个驱动在自带资源包里面 但是不可以用 只能自己在网上找 发现是型号
  • 数据库JDBC --- Java Database Connectivity

    数据库JDBC Java Database Connectivity 关于JDBC 什么是JDBC JDBC的组成 JDBC API JDBC的数据类型 创建JDBC的步骤 常用属性 Result Set ResultSetMetaData
  • Oracle使用IN 不能超过1000问题

    1 美图 2 背景 是写代码的是遇到问题 ORA 01795 列表中的最大表达式数为 1000 虽然使用了 批量处理解决了问题 但是因为是使用了myIbatis spring boot oracle 我不太想 直接改代码 想通过修改myIb
  • 25行jQuery代码实现轮播图

    对于刚刚学习前端的同学来说 做一个轮播图是非常不容易的 今天我就将自己的心得跟和大家分享一下 实现轮播图有很多方法 今天我们就讲其中一种方法 让图片显示在一行内 然后让图片有规律的向左移动 大家可以先看看效果http www shareko
  • sqli-labs (less-24)

    sqli labs less 24 进入24关 输入用户名和密码 登入后会显示你的用户名 下面的输入框就是改密码 我在输入用户名和密码的位置试了很多次 发现用户名和密码的位置是没有注入点的 这里我们先点击右下角的 New User clic
  • Flutter-设置分割线Divider

    Divider height 1 0 indent 0 0 color MyColors color gray 150
  • PowerBI开发 第十八篇:行级安全(RLS)

    PowerBI可以通过RLS Row level security 限制用户对数据的访问 过滤器在行级别限制数据的访问 用户可以在角色中定义过滤器 通过角色来限制数据的访问 在PowerBI Service中 workspace中的memb
  • uniapp getUserProfile:fail invalid session

    uniapp uni getUserProflie 部分安卓手机调不起来弹窗 错误原因 应该在uni getUserProflie之前调用uni login 但是直接在uni login的成功回调里面调用uni getUserProflie
  • 九、Linux系统中的文件传输

    九 Linux系统中的文件传输 实验准备 两台可以通信的主机 systemctl disable firewalld systemctl stop firewalld 9 1 scp命令 上传 scp 本地文件 远程主机用户 远程主机ip
  • SDUT 2023 summer team contest(for 22) - 14

    A Amanda Lounges 题意 有n个机场 m条边 对于每个机场可能需要等候室也可能不需要 如果输入2 代表路线连接的两个机场都需要建立 输入1 代表路线连接的其中一个机场建立 必须 输入0代表路线连接的两个机场都不可以建立 问你最
  • 关于https页面使用ifream嵌套http页面问题解决

    之前公司项目 部署的时候协议用的http 然后前几天把协议换成了https的 当时也没仔细测试 觉得没什么问题 然后 昨天发现其中的某个播放视频的页面显示不出来了 报错信息 接着上这个页面的部分代码 就是这个页面用ifram嵌套了另一个项目
  • MyBatisPlus-黑马-笔记

    MyBatisPlus 目录 入门案例 标准数据层开发 标准CRUD使用 分页 DQL编程控制 条件查询 null判定 查询投影 查询条件 等值查询 范围查询 模糊查询 映射匹配兼容性 DML编程控制 id生成策略控制 多记录操作 逻辑删除
  • window.close()失效问题

    一般的窗口关闭的JS如下写法 window close 但是呢 chrome firefox等中有时候会不起作用 改为下面的写法 window open about blank self close 或者 window open self
  • lstmcell转onnx报错 自定义pytorch模型替换

    lstmcell在转onnx的时候会遇到不支持的情况 如果模型已经训练好 可以通过自己实现lstmcell的方式 加载训练好的权重 以下是实现代码 class MyLSTMCell nn Module def init self input