吴恩达机器学习笔记1:手写linear regression

2023-11-20

最近手写了linear regression,有以下几点收获:

  • 做batch gradient descent时,注意每一轮迭代要使用同一个error同时更新所有参数。
  • 归一化的时候,要注意记录相应的均值和方差,后续对新样本做预测时也需要使用这两个参数,对特征做归一化。
  • 为何要做特征归一化?不做归一化,可能会降低模型的训练效率。
  • 不做归一化,模型也可以收敛,但需要选择合适的学习率lambda。

相关代码如下,数据取自吴恩达机器学习的课后作业(GitHub上可以找到)。

## 数据探索
import numpy as np
import matplotlib.pyplot as plt

# # 载入数据
# data = np.loadtxt('./ex1data1.txt', delimiter=',')
# # 区分特征X和标签y
# X,y = data[:,:1], data[:,1]

# 载入数据
data = np.loadtxt('./ex1data2.txt', delimiter=',')
# 区分特征X和标签y
X,y = data[:,:2], data[:,2]

# 数据维度
print data.shape, X.shape, y.shape
# 增加全1列,对应表达式的常数项
X=np.c_[np.ones(X.shape[0]),X]
print X.shape

# 绘图
plt.scatter(X[:,1],y)
plt.show()


## 计算模型函数和代价函数
import numpy as np

# 定义模型函数
def h(Theta, X):
    # Theta: 参数向量
    # x: 特征向量
    return np.dot(X,Theta.T)

# 定义代价函数,计算m个样本的损失
def J(Theta,X,y):
    # m表示样本数量
    m=X.shape[0]
    return 1.0/(2*m)*sum((np.dot(X,Theta.T)-y)**2)

# 初始化参数向量,参数向量的长度等于特征列的数量
Theta=np.zeros(X.shape[1])
print np.dot(X,Theta.T).shape, J(Theta,X,y)


## 模型训练:参数迭代更新(归一化)
# 梯度下降:同时对每一个参数进行迭代
def model_fit(Theta,X,y,alpha,iterations):
    # 代价函数的迭代过程记录
    cost=np.zeros(iterations)
    # 样本数量
    m=X.shape[0]
    # 迭代次数
    for i in range(iterations):
        # 必须要先计算error:对所有参数进行更新时,使用相同的error值
        error=np.dot(X,Theta.T)-y
        # batch gradient descent:对参数进行更新
        for j in range(len(Theta)):
            Theta[j]=Theta[j]-1.0/m*alpha*np.dot(error,X[:,j])
        cost[i]=J(Theta,X,y)
    return Theta,cost

# 归一化时:剔除全1列
Xless=X[:,1:]
# 存储归一化参数
X_mean=np.mean(Xless, axis=0)
X_std=np.std(Xless, axis=0)
# 归一化
Xless_std=(Xless-X_mean)/X_std
X_std=np.c_[np.ones(X.shape[0]),Xless_std]

# 定义迭代参数,如果alpha的数量级选择不正确,有可能会无法收敛
alpha,iterations=0.1,50
# 初始化参数
init_theta=np.zeros(X_std.shape[1])
# 模型训练
theta,cost=model_fit(init_theta,X_std,y,alpha,iterations)
print cost,theta

# 绘图:原始数据
plt.scatter(X_std[:,1],y)
plt.plot(X_std[:,1],np.dot(X_std,theta),color='red')
plt.show()
# 绘图:收敛速度
plt.plot([i for i in range(len(cost))],cost,marker='x')
plt.show()
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

吴恩达机器学习笔记1:手写linear regression 的相关文章

