感知机及算法实现

2023-10-26

在这里插入图片描述

1、感知机二类分类的线性分类模型,输入为实例的特征向量,输出为实例的类别,取+1和-1二值,感知机对应于输入空间中将实例划分为正负两类的分离超平面,属于判别模型。
感知机学习旨在求出将训练数据进行线性划分的分离超平面,为此导入基于误分类的损失函数,利用梯度下降法对损失函数进行极小化,求得感知机模型。感知机学习算法具有简单、易于实现的优点,分为原始形式和对偶形式,感知机预测是用学习到的感知机模型对新的输入实例进行分类。

2、感知机:假设输入空间(特征空间)是X⊆R,输出空间Y={+1,-1},输入x∈X表示实例的特征向量,对应于输入空间的点;输出y∈Y表示实例的类别。有输入空间到输出空间的如下函数f(x)=sign(w•x+b)称为感知机,其中w和b为感知机模型参数,w∈R叫做权值或权值向量,b∈R叫做偏置,w•x为内积,sign是符号函数
在这里插入图片描述
感知机是一种线性分类模型,属于判别模型,假设空间是定义在特征空间中所有线性分类模型或线性分类器,即函数集合{f|f(x)=w•x+b}。感知机学习是由训练数据集求得上述感知机模型,即求出模型参数w,b;感知机预测是通过学习得到的感知机模型,对新的输入实例给出对应的输出类别。

3、感知机的几何解释:线性方程w•x+b=0对应于特征空间R中的一个超平面S(分离超平面),其中w是超平面的法向量,b是超平面的截距,这个超平面将特征空间划分为两个部分,位于两部分的点被分为正负两类。

4、数据集的线性可分性:给定一个数据集T={(x1,y1),(x2,y2),…,(xn,yn)},其中xi∈X,yi∈Y={+1,-1},如果存在某个超平面S,w•x+b=0能够将数据集的正实例点和负实例点完全正确地划分到超平面两侧,即对所有yi=+1的实例i,有w•xi+b>0;对所有yi=-1的实例i,有w•xi+b<0,则称数据集T为线性可分数据集,否则称数据集T线性不可分。

5、为了找到能将训练集正实例点和负实例点完全正确分开的分离超平面,需要确定感知机模型形参数w、b,这时需要确定学习策略即定义损失(经验)函数并将损失函数极小化。
损失函数的一个自然选择是误分类点的总数,但是这样的损失函数不是参数w,b的连续可导函数,不易优化,另一个选择是误分类点到超平面S的总距离,这是感知机所采用的。损失函数定义为
在这里插入图片描述
其中,M为误分类点的集合。损失函数是非负的,如果没有误分类点,损失函数值为0,而且误分类点数越少,误分类点离超平面越近,损失函数值越小。

6、感知机算法是误分类驱动的,具体所采用的是随机梯度下降法,首先任意选取一个超平面w0,b0,然后用梯度下降法不断地极小化目标函数,极小化的过程中不是一次使M中所有误分类点的梯度下降,而是一次随机选取一个误分类点使其梯度下降。损失函数L(w,b)的梯度由下式给出。
在这里插入图片描述
随机选择一个误分类点(xi,yi),对w,b进行更新:
在这里插入图片描述
式中η(0<η≤1)是步长,在统计学习中又称为学习率,通过迭代可以期待损失函数不断减小直至0。

7、感知机学习算法的原始形式
输入——训练数据集T={(x1,y1),(x2,y2),…,(xn,yn)},其中xi∈X=R,yi∈Y={+1,-1},i=1,2,…,N,学习率η(0<η≤1)
输出——w,b,感知机模型f(x)=sign(w•x+b)。
(1)选取初值w0,b0;
(2)在训练集中选择数据(xi,yi)
(3)如果yi(w•xi+b)≤0
在这里插入图片描述
(4)转至(2),直至训练集中没有误分类点
算法直观解释——当一个实例点被误分类,位于分离超平面的错误一侧时,则调整w,b的值,使分离超平面向该误分类点的一侧移动,以减少该误分类点与超平面间的距离,直至超平面越过该误分类点使其被正确分类。

# 利用Python实现感知机算法的原始形式
import numpy as np
import matplotlib.pyplot as plt

def createData():
    samples = np.array([[3, 3], [4, 3], [1, 1]])
    labels = [-1, -1, 1]
    return samples, labels

