pytorch:分类(Classification)

2023-11-03

使用pytorch+python实现分类。

程序:

import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt

# 创造数据
n_data = torch.ones(100, 2)  # 数据的的基本形态
x0 = torch.normal(2*n_data, 1)   # class0 x shape=(100,2) 创建一个服从均值为2*n_data的张量,标准差为均为1的tensor
y0 = torch.zeros(100)            # class0 y shape=(100,1)
x1 = torch.normal(-2*n_data, 1)  # class1 x shape=(100,2) 创建一个服从均值为-2*n_data的张量,标准差为均为1的tensor,shape=(100,2)
y1 = torch.ones(100)             # class1 y shape=(100,1)

# 合并数据(cat是cancatenate的意思)
x = torch.cat((x0, x1), 0).type(torch.FloatTensor)  # FloatTensor = 32-bit float 其中0是按行拼接,1是按列拼接
y = torch.cat((y0, y1)).type(torch.LongTensor)  # LongTensor = 64-bit integer

# 画图
plt.scatter(x.data.numpy()[:, 0], x.data[:].numpy()[:, 1])
plt.show()

# 创建网络
class Net(torch.nn.Module):  # 继承torch的Module
    def __init__(self, n_feature, n_hidden, n_output):
        super(Net, self).__init__()  # 继承__init__功能
        self.hidden = torch.nn.Linear(n_feature, n_hidden)
        self.out = torch.nn.Linear(n_hidden,n_output)

    def forward(self, x):
        X = F.relu(self.hidden(x))  # 激活函数
        X = self.out(X)  # 输出值,但这不是预测值,预测值还需另外计算
        return X

net = Net(n_feature=2, n_hidden=10, n_output=2)
# print(net)

# 训练
optimizer = torch.optim.SGD(net.parameters(), lr=0.02)  # 传入net的所有参数和学习率
loss_func = torch.nn.CrossEntropyLoss()

for i in range(100):
    out = net(x)  # 把数据x作为输入,输出分析值
    loss = loss_func(out, y)  # 计算两者误差
    optimizer.zero_grad()  # 清空上一步的残余更新参数值
    loss.backward()  # 误差反向传播,计算参数更新值
    optimizer.step()  # 将参数更新值施加到net的parameters上

    if i % 2 == 0:
        plt.cla()
        prediction = torch.max(F.softmax(out), dim=1)[1]
        pred_y = prediction.data.numpy().squeeze()
        target_y = y.data.numpy()
        plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=pred_y, s=100, lw=0, cmap='RdYlGn')
        accuracy = sum(pred_y == target_y)/200.  # 预测中有多少和真实值一样
        plt.text(1.5, -4,'Accuracy=%.2f'%accuracy, fontdict={'size': 20, 'color': 'red'})
        plt.pause(0.1)

结果:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

pytorch:分类(Classification) 的相关文章

