构建模型三要素与权重初始化

2023-11-06

学习过程中的好文,谨防失效,转载自博客园
,结合此篇

1、模型三要素

三要素其实很简单:

  1. 必须要继承nn.Module这个类,要让PyTorch知道这个类是一个Module
  2. __init__(self)中设置好需要的组件,比如conv,pooling,Linear,BatchNorm等等。
  3. 最后在forward(self,x)中用定义好的组件进行组装,就像搭积木,把网络结构搭建出来,这样一个模型就定义好了。

我们先看一个例子:
先看__init__(self)函数

# class Net(nn.Module):
def __init__(self):
	super(Net,self).__init__()
	self.conv1 = nn.Conv2d(3,6,5)
	self.pool1 = nn.MaxPool2d(2,2)
	self.conv2 = nn.Conv2d(6,16,5)
	self.pool2 = nn.MaxPool2d(2,2)
	self.fc1 = nn.Linear(16*5*5,120)
	self.fc2 = nn.Linear(120,84)
	self.fc3 = nn.Linear(84,10)

第一行是初始化,往后定义了一系列组件。nn.Conv2d 就是一般图片处理的卷积模块,然后池化层,全连接层等等。

定义完这些,再定义forward函数

def forward(self,x):
	x = self.pool1(F.relu(self.conv1(x)))
	x = self.pool2(F.relu(self.conv2(x)))
	x = x.view(-1,16*5*5)
	x = F.relu(self.fc1(x))
	x = F.relu(self.fc2(x))
	x = self.fc3(x)
	return x

x为模型的输入,第一行表示x经过conv1,然后经过激活函数relu,然后经过pool1操作
第三行表示对x进行reshape,为后面的全连接层做准备

至此,对一个模型的定义完毕,如何使用呢?
例如:

net = Net()
outputs = net(inputs)

其实net(inputs),就是类似于使用了net.forward(inputs)这个函数。

2、参数初始化

简单地说就是设定什么层用什么初始方法,初始化的方法会在torch.nn.init中

# 定义权值初始化
def initialize_weights(self):
	for m in self.modules():
		if isinstance(m,nn.Conv2d):
			torch.nn.init.xavier_normal_(m.weight.data)
			if m.bias is not None:
				m.bias.data.zero_()
		elif isinstance(m,nn.BatchNorm2d):
			m.weight.data.fill_(1)
			m.bias.data.zero_()
		elif isinstance(m,nn.Linear):
			torch.nn.init.normal_(m.weight.data,0,0.01)
			# m.weight.data.normal_(0,0.01)
			m.bias.data.zero_()

这段代码的基本流程就是,先从self.modules()中遍历每一层,然后判断更曾属于什么类型,是否是Conv2d,是否是BatchNorm2d,是否是Linear的,然后根据不同类型的层,设定不同的权值初始化方法,例如Xavierkaimingnormal_等等。kaiming也是MSRA初始化,是何恺明大佬在微软亚洲研究院的时候,因此得名。

上面代码中用到了self.modules(),这个是什么东西呢?

# self.modules的源码
def modules(self):
	for name,module in self.named_modules():
		yield module

功能就是:能依次返回模型中的各层,yield是让一个函数可以像迭代器一样可以用for循环不断从里面遍历(可能说的不太明确)。

3、完整运行代码

我们用下面的例子来更深入的理解self.modules(),同时也把上面的内容都串起来(下面的代码块可以运行):

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset,DataLoader

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.xavier_normal_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                torch.nn.init.normal_(m.weight.data, 0, 0.01)
                # m.weight.data.normal_(0,0.01)
                m.bias.data.zero_()

net = Net()
net.initialize_weights()
print(net.modules())
for m in net.modules():
    print(m)

运行结果:

# 这个是print(net.modules())的输出
<generator object Module.modules at 0x0000023BDCA23258>
# 这个是第一次从net.modules()取出来的东西,是整个网络的结构
Net(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)
# 从net.modules()第二次开始取得东西就是每一层了
Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
Linear(in_features=400, out_features=120, bias=True)
Linear(in_features=120, out_features=84, bias=True)
Linear(in_features=84, out_features=10, bias=True)

其中呢,并不是每一层都有偏执bias的,有的卷积层可以设置成不要bias的,所以对于卷积网络参数的初始化,需要判断一下是否有bias,(不过我好像记得bias默认初始化为0?不确定,有知道的朋友可以交流)

torch.nn.init.xavier_normal(m.weight.data)
if m.bias is not None:
	m.bias.data.zero_()

上面代码表示用xavier_normal方法对该层的weight初始化,并判断是否存在偏执bias,若存在,将bias初始化为0。

4、尺寸计算与参数计算

我们把上面的主函数部分改成:

net = Net()
net.initialize_weights()
layers = {}
for m in net.modules():
    if isinstance(m,nn.Conv2d):
        print(m)
        break

这里的输出m就是:

Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))

这个卷积层,就是我们设置的第一个卷积层,含义就是:输入3通道,输出6通道,卷积核 5 × 5 , 步长1,padding=0。