class Perception:
    def __init__(self, x, y, a=1):
        self.x = x
        self.y = y
        self.l_rate = a
        self.w = np.zeros((x.shape[1], 1))
        self.b = 0
        self.numSimples = x.shape[0]
        self.numFeatures = x.shape[1]

    def sign(self, w, b, x):
        y = np.dot(x, w) + b  # x .w + b
        return int(y)

    def update(self, label_i, data_i):
        tmp = label_i * self.l_rate * data_i  # w = w + n yx
        tmp = tmp.reshape(self.w.shape)
        self.w = tmp + self.w
        self.b = self.b + label_i * self.l_rate  # b = b + n y

    def train(self):
        isFind = False
        while not isFind:
            count = 0
            for row in range(self.numSimples):
                simY = self.sign(self.w, self.b, self.x[row, :])
                if simY * self.y[row] <= 0:  # 如果是一个误分类实例点
                    print('误分类点为:', self.x[row, :], '此时的w和b为:', self.w, self.b)
                    count += 1
                    self.update(self.y[row], self.x[row])
            if count == 0:
                print('最终训练得到的w和b为:', self.w, self.b)
                isFind = True
        return self.w, self.b

class Picture:
    def __init__(self, data, w, b):
        self.w = w
        self.b = b
        plt.figure(1)
        plt.title("Perception Learning Algorithm", size=14)
        plt.xlabel("x0", size=14)
        plt.ylabel("x1", size=14)
        xData = np.linspace(0, 5, 100)
        yData = self.expression(xData)
        plt.plot(xData, yData, color='r', label='data')

        plt.scatter(data[0][0], data[0][1], s=50)
        plt.scatter(data[1][0], data[1][1], s=50)
        plt.scatter(data[2][0], data[2][1], s=50, marker='x')
        plt.savefig('2d_base.png', dpi=75)

    def expression(self, x):
        y = (-self.b - self.w[0] * x) / self.w[1]
        # 注意在此,把x0,x1当做两个坐标轴,把x1当做自变量,x2为因变量
        return y
        
    def show_pic(self):
        plt.show()

if __name__ == '__main__':
    samples, labels = createData()
    myperceptron = Perception(x=samples, y=labels)
    weights, bias = myperceptron.train()
    Picture = Picture(samples, weights, bias)
    Picture.show_pic()

8、Novikoff定理
在这里插入图片描述
定理表明,误分类次数k由上界,经过有限次搜索可以找到将训练数据完全正确分开的分离超平面,当训练数据集线性可分时,感知机学习算法原始形式迭代是收敛的。为了得到唯一的超平面,需要对分离超平面增加约束条件。

9、感知机学习算法的对偶形式的基本思路
将w和b表示为实例xi和标记yi的线性组合形式,通过求解其系数求得w和b,可先假设w0、b0均为0,对误分类点先通过下式
在这里插入图片描述
逐步修改w,b,修改n次则w,b关于(xi,yi)的增量分别是αiyixi和αiyi,αi=niη,则最后学习到的w,b可以分别表示为,
在这里插入图片描述
其中αi≥0,当η=1时表示第i个实例点由于误分而进行更新的次数,实例点更新次数越多,意味着它距离分离超平面越近,越难正确分类。

10、感知机学习算法的对偶形式
输入——线性可分训练数据集T={(x1,y1),(x2,y2),…,(xn,yn)},其中xi∈X=R,yi∈Y={+1,-1},i=1,2,…,N,学习率η(0<η≤1)
输出——ɑ,b,感知机模型f(x)=sign(∑ɑjyjxj•x+b)。
其中
在这里插入图片描述
(1)ɑ←0,b←0
(2)在训练集中选取数据(xi,yi)
(3)如果yi(∑ɑjyjxj•xi+b)≤0
在这里插入图片描述
(4)转至(2)直到没有误分类数据。
对偶形式中训练数据仅以内积形式出现,可以预先将训练集中实例间的内积计算出来并以矩阵形式存储,该矩阵即Gram矩阵:G=[xi•xj](NxN)。

# 利用Python实现感知机算法的对偶形式
import numpy as np
import matplotlib.pyplot as plt

# 1、 创建数据集
def createdata():
    samples = np.array([[3, 3], [4, 3], [1, 1]])
    labels = np.array([-1, -1, 1])
    return samples, labels

