[546]python实现K-Nearest Neighbor算法

2023-11-13

K-Nearest Neighbor(KNN)可以翻译为K最近邻算法,是机器学习中最简单的分类算法。为了更好的理解这个算法,本帖使用Python实现这个K-Nearest Neighbor算法 ,最后和scikit-learn中的k-Nearest Neighbor算法进行简单对比。

KNN算法基本原理

假设我有如下两个数据集:

dataset  =  {'black':[  [1,2],  [2,3],  [3,1]  ],  'red':[  [6,5],  [7,7],  [8,6]  ]  }

使用Python实现k-Nearest Neighbor算法

上面画出了两组数据点:black,red。假设你在上图任意添加一个点,如(3.5, 5.3),KNN的任务就是判断这个点(下图中的绿点)该划分到哪个组。

使用Python实现k-Nearest Neighbor算法

KNN分类算法超级简单:只需使用初中所学的两点距离公式(欧拉距离公式),计算绿点到各组的距离,看绿点和哪组更接近。K代表取离绿点最近的k个点,这k个点如果其中属于红点个数占多数,我们就认为绿点应该划分为红组,反之,则划分为黑组。

如果有两组数据(如上图),k值最小应为3;如果有三组数据(如下图),k值最小应为5。scikit-learn默认k值为5。

使用Python实现k-Nearest Neighbor算法

上面使用的是二维数据,同样的逻辑可以推广到三维或任意纬度。

除了K-Nearest Neighbor之外还有其它分组的方法,如Radius-Based Neighbor。

使用Python实现KNN算法
# -*- coding:utf-8 -*-
import math
import numpy as np
from matplotlib import pyplot
from collections import Counter
import warnings


# k-Nearest Neighbor算法
def k_nearest_neighbors(data, predict, k=5):
    if len(data) >= k:
        warnings.warn("k is too small")
    # 计算predict点到各点的距离
    distances = []
    for group in data:
        for features in data[group]:
            # euclidean_distance = np.sqrt(np.sum((np.array(features)-np.array(predict))**2))   # 计算欧拉距离,这个方法没有下面一行代码快
            euclidean_distance = np.linalg.norm(np.array(features) - np.array(predict))
            distances.append([euclidean_distance, group])
    print(sorted(distances))
    sorted_distances = [i[1] for i in sorted(distances)]
    top_nearest = sorted_distances[:k]
    # print(top_nearest)  ['red','black','red']
    group_res = Counter(top_nearest).most_common(1)[0][0]
    confidence = Counter(top_nearest).most_common(1)[0][1] * 1.0 / k
    # confidences是对本次分类的确定程度,例如(red,red,red),(red,red,black)都分为red组,但是前者显的更自信
    return group_res, confidence


if __name__ == '__main__':
    dataset = {'black': [[1, 2], [2, 3], [3, 1]], 'red': [[6, 5], [7, 7], [8, 6]]}
    new_features = [3.5, 5.2]  # 判断这个样本属于哪个组
    for i in dataset:
        for ii in dataset[i]:
            pyplot.scatter(ii[0], ii[1], s=50, color=i)

    which_group, confidence = k_nearest_neighbors(dataset, new_features, k=3)
    print(which_group, confidence)
    pyplot.scatter(new_features[0], new_features[1], s=100, color=which_group)
    pyplot.show()