【问题1:输入特征图和输出特征图的尺寸计算】

之前的文章也讲过这个了,

在这里插入图片描述
用代码来验证一下这个公式:

net = Net()
net.initialize_weights()
input = torch.ones((16,3,10,10))
output = net.conv1(input)
print(input.shape)
print(output.shape)

初始结果:

torch.Size([16, 3, 10, 10])
torch.Size([16, 6, 6, 6])

第一个维度上batch , 第二个是通道channel, 第三个和第四个是图片
(特征图)的尺寸

在这里插入图片描述

【问题2:这个卷积层中有多少的参数?】

输入通道是3通道的,输出是6通道的,卷积核是 5×5
的,所以理解为6个3 × 5 × 5的卷积核,所以不考虑bias的话,参数量是
3 × 5 × 5 × 6 = 450 , 考虑bais的话,就每一个卷积核再增加一个偏置值。(这是一个一般人会忽略的知识点欸)

下面用代码来验证:

net = Net()
net.initialize_weights()
for m in net.modules():
    if isinstance(m,nn.Conv2d):
        print(m)
        print(m.weight.shape)
        print(m.bias.shape)
        break

输出结果是:

Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
torch.Size([6, 3, 5, 5])
torch.Size([6])
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

构建模型三要素与权重初始化 的相关文章

  • 单例模式的4种写法

    单例模式是开发过程中常用的模式之一 首先了解下单例模式的四大原则 构造方法私有 以静态方法或枚举返回实例 确保实例只有一个 尤其是多线程环境 确保反射或反序列化时不会重新构建对象 饿汉模式 饿汉模式在类被初始化时就创建对象 以空间换时间 故
  • Redis的多路复用机制

    Redis是单线程还是多线程 通常我们所说的Redis 是单线程 主要是指 Redis 的网络 IO 和键值对读写是由一个线程来完成的 这也是 Redis 对外提供键值存储服务的主要流程 但 Redis 的其他功能 比如持久化 异步删除 集
  • [RoarCTF 2019]Easy Calc

    进入题目是一个计算器的功能界面 查看源代码 可以发现是有WAF的 且存在一个calc php文件 这里接收一个num参数 可以看到这里创建了一个黑名单列表 然后用正则是去匹配 进行非法参数的过滤 那这题就是要绕过这个过滤和过一个WAF了 先
  • Pycharm使用---Black代码格式化工具

    前言 一个代码规范 可读性强是我们在写代码或者看代码时最期望的 也便于我们理解代码的功能和思路 而对于格式不是很规范的代码 要去修改其格式 如果单纯靠人工更正格式 对于简短的代码 难度不是很高 但是遇到一个较长 功能复杂的代码或者项目 人工
  • vue3.3 v-model 双向绑定

    配置代码还是有必要贴出的 老截图也不好 plugins vue script defineModel true propsDestructure true vueJsx
  • mongoDB数据库net stop mongoDB 发生系统错误 5。 拒绝访问。

    在使用mongoDB的时候命令行输入 net stop start mongDB停止 启动数据数据库时 终端报错如下 报错原因 权限不够 启动MongoDB服务需要以管理员的身份启动CMD 解决方案 CMD命令提示符地址 c盘 gt win
  • jmeter—建立测试计划

    一个测试计划描述了一系列 Jmeter 运行时要执行的步骤 一个 完整的测试计划包含 一个或者多个线程组 逻 辑控制 取样发生控制 监听器 定时器 断言和配置元件 一 建立测试计划 在这一部分 你将学到如何创建一个基础的测试计划来测试网站

