k近邻法matlab_【机器学习算法笔记】01-k近邻法

2023-10-27

k近邻法是一种基本的分类和回归方法,是最简单的监督学习算法之一,其本质极其直观:在给定的训练数据集里找距离新的输入实例最近的k个实例,将该k个实例中的多数标签赋给输入实例即可。问题的关键在于三个方面:

  • 距离度量

样本空间中两个实例点的距离显然是其相似程度的反映,常见模型一般是n维实数向量空间

,多数使用欧式距离,及
distance 中的 p = 2。所以算法程序中此处可设置默认参数;
  • k值的选择

k的选择对结果影响重大:

一般选用一个比较小的数值,后面根据我们的实际数据作交叉验证。

  • 分类决策规则

往往是多数表决,即直接取k个中最多的标签即可。其实应当还有半数表决等更强要求的方法,数据和精力有限暂未谈论。


这三点一说其实理论准备就已经说完了,只是这样的算法每次要计算训练集所有实例到输入的距离,效率过低,这个问题将在下一篇【机器学习算法笔记】02——kd树中聊一聊,这篇基本还是以简单k近邻法实现快速分类为主。以下是主题代码。

1、数据集准备

根据书上内容准备了三个数据集

# 函数1 创建题设数据集
def create_movie_data_set():
    group_out = np.array([[1, 101], [5, 89], [108, 5], [115, 8]])
    labels_out = np.array(['爱情片', '爱情片', '动作片', '动作片'])
    return group_out, labels_out


# 函数2 读取鸾尾花数做实验(弃)
def create_iris_data_set(test_proportion=0.2):
    from sklearn import datasets
    from sklearn.model_selection import train_test_split
    iris_all = datasets.load_iris()
    iris_train, iris_test, iris_label_train, iris_label_test 
        = train_test_split(iris_all.data, iris_all.target,test_size=test_proportion)
    return iris_train, iris_test, iris_label_train, iris_label_test


# 函数3 读取书上约会数据实验
def create_dating_data_set(test_proportion=0.2):
    from sklearn.model_selection import train_test_split
    # 原始数据归一化处理
    dating_data = normalization1(np.loadtxt("datingTestSet.txt", delimiter='t', usecols=(0, 1, 2)))
    dating_target = np.loadtxt("datingTestSet.txt", delimiter='t', dtype = 'str',usecols=3)
    dating_train, dating_test, dating_label_train, dating_label_test 
        = train_test_split(dating_data, dating_target,test_size=test_proportion)
    return dating_train, dating_test, dating_label_train, dating_label_test

sklearn中有

2、kNN算法主体

《机器学习实战》书中内容及目前CSDN中多个高浏览的kNN代码都过于陈旧,明明使用了numpy模块,有好用的boardcast规则不用,偏偏要把一维数组扩到n维占内存毫无意义。

# 函数4 简单k邻近算法
def kNN1(x_in, data_train_in, label_train_in, k, n=2):
    """
    :param x_in: 待测试的数据点
    :param data_train_in: 训练集
    :param label_train_in: 训练集标签
    :param k: k近邻法参数
    :param n: 距离度量方法选择,默认n=2即计算欧氏距离
    :return label_result:测试点的标签
    """
    import numpy as np

    # 计算各点到目标点的距离,默认n=2即欧氏距离
    distance = (((data_train_in - x_in) ** n).sum(axis=1)) ** (1 / n)

    distance_sort = distance.argsort()  # 求取distance中从小到大的索引值

    label_dic = {}
    for i in range(k):
        # 读取距离最近的k个点的标签
        label_vote = label_train_in[distance_sort[i]]
        # 将该k个点的标签写入字典
        label_dic[label_vote] = label_dic.get(label_vote, 0) + 1
    # 读取最近k个点的标签值入字典后,直接求取字典中值最大的健即可
    label_result = max(label_dic.items(), key=lambda x: x[1])[0]  # 快速求取字典最值

    return label_result

对于上述算法中的三个关键点:(1)距离度量默认欧式距离,可根据实际情况更改参数n;(2)k值需调用时输入;(3)分类决策规则为简单的多数表决。

1.0版本的算法应当说还是非常粗糙的,没有容错机制,但通过自己尝试修改能深刻理解算法。

3、归一化函数

算法函数一写完就急不可耐地测试了,结果发现效果很差,才想起来对于各维度数值差异较大的数据一定是要归一化的。

# 函数5 简单直接归一化算法
def normalization1(array_in):
    array_in_max = array_in.max(0)
    array_in_min = array_in.min(0)
    array_out = (array_in-array_in_min)/(array_in_max - array_in_min)
    return array_out

