梯度下降法求解线性回归--Numpy实现

2023-11-15

梯度下降法求解一元线性回归

在这里插入图片描述
在这里插入图片描述
依然是这个房价预测的任务,这是一个一元线性回归问题,这次我们采用梯度下降法来求解它可以分为5步

  • 第1步加载样本数据x,y
  • 第2步设置超参数,在这个例子中,超参数包括学习率和迭代次数
  • 第3步设置模型参数的初值 w 0 w_0 w0 b 0 b_0 b0,这个初值可以是任意的
  • 第4步训练模型使用迭代公式更新模型,参数迭代完成之后,以可视化的形式输出结果

这是程序流程图,因为有迭代运算,所以需要通过循环来实现。
在这里插入图片描述
这部分是梯度下降法的实现,首先设置 w 0 w_0 w0 b 0 b_0 b0,设置循环变量,然后利用迭代公式不断更新,w和b,并且计算每一次迭代的损失,直到循环结束为止。

在上节课中我们知道超参数的设置非常重要,对训练结果有很大的影响,在训练之前我们往往并不知道这个超参数应该设置成多少,一般需要根据经验反复尝试,同时观察算法是否收敛,并且达到了我们需要的精度,下面我们就编程实现以上步骤。

第1步导入需要的库加载数据

x和y分别是面积和房价,把它们放在Numpy数组中,它们都是长度为16的一维数组

import numpy as np
import matplotlib.pyplot as plt

x = np.array([137.97, 104.50, 100.00, 124.32, 79.20, 99.00, 124.00, 114.00,
              106.69, 138.05, 53.75, 46.91, 68.00, 63.02, 81.26, 86.21])

y = np.array([145.00, 110.00, 93.00, 116.00, 65.32, 104.00, 118.00, 91.00,
              62.00, 133.00, 51.00, 45.00, 78.50, 69.65, 75.69, 95.30])

第2步设置超参数

  • learn_rate是学习率,通常是一个很小的常数。
  • iter是迭代次数迭代100次
  • 如果每次迭代都输出结果输出就会很长,也没有必要,因此我们可以每10次迭代输出一次结果这个display_step就是用来设置输出结果的。间隔它不属于超参数,因为它的取值完全不会影响模型的训练,只是会改变显示的效果,
learn_rate = 0.00001
iter = 100
display_step = 10

第3步给模型参数w和b设置初值

这是numpy的随机数生成函数。返回一个正态分布的浮点数组,当参数为空时,随机生成一个数字。

np.random.seed(612)
w = np.random.rand()
b = np.random.rand()

第4步训练模型

这个mse是一个Python列表,用来保存每次迭代后的损失值。
下面使用for循环实现迭代循环变量,从0开始,循环101次。为了描述方便,当i等于10时,我们就说第10次迭代。

mse = []
for i in range(0, iter + 1):
    dL_dw = np.mean(x*(w * x + b - y))
    dL_db = np.mean(w * x + b - y)
    
    w = w - learn_rate * dL_dw
    b = b - learn_rate * dL_db
    
    pred = w * x + b
    Loss = np.mean(np.square(y-pred))/2
    
    mse.append(Loss)
    
    if i % display_step == 0:
        print("i: %i, Loss: %f, w: %f, b: %f" % (i, mse[i], w, b))

在循环体中,首先计算损失函数队w和b的偏导数
在这里插入图片描述
然后使用迭代公式更新w和b
在这里插入图片描述
到这里就已经实现了梯度下降法的一次迭代,可以进入下一次循环了。但是我们希望能够观察到每次迭代的结果,判断是否收敛或者什么时候开始收敛,因此需要使用每次迭代后的w和b计算损失,并且把它显示出来。

