吴恩达机器学习笔记:手搓线性回归(梯度下降寻优)

2023-11-14

概念就不介绍了,记录下公式推导和代码实现,以及与最小二乘的比较。

吴恩达老师课程中使用一个参数theta保存两个变量,不过我好像没把中间变量的形状对齐,所以最后实现了两个参数的版本。

 

 

代码:

import time
import numpy as np
import pandas
import matplotlib.pyplot as plt

#随机种子
rd = np.random.RandomState(round(time.time())) 

#输出一个5*5的单位矩阵、随机矩阵
arr1 = np.ones((5,5))
arr2 = rd.randint(-10,10,(2,3))
print('5*5单位矩阵:')
print(arr1)
print('2*3随机int矩阵:')
print(arr2)


#随机生成散点数据
Population = rd.uniform(1,50,100) #均匀分布
Profit = rd.rand(100)*1000  #高斯分布


#排序(模拟线性数据) 如果是numpy array,则要使用numpy.sort()
Population.sort()
Profit.sort()


#绘制散点图
fig = plt.figure(figsize=(16, 9), dpi=40)
plt.scatter(Population, Profit) #散点图
plt.grid(color='r', linestyle='--', linewidth=1,alpha=0.3)
plt.show()


#构造代价函数Cost,使用MSE估计误差,为方便求导将目标函数设置为MSE/2m (m为样本数量)
#对于线性回归y=wx+b,回归参数有两个:w与b,我们将打包其放入“theta”中
def costFunction(x,y,theta):
    tmp = np.matmul(x,theta)
    cost = np.power(tmp-y, 2)
    return np.sum(cost)/(2*len(x))


#梯度下降
def gradientDescent(x, y, theta, rate, iters):
    X = np.array(x.reshape(len(x),1))
    Y = np.array(y.reshape(len(y),1))
    theta = theta.reshape(2,1)
    #X添加一列ones,以便与theta相乘
    X = np.insert(X, 1, np.array(np.ones(len(x))), axis=1)
    
    temp = np.matrix(np.zeros(theta.shape))
    cost = np.zeros(iters)
    
    i = 0
    cost[i] = 1e-4
    
    while i < iters and cost[i]>1e-5 :
        error = np.dot(X , theta) - Y
        print(error.shape)
        print(X[:,0].shape)
        term = error * X[:,0]
        print(term.shape)
        theta[0,0] = theta[0,0] - ((rate /(len(x))) * (( np.sum(term) )))
        theta[1,0] = theta[1,0] - ((rate /(len(x))) * ((np.sum(error))) )
    
        cost[i] = costFunction(X, Y, theta)
        
    return theta, cost


#两个参数版本的costFuction与gradientDescent
def costFun(x,y,theta_w,theta_b):
    x = x.reshape(len(x),1)
    y = y.reshape(len(y),1)
    cost = (np.dot(x,theta_w) + theta_b) - y
    return np.sum(np.power(cost,2))/(len(x)*2)


def gradientDescent2(x,y,rate,iters,theta_w,theta_b):
    x = np.array(x.reshape(len(x),1))
    y = np.array(y.reshape(len(y),1))
    
    y_hat = np.dot(x,theta_w) + theta_b
    error = np.array(y_hat - y)
    
    dw = 2*(np.vdot(x,error))
    db = sum(2*error)
    
    i = 0
    cost = np.zeros(iters)
    cost[i] = 1e-4
    while i <iters and cost[i] > 1e-5:
        y_hat = np.dot(x,theta_w) + theta_b
        error = y_hat - y
    
        dw = 2*np.vdot(x,error)
        db = sum(2*error)
        
        #更新参数
        theta_w = theta_w - rate / (len(x)*2)*dw
        theta_b = theta_b - rate / (len(x)*2)*db
        
        cost[i] = costFun(x,y,theta_w,theta_b) #打印出每一轮迭代的cost,查看是否一直在变小
        i+=1
        
    Theta = list([theta_w,theta_b])
    
    return Theta,cost
    
    
#初始化,需保证输入的数据为numpy matrix
#x = np.matrix(Population.values)
#y = np.matrix(Profit.values)


#超参数
rate = 0.001  #学习率
iters = 1500  #迭代次数
theta = np.matrix(np.array([10,5])) # a与b的初始值为0


result2 = gradientDescent2(Population, Profit, rate,iters,1,1)
print(result2)

#训练结果可视化
plt.title('Linear Regression with 1D Feature')

X = np.linspace(np.min(Population),np.max(Population*1.1),len(Population)*100)
Y = X*result2[0][0] + result2[0][1]
ax1 = plt.scatter(Population, Profit, marker= 'o', s=50)
plt.plot(X,Y,'r', label='Gradient Descending')
plt.grid(color='r', linestyle='--', linewidth=1,alpha=0.3)
plt.xlabel('Input_value(x)')
plt.ylabel('Output_value(y)')
plt.show()



