基于Pytorch1.8.0+Win10+RTX3070的MNIST网络构建与训练

2023-11-16

直接上代码

先上整个的代码

import torch
import torchvision
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

#  参考:https://blog.csdn.net/sxf1061700625/article/details/105870851?ops_request_misc=%257B%2522request%255Fid%2522%253A%2522162486393316780265489114%2522%252C%2522scm%2522%253A%252220140713.130102334..%2522%257D&request_id=162486393316780265489114&biz_id=0&utm_medium=distribute.pc_search_result.none-task-blog-2~all~sobaiduend~default-1-105870851.first_rank_v2_pc_rank_v29_1&utm_term=pytorch++mnist&spm=1018.2226.3001.4187

class Mnist_Net(nn.Module):
    '''
    定义网络
    '''
    def __init__(self):
        super(Mnist_Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        # 激活函数
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        # 激活函数
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        # 激活函数
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        # 返回结果
        return F.log_softmax(x)

def training_net(epoch,network,train_loader,optimizer,train_losses, train_counter,log_interval):
    '''
    一个种群训练一代
    :param epoch: 用于现实到第几个代了
    :param network: 模型对象
    :param train_loader: 训练数据对象
    :param optimizer: 优化器对象
    :param train_losses:
    :param train_counter:
    :param log_interval:
    :return:
    '''
    network.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        # 将一个图片传入到网络中,得到out结果
        output = network(data)
        # 计算LOSS
        loss = F.nll_loss(output, target)
        # 反向传播LOSS
        loss.backward()
        # 优化器
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                       100. * batch_idx / len(train_loader), loss.item()))
            train_losses.append(loss.item())
            train_counter.append((batch_idx * 64) + ((epoch - 1) * len(train_loader.dataset)))
            # 保存网络模型
            torch.save(network.state_dict(), './model.pth')
            # 保存优化器结果
            torch.save(optimizer.state_dict(), './optimizer.pth')


def testing_net(network, test_loader,test_losses):
    '''
    测试集执行
    :param network:
    :param test_loader:
    :param test_losses:
    :return:
    '''
    network.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            # 首先得到out结果
            output = network(data)
            # 计算LOSS
            test_loss += F.nll_loss(output, target, size_average=False).item()
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.data.view_as(pred)).sum()
    test_loss /= len(test_loader.dataset)
    test_losses.append(test_loss)
    print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss, correct, len(test_loader.dataset),100. * correct / len(test_loader.dataset)))


def view_dataset_figure(test_loader):
    '''
    展示训练和测试的数据图
    :param test_loader:
    :return:
    '''
    # 让我们看看一批测试数据由什么组成。
    examples = enumerate(test_loader)
    batch_idx, (example_data, example_targets) = next(examples)
    print(example_targets)
    print(example_data.shape)
    fig = plt.figure()
    for i in range(6):
        plt.subplot(2, 3, i + 1)
        plt.tight_layout()
        plt.imshow(example_data[i][0], cmap='gray', interpolation='none')
        plt.title("Ground Truth: {}".format(example_targets[i]))
        plt.xticks([])
        plt.yticks([])
    plt.show()


def show_loss_line_figure(train_counter,train_losses,test_counter,test_losses):
    '''
    展示LOSS曲线
    :param train_counter:
    :param train_losses:
    :param test_counter:
    :param test_losses:
    :return:
    '''
    fig = plt.figure()
    plt.plot(train_counter, train_losses, color='blue')
    plt.scatter(test_counter, test_losses, color='red')
    plt.legend(['Train Loss', 'Test Loss'], loc='upper right')
    plt.xlabel('number of training examples seen')
    plt.ylabel('negative log likelihood loss')
    plt.show()


def show_predict_result(network,test_loader):
    '''
    展示预测数据的结果,目前是用的test数据集中的数据
    :param network:
    :param test_loader:
    :return:
    '''
    examples = enumerate(test_loader)
    batch_idx, (example_data, example_targets) = next(examples)
    with torch.no_grad():
        output = network(example_data)
    fig = plt.figure()
    for i in range(6):
        plt.subplot(2, 3, i + 1)
        plt.tight_layout()
        plt.imshow(example_data[i][0], cmap='gray', interpolation='none')
        plt.title("Prediction: {}".format(
            output.data.max(1, keepdim=True)[1][i].item()))
        plt.xticks([])
        plt.yticks([])
    plt.show()