这样单独写有个问题就在于对于新数据,每次调用kNN1()前都得先归一化,非常麻烦,暂时搁置,后期更新kNN2时再将normalization直接内置进kNN中。

4、算法效果验证

选取不同的k值,每个k值抽取200次iris数据做验证,

使用经典的iris数据,按照8:2的比例将15个原始数据分为训练集和测试集作初步验证。

k值取1~30,每个k值做200次验证取平均错误率,实现代码如下:

if __name__ == '__main__':
    """直接验证尝试"""
       
    test_result = np.array([])
    for k_test in range(1,31):
        error_rate = np.array([])
        for n in range(200):

            data_train, data_test, label_train, label_test = create_iris_data_set(0.2)
            len_test = float(len(data_test))
            result_error = 0
            
            for i in range(len(data_test)):
                label_try = kNN1(data_test[i], data_train, label_train, k_test)
                # 检查统计每个输入实例分类结果是否与实际标签一致
                if label_try != label_test[i]:
                    result_error += 1
            # 统计每批的错误率
            error_rate = np.append(error_rate, float(result_error) / len_test)

        test_result = np.append(test_result,np.mean(error_rate))

    print(test_result)
    # 画图看一看
    plt.title("iris_data_kNN_test")
    plt.xlabel("k")
    plt.ylabel("error_rate")
    plt.plot(np.arange(1,31), test_result)
    plt.show()

结果如下:

[ 4.65% 4.40% 3.77% 3.80% 3.72% 4.18% 3.72% 3.67% 3.73% 3.30%
3.55% 3.28% 2.58% 3.23% 3.40% 3.58% 3.32% 3.25% 3.62% 3.67%
4.95% 4.30% 4.52% 4.55% 5.22% 5.05% 5.48% 5.27% 5.55% 4.77%]

即当k取13时,分类误差达到最小,约2.58%。

5、收获

  • 严格来讲k近邻法并没有弄完,kd树能有效提升计算效率我还没弄好,还得加紧;
  • 目前主程序了调用同一数据源的create函数需要多次读取硬盘很没有效率,这是我在前期构建数据时没有提前考虑到的,以后在构建数据来源接口时务必考虑后面要怎么用,根据功能需求调整接口;
  • 自己码代码和看别人的成品过程收获是截然不同的,网上可参考资料虽然多,但没有自己真的做过看再多也没用;
  • 要善于利用python各类模块的便捷,比如iris数据,我本来是下载了txt放电脑上准备硬盘读取的,但实际著名的sklearn模块都自带了这些经典数据集,如breast_cancer等等,完全不需要自己来回折腾;另一个分割数据集的时候我是先自己用sample功能写了个小函数来拆分原始数据,成功实现后觉得太过麻烦还是换成了sklearn的相应功能,这些小细节有精力了当然自己做更好,但已经有成熟的解决方案时就没有必要非得用自己的。
  • numpy操作还是不熟练啊,格式化打印个结果弄了半天。文档看的再细帮助也有限,还是得大量练习,长期浸润其中,Practice make perfect,还是不能着急,毕竟用的时间太短,数组在我手上还开不了花。得想想办法怎么把平常用到的numpy小技巧积累起来。
  • 效率得提高,写这么一篇没多少干货的文字都花了不少功夫,得思考思考如何在学习和编程的同时就把思路和收获记下来,否则每周时间这么紧,如果做算法笔记成了累赘那就得不偿失了。
  • 这周末要去跑西马,基本上一天半就交代出去了,还有些计划任务都没完成,下周估计得加把劲儿了,希望不要鸽。
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

