pytorch入门day5-卷积神经网络实战

2023-11-18

目录

LeNet网络实战

ResNet 

 训练函数


LeNet网络实战

import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torch import nn
from    lenet5 import Lenet5
from torch import optim


def main():
    batch_size=32
    cifar_train=datasets.CIFAR10('cifar',True,transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor()
    ]),download=True)
    cifar_train=DataLoader(cifar_train,batch_size=batch_size,shuffle=True)

    cifar_test = datasets.CIFAR10('cifar', False, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor()
    ]), download=True)
    cifar_test = DataLoader(cifar_train, batch_size=batch_size, shuffle=True)
    
    x,label=iter(cifar_train).next()
    print('x:', x.shape, 'label:', label.shape)

    model.train()
    device = torch.device('cuda')
    model = Lenet5().to(device)
    # CrossEntropyLoss进行交叉熵运算(包括了softmax计算),判断多分类问题中预测试与真实值的差距
    criteon = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    print(model)

    for epoch in range(1000):
        for batchidx,(x,label) in enumerate(cifar_train):
            # x [b, 3, 32, 32]
            # label [b]
            x, label = x.to(device), label.to(device)
            logits=model(x)
            # logits: [b, 10]
            # label:  [b]
            # loss: tensor scalar
            loss = criteon(logits, label)

            # backprop
            #先将梯度清零,再进行反向传播,再进行梯度更新
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if epoch % 10 == 0:
            print('epoch {}: loss = {}'
                  .format(epoch, loss.item()))

        model.eval()
        #在测试集内 不需要跟踪反向梯度计算
        with torch.no_grad():
            total_correct=0
            total_num=0
            for x,label in cifar_test:
                x,label=x.to(device),label.to(device)
                logits=model(x)
                pred=logits.argmax(dim=1)
                correct=torch.eq(pred,label).float().sum().item()#eq函数调用后会返回一个byte,true或者false估计,然后需要将其转换成float类型再通过item()函数来提取它的值
                total_correct+=correct
                total_num+=x.size(0)

                acc = total_correct / total_num
                print(epoch, 'test acc:', acc)



if __name__ == '__main__':
    main()

ResNet 

import torch
from    torch import  nn
from    torch.nn import functional as F


class ResBlk(nn.Module):
    """
    残差网络块
    """
    def __init__(self,ch_in,ch_out,stride=1):
        super(ResBlk,self).__init__()

        self.conv1=nn.Conv2d(ch_in,ch_out,kernel_size=3,stride=stride,padding=1)
        self.bn1=nn.BatchNorm2d(ch_out)
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)
        self.extra=nn.Sequential()

        if ch_out!=ch_in:
            #[b, ch_in, h, w] => [b, ch_out, h, w]
            self.extra=nn.Sequential(
                nn.Conv2d(ch_in,ch_out,kernel_size=1,stride=stride),
                nn.BatchNorm2d(ch_out)
            )

    def forward(self,x):
        """
            :param x: [b, ch, h, w]
        """
        out=F.relu(self.bn1(self.conv1(x)))
        out=self.bn2(self.conv2(out))
        #短接
        # extra module: [b, ch_in, h, w] => [b, ch_out, h, w] 保证能够进行残差计算
        out = self.extra(x) + out
        out = F.relu(out)

        return out

class ResNet18(nn.Module):
    def __init__(self):
        super(ResNet18, self).__init__()

        self.conv1=nn.Sequential(
            nn.Conv2d(3,64,kernel_size=3,stride=3,padding=1),
            nn.BatchNorm2d(64)

        )
        #紧跟着设置4个残差块
        # [b, 64, h, w] => [b, 128, h ,w]
        self.blk1=ResBlk(64,128,stride=2)
        # [b, 128, h, w] => [b, 256, h, w]
        self.blk2=ResBlk(128,256,stride=2)
        #  [b, 256, h, w] => [b, 512, h, w]
        self.blk3 = ResBlk(256, 512, stride=2)
        #  [b, 512, h, w] => [b, 512, h, w]
        self.blk4 = ResBlk(512, 512, stride=2)

        self.outlayer = nn.Linear(512 * 1 * 1, 10)

    def forward(self,x):
        x=F.relu(self.conv1(x))
        # [b, 64, h, w] => [b, 1024, h, w]
        x = self.blk1(x)
        x = self.blk2(x)
        x = self.blk3(x)
        x = self.blk4(x)
        # print('after conv:', x.shape)#[2, 512, 1, 1]
        x = F.adaptive_avg_pool2d(x, [1, 1])
        # print('after pool:', x.shape)
        x = x.view(x.size(0), -1)
        x = self.outlayer(x)

        return x