这是使用当前这次循环得到的w和b计算所有样本的房价的估计值,这里的x是一个长度为16的一维数组,保存着所有样本的面积,因此这个pred也是一个长度为16的,一维数组,是所有的房价的预测值。
在这里插入图片描述
然后使用房价的实际值和预测值计算均方误差。y和pred都是长度为16的一维数组,这部分运算的结果仍然是一个一维数组,然后对这个数组中的所有元素,求评均值得到一个数字,最后再乘以1/2把得到的均方误差加入列表mse。mse一开始是一个空的列表,以后每执行一次,循环列表中就增加一个元素,整个循环执行完之后,其中就有101个元素,分别是每次迭代的损失值,如果当前的循环次数能够被10整除,就显示均方误差,w的值和b的值,这样做是为了能够动态的观察模型的训练过程。
在这里插入图片描述
这是i=0时的输出。
在这里插入图片描述
可以看到这个损失的值非常庞大。在前30次循环中,损失以很快的速度下降,这是因为在远离机制点的地方,损失函数曲线很陡峭,而且更新的步长比较大,因此损失的下降速度很快。
在这里插入图片描述
到第40次循环开始,损失下降的速度越来越慢,这是因为随着不断接近极值点,损失函数的曲线越来越平缓,步长也越来越小,因此损失的下降也越来越小。
在这里插入图片描述
在第100次迭代时,损失的值虽然还在下降,但是差值已经非常小了,这时候虽然还没有达到极值点,但是已经非常接近极值点了,大家可以把循环次数修改的更大一些,看看损失函数是不是还会下降。

梯度下降法得到的数值解是一个近似值,在收敛之后,只要达到了精度要求,就可以停止迭代,否则可以继续迭代,直到满足精度要求为止,下面把模型训练的结果以可视化的形式输出出来,

结果可视化–数据和模型

首先使用样本数据绘制销售记录散点图,然后绘制预测房价的散点图,这个pred是最后一次迭代之后计算出的房价的估计值,把它们连接在一起就是得到的线性模型。

plt.rcParams['font.sans-serif'] = ['SimHei']
plt.figure()

plt.scatter(x, y, color="red", label="销售记录")
plt.scatter(x, pred, color="blue", label="梯度下降法")
plt.plot(x, pred, color="blue")

plt.xlabel("Area", fontsize=14)
plt.ylabel("Price", fontsize=14)

plt.legend(loc="upper left")
plt.show()

这是运行的结果。
在这里插入图片描述
其中红色的点是实际的销售房价,蓝色的点是预测出的房价。蓝色的直线是训练得到的模型,那么这个模型是否准确呢?这是在上一讲中我们计算出的解。解析解是一个精确的结果。现在我们可以把解析解对应的线性模型也绘制出来,进行比较。

plt.rcParams['font.sans-serif'] = ['SimHei']
plt.figure()

plt.scatter(x, y, color="red", label="销售记录")
plt.scatter(x, pred, color="blue")
plt.plot(x, pred, color="blue", label="梯度下降法")
plt.plot(x, 0.89*x+5.41, color="green", label="解析法")

plt.xlabel("Area", fontsize=14)
plt.ylabel("Price", fontsize=14)

plt.legend(loc="upper left")
plt.show()

在这里插入图片描述
这里的w和b保留两位小数,这条绿色的直线就是解析法得到的模型。可以看到采用梯度下降法得到的模型和它有一定的偏差,但是在可以接受的范围之内,大家也可以尝试增加迭代次数,继续更新权值,让w和b更接近极值点。

在这里插入图片描述
这张图展示了整个迭代过程中模型直线的变化过程,这是使用w和b的初始值得到的直线,这是迭代100次之后得到的直线。后面的这些直线非常接近,很多都重叠在一起了,要做出这个图,只要在每次迭代之后,都增加一条plot函数绘制直线就可以了。

mse = []
for i in range(0, iter + 1):
    dL_dw = np.mean(x*(w * x + b - y))
    dL_db = np.mean(w * x + b - y)

    w = w - learn_rate * dL_dw
    b = b - learn_rate * dL_db

    pred = w * x + b
    Loss = np.mean(np.square(y-pred))/2

    mse.append(Loss)
    
    plt.plot(x, pred)

    if i % display_step == 0:
        print("i: %i, Loss: %f, w: %f, b: %f" % (i, mse[i], w, b))

