pytorch和paddle的存储模型变量state_dict命名规则分析

2023-05-16

目录

  • 一、Pytorch存储模型变量命名分析
  • 二、PaddlePaddle存储模型变量命名分析
  • 三、Pytorch和Paddle相互转化

一、Pytorch存储模型变量命名分析

在pytorch中,存储变量的名称就在def init(self)中定义,名字就是self中的定义名称。若在类中还调用了其他的类,那么名称则为实例化的变量名称。

典型示例如下:

import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
 
 
class test(nn.Module):
    '''
    定义子类
    '''
    def __init__(self):
        super(test, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)


class TheModelClass(nn.Module):
    '''
    定义测试类
    '''
    def __init__(self):
        super(TheModelClass, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
        self.test= test()
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


def main():
    '''
    主函数
    '''
    # 建立模型
    model = TheModelClass()

    # 建立优化器
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    # 输出模型变量字典
    print('Model.state_dict:')
    for param_tensor in model.state_dict():
        # 打印 key value字典
        print(param_tensor, '\t', model.state_dict()[param_tensor].size())

    # 输出优化器变量字典
    print('Optimizer,s state_dict:')
    for var_name in optimizer.state_dict():
        print(var_name, '\t', optimizer.state_dict()[var_name])


if __name__ == '__main__':
    '''
    程序入口
    '''
    main()

输出结果如下:

Model.state_dict:
conv1.weight     torch.Size([6, 3, 5, 5])
conv1.bias       torch.Size([6])
conv2.weight     torch.Size([16, 6, 5, 5])
conv2.bias       torch.Size([16])
fc1.weight       torch.Size([120, 400])
fc1.bias         torch.Size([120])
fc2.weight       torch.Size([84, 120])
fc2.bias         torch.Size([84])
fc3.weight       torch.Size([10, 84])
fc3.bias         torch.Size([10])
test.conv1.weight        torch.Size([6, 3, 5, 5])
test.conv1.bias          torch.Size([6])
test.conv2.weight        torch.Size([16, 6, 5, 5])
test.conv2.bias          torch.Size([16])
Optimizer,s state_dict:
state    {}
param_groups     [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]}]

二、PaddlePaddle存储模型变量命名分析

典型代码如下:

import paddle.nn as nn
import paddle.optimizer as optim
import paddle.nn.functional as F
 
 
class test(nn.Layer):
    '''
    定义子类
    '''
    def __init__(self):
        super(test, self).__init__()
        self.conv1 = nn.Conv2D(3, 6, 5)
        self.pool = nn.MaxPool2D(2, 2)
        self.conv2 = nn.Conv2D(6, 16, 5)


