pytorch入门级教程——分类代码分析与实现(iris数据集)

2023-11-02

用iris数据进行分类训练,并可视化

首先导入API:

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from torch.autograd import Variable
from sklearn.datasets import load_iris
import pandas as pd
import numpy as np

获得数据集:

iris = load_iris()
iris_d = pd.DataFrame(iris['data'], columns = ['Sepal_Length', 'Sepal_Width', 'Petal_Length', 'Petal_Width'])
iris_d['Species'] = iris.target

特征降维——主成分分析(PCA),将四维降至两维

transfer_1 = PCA(n_components=2)
iris_d = transfer_1.fit_transform(iris_d)
x = torch.from_numpy(iris_d)
y =torch.from_numpy(iris.target)
x, y = Variable(x), Variable(y)

构建前向神经网络,2个输入神经元,10个中间层神经元,3个输出神经元,激活函数使用RELU:

net =torch.nn.Sequential(
    torch.nn.Linear(2, 10),
    torch.nn.ReLU(),
    torch.nn.Linear(10, 3),
    )
print(net) 

设定autograd为梯度下降法以及损失函数为交叉熵误差:

optimizer = torch.optim.SGD(net.parameters(), lr=0.2)  #随机梯度下降
loss_func = torch.nn.CrossEntropyLoss() 

训练模型,循环150次:

for t in range(150):
    out = net(x.float())                 # input x and predict based on x
    loss = loss_func(out, y.long())     # must be (1. nn output, 2. target), the target label is NOT one-hotted

    optimizer.zero_grad()   # clear gradients for next train
    loss.backward()         # backpropagation, compute gradients
    optimizer.step()        # apply gradients

可视化,每十步更新一次:

if t % 10 == 0:
        # plot and show learning process
        plt.cla()
        prediction = torch.max(out, 1)[1]
        pred_y = prediction.data.numpy()
        target_y = y.data.numpy()
        plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=pred_y, s=100, lw=0, cmap='RdYlGn')
        accuracy = float((pred_y == target_y).astype(int).sum()) / float(target_y.size)
        plt.text(1.5, -4, 'Accuracy=%.2f' % accuracy, fontdict={'size': 20, 'color':  'red'})
        plt.pause(0.1)
plt.show()

保存及加载神经网络模型:

torch.save(net, 'E:\\net.pkl')  # save entire net
torch.save(net.state_dict(), 'E:\\net_params.pkl')   # save only the parameters
net1 = torch.load('net.pkl')  #加载神经网络

附完整代码:

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from torch.autograd import Variable

# make fake data
from sklearn.datasets import load_iris
import pandas as pd
import numpy as np
iris = load_iris()
iris_d = pd.DataFrame(iris['data'], columns = ['Sepal_Length', 'Sepal_Width', 'Petal_Length', 'Petal_Width'])
iris_d['Species'] = iris.target
 
#特征降维——主成分分析
transfer_1 = PCA(n_components=2)
iris_d = transfer_1.fit_transform(iris_d)
x = torch.from_numpy(iris_d)
y =torch.from_numpy(iris.target)
x, y = Variable(x), Variable(y)

# mehod 
net =torch.nn.Sequential(
    torch.nn.Linear(2, 10),
    torch.nn.ReLU(),
    torch.nn.Linear(10, 3),
    )
#net1 = Net(n_feature=2, n_hidden=10, n_output=3)     # define the network
print(net)  # net architecture

optimizer = torch.optim.SGD(net.parameters(), lr=0.2)  #随机梯度下降
loss_func = torch.nn.CrossEntropyLoss()  # the target label is NOT an one-hotted



for t in range(150):
    out = net(x.float())                 # input x and predict based on x
    loss = loss_func(out, y.long())     # must be (1. nn output, 2. target), the target label is NOT one-hotted

    optimizer.zero_grad()   # clear gradients for next train
    loss.backward()         # backpropagation, compute gradients
    optimizer.step()        # apply gradients

    if t % 10 == 0:
        # plot and show learning process
        plt.cla()
        prediction = torch.max(out, 1)[1]
        pred_y = prediction.data.numpy()
        target_y = y.data.numpy()
        plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=pred_y, s=100, lw=0, cmap='RdYlGn')
        accuracy = float((pred_y == target_y).astype(int).sum()) / float(target_y.size)
        plt.text(1.5, -4, 'Accuracy=%.2f' % accuracy, fontdict={'size': 20, 'color':  'red'})
        plt.pause(0.1)