结果可视化–损失变化

在这里插入图片描述
通过这张图,我们可以更加清楚的观察损失值的变化,图中的横坐标是迭代次数,纵坐标是损失值,然后逐渐减缓,会是这个图也非常简单,因为每次迭代的损失值都已经被存放在列表mse中了,现在只要把它们取出来,连在一起,就可以得到损失值变化的曲线图。

plt.figure()
plt.plot(mse)

plt.xlabel("Iteration", fontsize=14)
plt.ylabel("Loss", fontsize=14)

plt.show()

在这张图中,因为开始时的损失值非常大,所以纵坐标的刻度也很大,导致从第20次迭代以后损失的下降很难直接在这张图中看出来。如果想要观察到第20次迭代之后损失,可以把plot函数的参数修改一下第20次迭代之后的。

plt.figure()
plt.plot(range(20, 100), mse[20:100])

plt.xlabel("Iteration", fontsize=14)
plt.ylabel("Loss", fontsize=14)

plt.show()

在这里插入图片描述
可以看到现在从第20次迭代时的损失值148,开始显示损失变化曲线,在第20~40次迭代时损失下降的很快,40次之后就逐渐平坦了。

采用同样的方法也可以绘制出第40次迭代之后损失变化的曲线,大家可以继续尝试一下观察和总结损失下降的规律。

plt.figure()
plt.plot(range(40, 100), mse[40:100])

plt.xlabel("Iteration", fontsize=14)
plt.ylabel("Loss", fontsize=14)

plt.show()

在这里插入图片描述
在这里插入图片描述

另外要注意的是,这个图是在训练模型的过程中,损失函数的值变化的曲线,不是损失函数的图。损失函数本身应该和右边的图近似。

结果可视化–估计值&标签值

为了更加直观的展示预测值和实际值之间的差距,可以使用这样的图。这个图中的横坐标是样本序号,一共有16个点,每个点对应一套商品房,纵坐标是房价。

plt.plot(y, color="red", marker="o", label="销售记录")
plt.plot(pred, color="blue", marker="o", label="梯度下降法")

plt.legend()
plt.xlabel("Sample", fontsize=14)
plt.ylabel("PRICE", fontsize=14)
plt.show()

在这里插入图片描述
红色的数据点是样本数据,它们的纵坐标是每套商品房的实际销售价格。蓝色的点是我们通过模型预测出来的房价。
可以看到有些房价的估计很准确,有些存在着一定的偏差,这个图形也是使用plot函数绘制的,把参数mark的值设置为o,数据点就以这种大圆点的形式显示出来。

plt.plot(y, color="red", marker="o", label="销售记录")
plt.plot(pred, color="blue", marker="o", label="梯度下降法")
plt.plot(0.89*x+5.41, color="green", marker="o", label="解析法")

plt.legend()
plt.xlabel("Sample", fontsize=14)
plt.ylabel("PRICE", fontsize=14)
plt.show()

在这里插入图片描述
也可以把解析法得到的结果绘制出来进行对比,这条绿色的线就是解析法画出来的线。

为了使代码更加的简洁紧凑,结果更加便于观察,我们把这些图放在同一个画布中,显示首先设置画布尺寸,然后划分子图,在每个子图中分别绘制不同的图形。

plt.rcParams['font.sans-serif'] = ['SimHei']
plt.figure(figsize=(20, 4))

plt.subplot(1, 3, 1)
plt.scatter(x, y, color="red", label="销售记录")
plt.plot(x, pred, color="blue", label="预测记录")
plt.xlabel("Area", fontsize=14)
plt.ylabel("Price", fontsize=14)
plt.legend(loc="upper left")

plt.subplot(1, 3, 2)
plt.plot(mse)
plt.xlabel("Iteration", fontsize=14)
plt.ylabel("Loss", fontsize=14)

plt.subplot(1, 3, 3)
plt.plot(y, color="red", marker="o", label="销售记录")
plt.plot(pred, color="blue", marker="o", label="预测记录")
plt.legend()
plt.xlabel("Sample", fontsize=14)
plt.ylabel("PRICE", fontsize=14)
plt.legend(loc="upper left")

