【机器学习】线性回归

2023-05-16

目录

    • 1. 问题
    • 2. 解决
    • 3. 代码
    • 4. 结果
    • 5. 数据

1. 问题

假设你是一个餐饮连锁店的CEO,你打算在不同的城市开设不同的分店。你已经在一些城市开了分店而且你有这些城市人口与利润的数据(见 5. 数据 data.txt),你希望通过这些数据来决定在哪些城市新开分店(也就是通过新城市的人口预测新城市的利润)。

2. 解决

线性回归

假设 利润人口数 的函数关系为: h θ ( x ) = θ 0 + θ 1 x h_{\theta}(x) = \theta_0 + \theta_1x hθ(x)=θ0+θ1x

实现了单变量线性回归模型,且仅针对单变量线性回归有效.

实现时为使最小化代价函数 J ( θ ) = 1 2 m ∑ i = 1 m ( h θ ( x ( i ) ) − y ( i ) ) 2 J(\theta) = \frac{1}{2m}\sum_{i=1}^m (h_{\theta}(x^{(i)})-y^{(i)})^2 J(θ)=2m1i=1m(hθ(x(i))y(i))2(均方误差),使用梯度下降法获得线性回归参数。

在这里插入图片描述
代价函数的导数:
在这里插入图片描述
需要设置的初始参数有 θ 0 \theta_0 θ0 θ 1 \theta_1 θ1、学习率 α \alpha α

θ 0 = 0 \theta_0 = 0 θ0=0
θ 1 = 0 \theta_1 = 0 θ1=0
α = 0.01 \alpha = 0.01 α=0.01