def main():
    blk = ResBlk(64, 128, stride=2)
    tmp = torch.randn(2, 64, 32, 32)
    out = blk(tmp)
    print('block:', out.shape)

    x = torch.randn(2, 3, 32, 32)
    model = ResNet18()
    out = model(x)
    print('resnet:', out.shape)

if __name__ == '__main__':
    main()

 训练函数

import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torch import nn
from    lenet5 import Lenet5
from torch import optim
from    resnet import ResNet18


def main():
    batch_size=32
    cifar_train = datasets.CIFAR10('cifar', True, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ]), download=True)
    cifar_train = DataLoader(cifar_train, batch_size=batch_size, shuffle=True)

    cifar_test = datasets.CIFAR10('cifar', False, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ]), download=True)
    cifar_test = DataLoader(cifar_test, batch_size=batch_size, shuffle=True)
    
    x,label=iter(cifar_train).next()
    print('x:', x.shape, 'label:', label.shape)


    device = torch.device('cuda')
    #model = Lenet5().to(device)
    model = ResNet18().to(device)
    # CrossEnmodel = Lenet5().to(device)tropyLoss进行交叉熵运算(包括了softmax计算),判断多分类问题中预测试与真实值的差距
    criteon = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    print(model)

    model.train()
    for epoch in range(1000):
        for batchidx,(x,label) in enumerate(cifar_train):
            # x [b, 3, 32, 32]
            # label [b]
            x, label = x.to(device), label.to(device)
            logits=model(x)
            # logits: [b, 10]
            # label:  [b]
            # loss: tensor scalar
            loss = criteon(logits, label)

            # backprop
            #先将梯度清零,再进行反向传播,再进行梯度更新
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if epoch % 10 == 0:
            print('epoch {}: loss = {}'
                  .format(epoch, loss.item()))

        model.eval()
        #在测试集内 不需要跟踪反向梯度计算
        with torch.no_grad():
            total_correct=0
            total_num=0
            for x, label in cifar_test:
                x,label=x.to(device),label.to(device)
                logits=model(x)
                pred=logits.argmax(dim=1)
                correct=torch.eq(pred,label).float().sum().item()#eq函数调用后会返回一个byte,true或者false估计,然后需要将其转换成float类型再通过item()函数来提取它的值
                total_correct+=correct
                total_num+=x.size(0)

                acc = total_correct / total_num
            if epoch % 10 == 0:
                print('epoch {}: acc = {}'
                        .format(epoch, acc))




if __name__ == '__main__':
    main()

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

pytorch入门day5-卷积神经网络实战 的相关文章