plt.show()

这是运行的结果。
在这里插入图片描述
现在我们已经使用梯度下降法完成了对一元线性回归模型的训练,下面就可以使用这个模型对未知的数据进行预测了。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

梯度下降法求解线性回归--Numpy实现 的相关文章

  • 如何访问pandas数据框中的多级索引?

    我想用相同的索引来调用这些行 这是示例数据框 arrays np array bar bar baz baz foo foo qux qux np array one two one two one two one two df pd Da
  • 打印 scrapy 请求的“响应”

    我正在尝试学习 scrapy 在遵循教程的同时 我正在尝试进行细微的调整 我想简单地从请求中获取响应内容 然后我会将响应传递到教程代码中 但我无法发出请求并获取响应内容 建议就好 from scrapy http import Respon
  • 如何更改充当按钮的范围的文本

    我正在为自定义 Web 应用程序编写自动化测试 我遇到了无法更改跨度文本的问题 我尝试过使用 driver execute script 但没有运气 如果我更好地了解 javascript 这确实会有帮助 据我所知 您无法单击跨度 并且列表
  • 在 Python 中使用 sec 函数的反函数

    我正在创建一个程序 用于计算从一定高度范围和设定初始速度发射射弹的最佳角度 在我需要使用的最终方程中 存在一个反 sec 函数 它导致了一些麻烦 我已经导入了数学并尝试使用 asec 无论如何 但是数学似乎无法计算反秒函数 我也明白 sec
  • 在 python-docx 中搜索和替换

    我有一个包含以下字符串的文档 模板 你好 我的名字是鲍勃 鲍勃是一个很好的名字 我想使用 python docx 打开此文档并使用 查找和替换 方法 如果存在 来更改每个字符串 Bob gt Mark 最后 我想生成一个新文档 其中包含字符
  • Python:当前目录是否自动包含在路径中?

    Python 3 4 通过阅读其他一些 SO 问题 似乎如果moduleName py文件位于当前目录之外 如果要导入它 必须将其添加到路径中sys path insert 0 path to application app folder
  • 将一个时间序列插入到 pandas 中的另一个时间序列中

    我有一组定期测量的值 说 import pandas as pd import numpy as np rng pd date range 2013 01 01 periods 12 freq H data pd Series np ran
  • python ttk treeview:如何选择并设置焦点在一行上?

    我有一个 ttk Treeview 小部件 其中包含一些数据行 如何设置焦点并选择 突出显示 指定项目 tree focus set 什么也没做 tree selection set 0 抱怨 尽管小部件明显填充了超过零个项目 但未找到项目
  • Python:随时接受用户输入

    我正在创建一个可以做很多事情的单元 其中之一是计算机器的周期 虽然我将把它转移到梯形逻辑 CoDeSys 但我首先将我的想法放入 Python 中 我将进行计数 只需一个简单的操作 counter 1 print counter 跟踪我处于
  • 使用Python将图像转换为十六进制格式

    我的下面有一个jpg文件tmp folder upload path tmp resized test jpg 我一直在使用下面的代码 Method 1 with open upload path rb as image file enco
  • Python unicode 字符代码?

    有没有办法将 Unicode 字符 插入 Python 3 中的字符串 例如 gt gt gt import unicode gt gt gt string This is a full block s unicode charcode U
  • Python int 太大,无法放入 SQLite

    我收到错误 OverflowError Python int 太大 无法转换为 SQLite INTEGER 来自以下代码块 该文件约25GB 因此必须分部分读取 length 6128765 Works on partitions of
  • 在 pip.conf 中指定多个可信主机

    这是我尝试在我的中设置的 etc pip conf global trusted host pypi org files pythonhosted org 但是 它无法正常工作 参考 https pip pypa io en stable
  • 使用 Doc2vec 后如何解释 Clusters 结果?

    我正在使用 doc2vec 将关注者的前 100 条推文转换为矢量表示形式 例如 v1 v100 之后 我使用向量表示来进行 K 均值聚类 model Doc2Vec documents t size 100 alpha 035 windo
  • 是否可以强制浮点数的指数或有效数匹配另一个浮点数(Python)?

    这是我前几天试图解决的一个有趣的问题 是否可以强制一个的有效数或指数float与另一个人一样float在Python中 出现这个问题是因为我试图重新调整一些数据 以便最小值和最大值与另一个数据集匹配 然而 我重新调整后的数据略有偏差 大约小
  • 从 dask 数据框中的日期时间序列获取年份和星期?

    如果我有一个 Pandas 数据框和一个日期时间类型的列 我可以按如下方式获取年份 df year df date dt year 对于 dask 数据框 这是行不通的 如果我先计算 像这样 df year df date compute
  • 如何对字符串列表进行排序?

    在 Python 中创建按字母顺序排序的列表的最佳方法是什么 基本回答 mylist b C A mylist sort 这会修改您的原始列表 即就地排序 要获取列表的排序副本而不更改原始列表 请使用sorted http docs pyt
  • 使用 Keras 和 fit_generator 绘制 TensorBoard 分布和直方图

    我正在使用 Keras 使用 fit generator 函数训练 CNN 这似乎是一个已知问题 https github com fchollet keras issues 3358TensorBoard 在此设置中不显示直方图和分布 有
  • 从时间序列生成日期特征

    我有一个数据框 其中包含如下列 Date temp data holiday day 01 01 2000 10000 0 1 02 01 2000 0 1 2 03 01 2000 2000 0 3 30 01 2000 200 0 30
  • 将此 MATLAB 代码转换为 Python 时我做错了什么?

    我正在努力将生成波形的 MATLAB 代码转换为 Python 就上下文而言 这是原子力显微镜带激发响应的模拟 与代码错误无关 在 MATLAB 中从 r vec 生成的图形与我在 Python 中生成的图形不同 我是否正确地将 MATLA

