基于libtorch的LeNet-5卷积神经网络实现(2)--Cifar-10数据分类

2023-11-18

上篇文章中我们使用libtorch实现了LeNet-5卷积神经网络,并对Minst数据集进行训练与分类。本文我们尝试使用该实现的网络对更加复杂的Cifar-10数据集进行训练、分类。

基于libtorch的LeNet-5卷积神经网络实现

LeNet-5网络地总体结构如下,详细请参考上方地链接。

1. Cifar-10数据集介绍

Cifar-10是一个专门用于测试图像分类的公开数据集,其包含的彩色图像分为10种类型:飞机、轿车、鸟、猫、鹿、狗、蛙、马、船、货车。且这10种类型图像的标签依次为0、1、2、3、4、5、6、7、8、9。

该数据集分为Python、Matlab、C/C++三个不同的版本,顾名思义,三个版本分别适用于对应的三种编程语言。因为我们使用的是C/C++语言,所以使用对应的C/C++版本就好,该版本的数据集包含6个bin文件,如下图所示,其中data_batch_1.bin~data_batch_5.bin通常用于训练,而test_batch.bin则用于训练之后的识别测试

如下图所示,每个bin文件包含10000*3073个字节数据,在每个3073数据块中,第一个字节是0~9的标签,后面3072字节则是彩色图像的三通道数据:红通道 --> 绿通道 --> 蓝通道 (1024 --> 1024 --> 1024)。其中每1024字节的数据就是一帧单通道的32*32图像,3帧32*32字节的单通道图像则组成了一帧彩色图像。所以总体来说,每一个bin文件包含了10000帧32*32的彩色图像。

2. 训练策略

(1) epoch

首先我们讲一下epoch的概念:一个epoch就是将所有的训练数据都输入神经网络,并完成前向传播、反向传播、参数更新的过程。比如我们使用Cifar-10数据集进行训练的时候,训练数据包含于5个bin文件中,每个bin文件有10000张32*32图像,因此总共有5*10000张训练图像,当我们把这5*10000张训练图像都输入网络并完成训练的过程,就是一个epoch过程。

然而,往往一个epoch过程过程达不到参数收敛的目的,因此需要执行多个epoch过程,也就是说:使用这5*10000张训练图像完成一次训练之后,在此次训练得到参数模型的基础上,再重复使用这5*10000张训练图像进行下一轮训练。

(2) 学习率的改变

我们把学习率α的初始值设置为0.001,每完成一个epoch,参数都进一步接近收敛状态,因此这个时候我们需要适当比例地减小α,以缩短步长:

α = α*0.8

3. 数据格式转换

(1) 图像格式转换

我们使用Opencv来读取Cifar-10图像为Mat格式,但是libtorch框架处理的数据格式为Tensor格式,因此首先需要把Mat格式的图像转换为Tensor格式。调用from_blob函数可方便进行转换,不过要注意转换时需指定数据维度为1*1*row*col,其中row、col分别为图像的高、宽。

//test_img[i]为Mat格式,test_img[i].data为Mat的数据首地址
//{ 1, 1, test_img[i].rows, test_img[i].cols }指定Tensor张量的维度为1*1*row*col
//torch::kFloat表示以float格式转换数据,该类型要与Mat本来的数据类型相一致,否则会出错
torch::Tensor inputs = torch::from_blob(test_img[i].data, { 1, 1, test_img[i].rows, test_img[i].cols }, torch::kFloat);

(1) 标签格式转换

我们读取的标签为单个uchar型数据,但是libtorch框架处理的数据格式为Tensor格式,因此首先需要把单个uchar数据转换为Tensor格式。调用from_blob函数也可方便转换,同样要注意转换时需指定数据维度为1,也就是只有一个数据的张量。

//test_label[i]为一个uchar数据,&test_label[i]表示该数据的地址
//{ 1 }表示该张量的维度为1
//torch::kByte表示以uchar类型读取该数据,该参数需要与数据本身的类型一致
//toType(torch::kLong)表示把Tensor张量转换为long int类型,因为要求标签的类型为long int
torch::Tensor labels = torch::from_blob(&test_label[i], { 1 }, torch::kByte).toType(torch::kLong);

4. 主要代码实现

(1) 读取Cifar-10数据与标签代码