def execute_through_new():
    '''
    新的执行训练
    :return:
    '''
    n_epochs = 3
    batch_size_train = 64
    batch_size_test = 1000
    learning_rate = 0.01
    momentum = 0.5
    log_interval = 10
    random_seed = 1
    torch.manual_seed(random_seed)
    train_loader_obj = torch.utils.data.DataLoader(
        torchvision.datasets.MNIST('./data/', train=True, download=True,
                                   transform=torchvision.transforms.Compose([
                                       torchvision.transforms.ToTensor(),
                                       torchvision.transforms.Normalize(
                                           (0.1307,), (0.3081,))
                                   ])),
        batch_size=batch_size_train, shuffle=True)
    test_loader_obj = torch.utils.data.DataLoader(
        torchvision.datasets.MNIST('./data/', train=False, download=True, transform=torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((0.1307,), (0.3081,))])
                                   ), batch_size=batch_size_test, shuffle=True
    )
    view_dataset_figure(test_loader_obj)
    network_obj = Mnist_Net()
    optimizer_obj = optim.SGD(network_obj.parameters(), lr=learning_rate,momentum=momentum)
    train_losses_obj = []
    train_counter_obj = []
    test_losses_obj = []
    test_counter_obj = [i * len(train_loader_obj.dataset) for i in range(n_epochs + 1)]
    testing_net(network_obj, test_loader_obj, test_losses_obj)
    for epoch in range(1, n_epochs + 1):
        # 训练一代
        training_net(epoch, network_obj, train_loader_obj, optimizer_obj, train_losses_obj, train_counter_obj,log_interval)
        # 测试一代
        testing_net(network_obj, test_loader_obj, test_losses_obj)
    #画一下训练曲线
    show_loss_line_figure(train_counter_obj,train_losses_obj,test_counter_obj,test_losses_obj)
    #做预测的可视化
    show_predict_result(network_obj,test_loader_obj)


def execute_through_checkpoint():
    '''
    基于断点的执行训练
    :return:
    '''
    n_epochs = 30
    batch_size_train = 64
    batch_size_test = 1000
    learning_rate = 0.01
    momentum = 0.5
    log_interval = 10
    random_seed = 1
    torch.manual_seed(random_seed)
    # 加载数据
    train_loader_obj = torch.utils.data.DataLoader(
        torchvision.datasets.MNIST('./data/', train=True, download=True,
                                   transform=torchvision.transforms.Compose([
                                       torchvision.transforms.ToTensor(),
                                       torchvision.transforms.Normalize(
                                           (0.1307,), (0.3081,))
                                   ])),batch_size=batch_size_train, shuffle=True)
    # 加载数据
    test_loader_obj = torch.utils.data.DataLoader(
        torchvision.datasets.MNIST('./data/', train=False, download=True, transform=torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((0.1307,), (0.3081,))])
                                   ), batch_size=batch_size_test, shuffle=True
    )
    # 查看数据
    view_dataset_figure(test_loader_obj)
    # 形成网络对象
    continued_network_obj = Mnist_Net()
    # 形成优化器对象
    continued_optimizer_obj = optim.SGD(continued_network_obj.parameters(), lr=learning_rate,momentum=momentum)
    # 重载断点
    network_state_dict = torch.load('model.pth')
    continued_network_obj.load_state_dict(network_state_dict)
    optimizer_state_dict = torch.load('optimizer.pth')
    continued_optimizer_obj.load_state_dict(optimizer_state_dict)

    train_losses_obj = []
    train_counter_obj = []
    test_losses_obj = []
    test_counter_obj = [i * len(train_loader_obj.dataset) for i in range(n_epochs + 1)]
    # 测试一下测试集 Test set: Avg. loss: 0.0347, Accuracy: 9887/10000 (99%)
    testing_net(continued_network_obj, test_loader_obj, test_losses_obj)
    for epoch in range(1, n_epochs + 1):
        # 每个epoch,test一下
        # 训练网络
        training_net(epoch, continued_network_obj, train_loader_obj, continued_optimizer_obj, train_losses_obj, train_counter_obj,log_interval)
        testing_net(continued_network_obj, test_loader_obj, test_losses_obj)
    #画一下训练曲线
    show_loss_line_figure(train_counter_obj,train_losses_obj,test_counter_obj,test_losses_obj)
    #做预测的可视化
    show_predict_result(continued_network_obj,test_loader_obj)

