PyTorch实现Logistic regression

2023-10-26

逻辑回归(Logistic regression)

回归方法是对数值型连续随机变量进行预测和建模的监督学习算法。其特点是标注的数据集具有数值型的目标变量。回归的目的是预测数值型的目标值。

逻辑回归对应线性回归,旨在解决分类问题,即将模型的输出转换为0/1值。逻辑回归直接对分类的可能性进行建模,无需事先假设数据的分布。

最理想的转换函数是单位阶跃函数(也称Heaviside函数),但单位阶跃函数是不连续的,没法在实际计算中使用。故而,在分类过程中更常使用对数几率函数(即sigmoid函数):

因此,回归模型为:

将该模型的损失函数定义为:

(其中, y ^ \hat{y} y^表示任意输入一个数据,经过Sigmoid之后这个数据点属于第二类的概率,那么其属于第一类的概率就是 1 − y ^ 1-\hat{y} 1y^;y 表示真实的 label,只能取 {0, 1} 这两个值。)

PyTorch实现逻辑回归

import torch
from torch import nn
from torch.autograd import Variable

#构造数据
n_data = torch.ones(100,2)
x0 = torch.normal(2*n_data)
y0 = torch.zeros(100)
x1 = torch.normal(-2*n_data)
y1 = torch.ones(100)

x = torch.cat((x0,x1)).type(torch.FloatTensor)
y = torch.cat((y0,y1)).type(torch.FloatTensor)

#定义LogisticRegression
class LogisticRegression(nn.Module):
    def __init__(self):
        super(LogisticRegression,self).__init__()
        self.lr = nn.Linear(2,1)
        self.sm = nn.Sigmoid()
        
    def forward(self,x):
        x = self.lr(x)
        x = self.sm(x)
        return x
    
logistic_model = LogisticRegression()
criterion = nn.BCELoss()
optimizer = torch.optim.SGD(logistic_model.parameters(),lr = 1e-3,momentum=0.9)

#训练过程
for epoch in range(10000):
    x_data = Variable(x)
    y_data = Variable(y)
    
    out = logistic_model(x_data)
    loss = criterion(out,y_data)
    print_loss = loss.data.item()
    mask = out.ge(0.5).float()
    correct = (mask == y_data).sum()
    acc = correct.item()/x_data.size(0)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    if (epoch+1)%20 == 0:
        print('*'*10)
        print('epoch {}'.format(epoch+1))
        print('loss is {:.4f}'.format(print_loss))
        print('acc is {:.4f}'.format(acc))

#参数输出
w0,w1 = logistic_model.lr.weight[0]
w0 = float(w0.item())
w1 = float(w1.item())
b = float(logistic_model.lr.bias.item())

print('w0:{}\n'.format(w0),'w1:{}\n'.format(w1),'b:{0}'.format(b))

参考:
https://github.com/KeKe-Li/tutorial/blob/master/assets/src/RAM/RAM.0.3.md
https://blog.csdn.net/qjk19940101/article/details/79573623

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

PyTorch实现Logistic regression 的相关文章

