【机器学习 - 4】:线性回归算法

2023-10-27

线性回归


线性回归的理解

线性回归:判断数据的特征和目标值之间具有一定的线性关系。
简单线性回归:样本的特征只有一个,用线性回归法进行预测,叫做简单线性回归。
多元线性回归:样本的特征有两个或两个以上,叫做多元线性回归。

如下图所示,为线性回归模型
在这里插入图片描述

损失函数

损失函数:np.sum((y`-y)**2),即预测值和真实值的差值之和。因为有复数的存在所以求平方,不用绝对值的原因:用平方方便后续的求导和求极值。

最小二乘法
在这里插入图片描述
一些推导过程:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
重要结论
在这里插入图片描述

简单线性回归


  1. 先画出数据的散点图
import numpy as np
import matplotlib.pyplot as plt

x = np.array([1,2,3,4,5])
y = np.array([2,1,3,2,5])

plt.scatter(x,y)
plt.axis([0,6,0,6])
plt.show()

在这里插入图片描述

  1. 对数据进行处理,求出a和b
# y = a * x + b
# 先求出平均值
x_mean = np.mean(x)
y_mean = np.mean(y)

num = 0.0 # 分子
d = 0.0 # 分母

for x_i, y_i in zip(x, y):
    num += (x_i-x_mean)*(y_i-y_mean)
    d += (x_i-x_mean)**2
a = num/d
b = y_mean-a*x_mean

在这里插入图片描述

  1. 求出y`,并画出预测直线,求出这条线,使得真实值与预测值的差值达到最小。
y_hat = a * x + b

plt.plot(x, y_hat, color='r')
plt.scatter(x,y)
plt.axis([0, 6, 0, 6])
plt.show()

在这里插入图片描述

封装线性回归算法

import numpy as np


class SimpleLinearRegression:
    def __init__(self):
        self.a_ = None
        self.b_ = None
        self.x_mean = None
        self.y_mean = None

    def fit(self, x_train, y_train):
        self.x_mean = np.mean(x_train)
        self.y_mean = np.mean(y_train)

        num = 0.0   # 分子
        d = 0.0     # 分母
        for x_i, y_i in zip(x_train, y_train):
            num += (x_i-self.x_mean) * (y_i-self.y_mean)
            d += (x_i-self.x_mean)**2
        self.a = num/d
        self.b = self.y_mean - self.a * self.x_mean

        return self

    def predict(self, x_test):
        return self.a * x_test + self.b

    def __repr__(self):
        return 'SimpleLinearRegression()'

在jupter notebook中导入运行
在这里插入图片描述

线性回归算法


使用线性回归算法的前提:数据具有一定的线性关系。

我们希望找到一条最佳拟合的直线方程,y=ax+b,对于每一个样本点,在这个直线方程上都有一个预测值,预测值和真实值有一定的差距,我们希望这些样本到直线方程的差距之和最小。

计算差距:sqrt(|y-y`|**2),使用平方并开根号的方式更适合我们进行求导或求值。

在sklearn中调用线性回归算法

  1. 导入模块
from sklearn.linear_model import LinearRegression
import numpy as np
import matplotlib.pyplot as plt
  1. 准备数据,训练模型
# 准备数据
x = np.array([1,2,3,4,5])
y = np.array([3,1,4,3,6])

lin_reg = LinearRegression()
lin_reg.fit(x.reshape(-1,1), y)	# 拟合,训练模型

在这里插入图片描述

  1. 画出散点图和预测直线
plt.scatter(x, y)
plt.plot(x, lin_reg.predict(x.reshape(-1,1)), color='r')
plt.axis([0,6,0,7])
plt.show()

在这里插入图片描述

向量化运算

如下图所示,向量化运算更加方便,向量点乘是先乘后加,原理一样。
在这里插入图片描述

x = np.array([1,2,3,4,5])
y = np.array([3,1,4,3,6])
def lin_fit(x, y):
    x_mean = np.mean(x)
    y_mean = np.mean(y)
    num = 0.0
    d = 0.0
    num = (x-x_mean).dot(y-y_mean)
    d = (x-x_mean).dot(x-x_mean)
    a = num/d
    b = y_mean-a*x_mean
    return a, b

在这里插入图片描述

线性回归模型中的误差


在分类问题可以将score看成准确率,在回归问题将score看成模型的好坏程度。
在这里插入图片描述

均方误差 MSE

均方误差的公式如下图所示:
在这里插入图片描述
为什么要除以样本数量m?
举个例子,比如第一个团队有2个人,统计其工资的均方误差为800,第二个团队有100个人,工资的均方误差为1000,能说明第一个团队比较好吗?这是不行的,因为统计的个数不同,样本不同,导致量纲不一样,所以需要除以样本数量m,减少量纲的影响。

封装的函数

# 均方误差 MSE
def MSE(y_true, y_predict):
    return np.sum((y_true-y_predict)**2)/len(y_true)

在这里插入图片描述

均方根误差

在这里插入图片描述
在均方误差中进行开根号处理,可以消除量纲的影响。

封装的均方根误差

# 均方根误差
from math import sqrt
def RMSE(y_true, y_predict):
    return sqrt(np.sum((y_true-y_predict)**2)/len(y_true))

在这里插入图片描述

平均绝对误差

在这里插入图片描述
封装的平均绝对误差

# 平均绝对误差
def MAE(y_true, y_predict):
    return np.sum(np.absolute(y_true-y_predict))/len(y_true)

在这里插入图片描述

调用sklearn中的均方根误差和平均绝对误差函数

from sklearn.metrics import mean_squared_error, mean_absolute_error
mean_squared_error(x, y_hat)
mean_absolute_error(x, y_hat)

在这里插入图片描述

R squared error (常用)

R^2(以下用R2表示)分类的准确度在0和1之间,R2为1时,模型最优,即没有出现任何错误。

计算公式如下:
在这里插入图片描述
封装R squared error

import numpy as np
x = np.array([1,2,3,4,5])
y = np.array([3,1,4,3,6])
def r2_score(x_true, y_predict):
    return 1-((np.sum((x_true-y_predict)**2)/len(x_true))/np.var(x_true))

在这里插入图片描述

或调用均值方差 MSE
在这里插入图片描述
调用sklearn中的线性回归算法,计算预测值,最终的误差结果还是一样

from sklearn.linear_model import LinearRegression
lin_reg = LinearRegression()
lin_reg.fit(x.reshape(-1,1), y)
y_predict = lin_reg.predict(x.reshape(-1,1))

在这里插入图片描述

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

【机器学习 - 4】:线性回归算法 的相关文章

  • 从框架中获取可调用对象

    给定框架对象 由sys getframe http docs python org library sys html sys getframe 例如 我可以获得底层的可调用对象吗 代码解释 def foo frame sys getfram
  • Pandas ParserError:标记数据时出错。 C 错误:字符串内有 EOF

    我的数据超过 400 000 行 运行此代码时 f pd read csv filename error bad lines False 我收到以下错误 pandas errors ParserError Error tokenizing
  • 用于读取类似 CSV 行的 Python 正则表达式

    我想解析传入的类似 CSV 的数据行 值用逗号分隔 逗号周围可能有前导和尾随空格 并且可以用 或 引用 例如 这是有效的行 data1 data2 data3 data4 data5 但这是格式错误的 data1 data2 da ta3
  • 使用python编辑html,但是lxml将漂亮的html实体转换为奇怪的编码

    我正在尝试使用 python 带有 pyquery 和 lxml 来更改和清理一些 html Eg html div p It 146 s a spicy meatball p div lxml html clean 函数 clean ht
  • 让 python 脚本打印到终端而不作为标准输出的一部分返回

    我正在尝试编写一个返回值的 python 脚本 然后我可以将其传递给 bash 脚本 问题是我想要在 bash 中返回一个单一值 但我想要一些东西一路打印到终端 这是一个示例脚本 我们称之为 return5 py usr bin env p
  • 在 ubuntu 中卸载 python 模块

    我必须删除一个名为 django 的 python 模块 一种流行的模块 因为我安装了错误的版本 1 3 py 2 6 中的 beta 如何卸载这个模块 请解释一下 因为我只在 Windows 中使用过 python 而从未在 Ubuntu
  • 清理 MongoDB 的输入

    我正在为 MongoDB 数据库程序编写 REST 接口 并尝试实现搜索功能 我想公开整个 MongoDB 接口 我确实有两个问题 但它们是相关的 所以我将它们放在一篇文章中 使用 Python json 模块解码不受信任的 JSON 是否
  • 在Python中,如何通过去掉括号和大括号来打印Json

    我想以一种很好的方式打印 Json 我想去掉方括号 引号和大括号 只使用缩进和行尾来显示 json 的结构 例如 如果我有一个像这样的 Json A A1 1 A2 2 B B1 B11 B111 1 B112 2 B12 B121 1
  • Python SQLite3 SQL注入漏洞代码

    我知道下面的代码片段由于 format 的原因很容易受到 SQL 注入的攻击 但我不知道为什么 有谁明白为什么这段代码容易受到攻击以及我从哪里开始修复它 我知道这些代码片段使输入字段保持打开状态 以便通过 SQL 注入执行其他恶意命令 但不
  • 熊猫系列到二维数组

    所以 我使用了来自的答案将二维数组放入 Pandas 系列中 https stackoverflow com questions 38840319 put a 2d array into a pandas series将 2D numpy
  • python 硒 按名称查找元素

    查找电子邮件输入的正确代码是什么https accounts google com ServiceLogin html 是
  • 在 Python 中引发异常的正确方法是什么? [复制]

    这个问题在这里已经有答案了 这是简单的代码 import sys class EmptyArgs StandardError pass if name main The first way to raise an exception if
  • 如何在使用 Flask for Python 3 的同时使用 Bootstrap 4?

    我检查过 发现默认安装时 Flask Bootstrap 原生使用 Bootstrap 3 3 7 但实际上我想通过使用 Flask Bootstrap 包在我的项目中使用 Bootstrap 4 任何有关如何更新它或类似内容的帮助将不胜感
  • 哈希 freezeset 与排序元组

    在 Python 中 给定一组可比较的 可散列的元素s 散列是否更好frozenset s or tuple sorted s 这取决于你在做什么 创建一个更快frozenset 比排序tuple but frozenset占用的内存比tu
  • PySpark DataFrame 上分组数据的 Pandas 式转换

    如果我们有一个由一列类别和一列值组成的 Pandas 数据框 我们可以通过执行以下操作来删除每个类别中的平均值 df DemeanedValues df groupby Category Values transform lambda g
  • model.predict() 返回类而不是概率

    Hello 我是第一次使用 Keras 我训练并保存了一个模型 作为 json 文件及其权重 该模型旨在将图像分为 3 个类别 我的编译方法 model compile loss categorical crossentropy optim
  • 阻止 BeautifulSoup 将我的 XML 标签转换为小写

    我正在使用 BeautifulStoneSoup 来解析 XML 文档并更改一些属性 我注意到它会自动将所有 XML 标签转换为小写 例如我的源文件有
  • Docker Python 脚本找不到文件

    我已经成功构建了一个 Docker 容器 并将应用程序的文件复制到 Dockerfile 中的容器中 但是 我正在尝试执行引用输入文件 在 Docker 构建期间复制到容器中 的 Python 脚本 我似乎无法弄清楚为什么我的脚本告诉我它无
  • Python 中的可逆 STFT 和 ISTFT

    有没有通用的形式短时傅立叶变换 https en wikipedia org wiki Short time Fourier transform与内置于 SciPy 或 NumPy 或其他什么中的相应逆变换 这是pyplotspecgram
  • 从 HDF5 文件中删除信息

    我意识到 SO 用户以前曾问过这个问题question https stackoverflow com questions 1124994 removing data from a hdf5 file rq 1但它是在 2009 年被问到的

随机推荐

  • hibernate注解

    现在EJB3实体Bean是纯粹的POJO 实际上表达了和Hibernate持久化实体对象同样的概念 他们的映射都通过JDK5 0注释来定义 EJB3规范中的XML描述语法至今还没有定下来 注释分为两个部分 分别是逻辑映射注释和物理映射注释
  • 目标检测一阶段和二阶段对比图

    图片来源
  • 『学Vue2+Vue3』认识Vue3

    认识Vue3 1 Vue2 选项式 API vs Vue3 组合式API 特点 代码量变少 分散式维护变成集中
  • Linux下安装jre

    Linux下安装Java运行环境 现需要项目部署到Linux中 需要配置java运行环境 注 以下测试环境系统为centOS 用户为超级管理员 jre8 1 下载最新版的jre 服务器环境下不需要配置jdk 下载地址如下 http www
  • microsoft visual c++ 6.0中文版两种使用方法

    microsoft visual c 6 0 是一款语言编程软件 那么很多人都不知道microsoft visual c 6 0中文版怎么使用 其实使用方法很简单哦 只要打开microsoft visual c 6 0中文版就可以进行语言编
  • 《数据结构》 图的创建与遍历 代码表示

    测试数据 10 15 共10个顶点 15条边 0 1 0 8 0 0 第一 二个数表示连接两个顶点的起始顶点 第三个数1表示单通行 0表示双向通行 4 8 1 5 4 0 5 9 1 0 6 0 7 3 1 8 3 1 2 5 0 2 1
  • c#排列组合算法

    Combinatorics cs代码清单 using System using System Collections using System Data
  • 二叉树所有节点转换成大于该节点的平均值,没有最大值就转换成0

    import java util ArrayList import java util List import java util function ToIntFunction import java util stream Collect
  • CUnit(单元测试框架)

    CUnit是一个用C语言编写 管理和运行单元测试的轻量级系统 它为C程序员提供了具有灵活多样用户界面的基本测试功能 CUnit是作为一个静态库构建的 它与用户的测试代码链接在一起 它使用一个简单的框架来构建测试结构 并为测试公共数据类型提供
  • Buildroot制作根文件系统过程(基于MYD-AM335X开发板)

    buildroot的功能很强大 可以利用它制作交叉编译工具链 根文件系统 甚至可以构建多种嵌入式平台的bootloader linux 下面以米尔科技的MYD AM335X平台为例展示如何利用buildroot制作自己所需的根文件系统 一
  • 柔性OLED拼接屏有哪些场景化应用?

    柔性OLED拼接屏是一种新型的显示技术 它采用了柔性OLED屏幕 可以实现多个屏幕的拼接 形成一个大屏幕显示 这种技术可以应用于各种场合 如商业展示 广告宣传 会议演示等 柔性OLED屏幕是一种新型的显示技术 它采用了柔性材料作为基底 可以
  • java基于ssm+vue的共享充电宝管理系统 elementui

    随着时代的发展 人们的生活越来越离不开手机 但是因为技术水平等原因的限制 手机的电池并没有人们想象中的那么耐用 很多时候人们在外出的时候 很可能会遇到手机没电的情况发生 作为日常通讯的必备工具 如果没电了 很可能会影响一些重要的事情 尤其是
  • 3D相机调研

    最近因为自己实验需要配置一个3D相机 安装在机械臂上实现eye in arm的自动化引导过程 调研结果记录如下 3D相机又称为深度相机 即通过该相机能检测出拍摄空间的景深距离 与普通相机 2D的最大区别 普通彩色相机 2D相机 拍摄到的图片
  • JS -- input输入框只能输入正整数

    摘自文章 input输入框只能输入正整数 半城烟沙的技术博客 51CTO博客 one
  • 最大比例

    X星球的某个大奖赛设了M级奖励 每个级别的奖金是一个正整数 并且 相邻的两个级别间的比例是个固定值 也就是说 所有级别的奖金数构成了一个等比数列 比如 16 24 36 54 其等比值为 3 2 现在 我们随机调查了一些获奖者的奖金数 请你
  • 面试题深入思考01-----Arrays.sort()与Collections.sort()

    面试题深入思考01 Arrays sort 与Collections sort 1 Collections sort Collections本质是关于集合的一种工具类 其中包含对集合的各种api 例如排序 反转 交换和复制等 其中sort方
  • word怎么恢复保存前的文件,word文件恢复

    我们在使用word编辑文档时偶尔会有误删除文档的经历 word要怎么恢复保存前的文件呢 本文为你提供了五种解决思路 你可以通过搜索word文档的备份文档 自动恢复文件 临时文件 回收站 第三方数据恢复软件找到文档 方法一 搜索 Word 备
  • katex

    Katex Accents Accent functions inside text Delimiters Delimiter Sizing Environments Letters and Unicode Other Letters Un
  • Android ----蓝牙架构

    蓝牙 1 fromwork 2 service 3 driver Bluetooth apk bluedroid 芯片厂家 fromwork到service直接调用 service到driver利用service调用 fromwork到dr
  • 【机器学习 - 4】:线性回归算法

    文章目录 线性回归 线性回归的理解 损失函数 简单线性回归 封装线性回归算法 线性回归算法 在sklearn中调用线性回归算法 向量化运算 线性回归模型中的误差 均方误差 MSE 均方根误差 平均绝对误差 调用sklearn中的均方根误差和