读取到的图像为uchar型的三通道彩色图像,因此需要将其转换为单通道灰度图,并转换为-1~1之间的浮点型数据,方便后续的训练、分类。

void read_cifar_bin(char *bin_path, vector<Mat> &img_liat, vector<uchar> &label_list)
{
  const int img_num = 10000;
  const int img_size = 3073;   //第一字节是标签
  const int img_size_1 = 1024;
  const int data_size = img_num * img_size;
  const int row = 32;
  const int col = 32;


  uchar *cifar_data = (uchar *)malloc(data_size);
  if (cifar_data == NULL)
  {
    cout << "malloc failed" << endl;
    return;
  }


  FILE *fp = fopen(bin_path, "rb");
  if (fp == NULL)
  {
    cout << "fopen file failed" << endl;
    free(cifar_data);
    return;
  }


  fread(cifar_data, 1, data_size, fp);


  img_liat.clear();
  label_list.clear();
  for (int i = 0; i < img_num; i++)
  {
    long int offset = i * img_size;
    long int offset0 = offset + 1;    //红
    long int offset1 = offset0 + img_size_1;    //绿
    long int offset2 = offset1 + img_size_1;   //蓝


    uchar label = cifar_data[offset];   //标签
    Mat img(row, col, CV_8UC3);


    for (int y = 0; y < row; y++)
    {
      for (int x = 0; x < col; x++)
      {
        int idx = y * col + x;
        img.at<Vec3b>(y, x) = Vec3b(cifar_data[offset2 + idx], cifar_data[offset1 + idx], cifar_data[offset0 + idx]);    //BGR
      }
    }


    Mat gray;
    cvtColor(img, gray, COLOR_BGR2GRAY);  //三通道彩色图转换为单通道灰度图
    gray.convertTo(gray, CV_32F);  //uchar转换为float
    gray = gray / 255.0;   //范围0~1
    gray = (gray - 0.5) / 0.5;  //范围-1~1


    img_liat.push_back(gray.clone());   //float
    label_list.push_back(label);       //uchar
  }


  fclose(fp);
  free(cifar_data);
}

(2) LeNet-5网络定义

struct LeNet5 : torch::nn::Module
{
  //arg_padding为C1层的padding参数,当输入图像为28*28时,需要将其填充为32x32的图像
  //这里可能有人会有疑惑,为什么没有定义S2、S4层,这是因为池化层放在前向传播函数中执行即可,不需要再定义了,详细见forward函数
  LeNet5(int arg_padding = 0)
    //C1层
    : C1(register_module("C1", torch::nn::Conv2d(torch::nn::Conv2dOptions(1, 6, 5).padding(arg_padding))))
    //C3层
    , C3(register_module("C3", torch::nn::Conv2d(6, 16, 5)))
    //C5层
    , C5(register_module("C5", torch::nn::Conv2d(16, 120, 5)))
    //F6层
    , F6(register_module("F6", torch::nn::Linear(120, 84)))
    //OUTPUT层
    , OUTPUT(register_module("OUTPUT", torch::nn::Linear(84, 10)))
  {


  }


  ~LeNet5()
  {


  }
  
  //该函数用于将多维数据一维展开成一维向量
  int64_t num_flat_features(torch::Tensor input)
  {
    int64_t num_features = 1;
    auto sizes = input.sizes();
    for (auto s : sizes) 
    {
      num_features *= s;
    }
    return num_features;
  }
  
  //前向传播函数
  torch::Tensor forward(torch::Tensor input)
  {
    namespace F = torch::nn::functional;
    //这一步其实包含了3个操作,首先是C1层的卷积,其次是将卷积结果输入Relu函数,接着是将Relu函数的结果做最大值的池化操作
    auto x = F::max_pool2d(F::relu(C1(input)), F::MaxPool2dFuncOptions({ 2,2 }));
    //这一步也包含了3个操作,首先是C3层的卷积,其次是将卷积结果输入Relu函数,接着是将Relu函数的结果做最大值的池化操作
    x = F::max_pool2d(F::relu(C3(x)), F::MaxPool2dFuncOptions({ 2,2 }));
    //将C5层的卷积结果输入Relu函数,Relu函数的结果作为本层输出
    x = F::relu(C5(x));
    //120张1*1的卷积结果图像按顺序展开成长度为120的一维向量
    x = x.view({ -1, num_flat_features(x) });
    //F6层的Affine计算
    x = F::relu(F6(x));
    //OUTPUT层的Affine计算,注意这里不包括Softmax层计算,Softmax层计算放到后面的交叉熵误差函数中去做
    x = OUTPUT(x);
    
    return x;
  }


