pytorch自带的模型剪枝工具prune的使用

2023-10-27

torch.nn.utils.prune可以对模型进行剪枝,官方指导如下:

https://pytorch.org/tutorials/intermediate/pruning_tutorial.html

直接上代码

首先建立模型网络:

import torch
import torch.nn as nn
from torchsummary import summary
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class SimpleNet(nn.Module):
    def __init__(self, num_classes=10):
        super(SimpleNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=12, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=12, out_channels=12, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2)
        self.conv3 = nn.Conv2d(in_channels=12, out_channels=24, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(in_channels=24, out_channels=24, kernel_size=3, stride=1, padding=1)
        self.fc = nn.Linear(in_features=16 * 16 * 24, out_features=num_classes)
    def forward(self, input):
        output = self.conv1(input)
        output = nn.ReLU()(output)
        output = self.conv2(output)
        output = nn.ReLU()(output)
        output = self.pool(output)
        output = self.conv3(output)
        output = nn.ReLU()(output)
        output = self.conv4(output)
        output = nn.ReLU()(output)
        output = output.view(-1, 16 * 16 * 24)
        output = self.fc(output)
        return output
model = SimpleNet().to(device=device)

看一下模型的 summary

summary(model, input_size=(3, 512, 512))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 12, 512, 512]             336
            Conv2d-2         [-1, 12, 512, 512]           1,308
         MaxPool2d-3         [-1, 12, 256, 256]               0
            Conv2d-4         [-1, 24, 256, 256]           2,616
            Conv2d-5         [-1, 24, 256, 256]           5,208
            Linear-6                   [-1, 10]          61,450
================================================================
Total params: 70,918
Trainable params: 70,918
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 3.00
Forward/backward pass size (MB): 78.00
Params size (MB): 0.27
Estimated Total Size (MB): 81.27
----------------------------------------------------------------

打印一下模型结构各层的名称:

print(model.state_dict().keys())

结果:

odict_keys(['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'conv3.weight', 'conv3.bias', 'conv4.weight', 'conv4.bias', 'fc.weight', 'fc.bias'])

接下来 对其进行剪枝操作:

import torch.nn.utils.prune as prune
parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.conv4, 'weight'),
    (model.fc, 'weight'),
)

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)

执行结束后,再打印一下:

print(model.state_dict().keys())

结果:

odict_keys(['conv1.bias', 'conv1.weight_orig', 'conv1.weight_mask', 'conv2.bias', 'conv2.weight_orig', 'conv2.weight_mask', 'conv3.weight', 'conv3.bias', 'conv4.bias', 'conv4.weight_orig', 'conv4.weight_mask', 'fc.bias', 'fc.weight_orig', 'fc.weight_mask'])

我们发现剪枝结束后conv*.weight已经 消失了,出现了两个weight:weight_orig和weight_mask

其实weight_orig就是剪枝以前的weight,而weight_mask里面 只是0和1,0代表的是被剪枝的

打印一下:

print(model.state_dict()['conv1.weight_orig'])

