Python手撸机器学习系列(一):感知机 (附原始形式和对偶形式Python实现代码)

2023-11-15

感知机

1.感知机的定义

感知机是二分类的线性模型,是神经网络和SVM的基础。输入特征 x ∈ X x∈X xX,输出 y = { + 1 , − 1 } y = \{+1 , -1\} y={+1,1}

那么感知机算法可以表示为 f ( x ) = s i g n ( w ⋅ x + b ) f(x) = sign(w·x+b) f(x)=sign(wx+b),相当于一个简单的线性函数

其中 s i g n ( a ) = { + 1 , if a ≥ 0 − 1 , if a<0 sign(a)= \begin{cases} +1, & \text {if a$\geq$0} \\ -1, & \text{if a<0} \end{cases} sign(a)={+1,1,if a0if a<0

数据的线性可分性:存在 w ⋅ x + b = 0 w·x+b = 0 wx+b=0能将数据集中的正负样本分开。说人话就是能找到一条直线将两组不同的点分开。

在这里,感知机的数据集假设为线性可分的,即表示在一堆坐标点中,总能找到一条线将正副样本给分开,并且一般能找到多条线满足要求,如下图所示。

请添加图片描述

2.感知机原始形式

2.1 损失函数

损失函数若选取误分类点的个数,则对于 w w w b b b而言不连续可导,不易优化

所以,选取误分类点到超平面S的总距离作为损失函数,即 − 1 ∣ ∣ w ∣ ∣ ∑ x i ∈ M y i ( w ⋅ x i + b ) -\frac{1}{||w||}\displaystyle\sum_{x_i∈M}y_i(w·x_i+b) w1xiMyi(wxi+b),最终不考虑 1 ∣ ∣ w ∣ ∣ \frac{1}{||w||} w1,即得到最终的损失函数:
L ( w , b ) = − ∑ x i ∈ M y i ( w ⋅ x i + b ) L(w,b) = -\displaystyle\sum_{x_i∈M}y_i(w·x_i+b) L(w,b)=xiMyi(wxi+b)
推导:某一点到S的距离为 − 1 ∣ ∣ w ∣ ∣ ∣ w ⋅ x 0 + b ∣ -\frac{1}{||w||}|w·x_0+b| w1wx0+b,而误分类数据会有 − y i ( w ⋅ x i + b ) > 0 -y_i(w·x_i+b)>0 yi(wxi+b)>0,所以上上式可以转化为 − 1 ∣ ∣ w ∣ ∣ y i ( w ⋅ x i + b ) -\frac{1}{||w||}y_i(w·x_i+b) w1yi(wxi+b),总距离: − 1 ∣ ∣ w ∣ ∣ ∑ x i ∈ M y i ( w ⋅ x i + b ) -\frac{1}{||w||}\displaystyle\sum_{x_i∈M}y_i(w·x_i+b) w1xiMyi(wxi+b)

L ( w , b ) L(w,b) L(w,b)非负,没有误分类点则为0

2.2 计算过程

使用随机梯度下降(SGD)来优化参数,算法如下:

  1. 选取初值 w 0 w_0 w0 b 0 b_0 b0

  2. 在训练集中选取数据 ( x i , y i ) (x_i,y_i) (xi,yi)

  3. 如果 y i ( w ⋅ x i + b ) ≤ 0 y_i(w·x_i+b)\leq 0 yi(wxi+b)0,则有:

    w = w + η y i x i w = w+\eta y_ix_i w=w+ηyixi

    b = b + η y i b = b + \eta y_i b=b+ηyi

  4. 转至2,直到没有误分类点

其中 η \eta η为学习率, w w w b b b的梯度通过对损失函数 L ( w , b ) L(w,b) L(w,b)求导而来

2.3代码实现

import numpy as np
import matplotlib.pyplot as plt

x_true = np.array([[3,3],[4,3]])
x_false = np.array([[1,1]])
y = [1]* len(x_true) + [-1] * len(x_false)
x_all = np.vstack([x_true,x_false])

w = np.array([0,0])
lr = 1
b = 0
i = 0
#循环判断每一个样本有没有误分类,有则更新参数重新开始判断
while i<len(x_all):
    if y[i]*(w.dot(x_all[i].T)+b) <= 0:
        w = w + lr * y[i] * x_all[i]
        b = b + lr * y[i]
        i = 0
        print('w = {},b = {}'.format(w,b))
    else:
        i += 1
print('平面S为:{:.2f}x1 + {:.2f}x2 {} = 0'.format(w[0],w[1], str(b) if b < 0 else '+'+str(b)))
plot_x = [0,1,2,3,4,5]
plot_y = [-(x*w[0]+b)/w[1] for x in plot_x]
plt.figure(figsize =(10,10))
plt.scatter([x[0] for x in x_true], [x[1] for x in x_true] , c = 'blue')
plt.scatter([x[0] for x in x_false], [x[1] for x in x_false] , c = 'red')
plt.plot(plot_x , plot_y , c = 'black')
# plt.text(0.5,4.5,'Func:{:.2f}x1 + {:.2f}x2 {} = 0'.format(w[0],w[1], str(b) if b < 0 else '+'+str(b)),fontsize=15,color = "green",style = "italic")
plt.xlim(0, 5.0) #坐标轴
plt.ylim(0, 5.0)
plt.xlabel('x1',fontsize = 16)
plt.ylabel('x2',fontsize = 16)
plt.pause(0.001)
plt.show()

实现结果:

请添加图片描述

w w w b b b变化过程以及最终的平面S:

请添加图片描述

换一组更复杂的数据测试:
请添加图片描述

3.感知机对偶形式

对偶形式是将原始形式中的 w w w b b b表示为 x i x_i xi y i y_i yi的线性组合,即

{ w = ∑ i = 1 N n i y i x i b = ∑ i = 1 N n i y i \begin{cases} w =\displaystyle\sum_{i = 1}^Nn_i y_ix_i \\b = \displaystyle\sum_{i=1}^Nn_i y_i \end{cases} w=i=1Nniyixib=i=1Nniyi

n i n_i ni值越大,表示这个样本被误分类的次数越多,就意味着这个点离我们所需要的超平面越近,左移一点或者右移一点就会误分类,对于SVM而言,这个点极有可能就是支持向量

根据原始形式, f ( x ) = s i g n ( w x + b ) = s i g n ( ∑ j = 1 N n j y j x j ⋅ x + ∑ i = 1 N n i y j ) f(x) = sign(wx+b) = sign(\displaystyle\sum_{j=1}^Nn_j y_jx_j·x+\displaystyle\sum_{i=1}^Nn_i y_j) f(x)=sign(wx+b)=sign(j=1Nnjyjxjx+i=1Nniyj)

从之前的的优化 w w w b b b,变成了优化 n n n

误分类的判断条件也变成了 y i ( ∑ j = 1 N n j y j x j ⋅ x + ∑ i = 1 N n i y j ) < 0 y_i(\displaystyle\sum_{j=1}^Nn_j y_jx_j·x+\displaystyle\sum_{i=1}^Nn_i y_j)<0 yi(j=1Nnjyjxjx+i=1Nniyj)<0

3.1 计算过程

《统计学习方法》中将 n i n_i ni α i \alpha_i αi表示

  1. 选取初值 α \alpha α b b b

  2. 在训练集中选取数据 ( x i , y i ) (x_i,y_i) (xi,yi)

  3. 如果 y i ( ∑ j = 1 N α j y j x j ⋅ x i + b ) ≤ 0 y_i(\displaystyle\sum_{j=1}^N\alpha_jy_jx_j·x_i+b)\leq 0 yi(j=1Nαjyjxjxi+b)0,则有:

    α = α + η \alpha = \alpha+\eta α=α+η

    b = b + η y i b = b + \eta y_i b=b+ηyi

  4. 转至2,直到没有误分类点

在对偶形式中,样本以内积的形式计算,如果以内积矩阵形式存储,则会大大缩短计算时间,即Gram矩阵:

G = [ x i ⋅ x j ] G = [x_i·x_j] G=[xixj],代码可以表示成Gram = x.dot(x.T)

3.2 代码实现

import numpy as np
import matplotlib.pyplot as plt

x_true = np.array([[3, 3], [4, 3]])
x_false = np.array([[1, 1]])
x_all = np.vstack([x_true,x_false])
y = [1]*len(x_true) + [-1] * len(x_false)
n = len(x_all)


a = np.zeros(n)
b = 0
lr = 1

Gram = x_all.dot(x_all.T) #计算G

i = 0
#循环判断每一个样本有没有误分类,有则更新参数重新开始判断
while i < n:
    error = 0
    for j in range(n):
        error += a[j] * y[j] * Gram[j,i]
    if y[i] * (error + b) <= 0: #有负样本
        a[i] += lr
        b += lr * y[i]
        print('a = {},b = {}'.format(a,b))
        i = 0
    else:
        i += 1

w = np.zeros(2)
for j in range(n):
    w += a[j] * y[j] * x_all[j]

print('平面S为:{:.2f}x1 + {:.2f}x2 {} = 0'.format(w[0],w[1], str(b) if b < 0 else '+'+str(b)))

plot_x = [0,1,2,3,4,5]
plot_y = [-(x*w[0]+b)/w[1] for x in plot_x]
plt.figure(figsize =(10,10))
plt.scatter([x[0] for x in x_true], [x[1] for x in x_true] , c = 'blue')
plt.scatter([x[0] for x in x_false], [x[1] for x in x_false] , c = 'red')
plt.plot(plot_x , plot_y , c = 'black')
plt.xlim(0, 5.0) #坐标轴
plt.ylim(0, 5.0)
plt.xlabel('x1',fontsize = 16)
plt.ylabel('x2',fontsize = 16)
plt.pause(0.001)
plt.show()

实验结果:

请添加图片描述

其中 a a a b b b的变化过程,以及最终的平面S:

在这里插入图片描述

4.结语

本为初学者,难免有错误,有问题欢迎评论区指出或私信。

联系方式:1759412770@qq.com ; zn1759412770@163.com

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Python手撸机器学习系列(一):感知机 (附原始形式和对偶形式Python实现代码) 的相关文章

随机推荐

  • 最全面的Socket使用解析

    前言 Socket的使用在Android的网络编程中非常重要 今天我将带大家全面了解Socket及其使用方法 目录 1 网络基础 1 1 计算机网络分层 计算机网络分为五层 物理层 数据链路层 网络层 运输层 应用层 其中 网络层 负责根据
  • 一次内网 Harbor 镜像仓库导出迁移过程记录

    1 整体思路 Harbor 提供有丰富的 API 接口 可以获取所有项目信息 镜像和标签等信息 通过编写 shell 脚本循环处理即可实现批量导出镜像包的需求 登陆 Harbor 后 左下角有 API 控制中心按钮 进入可以查看和调试 2
  • centos 安装配置l2tp实现***

    centos 安装配置l2tp实现 1 前言 L2TP是一种工业标准的Internet隧道协议 功能大致和PPTP协议类似 比如同样可以对网络数据流进行加密 不过也有不同之处 比如PPTP要求网络为IP网络 L2TP要求面向数据包的点对点连
  • OSI七层模型---数据链路层(以太网帧、MAC地址、MTU、MSS、ARP协议)

    我们首先来了解一下物理层的作用 物理层的主要目的是实现比特流的透明传输 为数据链路层提供服务 物理层接口解决了用几根线 多大电压 每根线什么功能 以及几根线之间是怎么协调的问题 物理层介质解决了数据载体材质以及价格优缺点的问题 通信技术解决
  • 01_I.MX6U芯片简介

    目录 I MX6芯片简介 Corterx A7架构简介 Cortex A处理器运行模型 Cortex A 寄存器组 IMX6U IO表示形式 I MX6芯片简介 ARM Cortex A7内核可达900 MHz 128 KB L2缓存 并行
  • 李宏毅 机器学习 2016 秋:6、Classification: Logistic Regression

    文章目录 六 Classification Logistic Regression 六 Classification Logistic Regression 我们来讲 Logistic Regression 我们在上一份投影片里面 我们都已
  • 点云Las格式分析及python实现

    目录 一 Las格式分析 1 公共头 2 变长记录 3 参考文献 二 安装laspy 2 0 2 三 代码实现 一 Las格式分析 1 公共头 公共头用来记录数据集的基本信息 如Li DAR点总数 数据范围 Li DAR点格式 变长记录总数
  • 在switch语句中使用字符串以及实现原理

    对于Java语言来说 在Java 7之前 switch语句中的条件表达式的类型只能是与整数类型兼容的类型 包括基本类型char byte short和int 与这些基本类型对应的封装类Character Byte Short和Integer
  • Go单体服务开发最佳实践

    单体最佳实践的由来 对于很多初创公司来说 业务的早期我们更应该关注于业务价值的交付 并且此时用户体量也很小 QPS 也非常低 我们应该使用更简单的技术架构来加速业务价值的交付 此时单体的优势就体现出来了 正如我直播分享时经常提到 我们在使用
  • 什么是等保合规

    近年来 随着国家对网络安全的重视 我国对网络安全的监管要求也越来越高 各互联网企业都在积极落实网络安全等级保护 关键信息基础设施安全保护制度 为了保护网络安全 企业也在按照 网络安全法 及 等保2 0 系列标准要求 积极寻求等级保护测评 整
  • C语言进阶:C陷阱与缺陷(读书笔记总)

    大家不要只收藏不关注呀 哪怕只是点个赞也可以呀 粉丝私信发邮箱 免费发你PDF 最近读了一本C语言书 C陷阱与缺陷 还不错 挺适合刚刚工作后的人 特此分享读书笔记 写代码时应注意这些问题 笔记已做精简 读完大概需要30min 如果读起来感觉
  • 广义线性模型(GLM)

    在线性回归中 y丨x N 2 在逻辑回归中 y丨x Bernoulli 这两个都是GLM中的特殊的cases 我们首先引入一个指数族 the exponential family 的概念 如果一个分布能写成下列形式 那么我们说这个分布属于指
  • Bert机器问答模型QA(阅读理解)

    Github参考代码 https github com edmondchensj ChineseQA with BERT https zhuanlan zhihu com p 333682032 数据集来源于DuReader Dataset
  • Unity基础3——Resources资源动态加载

    一 特殊文件夹 一 工程路径获取 注意 该方式 获取到的路径 一般情况下 只在 编辑模式下使用 我们不会在实际发布游戏后 还使用该路径 游戏发布过后 该路径就不存在了 print Application dataPath 二 Resourc
  • C++ vector find()使用? ( if!=vec.end())

    std vector find是C STL中的一个函数 它可以用来在std vector中查找给定的元素 如果找到了这个元素 它将返回一个迭代器指向该元素 否则将返回一个名为end 的迭代器 下面是一个使用find的示例代码 include
  • C++11 条件变量(condition_variable) 使用详解

    官网 一 总述 在C 11中 我们可以使用条件变量 condition variable 实现多个线程间的同步操作 当条件不满足时 相关线程被一直阻塞 直到某种条件出现 这些线程才会被唤醒 主要成员函数如下 二 具体函数 1 wait函数
  • 泰勒阵列天线综合与matlab,阵列天线综合之切比雪夫低副瓣阵列设计Matlab

    在 自适应天线与相控阵 这门课中 我了解到了关于理想低副瓣阵列设计的一些方法 其中切比雪夫等副瓣阵列设计方法是一种基础的方法 故将其设计流程写成maltab程序供以后学习使用 在此分享一下 此方法全称为道尔夫 切比雪夫综合法 简称为切比雪夫
  • 量化交易框架开发实践(二)

    我们通过分析代码可以看出 PyAlgoTrade分为六个组件 Strategies Feeds Brokers DataSeries Technicals Optimizer 从业务流上看也是比较容易理解的 Feed 数据源 gt Data
  • 【C++】常用math函数

    C语言提供了很多实用的数学函数 如果要使用先添加头文件
  • Python手撸机器学习系列(一):感知机 (附原始形式和对偶形式Python实现代码)

    感知机 1 感知机的定义 感知机是二分类的线性模型 是神经网络和SVM的基础 输入特征 x X x X x X 输出 y