# 最小二乘回归拟合对比梯度下降=============================================================================
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.datasets import make_regression
# 生产一元随机回归数据集
# X,Y = make_regression(n_samples=10,n_features=1,n_informative=1,\
#                       n_targets=1,noise=0.1,random_state=0)

Population = Population.reshape(np.size(Population),1)
Profit = Profit.reshape(np.size(Profit),1)

# 数据分割
X_train,X_test,Y_train,Y_test = train_test_split(Population, Profit,test_size=1/4,\
                                random_state=0)


# 最小二乘
linreg = LinearRegression().fit(X_train,Y_train)
print('模型系数w:{}\n模型截距b:{}\n训练集R^2得分:{}\n测试集R^2得分:{:.3f}'\
      .format(linreg.coef_, linreg.intercept_,linreg.score(X_train,Y_train)\
      ,linreg.score(X_test,Y_test)))

ax2 = plt.scatter(Population, Profit, marker= 'o', s=10,label='Samples')

X_ols = np.linspace(np.min(Population),np.max(Population*1.5),len(Population)*100)
Y_ols = X_ols * linreg.coef_+linreg.intercept_
Y_ols = Y_ols.reshape(np.size(Y_ols),1)

print(Y_ols.shape)
ax3 = plt.plot(X_ols,Y_ols,'g',label='Least Sum of Squares')

print("ax3:")
print(ax3)
#梯度下降
result = gradientDescent2(np.array(X_train), np.array(Y_train), rate,iters,1,1)
print(result)


X_d = np.linspace(np.min(X),np.max(X*1.5),len(X)*100)
Y_d = X_d*result[0][0] + result[0][1]


ax4 = plt.plot(X_d,Y_d,'b',label='gradientDescent')
plt.legend()


#形状对不上..
# result1 = gradientDescent(Population, Profit, theta, rate, iters)
# print(result1)

输出结果:

 

 

纸上习得终觉浅..

Reference:

吴恩达机器学习ex1 - Heywhale.com

一文看懂简单线性回归:梯度下降法和最小二乘法(代码实现及数学公式详解)_Maxxi Chen的博客-CSDN博客_最小二乘法和梯度下降法实现线性模型的过程

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

吴恩达机器学习笔记:手搓线性回归(梯度下降寻优) 的相关文章

  • SQL中Group By的使用

    SQL中Group By的使用 1 概述 2 原始表 3 简单Group By 4 Group By 和 Order By 5 Group By中Select指定的字段限制 6 Group By All 7 Group By与聚合函数 8
  • 操作系统(02)- 四个基本特征

    文章目录 一 操作系统的特征 1 并发 2 共享 3 虚拟 4 异步 一 操作系统的特征 操作系统的特征可以分为四类 并发 共享 虚拟 异步 其中并发和共享是最基本的特征 二者互为存在条件 后面会给出详细解释 下面详细的介绍这四种特征 1
  • 中国智能建筑行业运行状况与十四五应用前景调研报告2022版

    中国智能建筑行业运行状况与十四五应用前景调研报告2022版 修订日期 2021年12月 搜索鸿晟信合研究院查看官网更多内容 第一章 智能建筑发展概述 1 1 智能建筑的相关概念 1 1 1 智能建筑的定义 1 1 2 智能建筑的层次划分 1
  • Ubuntu安装redis5.0.0

    一 下载 sudo wget http download redis io releases redis 5 0 0 tar gz 如果慢 可以传上去 二 解压编译安装 解压后 切换目录 cd app redis 5 0 0 编译 make
  • 【C++】对数组指针的理解,例如 int (*p)[3]

    目录 简介 思考 理解 结语 简介 Hello 非常感谢您阅读海轰的文章 倘若文中有错误的地方 欢迎您指出 昵称 海轰 标签 程序猿 C 选手 学生 简介 因C语言结识编程 随后转入计算机专业 获得过国家奖学金 有幸在竞赛中拿过一些国奖 省