class Perception:
    def __init__(self, x, y, a=1):
        self.x = x
        self.y = y
        self.w = np.zeros((1, x.shape[0]))
        self.b = 0
        self.a = 1  # 学习率
        self.numsamples = self.x.shape[0]
        self.numfeatures = self.x.shape[1]
        self.gMatrix = self.cal_gram(self.x)

    def cal_gram(self, x):
        gMatrix = np.zeros((self.numsamples, self.numsamples))
        for i in range(self.numsamples):
            for j in range(self.numsamples):
                gMatrix[i][j] = np.dot(self.x[i, :], self.x[j, :])
        return gMatrix

    def sign(self, w, b, key):
        y = np.dot(w * self.y, self.gMatrix[:, key]) + b  # αjYjXjXi + b
        return int(y)

    def update(self, i):
        self.w[:, i] = self.w[:, i] + self.a
        self.b = self.b + self.y[i] * self.a

    def cal_w(self):
        w = np.dot(self.w * self.y, self.x)
        return w

    def train(self):
        isFind = False
        while not isFind:
            count = 0
            for i in range(self.numsamples):
                tmpY = self.sign(self.w, self.b, i)
                if tmpY * self.y[i] <= 0:  # 如果是一个误分类实例点
                    print('误分类点为:', self.x[i, :], '此时的w和b为:', self.cal_w(), ',', self.b)
                    count += 1
                    self.update(i)
            if count == 0:
                print('最终训练得到的w和b为:', self.cal_w(), ',', self.b)
                isFind = True
        weights = self.cal_w()
        return weights, self.b

# 画图描绘
class Picture:
    def __init__(self, data, w, b):
        self.b = b
        self.w = w
        plt.figure(1)
        plt.title('Perception Learning Algorithm', size=14)
        plt.xlabel('x0', size=14)
        plt.ylabel('x1', size=14)

        xData = np.linspace(0, 5, 100)
        yData = self.expression(xData)
        plt.plot(xData, yData, color='r', label='data')

        plt.scatter(data[0][0], data[0][1], s=50)
        plt.scatter(data[1][0], data[1][1], s=50)
        plt.scatter(data[2][0], data[2][1], s=50, marker='x')
        plt.savefig('2d_duio.png', dpi=75)

    def expression(self, x):
        y = (-self.b - self.w[:, 0] * x) / self.w[:, 1]
        return y

    def show_pic(self):
        plt.show()

if __name__ == '__main__':
    samples, labels = createdata()
    myperceptron = Perception(x=samples, y=labels)
    weights, bias = myperceptron.train()
    Picture = Picture(samples, weights, bias)
    Picture.show_pic()

拓展知识

  • 损失函数之所以是非负的,因为每一个误分类点都满足[公式]
    因为当我们数据点正确值为+1的时候,你误分类了,那么判断为-1,则算出来[公式]
    当数据点是正确值为-1的时候,你误分类了,那么判断为+1,则算出来
    在这里插入图片描述
  • 异或问题不能用感知机表示的原因:异或问题可以分为根据输出可以分为两类,显示在二维坐标系中如上图所示:其中输出结果为1对应右图中红色的十字架,输出为0对应右图中蓝色的圆圈,我们可以发现对于这种情况无法找到一条直线将两类结果分开,即感知机无法找到一个线性模型对异或问题进行划分,并且所有的线性分类模型都无法处理异或分类问题。
    在这里插入图片描述

来源:统计学习方法

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

感知机及算法实现 的相关文章

