LeNet简单实现

2023-05-16

1 LeNet

import torch.nn as nn


class LeNet(nn.Module):
    def __init__(self, num_classes=10):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5, stride=1)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1)
        self.relu = nn.ReLU(inplace=True)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        # input shape: (batch_size, in_channels, 32, 32)
        x = self.conv1(x)
        # shape: (batch_size, 6, 28, 28)
        x = self.relu(x)
        # shape: (batch_size, 6, 28, 28)
        x = self.pool(x)
        # shape: (batch_size, 6, 14, 14)
        x = self.conv2(x)
        # shape: (batch_size, 16, 10, 10)
        x = self.relu(x)
        # shape: (batch_size, 16, 10, 10)
        x = self.pool(x)
        # shape: (batch_size, 16, 5, 5)
        x = x.view(x.size(0), -1)
        # shape: (batch_size, 16 * 5 * 5)
        x = self.fc1(x)
        # shape: (batch_size, 120)
        x = self.relu(x)
        # shape: (batch_size, 120)
        x = self.fc2(x)
        # shape: (batch_size, 84)
        x = self.relu(x)
        # shape: (batch_size, 84)
        x = self.fc3(x)
        # shape: (batch_size, num_classes)
        return x

2 Train

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
from model import LeNet

# Define the transform to normalize the data
transform = transforms.Compose([
    transforms.Resize((32,32)),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Define hyperparameters
batch_size = 64
learning_rate = 0.01
num_epochs = 10

# Load the MNIST dataset
train_dataset = MNIST(root='./dataset', train=True, download=True, transform=transform)
test_dataset = MNIST(root='./dataset', train=False, download=True, transform=transform)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Create the LeNet model
model = LeNet()

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

# Train the model
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Print training loss and accuracy
        if (i+1) % 100 == 0:
            total = labels.size(0)
            _, predicted = torch.max(outputs.data, 1)
            correct = (predicted == labels).sum().item()
            accuracy = correct / total
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Accuracy: {:.2f}%'
                  .format(epoch+1, num_epochs, i+1, len(train_loader), loss.item(), accuracy*100))

# Test the model
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print('Test Accuracy: {:.2f}%'.format(correct / total * 100))
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

LeNet简单实现 的相关文章

  • 免费资料 | RoboMaster资料包分享,备赛福利来啦

    资料包链接 xff1a 腾讯文档 RoboMaster 产品资料全集合 RoboMaster 产品资料全集合 2021年天之博特参与协办的首届RMUA人工智能挑战赛中国赛赛事 xff0c 见证了各个高校参赛队伍每一个奋力拼搏的瞬间 xff0
  • PostgreSQL与MySQL对比

    PostgreSQL与MySQL对比 都属于开放源码的一员 xff0c 性能和功能都在高速地提高和增强 MySQL AB的人们和PostgreSQL的开发者们都在尽可能地把各自的数据库改得越来越好 xff0c 所以对于任何商业数据库使用其中
  • 最简单最节省成本的锂电池充电电路!拆开火火兔,搬起小板凳,听老梁分析...

    作者 xff1a LR梁锐 xff0c 整理 xff1a 晓宇 微信公众号 xff1a 芯片之家 xff08 ID xff1a chiphome dy xff09 用了一年的火火兔坏了 xff0c 充不了电 作为一名合格的电工 xff0c
  • 题解 教主的魔法(分块学习记录)

    64 luogu 看到询问个数少 xff0c 分块的复杂度能过 xff0c 于是人生第一次打了分块 xff0c 居然A了 据说也有线段树瞎搞的 xff0c 不过我不会写 总之 xff0c 边角暴力 xff0c 块内二分 xff0c 受影响的
  • Ubuntu虚拟机找不到共享文件夹的解决办法

    Ubuntu虚拟机找不到共享文件夹的解决办法 一 查看共享文件夹是否设置成功 vmware hgfsclient 二 挂载共享文件夹到 mnt目录下 sudo vmhgfs fuse host mnt o nonempty o allow
  • 秒懂函数回调机制,回调函数看这篇就够了

    什么是回调函数 友情提示 xff1a 原理介绍部分摘自 xff1a https www jianshu com p 2f695d6fd64f 有一定基础的直接跳过即可 xff0c 直接查看后面精彩部分 回调函数就是一个通过函数指针调用的函数
  • vbox下安装archlinux

    博主linux小白一个 xff0c 一直想试试archlinux xff0c 最近终于有时间了 xff0c 一番努力之后成功了 xff0c 写出来与大家分享 archlinux版本 2013 06 01 archlinux的优点就不说了 x
  • IDL环境下,HDF文件转TIFF格式

    在IDL环境下 xff0c 将HDF文件转TIFF格式 在遥感图像处理过程中 xff0c 我们经常遇到HDF文件 xff0c 如modis影像数据 那么HDF数据到底是怎样的呢 xff1f 百科的解释 xff1a HDF是用于存储和分发科学
  • ModuleNotFoundError:No Module named‘lpips‘问题怎么解决?

    今天在复现论文的时候 xff0c 发现配置环境中缺少一个 34 lpips 34 的包 这里记录一下 xff0c 给有需要的小伙伴 废话不多说 xff0c 直接上干货 xff1a 1 打开 https pypi org xff0c 输入缺少
  • bat文件批处理vcbuild、msbuild或者devenv

    最近用bat文件调用vcbuild或者msbuild xff0c 对于只调用简单的命令行 xff0c 可以很快上手 xff0c 可以查询msdn的关于msbuild的使用指导http msdn microsoft com zh cn lib
  • Delphi源程序格式书写规范

    1 规范简介 本规范主要规定Delphi源程序在书写过程中所应遵循的规则及注意事项 编写该规范的目的是使公司软件开发人员的源代码书写习惯保持一致 这样做可以使每一个组员都可以理解其它组员的代码 xff0c 以便于源代码的二次开发记忆系统的维
  • NoMachine出现 The session negotiation failed的解决方案及踩坑总结

    问题情况 xff1a 我A电脑输入用户名和密码可以远程B电脑 xff0c B电脑输入用户名密码就是登录不上A电脑 B电脑上密码是用的账户密码 xff08 就是图标是一把钥匙的那个 xff09 A电脑上的密码是用的PIN密码 xff08 Wi
  • 利用Python+阿里云实现DDNS(动态域名解析)

    利用Python 43 阿里云实现DDNS 动态域名解析 因需求公司路由器公网ip不是动态的 xff0c 需要及时的修改阿里云的域名解析 前期准备 二 准备 1 公网IP xff08 向运营商申请的动态IP xff09 2 域名 xff08
  • 已知入栈顺序,总结出栈顺序的规律

    规律 xff1a 出栈的每一个元素的后面 xff0c 其中比该元素先入栈的一定按照入栈逆顺序排列 举例说明 xff1a 已知入栈顺序 xff1a 1 2 3 4 5 判断出栈顺序 xff1a 4 3 5 1 2 结果 xff1a 不合理 x
  • Linux系统使用cpulimit对CPU使用率进行限制

    介绍 cpulimit 是一个限制进程的 CPU 使用率的工具 xff08 以百分比表示 xff0c 而不是以 CPU 时间表示 xff09 当不希望批处理作业占用太多 CPU 时 xff0c 控制批处理作业很有用 目标是防止进程运行超过指
  • 题解·连续攻击游戏

    64 luogu 看上去这是一道二分图题 xff0c 将点i和它的两个属性值分别作为两个点集 xff0c 分别连边后跑匈牙利树 xff0c 若找不到匹配则输出解 span class token macro property span cl
  • Linux系统内网穿透教程

    Linux系统内网穿透可以通过使用SSH反向隧道 NAT端口映射 VPN等多种方式实现 xff0c 下面分别介绍这三种方式的实现方法 1 SSH反向隧道 SSH是一种加密的远程登录协议 xff0c 可以通过SSH反向隧道来实现内网穿透 首先
  • cpufreq 之powersave和performance governer的实现

    我们再来看看powersave的实现 xff0c 如下所示event是CPUFREQ GOV START时 xff0c 即开始这个governer时直接调用 cpufreq driver target来设定最低频率 19 static in
  • SQL 入门,看这篇就够了 ---- 基础篇

    目录 目录 目录 数据库安装 数据库基本概念 数据库管理系统 xff08 DBMS xff09 的分类 SQL 语句 创建 删除 更新操作 创建数据库 xff08 CREAT DATABASE xff09 创建表 删除表 更新表 查询 筛选
  • vi编辑器 编辑模式及命令模式常用命令

    在网上虽然有许多类似的文章 xff0c 但写的很杂 xff0c 不如这本书上看着顺畅 本文是 PHP 43 MySQL开发实战 220页到222页的内容 vi编辑器 文本编辑器是所有计算机系统中最常用的一种工具 UNIX下的编辑器有ex s

随机推荐