Pytorch实现多特征输入的分类模型 代码实操

2023-11-17

初学者学习Pytorch系列

第一篇 Pytorch初学简单的线性模型 代码实操
第二篇 Pytorch实现逻辑斯蒂回归模型 代码实操
第三篇 Pytorch实现多特征输入的分类模型 代码实操



前言

  1. 本文的输入数据集是基于一个糖尿病预测的案例。输入的大概意思是,给病人测试8次身体情况,预测明年是否会病情加重。这个8次,就是作为多特征的输入,8个特征,是否加重(0表示不会加重,1表示加重)作为一个二分类问题。

一、先上代码

代码如下(解释已经写在代码中):

import numpy as np
import torch

xy = np.loadtxt('../data/diabetes.csv.gz', delimiter=',', dtype=np.float32)
x_data = torch.from_numpy(xy[:, :-1])
y_data = torch.from_numpy(xy[:, [-1]])      # 这里第二维度用[-1],可以使分割出来的是矩阵
                                            # 不加[][1.0.2.0,3.0],加了后是[[1.0],[2.0],[3.0]],


class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear1 = torch.nn.Linear(8, 6)
        self.linear2 = torch.nn.Linear(6, 4)
        self.linear3 = torch.nn.Linear(4, 1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        x = self.sigmoid(self.linear1(x))
        x = self.sigmoid(self.linear2(x))
        x = self.sigmoid(self.linear3(x))
        return x


model = Model()
# 定义损失函数
criterion = torch.nn.BCELoss(reduction='mean')
# 优化器
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for epoch in range(10000):
    y_pred = model(x_data)
    loss = criterion(y_pred, y_data)
    if epoch > 9900:
        print(epoch, loss.item())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


二、测试结果

1. 数据结果

9979 轮开始展示
轮数 损失
9979 0.4617641568183899
9980 0.46176356077194214
9981 0.4617629647254944
9982 0.461762398481369
9983 0.46176180243492126
9984 0.46176114678382874
9985 0.46176064014434814
9986 0.461760014295578
9987 0.46175938844680786
9988 0.4617588222026825
9989 0.46175816655158997
9990 0.4617575705051422
9991 0.4617570638656616
9992 0.4617564380168915
9993 0.4617558419704437
9994 0.46175524592399597
9995 0.4617546796798706
9996 0.46175411343574524
9997 0.4617534875869751
9998 0.46175292134284973
9999 0.4617522656917572

三、代码说明

1. 数据集引入

xy = np.loadtxt('../data/diabetes.csv.gz', delimiter=',', dtype=np.float32)

这里是导入数据集,文件的位置是相对位置。在项目中的位置如下:
数据位置
数据来源:B站刘二大人课程资料
链接:https://pan.baidu.com/s/1_J1f5VSyYl-Jj2qIuc1pXw
提取码:wyhu

2.对Model类的理解

class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear1 = torch.nn.Linear(8, 6)
        self.linear2 = torch.nn.Linear(6, 4)
        self.linear3 = torch.nn.Linear(4, 1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        x = self.sigmoid(self.linear1(x))
        x = self.sigmoid(self.linear2(x))
        x = self.sigmoid(self.linear3(x))
        return x

这里对初学者可能会很不理解为什么要从8维到6维再到4维再到1维?而且为什么每次都要通过sigmoid?

  1. 这里的维度的变化,没变化一次就是深度学习里面的一层 (对于层的理解,可以把它看成一个过滤网,一层层过滤掉信息),层数多可以进行精细化一点的处理,在这个例子中也可以从8维到5维到3维再到1维,但是选用哪种好,是另一个问题,但是并不是过滤得越多越好,过滤越多学习能力太强,可能会把一些噪声的特征也学习进来。

这里可能有人会问为什么每一层都需要应该sigmoid函数?

  1. 首先得了解sigmoid是什么,它是一个激活函数,是给我们的线性函数(linear直线模型)做非线性变化的,这里又涉及到一个问题:为什么要做非线性变化? 这里我找到一张图片,很好的表现的为什么线性模型不行。

  2. 下图又两种颜色的点,我们需要使用深度学习将下面两种点进行分割,在第一张图中,只是线性模型,进行区分的能力有限,如果图中的点再多点再乱点,是无法满足需求的,但是在第二张图中引入了非线性模型,就可以进行更加灵活的变化,达到更加好的区分度。图三是过度学习的一个案例,区分得太过于精细了。
    非线性变化

  3. 那为什么用sigmoid能拟合呢??
    sigmoid的函数式是1/(1+e-x),在例子中,x是代入我们的线性模型wx+b的值,相当于1/(1+ewx+b),而我们调整w和b,这个函数是可以变成不同形状的函数,以此才能拟合奇奇怪怪的函数,不会只能拟合直线。

  4. 所以回到问题本身,我们知道了为什么用sigmoid,但是为什么用多次呢 ,下面我画了一个简单的示意图(图与本文例子无关)
    拟合我们仅仅通过一个sigmoid,就算我们不断调整w和b,函数变化的形式依旧比较单一,所以在上图中只用一个sigmoid,我们可能可以拟合0-1-2这一段,但是拟合2-3这一段可能会有比较大误差,如果拟合了2-3这一段,又可能影响了0-1-2这一段,但是我们在使用多个sigmoid的时候,就可能可以拟合各段的函数形式。

  5. 8维到6维对应的矩阵变化形式,这里的函数是向量函数,即对向量中每一个数代入后结果再放到向量中,结果还是一个函数。在这里插入图片描述

3.对优化器的理解

criterion = torch.nn.BCELoss(reduction='mean')

这里的reduction是一个属性,决定loss是要累加还是取平均,在本例中,如果不采用平均,会导致偏差很大。但是在前两篇文章中,使用累加没有影响。在实验后发现是lr是关系,lr是学习率,如果采用累加,可能需要把学习率调小一点。


总结

以上就是今天要讲的内容,本文仅仅是向初学者以通俗一点的方式介绍了pytorch实现多特征输入分类问题的基本使用。讲解过程属于个人理解,如有误,请谅解。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Pytorch实现多特征输入的分类模型 代码实操 的相关文章

随机推荐

  • tensorrt和onnxruntime-gpu同时调用gpu时tensorrt推理出现错乱解决方式

    问题 当我在同一个进程同时调用tensorrt和onnxruntime gpu时 出现了tensorrt推理结果全为0的情况 解决方式 将onnxruntime gpu放到cpu上 但是cpu的推理速度明显会不如gpu 如果在python中
  • 深度剖析数据在内存中的存储(修炼内力)

    目录 一 数据类型的介绍 1 1数据大小 1 2类型的基本归类 二 整型在内存中的存储 2 1原码 反码 补码 2 2大小端介绍 2 2 1大小端的起源 2 2 2大小端的概念 2 2 3为什么会有大端和小端 2 2 4设计一个小程序来判断
  • Fedora 启动顺序

    http hi baidu com wwwkljoel item 29620217882a585b2b3e2244 The start of the Fedora fedora 系统加电或复位后 中央处理器将内存中的所有数据清零 并对内存进
  • html往下滑变成水平,HTML - 水平滑块CSS最佳方法_html_开发99编程知识库

    由於每個部分的位置已經設置為relative 意味著將relative定位到上一節 因此可以將其他部分設置為left 0 margin 0 all sections display inline flex main about profes
  • 【学】saas系统前端技术选型,需要考虑哪些方面?

    对于saas前端技术选型 可以考虑以下几个方面 框架选择 目前比较流行的前端框架有React Vue Angular等 可以根据项目需求和团队技术水平选择合适的框架 例如 如果需要高度可定制性和灵活性 可以选择React 如果需要快速开发和
  • 数学建模之灰色关联实例含代码

    参考书籍 数学建模算法与应用 一 预备 1 无量纲化处理技术 二 灰色关联的步骤 通过对某健将级女子铅球运动员的跟踪调查 获得其 1982 年至 1986 年每年好成绩及16 项专项素质和身体素质的时间序列资料 见表 2 试对此铅球运动员的
  • linux-UNIX socket

    UNIX域套接字 域套接字作为进程间通信的一种手段 值得我们研究一下 域套接字实现本地进程间通信 同样有服务端和客户端之分 一个进程作为客户端 另一个进程作为服务端 这个和TCP socket类似 但是不一样 域套接字不经过底层网络 数据结
  • LaTeX 数学公式大全!

    LaTeX 数学公式大全 这里是来自一篇教程的截图 很全面
  • java.util之ArrayList使用

    java util之ArrayList使用 一 概述 ArrayList底层实际是通过一个数组来保存数据 其默认大小为10 扩容机制为新的容量 原始容量x3 2 1 允许空值 有序 为线程不安全 可以使用迭代器遍历 里面的的元素全部都是对象
  • NoteExpress安装时问题解决

    每次安装软件我都不能一次性成功 这次遇见的是NoteExpress和Word权限不一致的问题 版本 win10 office2019 网上有很多方法 其中CSDN博主 令令狐大侠 总结郭一篇 原文链接 https blog csdn net
  • 【华为OD机试】工号不够用了怎么办 (C++ Python Java)2023 B卷

    题目描述 3020年 空间通信集团的员工人数突破20亿人 即将遇到现有工号不够用的窘境 现在 请你负责调研新工号系统 继承历史传统 新的工号系统由小写英文字母 a z 和数字 0 9 两部分构成 新工号由一段英文字母开头 之后跟随一段数字
  • 关于split截取字符时,问号的特殊情况

    有一段字符 tring str gjjxxcx gjjxx cx jsp zgzh 1010024000019 如果使用如下代码 String strArray str split gjjxx cx jsp System out print
  • 基础算法题——带分数(全排列,工具库)

    前言 这道题理解起来不难 但是要找到一个合适的方法对题目进行优化 就会相对麻烦些 蓝桥杯的题 真的到处都是坑的感觉 带分数题目 资源限制 时间限制 1 0s 内存限制 256 0MB 问题描述 100 可以表示为带分数的形式 100 3 6
  • 表单注入——sqli-labs第11~16关

    目录 第11关 0 万能账号 密码的前提 1 判断是否POST注入 2 猜测后台SQL语句 3 判断闭合符 4 查询列数 5 找显示位 6 查库名 7 查表名 8 查列名 9 找账号密码 第12关 第13关 第14关 1 2 3 4 5 6
  • Leetcode148.排序链表——排序问题详解

    文章目录 引入 归并排序解法 其他 引入 148 排序链表题目如下 148 排序链表 在 O n log n 时间复杂度和常数级空间复杂度下 对链表进行排序 示例 1 输入 4 gt 2 gt 1 gt 3 输出 1 gt 2 gt 3 g
  • 工作中常用且容易遗忘的css样式整理,建议收藏

    1 文字超出部分显示省略号 单行文本的溢出显示省略号 一定要有宽度 p width 200rpx overflow hidden text overflow ellipsis white space nowrap 多行文本溢出显示省略号 p
  • Linux(驱动编程)(调试技术)(imx6ull)

    调试技术 1 在写驱动程序时函数未包含头文件 在linux内核源码driver char目录下输入命令 grep XXXX nrw 查看次函数在那个 c里用过 然后在vscode界面下按alt p搜索这个 c就可以参考这个 c的头文件 2
  • docker笔记(二)之镜像加速器

    国内从 Docker Hub 拉取镜像有时会遇到困难 此时可以配置镜像加速器 国内很多云服务商都提供了国内加速器服务 例如 阿里云加速器 点击管理控制台 gt 登录账号 淘宝账号 gt 右侧镜像中心 gt 镜像加速器 gt 复制地址 网易云
  • 从原理到应用,人人都懂的 ChatGPT 指南

    如何充分发挥ChatGPT潜能 成为了众多企业关注的焦点 但是 这种变化对员工来说未必是好事情 IBM计划用AI替代7800个工作岗位 游戏公司使用MidJourney削减原画师人数 此类新闻屡见不鲜 理解并应用这项新技术 对于职场人来说重
  • Pytorch实现多特征输入的分类模型 代码实操

    初学者学习Pytorch系列 第一篇 Pytorch初学简单的线性模型 代码实操 第二篇 Pytorch实现逻辑斯蒂回归模型 代码实操 第三篇 Pytorch实现多特征输入的分类模型 代码实操 文章目录 初学者学习Pytorch系列 前言