随机推荐

  • HTML5 FormData 方法介绍

    XMLHttpRequest 是一个浏览器接口 通过它 我们可以使得 Javascript 进行 HTTP S 通信 XMLHttpRequest 在现在浏览器中是一种常用的前后台交互数据的方式 2008年 2 月 XMLHttpReque
  • Node.js中的回调解析

    Node js 异步编程的直接体现就是回调 异步编程依托于回调来实现 但不能说使用了回调后程序就异步化了 回调函数在完成任务后就会被调用 Node 使用了大量的回调函数 Node 所有 API 都支持回调函数 例如 我们可以一边读取文件 一
  • TCP打洞和UDP打洞的区别

    为什么网上讲到的P2P打洞基本上都是基于UDP协议的打洞 难道TCP不可能打洞 还是TCP打洞难于实现 假设现在有内网客户端A和内网客户端B 有公网服务端S 如果A和B想要进行UDP通信 则必须穿透双方的NAT路由 假设为NAT A和NAT
  • 决策树和随机森林的实现,可视化和优化方法

    决策树原理 决策树原理这篇文章讲的很详细 本文仅写代码实现 构造决策树 matplotlib inline import matplotlib pyplot as plt import pandas as pd from sklearn d
  • Ubuntu 下配置android studio 配置 adb环境变量 普通用户可以执行 root用户无法执行

    我们环境变量其实已经配置好了 但普通用户下可以执行adb root权限下就不能执行 我们看一下 普通用户和 root用户所在的目录 root 权限不能执行 我们切换一下 普通用户的当前目录 普通用户adb 可以执行 我们新开一个窗口 切换到
  • notepad++插件查看十六进制

    下载hex editor 点击plugins 选择plugin manager show plugin managers 然后再available里面找到hex editor 然后下载 使用hex editor 点击plugins hex
  • 1055. 集体照 (25) PAT乙级真题

    1055 集体照 25 拍集体照时队形很重要 这里对给定的N个人K排的队形设计排队规则如下 每排人数为N K 向下取整 多出来的人全部站在最后一排 后排所有人的个子都不比前排任何人矮 每排中最高者站中间 中间位置为m 2 1 其中m为该排人
  • clockwise print binary search tree

    给一个二叉树 顺时针打印出所有的节点 例如 应该打印 20 8 4 10 14 25 22 思路 可以分为三步打印 1 打印左边界 最后一个叶子节点不打印 2 打印所有叶子结点 3 打印右边界 根和最后一个叶子结点不打印 代码如下 void
  • buu [ACTF2020 新生赛]Exec

    这道题倒是很直白简单 基本上知道命令注入这个知识点就行 命令注入 127 0 0 1 ls 127 0 0 1 ls flag 127 0 0 1 cat flag
  • 软件测试包括哪些内容

    以下是一些需要考虑的步骤 1 得到需求 功能设计 内部设计说书和其他必要的文档 2 得到预算和进度要求 3 确定与项目有关的人员和他们的责任 对报告的要求 所需的标准和过程 例如发行过程 变更过程 等等 4 确定应用软件的高风险范围 建立优
  • NIO - IO多路复用详解

    文章目录 Java NIO IO多路复用详解 现实场景 典型的多路复用IO实现 Reactor模型和Proactor模型 传统IO模型 Reactor事件驱动模型 Reactor模型 业务处理与IO分离 Reactor模型 并发读写 Rea
  • Adobe PhotoShop安装程序无法初始化的解决办法

    近日需要使用PhotoShop 不想下了好几个水版 给大家一个可以用的 http pan baidu com s 1dDnJLy5 不仅安装不了而且把机子的注册表改了 再次安装时就出现了 安装程序无法初始化 的问题 网上各种查找 零零散散的
  • 范数(简单的理解)、范数的用途、什么是范数

    没学好矩阵代数的估计范数也不是太清楚 当然学好的人也不是太多 范数主要是对矩阵和向量的一种描述 有了描述那么 大小就可以比较了 从字面理解一种比较构成规范的数 有了统一的规范 就可以比较了 例如 1比2小我们一目了然 可是 3 5 3 和
  • kafka消费的三种模式_快速认识Kafka

    1 Kafka是什么 简单的说 Kafka是由Linkedin开发的一个分布式的消息队列系统 Message Queue kafka的架构师jay kreps非常喜欢franz kafka 觉得kafka这个名字很酷 因此将linkedin
  • 【TypeScript(一)】TypeScript的变量类型及声明

    TypeScript的变量类型及声明 TS和JS最大的区别就是在其中给变量引入了类型的概念 比如像之前我们使用JS的时候 var a 10 这时a是数据类型 但当如果我们后面要使用a 我们也可以a hello 这时a就变成了一个字符串类型
  • 微信小程序开发设置获取权限管理,摄像头权限,位置权限,用户信息权限等

    在小程序开发的时候 我们总会遇到很多权限问题 比如摄像头权限 位置权限 用户信息权限等 如果不加以判断 很难给用户一个好的体验 有一天 小明来参观一个拍照微信小程序 他很感兴趣 看着精美的页面 忍不住点击了拍照按钮 然而 他太兴奋了 以至于
  • linux 系统调用列表 /usr/include/asm/unistd.h

    一 进程控制 fork 创建一个新进程 clone 按指定条件创建子进程 execve 运行可执行文件 exit 中止进程 exit 立即中止当前进程 getdtablesize 进程所能打开的最大文件数 getpgid 获取指定进程组标识
  • pybind 传递指针

    编码h264可以参考 https blog csdn net jacke121 article details 87484745 python部分 先接收指针vp 再调用 是可以的 coding utf 8 import binddemo
  • 连接db2的客户端工具(原创)

    最近在用友做项目 用得数据库是db2 以前从来没用过 但是对于写程序来说 啥数据库都一样 都是那几个语句 能执行就行 说是这样说 但是真用上就发现问题了 最大的就是没有好的客户端工具 网上搜了很多 什么toad quest都用了 感觉用着都
  • pytorch入门day5-卷积神经网络实战

    目录 LeNet网络实战 ResNet 训练函数 LeNet网络实战 import torch from torch utils data import DataLoader from torchvision import datasets