  //要求这里的各层定义与本结构体开头处的定义保持一致
  int m_padding = 0;
  torch::nn::Conv2d  C1;
  torch::nn::Conv2d  C3;
  torch::nn::Conv2d  C5;
  torch::nn::Linear  F6;
  torch::nn::Linear  OUTPUT;
};

(3) 训练代码

Cifar-10图像本来就是32*32大小,因此不需要像Minst数据集那样填充数据。



void tran_lenet_5_cifar_10(void)
{


  vector<Mat> train_img_total;
  vector<uchar> train_label_total;


  //定义一个LeNet-5网络结构体,输入的图像是32x32图像,不需要填充
  LeNet5 net1(0);
  //使用交叉熵误差函数
  auto criterion = torch::nn::CrossEntropyLoss();
  
  int kNumberOfEpochs = 8;   //训练8个epoch 
  int data_file_num = 5;  //总共5个训练文件
  double alpha = 0.001;   //学习率初始值0.001
  for (int epoch = 0; epoch < kNumberOfEpochs; epoch++)
  {
    printf("epoch:%d\n", epoch+1);
    //定义梯度下降法优化器
    auto optimizer = torch::optim::SGD(net1.parameters(), torch::optim::SGDOptions(alpha).momentum(0.9));
    
    for (int k = 1; k <= data_file_num; k++)
    {


      printf("data_file_num:%d\n", k);


      auto running_loss = 0.;


      char str[200] = { 0 };
      sprintf(str, "D:/Program Files (x86)/Microsoft Visual Studio 14.0/prj/KNN_test/KNN_test/cifar-10-batches-bin/data_batch_%d.bin", k);
      //读取10000张cifar-10训练图像
      read_cifar_bin(str, train_img_total, train_label_total);


      for (int i = 0; i < train_img_total.size(); i++)
      {
        //Mat转换为Tensor
        torch::Tensor inputs = torch::from_blob(train_img_total[i].data, { 1, 1, train_img_total[i].rows, train_img_total[i].cols }, torch::kFloat);  //1*1*32*32
        //uchar转换为Tensor
        torch::Tensor labels = torch::from_blob(&train_label_total[i], { 1 }, torch::kByte).toType(torch::kLong);   //1
        //清零梯度
        optimizer.zero_grad();
        //前向传播
        auto outputs = net1.forward(inputs);
        //计算交叉熵误差
        auto loss = criterion(outputs, labels);
        //误差反向传播
        loss.backward();
        //更新参数
        optimizer.step();
        //累加每1000个误差值,方便查看训练时交叉熵误差函数的下降情况
        running_loss += loss.item().toFloat();
        if ((i + 1) % 1000 == 0)
        {
          printf("loss: %.3f\n", running_loss / 1000);
          running_loss = 0.;
        }
      }
    }


    alpha *= 0.8;  //完成一个epoch,减小学习率
  }
  
  printf("Finish training!\n");
  torch::serialize::OutputArchive archive;
  net1.save(archive);
  archive.save_to("mnist_cifar_10.pt");
  printf("Save the training result to mnist.pt.\n");


}
 

(4) 分类测试代码