随机推荐

  • 毕业设计-基于机器颗粒状的农作物检测算法研究-YOLO

    目录 前言 课题背景和意义 实现技术思路 一 整体方案设计 二 基于 YOLO 的迁移学习算法 实现效果图样例 最后 前言 大四是整个大学期间最忙碌的时光 一边要忙着备考或实习为毕业后面临的就业升学做准备 一边要为毕业设计耗费大量精力 近几
  • curl post参数,接口接收不到数据问题

    今天遇到一个问题 注册下发短信失败 总提示无法发送注册短信 请从新发送 经检查 curl里面将post数据以json encode的方法转码之后传递 而且各选项设置感觉没有问题 怎么接口就接收不到post过去的数据的呢 在网上也搜索了不少网
  • 【git】git常用命令及所遇问题解决方法-小结

    git常用命令 小结 所遇到的问题会随时更新 git常用命令梳理 萌新git常用命令总结 要开始了哦 预备 走你 围攻git常用命令 1 git查看并设置用户名及邮箱 查看 git config user name git config u
  • 接口与自动化测试

    一 什么是接口 接口就是API 意思是应用程序编程接口 接口本质上程序开发的函数和方法 提供参数和返回值 二 接口组成的要素什么 接口访问的地址 请求的方法 参数 返回值 接口访问的地址 协议 IP地址或域名 端口号 应用名 功能名 请求的
  • IDEA使用Git更新项目提示:Push rejected: Push master to origin/master was rejected by remote

    失败的原因有很多 idea的提示不明确 网上答案大同小异的 网上说没有权限之类的 我的是自己创建的仓库 自己是管理员 直接排除这个选项 我们可以去你需要上传项目的文件夹 然后右键 选中Git Bash Here 打开 通过命令 git pu
  • 安装ubuntu出现空闲的空间不可用

    问题 我的系统已经安装了一win7 我先把其中一个70G的分区用于安装ubuntu 但在安装的过程中发现给ubuntu分了两个挂载点之后剩下的剩余空间显示为 不可用 并且也不能分配其他的挂载点了 请问谁知道可以怎样解决整个问题吗 解答 安装
  • IDEA打包jar包详尽流程

    打包流程 1 打开菜单栏File Project Structure 2 点击Artifacts 3 点击 JAR From module with depenencies 4 后弹出如下界面 自此开始 各种问题就来了 首先Module中
  • Django(2)-编写你的第一个 Django 应用

    本教程的目的是创建一个网络投票应用程序 它将由两部分组成 一个让人们查看和投票的公共站点 一个让你能添加 修改和删除投票的管理站点 创建应用 python manage py startapp polls 每一个应用是一个python包 一
  • ORA-17502 与 ORA-15173 错误解决

    用rman恢复spfile时 报错误如下 RMAN gt restore spfile from FLASHBACKDATA1 ORAC AUTOBACKUP 2010 05 16 s 719137976 308 719137979 Sta
  • python关键字保留字

    and 逻辑运算符 as 创建别名 assert 用于调试 break 跳出循环 class 定义类 continue 继续循环的下一个迭代 def 定义函数 del 删除对象 elif 在条件语句中使用 等同于 else if else
  • 小学生编程入门视频-机器人背后的机关密语

    学校要举办Mabot机关门设计比赛 在这个视频教程中就让我们一起看看 运用Mabot能做出怎样有趣的设计吧 查看完整视频教程 先来看看我们的设计理念吧 首先机关A的大门在触发下能够自动打开和关闭 机关B能检测外界干扰并闪灯报警 然后把信号发
  • Mac下载iReport安装之后无法打开解决办法

    iReport的运行是依赖于JDK的 我使用的是Mac系统 下载的iReport 5 0 0 dmg版本 iReport 5 0 0 dmg百度网盘下载地址 链接 https pan baidu com s 1ECTVHMvtM0fyK4
  • qt5打开摄像头采集图像并拍照

    qt5打开摄像头 主要用了QCamera类 要在pro文件里加入 multimedia multimediawidgets这两个模块 QT core gui multimedia multimediawidgets QCamera是摄像头类
  • pycharm预览HTML文件显示浏览器错误,系统找不到文件chrome

    出现该错误可以考虑pycharm可能没有识别到浏览器的路径 下面以Google浏览器为例 在pycharm中运行html文件 出现如下问题 解决问题的方法如下 文件 gt 设置 gt 工具 gt web浏览器和预览 gt 把chrome换成
  • 国培学习阶段性总结

    国培学习心得 一个月很快过去了 通过这次培训 我们学到很多先进的知识 感受到了浓郁的公司文化 开阔了眼界 见识了软件行业的精英 使我提高了认识 理清了思路 学到了新的教学理念 找到了自身的差距和不足 这些将会是我一生受益的宝贵经历和财富 也
  • css添加 渐变色 阴影

    1 css添加阴影 四面发光 box shadow 0px 0px 30px 888888 2 css添加渐变色 红色过渡到蓝色 background webkit linear gradient red blue Safari 5 1 6
  • 【springMVC】SpringMVC的拦截器(Interceptor)和异常处理【附源码】

    一 拦截器 1 SpringMVC的拦截器 Interceptor 类似于Servlet中的过滤器 Filter 主要用于拦截用户请求 控制器方法 并做出相应的处理 例如 权限验证 判断用户是否登录等 2 拦截器的定义 1 实现Handle
  • 有什么软件可以让手机使用卫星通信吗?

    作者 NASA链接 https www zhihu com question 49532903 answer 119675952来源 知乎著作权归作者所有 商业转载请联系作者获得授权 非商业转载请注明出处 手机正常情况是不可以的 需要卫星系
  • Android项目中音频文件的存储位置

    1 Android工程中声音文件的存储位置在资源文件的raw文件夹下 2 如果在res文件夹下没有raw文件夹 新建一个即可 3 在音频文件数量多的情况下 将音频文件存放在assets目录 可避免文件的重复编译
  • pytorch:分类(Classification)

    使用pytorch python实现分类 程序 import torch import torch nn functional as F from matplotlib import pyplot as plt 创造数据 n data to