k近邻法matlab_【机器学习算法笔记】01-k近邻法 的相关文章

  • 微信小程序分享功能总结

    小程序实现分享功能有如下三种方式 1 在js文件中实现onShareAppMessage函数 即可点击右上角菜单分享给微信好友 页面中默认已实现 在js文件中实现onShareTimeline函数 即可点击右上角菜单分享到微信朋友圈 需要自
  • 联盟链走向何方

    联盟链技术哪家强 开源架构Fabric FISCO BCOS 以下简称 BCOS CITA 技术对比 出品 碳链价值研究院 01 摘要 第 46 届世界经济论坛达沃斯年会将区块链与人工智能 自动驾驶等一并列入 第四次工业革命 经济学人 曾在
  • qt5.5.1 移植4412的问题过程

    编译错误 WTF wtf unicode wchar UnicodeWchar h In function bool WTF Unicode isAlphanumeric UChar WTF wtf unicode wchar Unicod
  • 开源项目部署之悟CRM部署 PHP服务端版

    文章目录 前言 一 部署环境 二 部署流程 1 安装宝塔等基础环境 2 部署CRM 点击安装即可 在这里插入图片描述 https img blog csdnimg cn 4f83ede5d3f74343a927f8a641c25e19 pn
  • 助推打造全球研发中心城市

    阿里 社招 一面 面不动了 真的面不动了一 项目挑一个你觉得最有挑战性的细说 有些细节被质疑了 嘴在前面飞脑子在后面追 以后说每一句话都要小心 笑cry 二 八股1 聚簇索引和非 题解 检索产品名称和描述 一 select prod nam
  • 3D关键点检测(2020-2017)

    3D关键点检测 1 3D关键点检测之PoseDRL Deep Reinforcement Learning for Active Human Pose Estimation AAAI2020 这篇文章可能与我们通常所处理的姿态估计任务略有不
  • 【BEV】BEVDet

    BEVDet 解析 BEVDet 模型 bevdet r50 训练配置 Scale NMS 优化配置 推理记录 注册 随机种子 总结 BEVDet BEVDet继承于CenterPoint gt MVTwoStageDetector 模型实
  • 射频工程师笔记---射频通信基础

    文章更新或问题可关注本人公众号 回顾一下移动通信技术的发展 其实是互联网和通信技术的融合过程 在这个过程中 很多应用都在不断加入其中 比如计算机跟通信的融合产生了互联网 互联网跟手机的融合带来了移动互联网 手机可以看杂志 看视频 听音乐 于
  • SpringCLoud——服务的拆分和远程调用

    服务拆分 服务拆分注意事项 一般是根据功能的不同 将不同的服务按照功能的不同而分开 微服务拆分注意事项 不同微服务 不要重复开发相同业务 微服务数据独立 不要访问其他微服务的数据库 微服务可以将自己的业务暴露为接口 供其他微服务调用 远程调
  • C++ 数据结构与算法(五)(哈希表)

    哈希表 1 定义 哈希表 Hash table 也称散列表 是根据关键码的值而直接进行访问的数据结构 一般哈希表都是用来快速判断一个元素是否出现集合里 只需要在初始化时用哈希函数 hash function 将这些元素映射在哈希表的索引上
  • WJ的Direct3D简明教程2:Render-To-Texture

    转载请注明 来自http blog csdn net skyman 2001 Rendering to a texture is one of the advanced techniques in Direct3D On the one h
  • Unity绘制户型(一)

    户型绘制主要对象数据 点 线 面 部件 门窗 主要难点是通过绘制的点寻找闭合多边形 多边形的生成 3D墙体的生成 门窗要在墙体上留下孔洞这四个功能 这篇文章我只写前两个问题 后面来两个问题单独再写一篇文章 1 如何寻找闭合多边形 我的方法是
  • 内容管理系统测试实战

    使用django和restframework开发接口 使用postman测试接口 使用unittest和requests模块测试接口 目录 Django安装 Django Rest Framework 创建API应用 数据库迁移 创建超级管
  • C++11中pair的用法

    概述 pair可以将两个数据组合成一种数据类型 C 标准库中凡是必须返回两个值的函数都使用pair pair有两个成员变量 分别是first和second 由于使用的struct而不是class 因此可以直接访问pair的成员变量 基本用法
  • Python_某宝某东秒杀抢购

    纯学习分享 只用于学习用途 请勿用于任何商业用途 本人不承担任何责任 视频编写过程 某宝秒杀程序 某宝源码 from selenium import webdriver from selenium webdriver common by i
  • springboot配置shiro多项目实现session共享的详细步骤

    springboot配置shiro多项目实现session共享的详细步骤 项目的配置步骤我已写到另一篇文章中 shiro框架 多项目登录访问共享session的实现 springboot redis shiro 的实现项目已共享到GitHu
  • 关于Tomcat端口被占用的情况

    今天打开eclipse突然发现运行不了 报错的提示为 Several ports 8005 8080 8009 required by Tomcat v7 0 Server at localhost are already in use 有
  • Android studio遇到问题:Emulator: PANIC: Cannot find AVD system path. Please define ANDROID_SDK_ROOT

    前言 在使用android studio时 配置模拟器的时候一直在报错这个 然后网上找到问题 并实际解决了问题 在这里记录下 目录 问题原因 没有配置环境的情况下 是因为他默认找的是这个路径的AVD 问题很明显了 中文路径导致的 C Use
  • Vue路由 传参几种方式

    动态路由传参 path detail username name a component gt import components Detail vue

随机推荐