void test_lenet_5_cifar_10(void)
{


  LeNet5 net1(0);


  torch::serialize::InputArchive archive;
  archive.load_from("mnist_cifar_10.pt");  //加载上一步骤训练好的模型


  net1.load(archive);
  //读取测试数据与标签
  vector<Mat> test_img;
  vector<uchar> test_label;
  read_cifar_bin("D:/Program Files (x86)/Microsoft Visual Studio 14.0/prj/KNN_test/KNN_test/cifar-10-batches-bin/test_batch.bin", test_img, test_label);


  int total_test_items = 0, passed_test_items = 0;
  double total_time = 0.0;
  
  for (int i = 0; i < test_img.size(); i++)
  {
    //将Mat格式转换为Tensor格式
    torch::Tensor inputs = torch::from_blob(test_img[i].data, { 1, 1, test_img[i].rows, test_img[i].cols }, torch::kFloat);  //1*1*32*32
    //uchar转换为Tensor
    torch::Tensor labels = torch::from_blob(&test_label[i], { 1 }, torch::kByte).toType(torch::kLong);   //1
    //使用训练好的模型对测试数据进行分类,也即前向传播过程
    auto outputs = net1.forward(inputs);
    //得到分类值,0 ~ 9
    auto predicted = (torch::max)(outputs, 1);
    //比较分类结果和对应的标签是否一致,如果一致则认为分类正确
    if (labels[0].item<int>() == std::get<1>(predicted).item<int>())
      passed_test_items++;


    total_test_items++;


    printf("label: %d.\n", labels[0].item<int>());
    printf("predicted label: %d.\n", std::get<1>(predicted).item<int>());
    
  }
  
  printf("total_test_items=%d, passed_test_items=%d, pass rate=%f\n", total_test_items, passed_test_items, passed_test_items*1.0 / total_test_items);
}

(5) main函数

int main()
{
  tran_lenet_5_cifar_10();  //训练
  test_lenet_5_cifar_10();  //测试


  return EXIT_SUCCESS;
}

5. 运行结果

运行上述代码,先训练模型、保存模型,然后再加载模型(实际使用时如果已经训练好模型,可直接加载模型而不需要再训练)。

训练时目标函数(损失函数)的变化如下图所示,可以看到其值逐步减小:

待训练完成以及保存模型之后,就可以加载模型对数据进行分类预测啦,我们的分类结果如下图所示,准确率达到了60.3%,这个分类结果并不理想,这是因为LeNet-5网络还是相对简单了,对相对Minst更复杂的数据分类效果并不好。因此在接下来的文章中我们将使用libtorch来实现更加复杂的网络,敬请期待~

欢迎扫码关注以下微信公众号,接下来会不定时更新更加精彩的内容噢~

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