随机推荐

  • 联想拯救者R7000p 2021风扇异响解决办法

    联想拯救者R7000p 2021风扇异响解决办法 23年了 电脑用了2年 F1键下面的风扇跟拖拉机一样 在没有开任何软件下 都一直再高速转 在网上搜了下 看到19款 21款的拯救者都有这个问题 解决办法呢 网上看到的有 更新BIOS 去官网
  • 设计模式(Design Patterns)

    原文地址 http blog csdn net zhangerqing article details 8194653 设计模式 Design Patterns 可复用面向对象软件的基础 设计模式 Design pattern 是一套被反复
  • 机器学习算法+代码

    机器学习 一 概述 1 机器学习研究方向 传统预测 图像识别 自然语言处理 2 数据集构成 数据集 特征值 目标值 监督学习 目标值为类别 属于分类问题 目标值为连续数据 属于回归问题 无监督学习 无目标值 3 机器学习流程 获取数据 数据
  • Python,OpenCV骨架化图像并显示(skeletonize)

    Python OpenCV骨架化图像并显示 skeletonize 1 效果图 2 源码 参考 1 效果图 自己画一张图 原图 VS 骨架效果图如下 opencv logo原图 VS 骨架化效果图如下 2 源码 图像骨架化
  • 统一登录门户系统

    随着等保2 0和密评工作的深入推进 各政企单位的应用系统建设会向着更安全 更标准方向发展 为了推进整合信息共享 破除各系统之间的壁垒 首先要建设的就是统一登录门户系统 常见的统一登录要求 还是基于一个统一的入口 由统一登录入口完成登录后 可
  • kafka java 性能测试_针对kafka_2.13版本测试过程中的一些坑

    声明 这是在windows10上进行kafka 2 13demo搭建时的过程记录 提供给同学们参考 1 jdk先要装一下 自己安装的kafka最好检查一下配置文件中的参数 server properties 1 zookeeper conn
  • Java Pattern.matcher()方法具有什么功能呢?

    转自 Java Pattern matcher 方法具有什么功能呢 下文笔者讲述Pattern matcher 方法的功能简介说明 如下所示 Pattern matcher 方法的功能 用于匹配字符串或返回Matcher实例 Pattern
  • VM安装mac问题

    安装VM以及mac虚拟机 http tieba baidu com p 2847457021 遇见问题 您的 mac os 客户机正在使用cd dvd 此操作无法继续 请忽略此消息 并从客户机内弹 首先 需要下载安装darwin6 iso才
  • 自动化测试岗位建议熟读!!!Python+Selenium代码编写方法大全

    整理过的自动化测试selenium工具代码常用方法大全 对于常使用selenium工具的朋友一定经常会使用 建议熟读熟练 当然收藏之后复制粘贴也可以 这些整理过的web自动化测试进阶资料 有需要的可以进入群聊免费领取点击并输入暗号 CSDN
  • ganymed-ssh2实现java ssh协议采集

    我的博客第一篇讲的就是用Maverick组件实现java ssh协议采集 可惜Maverick是个商业软件 不开放源码且只有45天的试用期 实际上在网上也能搜到不少实现java ssh的开源组件 例如orion ssh2 trilead s
  • Vue路由组件独有的两个生命周期钩子

    1 作用 用于捕获路由组件的激活状态 2 具体名字 2 1 activated路由组件被激活是触发 activated this timer setInterval gt console log this opacity 0 01 if t
  • Elastic Search:(一)快速入门

    目录 1 快速入门 1 1 核心概念介绍 1 2 RESTful风格介绍 1 2 1 概念 1 2 2 方法 1 3 索引 1 3 1 新增索引 PUT 1 3 2 获取索引 GET 1 3 3 删除索引 DELETE 1 3 4 判断索引
  • 以AI对抗AI,大模型安全的“进化论”

    点击关注 文丨刘雨琦 编 王一粟 互联网时代 我们是更危险 还是更安全 2016年 互联网正值高速发展之际 电梯广告经常出现这几个大字 两行标语 从病毒木马到网络诈骗 对于安全的思考 安全防范技术的建立一直在与科技发展赛跑 同样 大模型时代
  • paypal中授权返回_2020最新教程:如何在Unity Ads中填写W-8BEN(W8税表)

    税收资料作为payout profile的一部分 是必须要填写的 即使你不是美国居民 也需要填写个人资料的 纳税 部分 否则将无法获得来自于Unity的付款 由于我们不是美国居民 因此只需要填写W 8BEN即可 其实 它的填写方法与我之前写
  • 计算机中1kb等于多少字节,在计算机中1kb等于多少字节

    在计算机中1kb等于1024个字节 字节是计算机信息技术用于计量存储容量的一种计量单位 也表示一些计算机编程语言中的数据类型和语言字符 一个字节存储8位无符号数 本文操作环境 windows10系统 thinkpad t480电脑 学习视频
  • QEMU-在内核中增加驱动(6)

    上面是我的微信和QQ群 欢迎新朋友的加入 进入linux源码目录 增加驱动 hello c include
  • Java面试----2018最全Redis面试题整理

    1 什么是Redis 答 Redis全称为 Remote Dictionary Server 远程数据服务 是一个基于内存的高性能key value数据库 2 Redis的数据类型 答 Redis支持五种数据类型 string 字符串 ha
  • python 访问网络失败 huggingface ConnectionError

    使用Hugginface下载数据集 dataset load dataset path seamew ChnSentiCorp 结果遇到网络问题 huggingface ConnectionError Couldn t reach 原因是无
  • element中同一个一面使用两个table,使用v-if判断显示,数据混乱

    错误 在一个页面中使用两个table 绑定不同的数据 并且在table中row中使用
  • 梯度下降法求解线性回归--Numpy实现

    梯度下降法求解一元线性回归 依然是这个房价预测的任务 这是一个一元线性回归问题 这次我们采用梯度下降法来求解它可以分为5步 第1步加载样本数据x y 第2步设置超参数 在这个例子中 超参数包括学习率和迭代次数 第3步设置模型参数的初值 w