随机推荐

  • Git学习笔记

    配置user信息 配置user name git config global user name your name 配置user email git config global user email your email 查看所有配置信息
  • Jenkins中连接Git仓库时提示:无法连接仓库:Error performing git command: git ls-remote -h

    问题 Jenkins中连接Git仓库时提示 无法连接仓库 Error performing git command git ls remote h 原因 git的账号密码错误 解决方案 重新设置账号密码 操作 控制面板 凭证管理器 wind
  • 有趣的异常

    缘起 最近 在项目中遇到一个有趣的异常 在没附加调试器的情况下会直接崩溃 附加调试器后 会中断到调试器中 但是按 F5 继续运行后 程序还能继续执行 interesting 你能猜出这是个什么异常吗 初遇错误 在测试程序功能的时候 意外的崩
  • 【教程】一款Markdown 编辑器,免费版本 Typora 下载与使用.

    csdn资源老挂 再补一个网盘的 哪个能用用哪个吧 链接 https pan baidu com s 19c MJQRuas9v5lHxF1uB6A pwd f3n5 提取码 f3n5 gt gt 资源 lt
  • EXCEL-数据透视表、日数据整理成月数据

    1 当你面对一个很多年的日数据 想要把它整理成月数据 下图是2015年1月到2022年1月的日数据 2 首先我们把没用的信息挪开 在时间和日数据上加个表头 3 接着选中数据 包括表头 点击 插入 数据透视表 4 跳出来的框框 直接确认 5
  • AppsFlyer 研究(四)OneLink Deep Linking Guide

    一 简介 深度链接是指当用户打开移动应用时向其提供个性化的内容 或将用户带到应用内特定位置的操作 通过这种操作 您可以为用户提供优质的用户体验 从而极大加强用户与应用的互动 两种深度链接类型 由于用户不一定安装了移动应用 所以有两种深度链接
  • 曾经被视为「牛市制造机」们的机构巨鲸,如今都怎么了?

    这是白话区块链的第1790期原创 作者 Terry出品 白话区块链 ID hellobtc 11 月 17 日 萨尔瓦多总统 Nayib Bukele 表示 从明天开始 我们将每天购买一个比特币 直接开始了国家级别的比特币定投之旅 相信不少
  • 【图片二值化处理,以及byte[] 与bitmap互相转化问题】

    1 byte与bitmap相互转换 将byte流转换为bitmap byte signature item ToArray MemoryStream ms1 new MemoryStream signature Bitmap bm Bitm
  • 打开ABAQUS时,显示找不到 MFC140U.DLL 文件,打不开软件,亲测解决

    打开ABAQUS时 显示找不到 MFC140U DLL 文件 打不开软件 如何解决 下载了X64版本的 安装完毕后就可以打开了 Microsoft Visual C 2017 Redistributable 32位链接 link 64位链接
  • 时钟同步-注意客户端和服务端都需要开启123端口 udp协议

    确认时钟源 chronyc sources v chronyc tracking Linux Chronyd时间同步服务器详解 wangjie722703的博客 CSDN博客 local stratum 10 即使自己未能通过网络时间服务器
  • pytorch-lightning如何设置训练epoch

    Trainer初始化时添加max epochs参数 init model autoencoder LitAutoEncoder trainer pl Trainer gpus 8 max epochs 50 trainer fit auto
  • iOS uiscrollView 嵌套 问题 的解决

    苹果官方文档里面提过 最好不要嵌套scrollView 特别提过UITableView和UIWebView 因为在滑动时 无法知道到底是希望superScrollView滑动还是subScrollView滑动 一旦出现这种情况 情况就出乎我
  • 一文了解websocket全双工通信java实现&socket地址404问题解决

    websocket介绍 1 websocket介绍 1 1注解介绍 2 demo 2 1 后端代码 2 2 前端代码 2 3 效果 附录 socket地址404问题解决 1 websocket介绍 WebSocket是一种在单个TCP连接上
  • 背包问题

    一 01背包 题目 有一个容量为T的背包 现有n个物品 每个物品有都有一个体积w i 和自身价值v i 现在要求求出背包能够装的物品的价值最大 每个物品只可以装一次 基本思路 01背包是背包中的最基础的问题 后面很多背包问题都是01背包和完
  • [会议分享]2022年欧洲计算机科学与信息技术会议(ECCSIT 2022)

    2022年欧洲计算机科学与信息技术会议 ECCSIT 2022 重要信息 会议网址 www eccsit org 会议时间 2022年11月25 27日 召开地点 南京 截稿时间 2022年10月20日 录用通知 投稿后2周内 收录检索 E
  • 【DevOps核心理念基础】3. 敏捷开发最佳实践

    一 敏捷开发最佳实践 1 1 项目管理 1 2 需求管理 1 3 技术架构 1 4 技术开发 1 5 测试 二 敏捷开发最佳实践 2 1 敏捷开发的执行细节 三 全面的DevOps工具链 四 版本控制和协作开发工具 4 1 集中式版本控制工
  • SX1281驱动学习笔记一:Lora驱动移植

    目录 一 资料下载 1 中文手册下载地址 2 英文手册下载地址 3 固件下载地址 4 SX1281的速率计算器下载地址 5 SX128X区别 二 驱动讲解 1 radio h文件 2 sx1281 c文件 3 sx1281 hal c文件
  • unity在同屏幕显示多Camera并在脚本中修改Viewport Rece

    参考 https www it610 com article 1305219586412548096 htm 参考 https www zhihu com question 41879088 sort created 修改Camera的Vi
  • 开放平台认证方案

    背景 本次的直接起因是第三方那边接入系统后端引起的 第三方方觉得认证要过期比较麻烦 而且要用账号密码去调登录接口去刷token 设计不合理 客观来说 凭本人使用过其它开放平台来说确实有些不一样 常见的一些开放平台 有带web的 一般web能
  • 感知机及算法实现

    1 感知机二类分类的线性分类模型 输入为实例的特征向量 输出为实例的类别 取 1和 1二值 感知机对应于输入空间中将实例划分为正负两类的分离超平面 属于判别模型 感知机学习旨在求出将训练数据进行线性划分的分离超平面 为此导入基于误分类的损失