基于libtorch的LeNet-5卷积神经网络实现(2)--Cifar-10数据分类 的相关文章

  • Pygame 让精灵按照给定的旋转行走

    很久以前我做了一个Scratch脚本 我想用Pygame将其转换为Python 有很多示例显示图像的旋转 但我想知道如何更改精灵的旋转以使其沿给定方向移动 而不更改图像 这是我的暂存代码 这是我的 Pygame 精灵类 class Star
  • Pandas 在列级别连接数据帧时添加键

    根据 Pandas 0 19 2 文档 我可以提供keys参数来创建结果多索引 DataFrame 一个例子 来自 pandas 文档 是 result pd concat frames keys x y z 我将如何连接数据框以便我可以在
  • python blpapi安装错误

    我试图根据 README 中的说明为 python 安装 blpapi 3 5 5 但是在运行时 python setup py install 我收到以下错误 running install running build running b
  • 从 Django 基于类的视图的 form_valid 方法调用特殊(非 HTTP)URL

    如果你这样做的话 有一个 HTML 技巧 a href New SMS Message a 点击新短信打开手机的本机短信应用程序并预 先填写To包含所提供号码的字段 在本例中为 1 408 555 1212 以及body与提供的消息 Hel
  • 如何使用 Twython 将 oauth_callback 值传递给 oauth/request_token

    Twitter 最近刚刚强制执行以下规定 1 您必须通过oauth callbackoauth request token 的值 这不是可选的 即使您已经在 dev twitter com 上设置了一个 如果您正在执行带外 OAuth 请通
  • 更改 numpy 数组的结构强制给定值

    如何缩小栅格数据的比例4 X 6大小成2 X 3如果 2 2 像素内的任何元素包含 1 则大小强制选择 1 否则选择 0 import numpy as np data np array 0 0 1 1 0 0 1 0 0 1 0 0 1
  • 如何计算查询集中每个项目的两个字段的总和

    假设我有以下模型结构 class SomeModel Model base price DecimalField commision DecimalField 我不想存储total price在我的数据库中为了数据一致性并希望将其计算为ba
  • 监控培训课程如何运作?

    我试图理解使用之间的区别tf Session and tf train MonitoredTrainingSession 以及我可能更喜欢其中之一 似乎当我使用后者时 我可以避免许多 杂务 例如初始化变量 启动队列运行程序或设置文件编写器以
  • 如何在python mechanize中设置cookie

    向服务器发送请求后 br open http xxxx br select form nr 0 br form MESSAGE 1 2 3 4 5 br submit 我得到了响应标题 其中包含 set cookie Set Cookie
  • 多个列表和大小的所有可能排列

    在 python 中使用以下命令很容易计算简单的排列itertools permutations https docs python org 3 library itertools html itertools permutations 你
  • 如何使用 numpy 从一维数组创建对角矩阵?

    我正在使用 Python 和 numpy 来做线性代数 我表演了numpy对矩阵进行 SVD 以获得矩阵 U i 和 V 然而 i 矩阵表示为 1 行的 1x4 矩阵 IE 12 22151125 4 92815942 2 06380839
  • Python 中使用 globals() 的原因?

    Python 中有 globals 函数的原因是什么 它只返回全局变量的字典 这些变量已经是全局的 所以它们可以在任何地方使用 我只是出于好奇而问 试图学习Python def F global x x 1 def G print glob
  • 安塞布尔 + 10.11.6

    我在 非常 干净地安装 10 11 6 时遇到了 Ansible 的奇怪问题 我已经安装了brew zsh oh my zsh Lil snitch 和1password 实际上没有安装其他任何东西 我安装了ansible brew ins
  • Django INSTALLED_APPS 的命名约定是如何工作的?

    该网站上的教程创建了一个名为 polls 的应用程序 它使用 django 1 9 所以在 INSTALLED APPS 中它是 polls apps PollsConfig 我正在观看一个教程 他将应用程序命名为新闻通讯 并且在 INST
  • datetime strftime 不输出正确的时间戳

    下列 gt gt gt from dateutil parser import parse gt gt gt parse 2013 07 02 00 00 00 0000 datetime datetime 2013 7 2 0 0 tzi
  • 如何输入可变的默认参数

    Python 中处理可变默认参数的方法是将它们设置为无 https stackoverflow com a 366430 5049813 例如 def foo bar None bar if bar is None else bar ret
  • 重写 PyGObject 中的虚拟方法

    我正在尝试实施高宽几何管理 http developer gnome org gtk3 3 2 GtkWidget html geometry management在 GTK 和 Python 中用于我的自定义小部件 我的小部件是来自的子类
  • gnuplot:第 1 行:无效命令

    stackoverflow 上可爱的人们大家好 我正在尝试使用 gnuplot 绘制数据 我首先阅读表格并提取我想要的数据 我将此数据写入 dat 文件 截至目前 我只是尝试通过命令行绘制它 但会添加必要的代码以在 python 脚本工作后
  • AES 在 cryptojs 中加密并在 python Crypto.Cipher 中解密

    使用 js CryptoJS 加密并使用 python crypto Cipher 解密时出现问题 这是我在js中的实现 附加 iv 与加密消息并使用 base64 进行编码
  • 如何从Python枚举类中获取所有值?

    我正在使用 Enum4 库创建一个枚举类 如下所示 class Color Enum RED 1 BLUE 2 我要打印 1 2 作为某处的列表 我怎样才能实现这个目标 您可以执行以下操作 e value for e in Color