class TheModelClass(nn.Layer):
    '''
    定义测试类
    '''
    def __init__(self):
        super(TheModelClass, self).__init__()
        self.conv1 = nn.Conv2D(3, 6, 5)
        self.pool = nn.MaxPool2D(2, 2)
        self.conv2 = nn.Conv2D(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
        self.test= test()
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


def main():
    '''
    主函数
    '''
    # 建立模型
    model = TheModelClass()

    # 建立优化器
    optimizer = optim.SGD(parameters = model.parameters(), learning_rate=0.001, weight_decay=0.9)

    # 输出模型变量字典
    print('Model.state_dict:')
    for param_tensor in model.state_dict().keys():
        # 打印 key value字典
        print(param_tensor, '\t', model.state_dict()[param_tensor].shape)


if __name__ == '__main__':
    '''
    程序入口
    '''
    main()

其输出如下所示:

Model.state_dict:
conv1.weight     [6, 3, 5, 5]
conv1.bias       [6]
conv2.weight     [16, 6, 5, 5]
conv2.bias       [16]
fc1.weight       [400, 120]
fc1.bias         [120]
fc2.weight       [120, 84]
fc2.bias         [84]
fc3.weight       [84, 10]
fc3.bias         [10]
test.conv1.weight        [6, 3, 5, 5]
test.conv1.bias          [6]
test.conv2.weight        [16, 6, 5, 5]
test.conv2.bias          [16]

通过对比发现,在命名规则上pytorch和paddlepaddle是一样的。只不过对于fc层来说,它的weight的形状是相互转置的关系。

三、Pytorch和Paddle相互转化

通过上面的分析我们知道,pytorch和paddle的模型变量命名规则是完全一样的。那么对于训练好的pytorch或paddle模型,我们就可以基于上述原则进行互转。在互换时注意fc层,对于fc层的变量需要做转置处理。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

pytorch和paddle的存储模型变量state_dict命名规则分析 的相关文章

  • 开发GB28181监控平台前期准备总结

    首先得准备PJLIB的编译 xff0c 这个搜一下 xff0c 下载以后它是有VS的工程文件的 xff0c 所以编译很方便 得到这个库以后 xff0c 就可以编写SIP服务程序了 xff0c 服务程序可以验证GB28181的相关通讯流程 x
  • esp8266 丢失固件 丢失程序问题

    1 首先esp8266 丢失固件 丢失程序问题已经解决 2 解决方法 我们制作了一个固件保护主板 xff0c 提供2种供电接口 xff0c 支持5v稳压 串口电平保护 xff0c 固件保护 xff0c 反电动势保护 xff0c 支持复位按键
  • printf重定向

    1 printf与fputc 对于 printf 函数相信大家都不陌生 xff0c 第一个C语言程序就是使用 printf 函数在屏幕上的控制台打印出Hello World xff0c 之后使用 printf 函数输出各种类型的数据 xff
  • ESP32_BLUFI代码移植过程遇到的问题

    1 先是运行esp32官方给的例程 xff0c 出现了错误报错如下 xff1a esp image Image length 1053648 doesn t fit in partition length 1048576 boot Fact
  • Java 中的 Iterator 迭代器详解

    x1f366 Iterator 接口 在程序开发中 xff0c 经常需要遍历集合中的所有元素 针对这种需求 xff0c JDK 专门提供了一个接口 java util Iterator Iterator 接口也是 Java 集合中的一员 x
  • 三.【NodeJs入门学习】POST接口

    上一节我们学习了get接口 xff0c 这一节我们自己来写一下post接口 1 复习一下 先复习一下上一节中get请求的步骤 上图是在入口app js中处理get请求 xff0c 先拿到请求的url xff0c 然后设置了一个函数handl
  • 多进程和多线程比较

    原文 xff1a http blog csdn net lishenglong666 article details 8557215 很详细 对比维度 多进程 多线程 总结 数据共享 同步 数据共享复杂 xff0c 需要用IPC xff1b
  • C++ 之头文件声明定义

    最近在学习 c 43 43 在编译与链接过程中遇到了一些定义与声明的问题 经过多处查阅资料 基本解惑 现记录与此 希望让后面人少走些弯路 C 43 43 的头文件应该用什么扩展名 目前业界的常用格式如下 implementation fil
  • arduino修改串口缓冲区大小的三种办法

    由于SoftwareSerial h默认只接收64字节串行缓冲区 xff0c Arduino会将之后接收到的数据丢弃 xff0c 不满足业务需求 以下三种方法是笔者参考网上各种资料总结出来 xff0c 对于WEMOS D1 R2 xff0c
  • C语言调用libcurl的一个简单例子

    首先我们创建一个php页面 xff1a lt meta http equiv 61 span class hljs string 34 Content Type 34 span content 61 span class hljs stri
  • 【C++】类构造函数、析构函数的调用顺序「完整版」

    一 全局变量 静态变量和局部变量 全局变量在程序开始时调用构造函数 在程序结束时调用析构函数 静态变量在所在函数第一次被调用时调用构造函数 在程序结束时调用析构函数 xff0c 只调用一次 局部变量在所在的代码段被执行时调用构造函数 xff
  • linux下使用shell发送http请求

    本文主要介绍如何在linux下使用shell发送http请求 一 curl 1 get请求 curl命令默认下就是使用get方式发送http请求 curl www span class hljs preprocessor baidu spa
  • 【STL真好用】1057 Stack C++(30)

    1057 Stack 30 分 Stack is one of the most fundamental data structures which is based on the principle of Last In First Ou
  • C++学习之头文件引用

    目录结构如下 test h的定义如下 xff1a ifndef TEST H define TEST H include lt vector gt include lt string gt using namespace std class
  • checksum 算法

    说明 checksum xff1a 总和检验码 xff0c 校验和 xff0c 可以理解为check xff08 校验 xff09 xff0c sum xff08 和 xff09 在数据处理和通信领域 xff0c 通过一定算法对传输的数据进
  • 解决cannot open shared object file: No such file or directory

    一 linux下调用动态库 so文件时提示 xff1a cannot open shared object file No such file or directory 解决办法 xff1a 1 此时ldd xxx查看依赖缺少哪些库 lib
  • cmake 使用(六)

    本文是 cmake 使用的第六篇 主要介绍如何设置编译器优化标志 上一篇的链接为 xff1a https blog csdn net QCZL CC article details 119825737 xff0c 主要介绍如何将自己的软件安
  • 8086寄存器介绍

    8086 有14个16位寄存器 xff0c 这14个寄存器按其用途可分为 1 通用寄存器 2 指令指针 3 标志寄存器和 4 段寄存器等4类 1 通用寄存器有8个 又可以分成2组 一组是数据寄存器 4个 另一组是指针寄存器及变址寄存器 4个
  • C++常用操作符:: -> . (例子详解)

    C 43 43 提供了三种访问类或者类对象的操作符 xff0c 他们是 双冒号 点 箭头 gt 这三种操作符有着各自的使用场景和定义 双冒号 A B 表示作用域运算符 A一定是一个类的名称或命名空间的名称 仅仅用于当B是A类 A命名空间的一
  • STM32中断优先级的分配以及中断原则

    STM32d的中断优先级由NVIC IPRx寄存器来配置 xff0c IPR的宽度为8bit所以原则上每个中断可配置的优先级为0 255 xff0c 数值越小优先级越高 xff0c 但对于大部分的 Cortex M3芯片都会精简设计 xff

随机推荐