随机推荐

  • 谈谈测试种类有哪些?

    此块引用怎么也删不掉了 那就留着吧 本来想在前面写点感想 害 也许是我不会用 灰度测试 A B测试 BVT测试 UAT测试 埋点测试 接口测试 缓存测试 灰度测试 灰度测试 就是在某项产品或应用正式发布前 选择特定人群试用 逐步扩大其试用者
  • @RequestBody不生效,获取不到数据

    RequestBody不生效 获取不到数据 网上找了很多 试过了不生效 最后检查引用包的时候 看到引用到了swagger的 RequestBody去了 大无语事件 不看还不知道swagger也有一个 RequestBody 改为引用spri
  • 【廖雪峰python进阶笔记】函数式编程

    1 高阶函数 高阶函数就是可以把函数作为参数的函数 下面我们看一个简单的高阶函数 def add x y f return f x f y 如果传入abs作为参数f的值 add 5 9 abs 根据函数的定义 函数执行的代码实际上是 abs
  • 常用的COMSOL操作符和数学函数

    算符 d f x f对x方向的微分 1 使用d算符来计算一个变量对另一个变量的导数 如 d T x 指变量T对x求导 而d u 2 u 2 u等 2 如果模型中含有任何独立变量 建模中使用d算符会使模型变为非线性 3 在解的后处理上使用d算
  • moviepy音视频剪辑:使用fl_time进行诸如快播、慢播、倒序播放等时间特效处理的原理、代码实现以及需要注意的坑

    专栏 Python基础教程目录 专栏 使用PyQt开发图形界面Python应用 专栏 PyQt moviepy音视频剪辑实战 专栏 PyQt入门学习 老猿Python博文目录 老猿学5G博文目录 一 引言 在 moviepy音视频剪辑 mo
  • Eclipse中创建Web项目(2023年)

    在创建Web项目前要先配置好JDK环境以及Tomcat环境 配置教程已经发过了 接下来我们开始创建第一个Web项目 目录 一 创建web项目 二 整合Tomcat服务器 三 项目部署到Tomcat中 一 创建web项目 1 打开Eclips
  • Android集合和数据的相互转换

    1 集合转换成数组 如有需要可以把String换成其他类 List
  • 知识图谱从哪里来:实体关系抽取的现状与未来

    点击上方 Datawhale 选择 星标 公众号 第一时间获取价值内容 最近几年深度学习引发的人工智能浪潮席卷全球 在互联网普及带来的海量数据资源和摩尔定律支配下飞速提升的算力资源双重加持下 深度学习深入影响了自然语言处理的各个方向 极大推
  • iOS Provisioning Profile(Certificate)与Code Signing详解

    引言 关于开发证书配置 Certificates Identifiers Provisioning Profiles 相信做 iOS 开发的同学没少被折腾 对于一个 iOS 开发小白 半吊子 比如像我自己 抑或老兵 或多或少会有或曾有过以下
  • js中数组常用几种方法

    Array 前端js数组常用方法 1 for Each 此方法是将数组中的每个元素执行传进提供的函数 没有返回值 var arr 1 2 3 4 5 function m1 a console log a 2 arr forEach m1
  • jQuery XSS漏洞原因查找及解决方案

    测试网站是否存在此XSS跨站漏洞 以google浏览器为例 打开要测试的网站 在Console窗口输入 element attribute img src 123123 回车之后会出现弹窗 说明存在XSS跨站漏洞 解决方案 升级jquery
  • 专访虎牙直播毛茂德

    引言 作为一位经历了互联网 移动互联网阶段的老兵 毛茂德老师一路走来 始终保持自己的技术初心 不断探索未知领域的宽度 进入虎牙直播后 他积极推动虎牙拥抱云原生 进行业务创新 同时他也发挥技术优势 通过高效运维为企业实现了降本增效 专注于技术
  • OSS 如何获取阿里云的bucket和endpoint

    如何获取阿里云oss所需的bucket和endpoint 关于阿里云oss的使用 本篇文章主要讲述如何获取我们需要获取的参数是 bucket和endpoint 这2个参数比较好获得 实际上 和这2个参数决定了 您上传文件的最终访问地址 这个
  • 关联对象源码分析

    什么是关联对象 一个对象可以关联多个对象 可以扩展原有对象的能力 关联是拥有的关系 Case1 Category可以使用 property添加一个属性吗 interface NSString MyNSString property nona
  • Pandas模块:Python科学计算神器之一

    欢迎来到我的博客 作者 秋无之地 简介 CSDN爬虫 后端 大数据领域创作者 目前从事python爬虫 后端和大数据等相关工作 主要擅长领域有 爬虫 后端 大数据开发 数据分析等 欢迎小伙伴们点赞 收藏 留言 关注 关注必回关 上一篇文章已
  • JSON和String的相互转换

    1 java转JSON JSON toJSONString 将java对象 java集合 Json对象转为jsonString JSON toJSON 将java对象 java集合转为json对象 3 JSON转Java JSON pars
  • java中如何创建一个多线程类呢?

    转自 java中如何创建一个多线程类呢 下文笔者讲述创建多线程类的方法分享 如下所示 实现思路 方式1 继承Thread类 重新Run方法 方式2 继承Runnable接口 重写Run方法 方式3 使用拉姆达表达式 例 package co
  • 陷波滤波器(Notch Filter)和峰值滤波器(Peak Filter)

    陷波滤波器 Notch Filter 陷波滤波器是带阻滤波器的一种 其阻带很窄 因此也称点阻滤波器 常常用于去除固定频率分量或阻带很窄的地方 如用于去除直流分量 去除某些特定频率分量 峰值滤波器与陷波滤波器恰好相反 峰值滤波器是带通滤波器的
  • Sublime Text 3高亮主题配置

    之前由于sublime的默认主题 灰白 比较难看 所以用得少 最近找到了一个比较漂亮的主题 再次因为sublime的轻便再次高频使用 先上图 以下是python代码的显示情况 这里使用的是theme freesia主题1 该主题下还有很多配
  • 构建模型三要素与权重初始化

    学习过程中的好文 谨防失效 转载自博客园 结合此篇看 1 模型三要素 三要素其实很简单 必须要继承nn Module这个类 要让PyTorch知道这个类是一个Module 在 init self 中设置好需要的组件 比如conv pooli