tensor([[[[1., 1., 1.],
          [1., 1., 1.],
          [0., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[0., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 0.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 0.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 0.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 0.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]]], device='cuda:0')
prune.remove(module, 

剪枝后,其实还是比较鸡肋的,因为只是剪之后的神经元相当于置零了,模型大小不会变,下面打印一下,有点dropout的意思了

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 12, 512, 512]             336
            Conv2d-2         [-1, 12, 512, 512]           1,308
         MaxPool2d-3         [-1, 12, 256, 256]               0
            Conv2d-4         [-1, 24, 256, 256]           2,616
            Conv2d-5         [-1, 24, 256, 256]           5,208
            Linear-6                   [-1, 10]          61,450
================================================================
Total params: 70,918
Trainable params: 70,918
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 3.00
Forward/backward pass size (MB): 78.00
Params size (MB): 0.27
Estimated Total Size (MB): 81.27
----------------------------------------------------------------

是不是和剪枝之前实际上是一样的,可能会减少运算,但是似乎好像知乎大神提到的被证明运算也没啥提升

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

pytorch自带的模型剪枝工具prune的使用 的相关文章

随机推荐

  • 机器学习——numpy逻辑回归(手写数字识别)

    二分类 识别1 7 import numpy as np import struct import matplotlib pyplot as plt import os from PIL import Image from sklearn
  • timm库安装

    按理来说 conda config append channels conda forge conda install timm 更新timm版本 pip install timm 0 5 4 python3 8 pytorch 1 11
  • HTTP与TCP的区别和联系

    https blog csdn net u013485792 article details 52100533 相信不少初学手机联网开发的朋友都想知道Http与Socket连接究竟有什么区别 希望通过自己的浅显理解能对初学者有所帮助 一 基
  • Kaggle入门——Titanic+随机森林(调参)+逻辑回归

    本博客记录一下自己的Kaggle入门题目 Titanic 只弄了一天 特征工程做得比较草率 结果只有0 76 不过主要是为了体验一下Kaggle竞赛的流程 以及熟悉一下Kaggle的使用 目录 1 题目相关 2 特征工程 3 随机森林 调参
  • 2021-09-08 PuTTY & Xftp-5使用密钥连接服务器

    PuTTY Xftp 5使用密钥连接服务器 所需材料 Xftp 5操作 PuTTY操作 所需材料 Xftp 5 PuTTY 密钥 服务器IP 用户名和密码 Xftp 5操作 新建会话 输入会话名称 主机IP 选择SFTP协议 方法选择Pub
  • 用python输出0到100所有能被3整除的数字_python: 输出 1~100 之间不能被 7 整除的数,每行输出 10 个数字,要求应用字符串格式化方法美化输出格式。...

    输出 1 100 之间不能被 7 整除的数 j 0 for i in range 1 101 遍历1 100取值 定义为变量 i if i 7 0 找出不能被 7 整除的数 print 3d format i end Format格式化输出
  • 伏秒积和安秒积

    伏秒平衡原则 伏秒平衡原则 在稳态工作的开关电源中电感两端的正伏秒值等于负伏秒值 安秒平衡原则 在稳态工作的开关电源中电容两端的正安秒值等于负安秒值 电容两端的电压不能突变 当电容足够大时 可认为其电压不变 电感中的电流不能突变 当电感足够
  • ./configure –prefix 命令用法

    在Linux上编译安装软件时 经常遇到 configure prefix usr这个命令 configure prefix 是什么意思呢 下面简单介绍一下 configure prefix 的用法 源码的安装一般由有这三个步骤 配置 con
  • 使用valgrind检查内存问题并且输出报告

    valgrind内存泄漏分析 是在linux中检查内存泄漏的工具 当程序编写完之后我一般都会使用它来检查一次内存问题 基本上能杜绝服务器的内存泄漏问题 当然是面对C C 这样的语言的 使用方式就是将程序编译好 然后通过valgrind来启动
  • iOS内存管理之autorelease

    当你需要延迟调用release方法的时候会使用autorelease 如 NSString fullName NSString string NSString alloc initWithFormat self firstName self
  • 基础练习 矩阵乘法

    问题描述 给定一个N阶矩阵A 输出A的M次幂 M是非负整数 例如 A 1 2 3 4 A的2次幂 7 10 15 22 输入格式 第一行是一个正整数N M 1 lt N lt 30 0 lt M lt 5 表示矩阵A的阶数和要求的幂数 接下
  • AutoMapper基本使用

    导包 AutoMapper AutoMapper Extensions Microsoft DependencyInjection 假如需要将Student映射为StudentCopy namespace WebApplication14
  • 国内比较快的DNS服务器IP汇总

    DNS是什么 DNS Domain Name System 域名系统 简单的说 就是把我们输入的网站域名翻译成IP地址的系统 比如我们想访问百度 我们会在网页里键入www baidu com 但是电脑不会理解这串字符的含义 于是就把这串字符
  • ElementUI浅尝辄止33:Form 表单

    Form 表单 日常业务中很常见 由输入框 选择器 单选框 多选框等控件组成 用以收集 校验 提交数据 常见于表单请求 登录 数据校验等业务操作中 1 如何使用 包括各种表单项 比如输入框 选择器 开关 单选框 多选框等 在 Form 组件
  • 数据结构与算法(六):图结构

    一 基本概念 二 图的存储结构 1 邻接矩阵 2 邻接表 3 十字链表 三 图的遍历 1 深度优先遍历 2 广度优先遍历 四 最小生成树 1 Prim算法 2 Kruskal算法 五 最短路径 1 Dijkstra算法 图是一种比线性表和树
  • 15:00面试,15:06就出来了,问的问题有点变态。。。

    从小厂出来 没想到在另一家公司又寄了 到这家公司开始上班 加班是每天必不可少的 看在钱给的比较多的份上 就不太计较了 没想到8月一纸通知 所有人不准加班 加班费不仅没有了 薪资还要降40 这下搞的饭都吃不起了 还在有个朋友内推我去了一家互联
  • uniapp 发布前隐私条款、用户协议等配置

    隐私类型 1 android隐私与政策提示框 2 网页版隐私条款 android ios都会用到 3 app内部隐私条款 一 android隐私与政策提示框 根据工业和信息化部关于开展APP侵害用户权益专项整治要求 App提交到应用市场必须
  • CVE-2021-41773&&CVE-2021-42013复现

    CVE 2021 41773 漏洞原理 Apache HTTP Server 是 Apache 基础开放的流行的 HTTP 服务器 在其 2 4 49 版本中 引入了一个路径体验 满足下面两个条件的 Apache 服务器将受到影响 版本等于
  • 汇编寄存器介绍

    1 通用寄存器 名称 全称 32位 16位 8位 编号 功能 rax 累加器 Accumulator eax ax ah al 0 0000 返回值 rcx 计数器 Count Register ecx cx ch cl 1 0001 第二
  • pytorch自带的模型剪枝工具prune的使用

    torch nn utils prune可以对模型进行剪枝 官方指导如下 https pytorch org tutorials intermediate pruning tutorial html 直接上代码 首先建立模型网络 impor