终止条件采用了两种方法(两种方法中任意一个满足条件时迭代终止):

  1. 迭代步数限制 ( S T E P = 10000 STEP = 10000 STEP=10000
  2. 当两次迭代获得的 差异 ( Δ = 0.0000001 \Delta = 0.0000001 Δ=0.0000001)较小时终止迭代

3. 代码

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from pylab import mpl

mpl.rcParams['font.sans-serif'] = ['FangSong'] # 指定默认字体
mpl.rcParams['axes.unicode_minus'] = False # 解决保存图像是负号'-'显示为方块的问题

# 假设:单变量线性回归
def hypothesis(x, theta0, theta1):
    return theta0 + theta1 * x
# 代价函数(Cost function):
def CostFunction(X, Y, theta0, theta1):
    m = np.shape(X)[0]
    res = 1
    for i in range(m):
        res += pow(hypothesis(X[i], theta0, theta1) - Y[i], 2)
    res /= (2*m)
    return res
# 代价函数的导数
def CostFunction_derivative(j, X, Y, theta0, theta1):
    m = np.shape(X)[0]
    res = 0
    for i in range(m):
        tmp = hypothesis(X[i], theta0, theta1) - Y[i]
        if j == 1:
            tmp *= X[i]
        res += tmp
    res /= m
    return res
# 批量梯度下降算法
def gradientDescent(X, Y, theta0, theta1, a):
    temp0 = theta0 - a * CostFunction_derivative(0, X, Y, theta0, theta1)
    temp1 = theta1 - a * CostFunction_derivative(1, X, Y, theta0, theta1)
    return temp0, temp1
# 绘制二维散点图
def plot_data(X, Y, title, xlabel, ylabel):
	plt.plot(X, Y, 'ro', markersize=6)
	plt.title(title, fontsize=20)
	plt.xlabel(xlabel, fontsize=10)
	plt.ylabel(ylabel, fontsize=10)
	plt.ioff() # 使matplotlib的显示模式转换为交互(interactive)模式


# 读取数据
dataset = pd.read_csv('data1a.txt', header=None) # 设置header参数,读取文件的时候没有标题
X = dataset.iloc[:,0].values # 人口数
Y = dataset.iloc[:,1].values # 利润

# 参数 θ 初始化
theta0 = 0
theta1 = 0

# 学习率
learningRate = 0.01

# 循环终止条件设定
STEP = 10000 # 设定一个比较大的迭代步数。
Delta = 0.0000001 # 当两次迭代获的J(θ) 差异较小时终止迭代。

if __name__ == '__main__':
    cnt = 0
    Jlast = 0
    Jnow = CostFunction(X, Y, theta0, theta1)
    Jlist = [Jnow]
    while cnt < STEP and abs(Jnow - Jlast) > Delta:
        theta0, theta1 = gradientDescent(X, Y, theta0, theta1, learningRate);
        # print(theta0, theta1)
        Jlast = Jnow
        Jnow = CostFunction(X, Y, theta0, theta1)
        Jlist.append(Jnow)
        cnt += 1
    print("梯度下降法获得线性回归参数")
    print("θ0 = ", theta0)
    print("θ1 = ", theta1)
    print()
    print("回归模型在所有训练数据(train_data.txt)上最终的J(θ)值")
    print("J(θ) = ", Jnow)
    print()
    # 画图
    plt.figure(figsize=(10, 6))
    # figure = plt.subplot(211)
    plt.plot(Jlist)
    plt.xlabel(u'迭代步数')
    plt.ylabel(u'代价函数值J(θ)')
    plt.title(u'J(θ)随迭代步数的变化')

    plt.figure(figsize=(10, 6))
    # plt.subplot(212)
    plt.scatter(X, Y, color='red')
    plt.plot(X, predict, color='black')
    plt.xlabel(u'人口数')
    plt.ylabel(u'利润')
    plt.title(u'线性回归')
    plt.show()

4. 结果

梯度下降法获得线性回归参数:
θ 0 = − 3.8783681899109235 \theta_0 = -3.8783681899109235 θ0=3.8783681899109235
θ 0 = 1.1912843507674498 \theta_0 = 1.1912843507674498 θ0=1.1912843507674498

回归模型在所有训练数据(data.txt)上最终 J ( θ ) = 4.482153618457505 J(θ) = 4.482153618457505 J(θ)=4.482153618457505

循环过程中 J ( θ ) J(\theta) J(θ) 随迭代步数变化的图
在这里插入图片描述
线性回归的拟合效果
在这里插入图片描述

5. 数据

data.txt

6.1101,17.592
5.5277,9.1302
8.5186,13.662
7.0032,11.854
5.8598,6.8233
8.3829,11.886
7.4764,4.3483
8.5781,12
6.4862,6.5987
5.0546,3.8166
5.7107,3.2522
14.164,15.505
5.734,3.1551
8.4084,7.2258
5.6407,0.71618
5.3794,3.5129
6.3654,5.3048
5.1301,0.56077
6.4296,3.6518
7.0708,5.3893
6.1891,3.1386
20.27,21.767
5.4901,4.263
6.3261,5.1875
5.5649,3.0825
18.945,22.638
12.828,13.501
10.957,7.0467
13.176,14.692
22.203,24.147
5.2524,-1.22
6.5894,5.9966
9.2482,12.134
5.8918,1.8495
8.2111,6.5426
7.9334,4.5623
8.0959,4.1164
5.6063,3.3928
12.836,10.117
6.3534,5.4974
5.4069,0.55657
6.8825,3.9115
11.708,5.3854
5.7737,2.4406
7.8247,6.7318
7.0931,1.0463
5.0702,5.1337
5.8014,1.844
11.7,8.0043
5.5416,1.0179
7.5402,6.7504
5.3077,1.8396
7.4239,4.2885
7.6031,4.9981
6.3328,1.4233
6.3589,-1.4211
6.2742,2.4756
5.6397,4.6042
9.3102,3.9624
9.4536,5.4141
8.8254,5.1694
5.1793,-0.74279
21.279,17.929
14.908,12.054
18.959,17.054
7.2182,4.8852
8.2951,5.7442
10.236,7.7754
5.4994,1.0173
20.341,20.992
10.136,6.6799
7.3345,4.0259
6.0062,1.2784
7.2259,3.3411
5.0269,-2.6807
6.5479,0.29678
7.5386,3.8845
5.0365,5.7014
10.274,6.7526
5.1077,2.0576
5.7292,0.47953
5.1884,0.20421
6.3557,0.67861
9.7687,7.5435
6.5159,5.3436
8.5172,4.2415
9.1802,6.7981
6.002,0.92695
5.5204,0.152
5.0594,2.8214
5.7077,1.8451
7.6366,4.2959
5.8707,7.2029
5.3054,1.9869
8.2934,0.14454
13.394,9.0551
5.4369,0.61705
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

【机器学习】线性回归 的相关文章

  • hdfs shell 操作基本语法

    hdfs用户切换并查看文件 xshell登陆到linux服务器 root 用户切换 以hdfs用户登陆查看创建的hive数据库是否以文件夹的形式存在hive文件目录下 su hdfs hdfs dfs ls apps hive wareho
  • js 多级对象数组删除对象

    let firstIndex 61 null let secondIndex 61 null const findItemNested 61 arr itemId nestingKey 61 gt arr reduce a item myI
  • Aarch64安装Anaconda Pytorch Torchvision

    1 Anaconda wget https github com Archiconda build tools releases download 0 2 3 Archiconda3 0 2 3 Linux aarch64 sh sudo
  • 扩大VMWARE里面虚拟硬盘大小(*.vmdk)

    http blog csdn net bshawk archive 2008 01 28 2070587 aspx 最近编译2 6 22的内核时 xff0c 发现虚拟机器FC6硬盘空间不够了 xff0c 于是乎 xff0c 想扩展下硬盘的大
  • c#加载xml文件

    C 加载xml文件 XmlDocument xmlDoc 61 new XmlDocument xmlDoc Load Application StartupPath 43 34 34 43 34 xml xml 34 加载xml文件 Xm
  • zram

    wiki zram是Linux内核的一个模块 xff0c 之前被称为 compcache zram通过在RAM内的压缩块设备上分页 xff0c 直到必须使用硬盘上的交换空间 xff0c 以避免在磁盘上进行分页 xff0c 从而提高性能 由于
  • 英飞凌 AURIX 系列单片机的HSM详解(2)——与HSM相关的UCB和寄存器

    本系列的其它几篇文章 xff1a 英飞凌 AURIX 系列单片机的HSM详解 xff08 1 xff09 何为HSM 英飞凌 AURIX 系列单片机的HSM详解 xff08 2 xff09 与HSM相关的UCB和寄存器 英飞凌 AURIX
  • MySQL数据库知识点总结

    1 什么是 MySQL MySQL 是 种关系型数据库 xff0c 在 Java 企业级开发中 常常 xff0c 因为 MySQL 是开源免费的 xff0c 并 且 便扩展 阿 巴巴数据库系统也 量 到了 MySQL xff0c 因此它的稳
  • 论文笔记-Towards Scene Understanding-Unsupervised Monocular Depth Estimation

    论文信息 标题 xff1a Towards Scene Understanding Unsupervised Monocular Depth Estimation with Semantic aware Representation作者 x
  • 结合 Casbin 对 http 请求做 RBAC 鉴权以及添加请求路由参数支持

    目录 总结 背景 实操 安装 Casbin 创建一个 Casbin 模型 创建一个 Casbin 策略 加载 Casbin 模型和策略并创建一个路由 总结 在本文中 xff0c 我们将介绍如何结合 Casbin 对 HTTP 请求进行基于角
  • Git—— master|RELEASE1/1

    当提交代码时 xff0c 多人合作避免不了要冲突 公司就我一个前端 xff0c 所以我一般情况下几乎不习惯pull代码 记录一下今天的执行过程 1 在vscode工具中操作更新的代码 2 在Git Bash中push 代码 span cla
  • 【FPGA】Mint20.3系统安装VCS2018环境

    mint系统是目前新手入手linux系统最为容易的系统版本 xff0c 其界面与Windows系统高度重合 vcs是IC开发常用的系统仿真工具 xff0c 但vcs工具的安装是一个很头疼的事情 xff0c 本篇展现在mint20 3系统安装
  • kubernetes使用flannel网络插件服务状态显示CrashLoopBackOff

    使用Kubeadm安装K8s集群 xff0c 在安装flannel网络插件后 xff0c 发现pod kube flannel ds 一直是CrashLoopBackOff 报错内容如下 xff1a log is DEPRECATED an
  • 用Python写了个金融数据爬虫,半小时干了全组一周的工作量

    最近 xff0c 越来越多的研究员 基金经理甚至财务会计领域的朋友 xff0c 向小编咨询 xff1a 金融人需要学Python么 xff1f 事实上在2019年 xff0c 这已经不是一个问题了 Python已成为国内很多顶级投行 基金
  • SSIS_数据流转换(Union All&合并联接&合并)

    Union All xff1a 与sql语言 Union All 一样 xff0c 不用排序 xff0c 上下合并多个表 Union All转换替代合并转换 xff1a 输入输出无需排序 xff0c 合并超过两个表 合并联接 xff1a 有
  • LACP协议:链路聚合/华为交换机LACP

    链路聚合的3种模式 61 61 61 61 61 gt 手工聚合 静态聚合 动态聚合 手工聚合 xff1a 手工汇聚概述 xff1a 手工负载分担模式是一种最基本的链路聚合方式 xff0c Eth Trunk 接口的建立 xff0c 成员接
  • Pytorch中Tensor和numpy数组的互相转化

    Pytorch中Tensor和Numpy数组的相互转化分为两种 xff0c 第一种转化前后的对象共享相同的内存区域 xff08 即修改其中另外一个也会改变 xff09 xff1b 第二种是二者并不共享内存区域 共享内存区域的转化 这种涉及到
  • #51单片机# 用中断实现蜂鸣器

    蜂鸣器常作为提示音 xff0c 用于计算机 打印机 万用表等设备中 提示音一般很简单 xff0c 能响就行 某单片机的蜂鸣器原理图 xff1a 该单片机的CPU原理图 xff1a 下面这段程序用到了中断的算法 xff0c 实现了蜂鸣器在4k
  • VM跨主机通信ovs配置

    如果位于不同物理主机上的两个VM需要通信 xff0c 那么底层的虚拟交换机ovs需要配置tunnel端口 OVS中支持添加隧道 Tunnel 端口 xff0c 常见隧道技术有两种gre或vxlan 隧道技术是在现有的物理网络之上构建一层虚拟
  • 免受 DDoS 攻击的五种技术

    尽管 DDoS 攻击很可怕 xff0c 但好消息是它们很容易预防 本节将讨论保护您的业务免受 DDoS 攻击的五种技术 一 高质量的网络硬件 高质量的网络基础设施可以帮助您检测甚至阻止网站流量的恶意增加 网络硬件包括路由器 用于连接设备的电

随机推荐