随机推荐

  • linux使用date命令获取系统时间

    转载自Linux系统date命令的参数及获取时间戳的方法 date指令相关用法示例 date 用法 date OPTION FORMAT date u utc universal MMDDhhmm CC YY ss 直接输入date dat
  • 微信小程序开发之——用户登录-登录流程(1)

    一 概述 新建微信小程序自带用户登录简化 小程序登录流程时序 二 新建微信小程序自带用户登录简化 新建的微信小程序默认有用户登录功能 将多余功能去除后 简化如下 2 1 index wxml
  • 文心一言续写太监小说《名侦探世界的巫师》

    名侦探世界的巫师 是我的童年回忆 总是想着续写一下 但是又没有时间和文笔 文心一言出了 由于目前大模型貌似可以联网 可以尝试搞一波 目录 文章1 前六个故事还能看 后面就是在重复 故事2 辣眼睛 毁童年 非请勿看 故事3 流水账 故事4 其
  • JDK介绍

    JDK JRE和JVM之间的关系 JVM是运行环境 JRE是含运行环境和相关的类库 跟node环境是一个意思 JDK目录介绍 目录名称 说明 bin 该路径下存放了JDK的各种工具命令 javac和java就放在这个目录 conf 该路径下
  • C++学习笔记(十六):对vector进行更多的操作——泛型算法

    先强调一下 这里的泛型算法实际不光光是对vector的操作 对于 顺序容器 均可以 但是什么是顺序容器 我们都知道 容器就是一些特定类型对象的集合 而顺序容器为程序员提供了控制元素存储和访问的能力 这种容器的一个显著的特征 就是容器中元素的
  • ES6.x版本单机三节点配置discovery.zen.ping.unicast.hosts 错误

    问题 在同一个机子利用不同端口搭建3个ES节点 单节点正常运行 集群间无法联通 找不到主节点 表现 cluster uuid 一直没有注册成功 curl 0 0 0 0 29200 name es 01 cluster name es te
  • 浏览器地址栏输入url以后发生了什么

    在浏览器输入url后会发生的过程 1 DNS对域名进行解析 2 建立TCP连接 三次握手 3 发送HTTP请求 4 服务器处理请求 5 返回响应结果 6 关闭TCP连接 四次挥手 7 浏览器解析HTML 8 浏览器布局渲染 1 浏览器对输入
  • 华为OD机试 - 需要打开多少监控器(Java & JS & Python)

    题目描述 某长方形停车场 每个车位上方都有对应监控器 当且仅当在当前车位或者前后左右四个方向任意一个车位范围停车时 监控器才需要打开 给出某一时刻停车场的停车分布 请统计最少需要打开多少个监控器 输入描述 第一行输入m n表示长宽 满足1
  • 按照 C++ 11 标准,数组,指针,传递问题!

    一 一维数组 静态 int array 100 定义了数组array 并未对数组进行初始化 静态 int array 100 1 2 定义并初始化了数组array 动态 int array new int 100 delete array
  • Java 日历的制作 心得 写给自己

    之前已经跟着老师做过一次这个日历 但是时间一久便又拿出来自己再复习一遍 果然不出所料 已经做不出来了 而且因为在学习的时候使用的是Myeclipse 其中话中操作是由软件自己操作的 每写出一句代码软件也会自动提示哪里有问题 半傻瓜式的操作果
  • HTML5的多个video标签:截取视频源的封面图poster,监听视频播放状态的功能;

    在日常项目中 html5的video标签还是比较常用到的 开发过程中 我们都会使用到 通过监听video标签的播放 暂停 停止等等来使用 我们是否也会遇到过 有些浏览器在显示这标签 兼容不太友好 video标签的封面是一层黑色的 ok 那么
  • git-基本操作-1

    1安装 window上安装git 官网直接下载 下载完成后需要在git bash命令行中输入 git config global user name yourname git config global user email yourema
  • 非常详细的小程序搜索历史功能

    前言 我们在进行一些项目开发时 很有可能会涉及到在搜索框中搜索某一个词条 从而进行相应的检索 在这里就会出现一个优化功能 我们在搜索后的某一个词条 我希望能够显示在历史记录中 这样一个小的tip 可以给用户带来更高的使用体验 历史记录并不会
  • goland环境配置

    goland modules环境配置 下载和安装goland 环境配置 配置环境变量GOPATH 配置go modules GOPROXY代理的系统变量 工程目录中新建三个工作目录 goland中启用go modules 新建一个go程序
  • 浅谈图数据库

    本文主要讨论图数据库背后的设计思路 原理还有一些适用的场景 以及在生产环境中使用图数据库的具体案例 从社交网络谈起 下面这张图是一个社交网络场景 每个用户可以发微博 分享微博或评论他人的微博 这些都是最基本的增删改查 也是大多数研发人员对数
  • 【电子技术】什么是LFSR?

    目录 0 前言 1 数学基础 1 1 逻辑异或 1 2 模2乘法 和 模2除法 2 线性反馈移位寄存器LFSR 3 抽头和特征多项式 4 阶线性反馈移位寄存器实例 0 前言 线性反馈移位寄存器 Linear Feedback Shift R
  • mysql jdbc 实现读写分离

    这种方式直接在代码级别实现了mysql 读写分离 很简单 只需要改一下配置文件 就搞定了 是不是很嗨 jdbc driverClassName com mysql jdbc ReplicationDriver jdbc url jdbc m
  • Windows10安装Markdown安装教程(超级详细)

    markdown其实就是我们平常写博客的地方 下面我来详细介绍它的安装教程 首先到官网去安装 markdown 点击download 我反正点击download后它自动就下载了 然后下载好后是安装包 双击 然后一直next 最后它会跳出来
  • 被火车撞了都不能忘记的几道题(你会了吗?)

    目录 一 删除有序链表中的重复元素I 二 删除有序链表重复元素II 三 环形单链表中插入一个元素 四 单链表翻转II 五 奇偶链表 一 删除有序链表中的重复元素I 1 对应牛客网链接 删除有序链表中重复的元素 I 牛客题霸 牛客网 nowc
  • 吴恩达机器学习笔记1:手写linear regression

    最近手写了linear regression 有以下几点收获 做batch gradient descent时 注意每一轮迭代要使用同一个error同时更新所有参数 归一化的时候 要注意记录相应的均值和方差 后续对新样本做预测时也需要使用这