### 主入口
if __name__ == '__main__':
    # 情况一:训练全新的模型;
    # execute_through_new()
    # 情况二:在断点的基础上,接着训练
    execute_through_checkpoint()

算法流程

口号:2【加数据、定模型】+2【训练4、测试2】
在这里插入图片描述
这是主体流程,主要是训练和测试2大步骤,其中训练主要包括了4个环节:网络运行、LOSS计算、反向传播、优化;测试包括了2个环节:网络运行、计算LOSS;

讨论网络模型定义

构建5层,包括两个卷积层,一个Dropout层(降低过拟合),两个线性层,最后返回F.log_softmax(x)。其中,需要去了解Net是集成自nn.Module。
关于nn.Module的详细介绍会在后面的章节展开。

主要参考资料

https://blog.csdn.net/sxf1061700625/article/details/105870851?ops_request_misc=%257B%2522request%255Fid%2522%253A%2522162486393316780265489114%2522%252C%2522scm%2522%253A%252220140713.130102334..%2522%257D&request_id=162486393316780265489114&biz_id=0&utm_medium=distribute.pc_search_result.none-task-blog-2~all~sobaiduend~default-1-105870851.first_rank_v2_pc_rank_v29_1&utm_term=pytorch++mnist&spm=1018.2226.3001.4187
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

基于Pytorch1.8.0+Win10+RTX3070的MNIST网络构建与训练 的相关文章