随机推荐

  • Unity脚本中枚举类型在inspector面板中文显示

    效果 工具脚本 ChineseEnumTool cs using System using UnityEngine if UNITY EDITOR using UnityEditor using System Reflection usin
  • STM使用SPI协议通信-基础(标准库)

    SPI协议是摩托罗拉公司开发的协议 它以主从方式工作 这种模式通常有一个主设备和一个或多个从设备 至少需要下列4根线 1 MISO Master Input Slave Output 主设备数据输入 从设备数据输出 2 MOSI Maste
  • 华为OD机试 - 查字典(Java)

    题目描述 输入一个单词前缀和一个字典 输出包含该前缀的单词 输入描述 单词前缀 字典长度 字典 字典是一个有序单词数组 输入输出都是小写 输出描述 所有包含该前缀的单词 多个单词换行输出 若没有则返回 1 用例 输入 b 3 a b c 输
  • RocketMQ的死信队列

    死信队列用于处理无法被正常消费的消息 当一条消息初次消费失败 消息队列会自动进行消息重试 达到最大重试次数后 若消费依然失败 则表明消费者在正常情况下无法正确地消费该消息 此时 消息队列 不会立刻将消息丢弃 而是将其发送到该消费者对应的特殊
  • Python数据可视化之条形图和热力图

    Python数据可视化之条形图和热力图 提示 介绍 简单介绍Pthon可视化的图表使用 提示 热力图和条形图 文章目录 Python数据可视化之条形图和热力图 前言 一 导入数据包 二 选择数据集 2 加载数据 2 读入数据 总结 前言 提
  • vue 二级级联菜单

    ul class sidebar menu li li ul
  • carplay是否可以用安卓系统_carplay能连接安卓手机吗

    carplay能连接安卓手机吗 carplay并不可以连接安卓手机 这一系统只能连接苹果的设备 有非常多车基本都有carplay功能 假如有这一功能 那么就可以将手机与自己的苹果手机连接起来 这样子可以导航 接打电话 用语音助手调节车机 听
  • 微信小程序之数据缓存

    在H5之前 缓存一般都是用cookie 但是cookie的存储空间太小 于是 H5增加了新的缓存机制 即localstorage 和 sessionstorage 具体的介绍就不在多说 在微信小程序中 数据缓存其实就和localstorag
  • 前端URL编码与解码:理解、应用与实践

    目录 什么是URL编码和解码 为什么需要URL编码和解码 1 特殊字符处理 2 支持非ASCII字符 3 SEO优化与用户体验 JavaScript中的URL编码和解码 URL编码示例 URL解码示例 实际应用场景 1 处理查询参数 2 构
  • Vue-CLI and Leaflet(2):地图基本操作(放大,缩小,平移,定位等)

    一 Vue CLI and Leaflet 起步 在 Vue CLI 中使用 Leaflet 二 Vue CLI and Leaflet 地图基本操作 放大 缩小 平移 定位等 三 Vue CLI and Leaflet 添加 marker
  • 《caffe学习之路》第一章:Ubuntu16.04 cuda及cudnn环境搭建

    这里我们选择一种简单的方式搭建cuda环境 那就是JetPack他会自动安装最新的驱动 CUDA Toolkit cuDNN TensorRT Opencv Python等 环境 系统 Ubuntu16 04 显卡 NVIDIA GTX20
  • Java 8系列之重新认识HashMap

    摘要 HashMap是Java程序员使用频率最高的用于映射 键值对 处理的数据类型 随着JDK Java Developmet Kit 版本的更新 JDK1 8对HashMap底层的实现进行了优化 例如引入红黑树的数据结构和扩容的优化等 本
  • vs2015的OpenCV3.2.0编译

    我们希望添加第三方功能模块和库或者针对特定cpu和gpu的编译调整优化选项 这样的需求就需要自己去编译opencv了 准备东西 opencv opencv contrib cmake 还有两个文件 因为可能是国内的原因 在configure
  • eviews建立时间序列模型_如何用eviews分析时间序列(全面).pdf

    您所在位置 网站首页 gt 海量文档 nbsp gt nbsp中学教育 nbsp gt nbsp高中教育 如何用eviews分析时间序列 全面 pdf70页 本文档一共被下载 次 您可全文免费在线阅读后下载本文档
  • 二层组播和三层组播

    平时常常说组播 其实只是多播的另外一种叫法 多播中 因为把参与多播的所有接收者称为组 所以才有组播的说法 多播技术要比广播技术复杂的多 多播技术对一些应用很重要 比如电视会议 聊天室等 物理层多播 系统需要对网络接口进行配置 让接口识别该地
  • MATLAB行向量顺序颠倒函数 - fliplr

    fliplr A 只可用于行向量 列向量不行 实例 1 行向量 2 列向量
  • 如何使用正则表达式实现Java日志信息的抓取与收集

    首先 什么是Java日志信息 简单来说 Java应用程序在运行过程中会输出一些信息 这些信息可以用来追踪程序运行状态 调试错误等 而Java日志信息就是这些输出信息的集合 那么为什么要抓取和收集Java日志信息呢 一方面 这些信息可以帮助我
  • 失业的程序员(八):创业的要素

    一 管饭哥登场 按理说我规定我和卞工的上班时间是上午8点到10点 弹性足够大 虽曰规定 但是遵不遵守随意 原因只有一个 引用卞工的话 就两个人 考毛勤 我 很是认可 严密的考勤制度的建立是老板对员工不怎么太信任的开始 是一种等级制度的体现
  • 1.神奇的字符串之快速求和

    文章目录 前言 正题 先看第一个代码 直接循环取出每一位数 总结 前言 这个专栏是分享一些好用的数据 和一些解题比较快的小方法 会持续更新 因为博主还是计算机方向的小白 知道的东西还是很少 希望大家可以多多指教 正题 众所周知 字符串一直是
  • PyTorch实现Logistic regression

    逻辑回归 Logistic regression 回归方法是对数值型连续随机变量进行预测和建模的监督学习算法 其特点是标注的数据集具有数值型的目标变量 回归的目的是预测数值型的目标值 逻辑回归对应线性回归 旨在解决分类问题 即将模型的输出转