随机推荐

  • 傻瓜电梯项目实现

    目录 文档介绍 package lift entity Elevator java Entity java Floor java package lift Pretreatment Pretreatment java package lif
  • Elasticsearch——document相关原理

    1 document数据路由原理 1 1 document路由到shard上是什么意思 一个index的数据会被分为多片 每片都在一个shard中 所以说 一个document 只能存在于一个shard中 当客户端创建document的时候
  • [计算机毕业设计]大数据疫情分析与可视化系统

    前言 大四是整个大学期间最忙碌的时光 一边要忙着准备考研 考公 考教资或者实习为毕业后面临的就业升学做准备 一边要为毕业设计耗费大量精力 近几年各个学校要求的毕设项目越来越难 有不少课题是研究生级别难度的 对本科同学来说是充满挑战 为帮助大
  • mysql报错 -- (errno: 13 - Permission denied)

    重启服务器后 mysql没有自启动 手动启动的时候报错 后面经一番折腾后强行用root身份启动后又发现原有的数据库表都不见了 mysql 报错 ERROR 1018 HY000 Can t read dir of db translator
  • 模型选择+过拟合+欠拟合

    模型选择 当我们训练模型时 我们只能访问数据中的小部分样本 最大的公开图像数据集包含大约一百万张图像 而在大部分时候 我们只能从数千或数万个数据样本中学习 将模型在训练数据上拟合的比在潜在分布中更接近的现象称为过拟合 overfitting
  • 从代码角度理解DETR

    一个cnn的backbone 提图像的feature 比如 HWC 同时对这个feature做position embedding 然后二者相加 在Transformer里面就是二者相加 输入encoder 输入decoder 这里有obj
  • Matlab中实现图像处理的工作流程

    一 识别流程 Receipt Identification Workflow Working with Images in MATLAB Import display and manipulate color and grayscale i
  • Angular4.0_完善在线竞拍应用路由

    路由实战思路 一 创建商品详情组件 显示商品的图片和标题 使用Angular命令行工具生成一个新的组件 ng g component productDetail product detail component ts import Comp
  • latex Elsevier 模板给作者加脚注

    Elsevier 模板给作者加脚注 thanks 无效 网上有说使用 corref cor1 cortext cor1 Corresponding author 但是实测发现不行 只能加一个标注 再加一个就是两个 还有说使用 authorn
  • SVM算法(Support Vector Machine)

    一 SVM 支持向量机 support vector machines SVM 是一种二分类模型 将实例的特征向量映射为空间中的一些点 SVM 的目的就是想要画出一条线 以 最好地 区分这两类点 以至如果以后有了新的点 这条线也能做出很好的
  • GIT reset

    Git Reset 转载Git Reset reset 用于回退commit 主要有三个参数 hard mixed soft working工作区 cache暂存区 repository本地库 hard 清空 清空 清空 mixed 保留
  • window系统启动redis和清除缓存

    一 启动redis dos命令行方式 c user john gt d 进入所在盘 D gt cd D Redis x64 3 2 100 进入安装目录 D gt cd D Redis x64 3 2 100 gt redis server
  • git提交新项目操作笔记

    git提交新项目操作笔记 1 本地安装git环境 下载安装包安装即可 2 初始化git项目 生成 git 配置目录 进入项目根目录 右键 git bash here打开控制台 输入git init即可完成 3 将项目加入本地git仓库 gi
  • fork()函数详解

    一个进程 包括代码 数据和分配给进程的资源 fork 函数通过系统调用创建一个与原来进程几乎完全相同的进程 也就是两个进程可以做完全相同的事 但如果初始参数或者传入的变量不同 两个进程也可以做不同的事 一个进程调用fork 函数后 系统先给
  • 'gbk' codec can't decode byte 0xae 解决方法

    gbk codec can t decode byte 0xae 解决方法 今天使用python 读取txt的时候出现了如下报错 Message gbk codec can t decode byte 0xae in position 32
  • python一球从100米高度自由落下,一球从100米高度自由落下,每次落地后反跳回原高度的一半;再落下,求它在 第10次落地时,......

    首先一开始想到的就是用循环来计算的 所以就写了以下代码 include include include define H 100 define N 10 int main void int i 1 float weiyi distance
  • SVN安装及使用教程图文详解

    一 SVN简介 1 什么是SVN SVN全名Subversion 即版本控制系统 SVN与CVS一样 是一个跨平台的软件 支持大多数常见的操作系统 作为一个开源的版本控制系统 Subversion管理着随时间改变的数据 这些数据放置在一个中
  • 树莓派安装TensorFlow并使用[一步到位]

    树莓派安装TensorFlow并使用 一步到位 安装TensorFlow并使用 树莓派3B 树莓派安装TensorFlow并使用 一步到位 换源并更新 安装TensorFlow依赖包 安装TensorFlow并使用 各种问题 换源并更新 安
  • 帮我使用pytorch和opencv实现根据双目视差图生成点云

    可以使用OpenCV库读取双目图像 并使用SGBM算法或BM算法计算视差图 然后 可以使用OpenCV的reprojectImageTo3D函数将视差图映射到三维空间中 生成点云 以下是代码示例 import cv2 import nump
  • 基于libtorch的LeNet-5卷积神经网络实现(2)--Cifar-10数据分类

    上篇文章中我们使用libtorch实现了LeNet 5卷积神经网络 并对Minst数据集进行训练与分类 本文我们尝试使用该实现的网络对更加复杂的Cifar 10数据集进行训练 分类 基于libtorch的LeNet 5卷积神经网络实现 Le