随机推荐

  • Unity 改变鼠标指针的方法

    在网上查的帖子 先看一下 Texture2D ClickedCursorImg 把鼠标指针改为ClickedCursorImg Cursor SetCursor ClickedCursorImg Vector2 zero CursorMod
  • Api Savior 文档生成 idea 插件进阶教程

    原文地址见 Github Wiki Spring MVC 注解支持表 注解 注解字段 是否支持 作用描述 备注 RequestMapping value path 支持 绑定一个或多个 url RequestMapping method 支
  • JetBrains系列--工具使用方法

    JetBrains系列 工具使用方法 介绍 常用IDE 2 1 IDEA 2 2 pycharm 2 3 goland 2 4 clion 3 快捷方式 4 说明 JetBrains系列 工具使用方法 介绍 JetBrains 系列IDE是
  • 如何在Vue项目中给路由跳转加上进度条

    在平常浏览网页时 我们会注意到在有的网站中 当点击页面中的链接进行路由跳转时 页面顶部会有一个进度条 用来标示页面跳转的进度 如下图所示 虽然实际用处不大 但是对用户来说 有个进度条会大大减轻用户的等待压力 提升用户体验 本篇文章就来教你如
  • 程序员口中常说的API是什么意思?什么是API?

    什么是API 我的回答 API 应用程序编程接口 一般来说 这是一套明确定义的各种软件组件之间的通信方法 什么是API 我们不妨用一个小故事展示出来 研发人员A开发了软件A 研发人员B正在研发软件B 有一天 研发人员B想要调用软件A的部分功
  • Xilinx Vivado .coe文件生成

    一 COE格式文件生成 由于Quartus ii软件ROM用的是mif格式的文件 且可以用软件Guagle wave生成正弦波 三角波 锯齿波 我们可以利用这个软件先生成数据 然后再将其转化为符合COE格式的文件 具体请参考以下步骤 1 先
  • JavaWeb中如何将JSP文件的编码格式修改为UTF-8

    目录 一 修改原因 二 修改步骤 在使用eclipse学习jsp时 很多默认的编码都是ISO 8859 15 而我们需要使用的是utf 8编码 我们第一个接触改变jsp编码的方式可能都是在需要修改的jsp中修改 如下 将charset与pa
  • python线程与进程概述_1.24

    多进程与多线程 进程 Process 是计算机中的程序关于某数据集合上的一次运行活动 是系统进行资源分配和调度的基本单位 是操作系统结构的基础 线程 Thread 有时被称为轻量级进程 Lightweight Process LWP 是程序
  • Java 20新特性:Scoped Values 作用域值(孵化器)

    以下内容由New Bing自动生成 仅介绍了Scoped Values的部分内容 如果需要详细的Scoped Values信息 可以查阅官方JEP 429文档 Java JEP 429是 JDK 20 中引入的唯一一个新特性 目前还处于孵化
  • android点击按钮弹出输入框,android 弹出框(输入框和选择框)

    1 输入框 final EditText inputServer new EditText this inputServer setFilters new InputFilter new InputFilter LengthFilter 5
  • tcp短连接TIME_WAIT问题解决方法大全(3)——tcp_tw_recycle

    tcp tw recycle和tcp timestamps 参考官方文档 http www kernel org doc Documentation networking ip sysctl txt tcp tw recycle解释如下 t
  • ELK日志收集分析服务

    任务要求 搭建ELK集群 收集日志信息并展示 任务拆解 认识ELK 部署elasticsearch集群并了解其基本概念 安装elasticsearch head实现图形化操作 安装logstash收集日志 安装kibana日志展示 安装fi
  • Linux生产者消费者模型(POSIX信号量)

    目录 一 生产者消费者模型 1 基本概念 2 模型特点 3 模型优点 二 基于BlockingQueue的生产者消费者模型 1 基本概念 2 单生产者 单消费者为例进行模拟实现 3 基于计算任务的生产者消费者模型 三 POSIX信号量 1
  • Java经典面试:vuejs调用java后端

    第一个暴击 Spring 上一份Spring的手绘思维脑图 就像是个知识大纲总结 预览一下Spring的知识点 心里有个谱 不过这边我是采用的截图方式 为了把全部的内容都截取出来 所以整个就比较小 可能不是很清晰 Spring面试真题 七大
  • C语言进阶之路:对任意两个数字求和

    提示 可以参考笔者之前的文章 来对此篇博客进行思考 文章目录 介绍 一 如何正确去书写代码 二 使用步骤 1笔者所写代码 2 重要代码部分 总结 介绍 对本文要记录的大概内容 对任意两个数字进行求加减乘除运算 小数 以下是本篇文章正文内容
  • VMware-克隆虚拟机

    VMware 克隆虚拟机 在使用VMware过程中需要经常克隆虚拟机 但是在克隆完整虚拟机后通常都会出现一个问题就是 网络无法连接因为网卡冲突了 告诉大家如何解决 虚拟机克隆 在管理中选择克隆 克隆当前虚拟机状态 选择完整克隆 重新生成网卡
  • 数据结构和算法--链栈(C++实现)

    定义 栈是限定仅在表尾进行插入和删除操作的线性表 把允许插入和删除的一端称为栈顶 top 另一端称为栈底 bottom 不含任何数据元素的栈称为空栈 栈又称为后进先出 Last In First Out 的线性表 简称LIFO结构 incl
  • 【网络编程】网络编程知识点总结

    文章目录 UDP也需要端口号 基于TCP的socket通信中 简易服务端的六步依次为 基于TCP的socket通信中 简易客户端的四步依次为 介绍一下在linux环境下 服务器这六步的使用到的一些函数 参数 返回值类型等 介绍一下在linu
  • pycharm连接远程ssh服务器,Ctrl+S不能自动上传代码

    各位码友在用pycharm连接远程服务器编写代码时 有些情况下 需要保持本地文件和远程文件的同步 可以设置成自动上传 或者按Ctrl S才会上传 设置步骤如以下截图所示 本来这样操作就行了 但是笔者有时设置成按Ctrl S进行保存 按Ctr
  • 基于Pytorch1.8.0+Win10+RTX3070的MNIST网络构建与训练

    直接上代码 先上整个的代码 import torch import torchvision from torch utils data import DataLoader import matplotlib pyplot as plt im