plt.show()
torch.save(net, 'E:\\net.pkl')  # save entire net
torch.save(net.state_dict(), 'E:\\net_params.pkl')   # save only the parameters
net1 = torch.load('net.pkl')  #加载神经网络

训练结果展示:
在这里插入图片描述

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

pytorch入门级教程——分类代码分析与实现(iris数据集) 的相关文章

  • Docker日志日期时间精确查询

    docker logs since 2020 07 30T10 14 00 until 2020 07 30T10 15 00 tomcat80 这条代码可以通过2个时间来查询指定范围的时间日志 since起始时间 你要从什么时候开始查询
  • 【数据结构】UnionFind 并查集-2

    数据结构源码 UnionFind1 接口 public interface UnionFind int getSize boolean isConnected int p int q void unionElements int p int
  • 华大HC32L176与三相四线计量模块JSY_333通讯例程以及对三相三线认识误区

    在某宝购买这个产品后 需要编写程序读取数据 这款产品可以使用TTL和RS485进行通讯 我用的是用华大单片机HC32L176 首先对串口进行初始化 程序可以自行下载 链接 https pan baidu com s 1FD2VecV64ZH
  • 从端到端打通模型端侧部署流程(NCNN)

    文章目录 背景介绍 为什么要做端侧推理 端侧深度学习部署流程 一条主要技术路线 ONNX NCNN框架 NCNN的官方介绍 NCNN问题解决 NCNN使用样例 快速在NCNN框架下验证自己的模型 一般流程 YOLOv5的demo测试 全新部
  • CGSS中国综合社会调查

    数据详情 1 包含数据库和问卷 2 数据包含的年份为2003 2005 2006 2008 2010 2011 2012 2013 2015 2017 3 2017年数据为SPSS和STATA 14版 CSV EXCEL 编码表 4 15年
  • 8.14 ARM

    1 练习一 text 文本段 global start 声明一个 start函数入口 start start标签 相当于C语言中函数 mov r0 0x2 mov r1 0x3 cmp r0 r1 beq stop subhi r0 r0