随机推荐

  • 【MYSQL】mysql1130错误与安装重置密码

    1 连接服务器 mysql u root p 2 看当前所有数据库 show databases 3 进入mysql数据库 use mysql 4 查看mysql数据库中所有的表 show tables 5 查看user表中的数据 sele
  • freeswitch编译过程以及添加odbc连接mysql

    freeswitch 编译 参考官网wiki bootstrap sh j configure prefix home make make j install make j cd sounds install make j cd moh i
  • Git的Patch功能

    本文整理编辑自 http www cnblogs com y041039 articles 2411600 html http yuxu9710108 blog 163 com blog static 2375153420101114488
  • 可以免费使用的ChatGPT-4,微软开放Bing Chat功能,供用户体验

    微软Bing取消了 Bing Chat的等待名单 现在用户可以通过使用 Edge 浏览器并使用微软帐户登录就可以使用Bing Chat了 入口 打开Bing首页 用户点击 聊天 Chat 即可进入Bing Chat界面 目前Bing Cha
  • Open3D(C++) 模型锐化

    目录 一 模型锐化 1 概述 2 主要函数 二 代码实现 三 结果展示 1 原始模型 2 锐化处理 一 模型锐化 1 概述 Open3D中的实现一种模型锐化处理的算法 该算法的输出值 v o v o v
  • uni.switchTab()跳转不刷新页面问题

    uni switchTab 跳转不刷新页面问题 大家应该都遇到过 调转到 tabBar 里面的页面时 只能使用 uni switchTab 或者是 uni navigator 跳转 使用 uni reLauch 或者是 uni redire
  • 【BLE】-CC2541 OSAL操作系统抽象层应用程序接口API介绍

    参考源source 简介 目的 本文档的目的是定义OS抽象层 OSAL 的API 这个API适用于TI协议栈软件组的产品 例如Z 堆栈 RemoTI 和BLE 适用范围 该文件列举了由OSAL提供的所有函数调用 详细地说明了所有函数调用 方
  • C++11之继承构造函数(using 声明)

    系列文章 C 11之正则表达式 regex match regex search regex replace C 11之线程库 Thread Mutex atomic lock guard 同步 C 11之智能指针 unique ptr s
  • SQL,NowSQL及NewSQL浅析

    关系型数据库 NOSQL NEWSQL浅析 1 关系型数据库 关系数据库 是建立在关系模型基础上的数据库 借助于集合代数等数学概念和方法来处理数据库中的数据 简单来说 关系模型指的就是二维表格模型 而一个关系型数据库就是由二维表及其之间的联
  • 华为OD机试 - 在字符串中找出连续最长的数字串(含“+-”号)(Java)

    题目描述 请在一个字符串中找出连续最长的数字串 并返回这个数字串 如果存在长度相同的连续数字串 返回最后一个 如果没有符合条件的字符串 返回空字符串 注意 数字串可以由数字 0 9 小数点 正负号 组成 长度包括组成数字串的所有符号 仅能出
  • 臭名昭著的MOS管米勒效应

    概述 MOS管的米勒效应会在高频开关电路中 延长开关频率 增加功耗 降低系统稳定性 可谓是臭名昭著 各大厂商都在不遗余力的减少米勒电容 分析 如下是一个NMOS的开关电路 阶跃信号VG1设置DC电平2V 方波 振幅2V 频率50Hz T2的
  • 求大神们指教

    都已经定义了 为什么出现如下错误 求大神们指教 1 gt main obj error LNK2019 无法解析的外部符号 public char thiscall LinkStack
  • x86直接写屏显示字符串

    直接向显存地址 0xb800 xxxx 写入数据 屏幕显示 80列 25行 一个字符显存2byte file showstr s code16 globl start begtext begdata begbss endtext endda
  • 一般试卷的纸张大小是多少_平时打印卷子的纸是多大的?

    展开全部 一般使用的是A3大小的纸 一 打印的卷子纸 一般是8K大小 就像两张A4纸拼在一起的大小 但是 32313133353236313431303231363533e59b9ee7ad9431333365643661家用打印机一般只能
  • Java学习笔记32——字符缓冲流

    字符缓冲流 字符流 字符缓冲流 字符缓冲流的特有功能 IO流小结 字节流 字符流 字符流 字符缓冲流 BufferedWriter 将文本写入字符输出流 缓冲字符 以提供单个字符 数组和字符串的高效写入 可以指定缓冲区大小 或者可以接受默认
  • IDEA 解决Maven打包时控制台中文乱码

    File Settings VM Options中加入 DarchetypeCatalog internal Dfile encoding GBK
  • 遍历提取文件夹中特定的jpg图片并存入指定文件夹

    coding utf 8 usr bin python test copyfile py import os shutil rootdir home unbuntu Desktop yixian 要提取文件夹的根目录 dstdir0 hom
  • STM32 基础系列教程 48 – CJSON

    前言 JSON JavaScript Object Notation JS 对象简谱 是一种轻量级的数据交换格式 它基于 ECMAScript 欧洲计算机协会制定的js规范 的一个子集 采用完全独立于编程语言的文本格式来存储和表示数据 简洁
  • 统计字符串中重复的字符个数并输出

    输出字符串各个字符的个数 对重复的字符将其下标存放在vector中 使用unique函数只保存一份重复字符的数字 通过下标查找到相应的字符 从map中取出对应的统计数字 include iostream include windows h
  • 吴恩达机器学习笔记:手搓线性回归(梯度下降寻优)

    概念就不介绍了 记录下公式推导和代码实现 以及与最小二乘的比较 吴恩达老师课程中使用一个参数theta保存两个变量 不过我好像没把中间变量的形状对齐 所以最后实现了两个参数的版本 代码 import time import numpy as