PyTorch实现ResNet18

2023-10-31

ResNet-18结构

在这里插入图片描述

基本结点

在这里插入图片描述

代码实现

import torch
import torch.nn as nn
from torch.nn import functional as F


class RestNetBasicBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride):
        super(RestNetBasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        output = self.conv1(x)
        output = F.relu(self.bn1(output))
        output = self.conv2(output)
        output = self.bn2(output)
        return F.relu(x + output)


class RestNetDownBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride):
        super(RestNetDownBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride[0], padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride[1], padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.extra = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride[0], padding=0),
            nn.BatchNorm2d(out_channels)
        )

    def forward(self, x):
        extra_x = self.extra(x)
        output = self.conv1(x)
        out = F.relu(self.bn1(output))

        out = self.conv2(out)
        out = self.bn2(out)
        return F.relu(extra_x + out)


class RestNet18(nn.Module):
    def __init__(self):
        super(RestNet18, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = nn.Sequential(RestNetBasicBlock(64, 64, 1),
                                    RestNetBasicBlock(64, 64, 1))

        self.layer2 = nn.Sequential(RestNetDownBlock(64, 128, [2, 1]),
                                    RestNetBasicBlock(128, 128, 1))

        self.layer3 = nn.Sequential(RestNetDownBlock(128, 256, [2, 1]),
                                    RestNetBasicBlock(256, 256, 1))

        self.layer4 = nn.Sequential(RestNetDownBlock(256, 512, [2, 1]),
                                    RestNetBasicBlock(512, 512, 1))

        self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))

        self.fc = nn.Linear(512, 10)

    def forward(self, x):
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.avgpool(out)
        out = out.reshape(x.shape[0], -1)
        out = self.fc(out)
        return out

用来预测CIFAR-10数据集

数据集

官网链接:CIFAR-10 DATASET
在这里插入图片描述

测试代码

import torch
from torch import nn, optim
import torchvision.transforms as transforms
from torchvision import datasets
from torch.utils.data import DataLoader
from restnet18.restnet18 import RestNet18


#  用CIFAR-10 数据集进行实验

def main():
    batchsz = 128

    cifar_train = datasets.CIFAR10('cifar', True, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ]), download=True)
    cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True)

    cifar_test = datasets.CIFAR10('cifar', False, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ]), download=True)
    cifar_test = DataLoader(cifar_test, batch_size=batchsz, shuffle=True)

    x, label = iter(cifar_train).next()
    print('x:', x.shape, 'label:', label.shape)

    device = torch.device('cuda')
    # model = Lenet5().to(device)
    model = RestNet18().to(device)

    criteon = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    print(model)

    for epoch in range(1000):

        model.train()
        for batchidx, (x, label) in enumerate(cifar_train):
            # [b, 3, 32, 32]
            # [b]
            x, label = x.to(device), label.to(device)

            logits = model(x)
            # logits: [b, 10]
            # label:  [b]
            # loss: tensor scalar
            loss = criteon(logits, label)

            # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(epoch, 'loss:', loss.item())

        model.eval()
        with torch.no_grad():
            # test
            total_correct = 0
            total_num = 0
            for x, label in cifar_test:
                # [b, 3, 32, 32]
                # [b]
                x, label = x.to(device), label.to(device)

                # [b, 10]
                logits = model(x)
                # [b]
                pred = logits.argmax(dim=1)
                # [b] vs [b] => scalar tensor
                correct = torch.eq(pred, label).float().sum().item()
                total_correct += correct
                total_num += x.size(0)
                # print(correct)

            acc = total_correct / total_num
            print(epoch, 'test acc:', acc)


if __name__ == '__main__':
    main()

运行结果

在这里插入图片描述
感觉挺low的,迭代50多次能达到80多的准确率

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

PyTorch实现ResNet18 的相关文章

  • 华为OD2023(A卷)基础题26【最大利润、贪心的商人】

    题目 最大利润 商人经营一家店铺 有number种商品 由于仓库限制每件商品的最大持有数量是item index 每种商品的价格是item price item index day 通过对商品的买进和卖出获取利润 请给出商人在days天内能
  • qcharts控件如何提升

    条件 qt5 9版本以后 编译器也要对应的新版本 否则无法释放该版本qt库的所有功能 已经下载了qtcharts模块 如果安装qt时候没有勾选 则打开安装目录的MaintenanceTool exe软件 重新将qtcharts模块勾选上安装
  • R-Drop和SimCSE解读

    R Drop Regularized Dropout for Neural Networks R Drop的基本思想是 同一个step里面 对于同一个样本 前向传播两次 由于Dropout的存在 会得到两个不同但差异很小的概率分布 通过在原