随机推荐

  • python的类写法_python类写法

    广告关闭 腾讯云11 11云上盛惠 精选热门产品助力上云 云服务器首年88元起 买的越多返的越多 最高返5000元 在python中这一点仍然成立 in class fatboy object pass in fb fatboy in pr
  • 刷脸发甚至改变整个支付行业和零售行业

    在今年4月17日 蚂蚁金服在北京发布新一代刷脸支付产品 蜻蜓2 0 并宣称未来将会投入30亿让刷脸支付全国普及 助力商家数字化 让商家快速结付 提高商家运营效率 为顾客便利服务 为商家引流 支付宝蜻蜓二代接入刷脸即会员等数字化经营能力 试点
  • vue el-option只回显数字问题

    1 value前面没有加冒号说明是字符串 加个冒号即可回显label名称 2 后端返回的值可能已经将id类型返回为String 此时转换为number即可回显 3 也可用v for循环渲染选项 回显时肯定能回显label名称
  • 机器人避障路径规划--基于人工势场算法

    机器人避障路径规划 基于人工势场算法 机器人避障路径规划是机器人导航和控制中的一个基本问题 它的目标是在给定环境中找到一条安全可行的路径 使得机器人能够从起点到达目标点 并尽可能地避免与环境发生碰撞 人工势场算法是一种常用的机器人避障路径规
  • error: could not delete '/usr/local/lib/python3.6/site-packages/pip/_internal/configuration.py': Per

    brew install python 报错 error could not delete usr local lib python3 6 site packages pip internal configuration py Permis
  • 黑马程序员Javaweb学习笔记02【request和response】

    该博客主要记录在学习黑马程序员Javaweb过程的一些笔记 方便复习以及加强记忆 系列文章 JavaWeb学习笔记01 BS架构 Maven Tomcat Servlet JavaWeb学习笔记02 request和response Jav
  • 【三维语义分割】PointNet++ (二):模型结构详解

    本文为博主原创文章 未经博主允许不得转载 本文为专栏 python三维点云从基础到深度学习 系列文章 地址为 https blog csdn net suiyingy article details 124017716 本节主要介绍Poin
  • 电机的堵转检测及处理

    基于L298N控制的电机的堵转检测及其处理 一 L298N原理 二 电机堵转检测 三 电机堵转处理 一 L298N原理 1 L298N datasheet 2 使用须知 工作电压高 最高工作电压可达46V 输出电流大 瞬间峰值电流可达3A
  • jeesite框架分析理解

    前文 jeesite代码生成器的使用 实例 报销表 地址 http blog csdn net m0 38021128 article details 68490920 前文中使用了jeesite框架的代码生成功能实现了一个小实例 但是实际
  • STM32CubeMX—串口空闲中断+DMA接收

    一 串口中断通信 串口中断方式的特点 发送数据时 将一字节数据放入数据寄存器DR 接收数据时 将DR的内容存放到用户存储区 中断方式不必等待数据的传输过程 只需要在每字节数据收发完成后 由中断标志位触发中断 在中断服务程序中放入新的一字数据
  • iOS 微信发布 8.0.12 正式版,寂寞来袭

    今天微信突然更新8 0 12正式版 我马上更新 更新完后并没有发现什么新功能 我就赶紧发文告诉大家 大家快去更新 更新看看这次更新了什么 我在AppStore商店更新完毕后就大概看了一下 并没有什么实质性的功能 可能内测功能还是内测人使用吧
  • org.apache.catalina.core.ApplicationContext.log ssi: Can‘t find file: /index.htmlERROR ErrorPageFil

    自己配置的tomcat 部署应用时提示错误信息 org apache catalina core ApplicationContext log ssi Can t find file index html ERROR ErrorPageFi
  • 如何设计一个高并发系统?

    原创 苏三呀 苏三说技术 2023 09 08 08 21 发表于四川 收录于合集 系统设计3个 大家好 我是苏三 又跟大家见面了 前言 最近有位粉丝问了我一个问题 如何设计一个高并发系统 这是一个非常高频的面试题 面试官可以从多个角度 考
  • 【VUE2】VUE2基础知识和原理--超详细--超简介--零基础(一)

    vue基础知识和原理 1 初识Vue 想让Vue工作 就必须创建一个Vue实例 且要传入一个配置对象 demo容器里的代码依然符合html规范 只不过混入了一些特殊的Vue语法 demo容器里的代码被称为 Vue模板 Vue实例和容器是一一
  • 兼容火狐--常见问题修改

    此文为本人在实际工作中遇到的情况做的记录 所以比较乱 主要用于自己日后查看 如果对大家有帮助 当然也更好 最普遍的情况 当遇到功能不好使的情况 首先按f12看控制台有没有报错 A如果有定位错误 常见错误 window frames ifra
  • Qt与MSVC中文乱码问题的解决方案

    一 问题是什么 在学习Qt编程的过程中 大多数人都遇到过中文乱码的问题 总结起来有三类 1 Qt Creator中显示的汉字变为乱码 编辑器上方有 Could not decode with UTF 8 encoding Editing n
  • PHP与MySQL程序设计 学习笔记 第二章 环境配置

    主流Linux发行包中都加入了Apache 如果没有 也可以利用发行包的打包服务轻松安装 如Ubuntu的apt get命令 http httpd apache org download cgi可导航到离你最近的镜像站点 windows安装
  • 如何查看linux服务器的版本和配置信息

    linux下看配置 可没有windows那么直观 你只能一个一个查看 一 cpu root srv more proc cpuinfo grep modelname root srv grep model name proc cpuinfo
  • python quit()讲解_Python:pygame.QUIT()

    Just been messing around with pygame and ran into this error CODE import sys import pygame pygame init size width height
  • pytorch入门级教程——分类代码分析与实现(iris数据集)

    用iris数据进行分类训练 并可视化 首先导入API import torch import torch nn functional as F import matplotlib pyplot as plt from sklearn dec