执行结果:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-jbAzqpdG-1576323000530)(http://upload-images.jianshu.io/upload_images/12504508-f5829945cbe86017.png?imageMogr2/auto-orient/strip%7CimageView2/2/w/1240)]

使用现实数据测试上面实现的knn算法

数据集(Breast Cancer):https://archive.ics.uci.edu/ml/datasets/Breast+Cancer+Wisconsin+%28Original%29

数据集字段:

   #  属性                          Domain
   -- -----------------------------------------
   1. Sample code number            id   # 这一列没啥用
   2. Clump Thickness               1 - 10
   3. Uniformity of Cell Size       1 - 10
   4. Uniformity of Cell Shape      1 - 10
   5. Marginal Adhesion             1 - 10
   6. Single Epithelial Cell Size   1 - 10
   7. Bare Nuclei                   1 - 10
   8. Bland Chromatin               1 - 10
   9. Normal Nucleoli               1 - 10
  10. Mitoses                       1 - 10
  11. Class(分类):                   (2 代表良性, 4 代表恶性)

使用Python实现k-Nearest Neighbor算法

我们的任务是使用knn分类数据,预测肿瘤是良性的还是恶性的。

代码:

# -*- coding:utf-8 -*-
import math
import numpy as np
from collections import Counter
import warnings
import pandas as pd
import random


# k-Nearest Neighbor算法
def k_nearest_neighbors(data, predict, k=5):
    if len(data) >= k:
        warnings.warn("k is too small")

    # 计算predict点到各点的距离
    distances = []
    for group in data:
        for features in data[group]:
            euclidean_distance = np.linalg.norm(np.array(features) - np.array(predict))
            distances.append([euclidean_distance, group])

    sorted_distances = [i[1] for i in sorted(distances)]
    top_nearest = sorted_distances[:k]

    group_res = Counter(top_nearest).most_common(1)[0][0]
    confidence = Counter(top_nearest).most_common(1)[0][1] * 1.0 / k

    return group_res, confidence


if __name__ == '__main__':
    df = pd.read_csv('breast-cancer-wisconsin.data')  # 加载数据
    # print(df.head())
    print(df.shape)
    '''inplace=True:不创建新的对象,直接对原始对象进行修改;
    inplace=False:对数据进行修改,创建并返回新的对象承载其修改结果。'''
    df.replace('?', np.nan, inplace=True)  # -99999
    df.dropna(inplace=True)  # 去掉无效数据
    print(df.shape)
    df.drop(['id'], 1, inplace=True)
    # 把数据分成两部分,训练数据和测试数据
    full_data = df.astype(float).values.tolist()# 先将数据类型转为float类型,在转为列表
    random.shuffle(full_data)
    test_size = 0.2  # 测试数据占20%
    train_data = full_data[:-int(test_size * len(full_data))]
    test_data = full_data[-int(test_size * len(full_data)):]
    # print(test_data)
    train_set = {2: [], 4: []}
    test_set = {2: [], 4: []}
    for i in train_data:
        train_set[i[-1]].append(i[:-1])
    for i in test_data:
        test_set[i[-1]].append(i[:-1])

    correct = 0
    total = 0
    for group in test_set:
        for data in test_set[group]:
            # 你可以调整这个k看看准确率的变化,你也可以使用matplotlib画出k对应的准确率,找到最好的k值
            res, confidence = k_nearest_neighbors(train_set, data,k=5)
            if group == res:
                correct += 1
            else:
                print(confidence)
            total += 1
    print(correct / total)  # 准确率
    print(k_nearest_neighbors(train_set, [4, 2, 1, 1, 1, 2, 3, 2, 1], k=5))  # 预测一条记录

执行结果:

$ python breast_cancer_knn.py 
1.0    # 分类错误时对应的自信程度,100%自信但是分类错误,这是我们要注意的
0.6    #
0.6    #
0.9779411764705882   # 预测准确率
(2, 1.0)    # 良性
使用scikit-learn中k邻近算法
# -*- coding:utf-8 -*-
import numpy as np
# cross_validation已deprecated,使用model_selection替代
from sklearn import preprocessing, model_selection, neighbors
import pandas as pd


df = pd.read_csv('breast-cancer-wisconsin.data')
# print(df.head())
# print(df.shape)
df.replace('?', np.nan, inplace=True)  # -99999
df.dropna(inplace=True)
# print(df.shape)
df.drop(['id'], 1, inplace=True)

X = np.array(df.drop(['class'], 1))
Y = np.array(df['class'])

X_trian, X_test, Y_train, Y_test = model_selection.train_test_split(X, Y, test_size=0.2)

clf = neighbors.KNeighborsClassifier()
clf.fit(X_trian, Y_train)

accuracy = clf.score(X_test, Y_test)
print(accuracy)

sample = np.array([4, 2, 1, 1, 1, 2, 3, 2, 1])
print(sample.reshape(1, -1))
print(clf.predict(sample.reshape(1, -1)))

执行结果:

$  python breast_cancer.py
0.970802919708  # 预测准确率
[2]  # 良性

scikit-learn中的算法和我们上面实现的算法原理完全一样,只是它的效率更高,支持的参数更全。

来源:http://blog.topspeedsnail.com/archives/10287#more-10287

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

[546]python实现K-Nearest Neighbor算法 的相关文章

  • 【枚举的定义;枚举变量的定义、初始化和赋值】(学习笔记16--枚举)

    目录 枚举的定义 枚举变量的定义 枚举变量的初始化与赋值 使用枚举类型 可以提高程序代码的健壮性和可读性 并且枚举成员属于常量 甚至可以使用枚举成员名作为维的大小 来进行数组的定义 枚举的定义 定义枚举的格式为 enum 枚举名 枚举成员1
  • Muduo网络库核心梳理

    Muduo网络库 Muduo网络库本身并不复杂 是一个新手入门C 面向对象网络编程的经典实战项目 但是 新手在刚刚上手读代码的时候 非常容易陷入代码的汪洋大海 迷失方向 本文旨在简要梳理Muduo网络库的核心内容 帮助初学者快速上手源码阅读

随机推荐

  • DES算法简单介绍及用法

    大家好 今天给大家分享一下DES加密 一 DES介绍 加密一般分为可逆加密和不可逆加密 其中可逆加密一般又分为对称加密和非对称加密 前者是我们使用公用密钥加密之后可以使用公用密钥再解密出来 而后者则是使用公用密钥加密之后必须使用私用密钥来解
  • MySQL这一章就够了(一)

    前言 呕心沥血5个月淦出本文 整理所有MySQL知识 我愿称之为地表最强MySQL MySql笔记 MySQL是关系型数据库 基于SQL查询的开源跨平台数据库管理系统 它最初是由瑞典MySQL AB公司开发的 现在它是Oracle Corp
  • 手把手教你区块链java开发智能合约nft(第四篇)-如何动态获取gasPrice和gasLimit?

    手把手教你区块链java开发智能合约nft 第三篇 如何动态获取gasPrice和gasLimit 初学区块链 那真叫一个痛苦并无助 如果没有人带你的话 今天写的这篇是在前面文章基础上写的 初学区块链的朋友建议先看我前面写的文章 手把手教你
  • 【雕爷学编程】Arduino动手做(72)---HX711 人体称重模块

    37款传感器与执行器的提法 在网络上广泛流传 其实Arduino能够兼容的传感器模块肯定是不止这37种的 鉴于本人手头积累了一些传感器和执行器模块 依照实践出真知 一定要动手做 的理念 以学习和交流为目的 这里准备逐一动手尝试系列实验 不管
  • win32读取注册表

    直接代码 bool bIsIE6 false HKEY hKey NULL DWORD dwType DWORD dwSize LONG lReg RegOpenKey HKEY CLASSES ROOT HTTP shell open c
  • 技术方案设计没有深度?试试这套方法论

    原文为阿里技术发布的一篇文章 作者 高福来 不拔 读后受益匪浅 决定转载分享 平时听到一些同学说技术方案没什么深度 很难讲出来 怎么去体现技术方案设计的深度是大家普遍关心的一个问题 这个问题不是个例问题 因此分享下自己的一些观点和看法 主要
  • ps怎么对比原图快捷键_Photoshop最常用的10个快捷键,让你修图事半功倍!

    小伙伴们 小编今天要给大家发一波福利 揭秘Photoshop最常用的10个快捷键 让你修图事半功倍 1 Ctrl Ctrl 放大 缩小图层 使用Photoshop进行修图时 为了更加准确地进行精修 我们需要放大图片 此时使用快捷键 Ctrl
  • STM32F4-正点原子探索者-SYSTEM文件夹下的delay.c文件内延时函数详解

    目录 笔记 首先是对应的头文件delay h中的函数 1 delay init u8 SYSCLK 此处将把关于UCOS相关代码忽略 后面学习 注 以下为SysTick结构体详解 与主体函数只是有一定联系 可略过 SysTick结构体中的C
  • 散点矩阵

    import pandas as pd import matplotlib pyplot as plt import seaborn as sns crime pd read csv crimeRatesByState2005 csv cr
  • 【Unity3D】Unity3D游戏里实现复制粘贴功能

    public class sTest MonoBehaviour public InputField input public Button btn if UNITY IOS DllImport Internal private stati
  • ffmpeg开发环境的安装测试和更新的步骤

    本文将介绍ffmpeg开发环境的安装测试和更新的步骤 基于ubuntu16 04和ffmpeg3 2 1 安装x264 1 libx264需要yasm sudo apt get install yasm 但是yasm版本比较旧 所以安装na
  • 做个成功的管理者

    什么是 管理 大家对这词都不陌生 但什么才是真正的管理呢 管理的真谛在 理 不在 管 管理者的主要职责就是建立一个合理的游戏规则 让每个员工按照游戏规则自我管理 游戏规则要兼顾公司的利益和个人的利益 并且把公司的利益和个人的利益统一起来 尽
  • SQL语法基础

    结构化查询语言 Structured Query Language 是一种特殊目的的编程语言 是一种数据库查询和程序设计语言 用于存取数据以及查询 更新和管理关系数据库系统 数据查询语言 DQL SELECT 查询 数据操作语言 DML I
  • Flink实时任务性能调优

    前言 通常我们在开发完Flink任务提交运行后 需要对任务的参数进行一些调整 通常需要调整的情况是任务消费速度跟不上数据写入速度 从而导致实时任务出现反压 内存GC频繁 FullGC 频繁 内存溢出导致TaskManager被Kill 今天
  • 【Git】(六)子模块跟随主仓库切换分支

    场景 主仓库 TestGit 子模块 SubModule 分支v1 0 gitmodules文件 submodule Library SubModule path Library SubModule url git gitee com su
  • 内置变量列表(Unix)

    当前页可打印的行数 属于Perl格式系统的一部分 根据上下文内容返回错误号或者错误串 列表分隔符 打印数字时默认的数字输出格式 Perl解释器的进程ID 当前输出通道的当前页号 与上个格式匹配的字符串 当前进程的组ID 当前进程的有效组ID
  • 【JS逆向】之JS函数的闭包导出调用

    前言 闭包其实就是一个函数里面的私有方法 我们在函数外部无法调用 这个就叫闭包 理解起来其实也不难 这个其实也跟js的作用域有很大关系的 这里的闭包也用了作用域的特性 js的作用域分两种 全局和局部 基于我们所熟悉的作用域的知识 我们知道在
  • org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerAdapter

    大部分愿意是因为导入的包的版本问题 我是把jackson的包换成了高版本的就没有报错了
  • 【golang设计模式】Golang设计模式详解二

    六 工厂方法模式 工厂方法模式使用子类的方式延迟生成对象到子类中实现 Go中不存在继承 所以使用匿名组合来实现 代码如下 factorymethod go package factorymethod Operator 是被封装的实际类接口
  • [546]python实现K-Nearest Neighbor算法

    K Nearest Neighbor KNN 可以翻译为K最近邻算法 是机器学习中最简单的分类算法 为了更好的理解这个算法 本帖使用Python实现这个K Nearest Neighbor算法 最后和scikit learn中的k Near