随机推荐

  • JavaScript中远程级联调用(RPC)java对象中的方法并返回结果

    http code google com p json rpc for java downloads list
  • 如何修改服务器远程端口

    1 开始 运行 regedit 2 依次展开 HKEY LOCAL MACHINE SYSTEM CURRENTCONTROLSET CONTROL TERMINAL SERVER WDS RDPWD TDS TCP 右边键值中 PortN
  • Stm32f030 双串口

    void USART INIT void GPIO InitTypeDef GPIO InitStructure USART InitTypeDef USART InitStructure NVIC InitTypeDef NVIC Ini
  • Ubuntu 安装 conda

    下载 Anaconda 进入 Ubuntu 自己新建下载路径 输入以下命令开始下载 注意 如果不是 x86 64 需要去镜像看对应的版本 https mirrors bfsu edu cn anaconda archive C M O A
  • (十三)CMake MESSAGE和PROJECT

    一 MESSAGE MEESSAGE的功能是记录一个信息 当我们执行 编译 含有message命令的代码时 将会在终端打印指定内容 如果超过一个信息字符串 它将会拼接成一个信息 无缝连接 MESSAGE报告的信息可以是 普通信息 报告检查信
  • android手机安装ubuntu并创建ubuntu图形界面(1)

    在安卓手机上安装Ubuntu并创建图形界面 1 下载termux 用手机直接打开网址Termux F Droid Free and Open Source Android App Repository 点击下载apk并安装 安装后启动界面如
  • Scrapy运行builtins.ImportError: No module named 'win32api'

    windows 下 安装好scrapy后 运行 scrapy bench 报错builtins ImportError No module named win32api 解决方法 pip install pypiwin32
  • 关于自搭网站XAMPP(一)前后端AJAX-PHP数据连通

    前端AJAX代码
  • DEDECMS如何将图片轮播做到后台控制

    网上找了一大堆 试了好多方法 都不管用 最后偶尔看到这几行代码 没想到成功了 然后自己做个总结 方法如下 直接建立一个顶级栏目 然后在该顶级栏目里添加文档 在文档里面只上传缩略图 不要添加内容 然后在模板页面调用下面的代码标签 就好啦 把下
  • CRC32爆破小结

    前言 最近在bugku遇到了一道隐写题 binwalk之后发现里面有很多个压缩包 然后就无从下手 于是查看别人大佬的wp才发现是CRC32爆破 由于本人第一次遇到这种题目 就记录一下吧 正文 CRC想必大家都知道 它的全称是循环冗余校验 C
  • 2022-面试题汇总

    1 四大频繁Full GC原因 1 大量反射代码使永久代类太多导致频繁Full GC 解决方案 在有大量反射代码的场景下 只要把 XX SoftRefLRUPolicyMSPerMB 0 这个参数设置大一些即可 千万别让一些新手同学设置为0
  • 图像处理库(fbc_cv):源自OpenCV代码提取

    在实际项目中会经常用到一些基本的图像处理操作 而且经常拿OpenCV进行结果对比 因此这里从OpenCV中提取了一些代码组织成fbc cv库 项目fbc cv所有的代码已放到GitHub中 地址为 https github com feng
  • java总结输入流输出流

    1 什么是IO Java中I O操作主要是指使用Java进行输入 输出操作 Java所有的I O机制都是基于数据流进行输入输出 这些数据流表示了字符或者字节数据的流动序列 Java的I O流提供了读写数据的标准方法 任何Java中表示数据源
  • 算法笔记-图搜索

    统计图的连通分支数 思路 建图 搜索 注意这种建图方式是有向图 反例 1 2 3 4 4 1这种不会识别出来 因此建图时需要使用有向图 在add阶段加入两个方向的路径 add时从1开始的边的标号 0用来判断结束 斗则冲突有问题 int to
  • 追雨的际遇

    追雨 下班 刚出公司 隐约看到远处电闪雷鸣 明明今天是大好的晴天 看到电闪 确实稀奇 忽然豆大的雨点落了下来 恰逢我骑到桥洞底下 让雨先跑10分钟 等我换好雨衣 就去追她 桥洞底下 停车 开后备箱 开始换雨衣 陆续很多摩托停在我的身后 他们
  • Vue使用v-for遍历map

    功能 遍历数据库中按钮的图片和名字 当页面打开时 触发查询事件 以下图形式显示出来 前端代码 遍历存在数据库中的按钮名称和图片名称 其中按钮的click事件名称和按钮图片名称相同
  • 【Linux】Linux是如何诞生的?

    本文主要讲述Linux的诞生背景以及一些小故事 其中 还清晰地讲述了Unix BSD GUN GPL等名词的含义及来源 Table of Contents Unix C语言 BSD GUN GPL Linux Linux的内核发展 注意 本
  • 无法打开程序因为msvcp140.dll丢失,msvcp140.dll丢失的解决方法

    前几天看到有小伙伴再问什么是msvcp140 dll文件 相信很多人都不知道这是什么吧 如果电脑msvcp140 dll文件丢失的话会怎么样呢 丢失了应该如何找回呢 其实这些都是一些比较常见的电脑知识 我们是需要去了解一下的 废话不多说 下
  • DearMob iPhone Manager for Mac(iPhone手机数据加密传输软件)

    DearMob iPhone Manager 是Mac平台上一款功能强大的iPhone数据传输工具 无需iTunes即可完成数据传输 DearMob iPhone Manager Mac版能够为您进行影片 音乐 照片 通讯录等内容进行传输或
  • PyTorch实现ResNet18

    ResNet 18结构 基本结点 代码实现 import torch import torch nn as nn from torch nn import functional as F class RestNetBasicBlock nn