[549]python实现K-Means算法

2023-05-16

K-Means是一种聚类(Clustering)算法,使用它可以为数据分类。K代表你要把数据分为几个组,前文实现的K-Nearest Neighbor算法也有一个K,实际上,它们有一个相似之处:K-Means也使用欧拉距离公式。

  • K-Means:https://en.wikipedia.org/wiki/K-means_clustering
  • scikit-learn中的聚类算法:http://scikit-learn.org/stable/modules/clustering.html
  • scikit-learn K-Means文档:http://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html

K-Means算法的基本思想是初始随机给定K个簇中心,按照最邻近原则把待分类样本点分到各个簇。然后按平均法重新计算各个簇的质心,从而确定新的簇心。一直迭代,直到簇心的移动距离小于某个给定的值。

为了更好的理解这个K-Means,本帖使用Python实现K-Means算法。

K-Means简单图示(sklearn)

import numpy as np
from sklearn.cluster import KMeans
from matplotlib import pyplot
 
# 要分类的数据点
x = np.array([ [1,2],[1.5,1.8],[5,8],[8,8],[1,0.6],[9,11] ])
# pyplot.scatter(x[:,0], x[:,1])

image.png

# 把上面数据点分为两组(非监督学习)
clf = KMeans(n_clusters=2)
clf.fit(x)  # 分组
 
centers = clf.cluster_centers_ # 两组数据点的中心点
labels = clf.labels_   # 每个数据点所属分组
print(centers)
print(labels)
 
for i in range(len(labels)):
    pyplot.scatter(x[i][0], x[i][1], c=('r' if labels[i] == 0 else 'b'))
pyplot.scatter(centers[:,0],centers[:,1],marker='*', s=100)
 
# 预测
predict = [[2,1], [6,9]]
label = clf.predict(predict)
for i in range(len(label)):
    pyplot.scatter(predict[i][0], predict[i][1], c=('r' if label[i] == 0 else 'b'), marker='x')
 
pyplot.show()

image.png

*是两组数据的”中心点”;x是预测点分组。上面使用的是二维数据,方便可视化。

使用Python实现K-Means算法

K-Means聚类算法主要分为三个步骤

  • 第一步是为待聚类的点随机寻找聚类中心
  • 第二步是计算每个点到聚类中心的距离,将各个点归类到离该点最近的聚类中去
  • 第三步是计算每个聚类中所有点的坐标平均值,并将这个平均值作为新的聚类中心,反复执行(2)、(3),直到聚类中心不再进行大范围移动或者聚类次数达到要求为止

Python代码:

# -*- coding:utf-8 -*-
import numpy as np
from matplotlib import pyplot


class K_Means(object):
    # k是分组数;tolerance‘中心点误差’;max_iter是迭代次数
    def __init__(self, k=2, tolerance=0.0001, max_iter=300):
        self.k_ = k
        self.tolerance_ = tolerance
        self.max_iter_ = max_iter

    def fit(self, data):
        self.centers_ = {}
        for i in range(self.k_):
            self.centers_[i] = data[i]

        for i in range(self.max_iter_):
            self.clf_ = {}
            for i in range(self.k_):
                self.clf_[i] = []
            # print("质点:",self.centers_)
            for feature in data:
                # distances = [np.linalg.norm(feature-self.centers[center]) for center in self.centers]
                distances = []
                for center in self.centers_:
                    # 欧拉距离
                    # np.sqrt(np.sum((features-self.centers_[center])**2))
                    distances.append(np.linalg.norm(feature - self.centers_[center]))
                classification = distances.index(min(distances))
                self.clf_[classification].append(feature)

            # print("分组情况:",self.clf_)
            prev_centers = dict(self.centers_)
            for c in self.clf_:
                self.centers_[c] = np.average(self.clf_[c], axis=0)

            # '中心点'是否在误差范围
            optimized = True
            for center in self.centers_:
                org_centers = prev_centers[center]
                cur_centers = self.centers_[center]
                if np.sum((cur_centers - org_centers) / org_centers * 100.0) > self.tolerance_:
                    optimized = False
            if optimized:
                break

    def predict(self, p_data):
        distances = [np.linalg.norm(p_data - self.centers_[center]) for center in self.centers_]
        index = distances.index(min(distances))
        return index


if __name__ == '__main__':
    x = np.array([[1, 2], [1.5, 1.8], [5, 8], [8, 8], [1, 0.6], [9, 11]])
    k_means = K_Means(k=2)
    k_means.fit(x)
    print(k_means.centers_)
    for center in k_means.centers_:
        pyplot.scatter(k_means.centers_[center][0], k_means.centers_[center][1], marker='*', s=150)

    for cat in k_means.clf_:
        for point in k_means.clf_[cat]:
            pyplot.scatter(point[0], point[1], c=('r' if cat == 0 else 'b'))

    predict = [[2, 1], [6, 9]]
    for feature in predict:
        cat = k_means.predict(predict)
        pyplot.scatter(feature[0], feature[1], c=('r' if cat == 0 else 'b'), marker='x')

    pyplot.show()

执行结果:

使用Python实现K-Means算法

K-Means算法需要你指定K值,也就是需要人为指定数据应该分为几组。下一帖我会实现Mean Shift算法,它也是一种聚类算法(Hierarchical),和K-Means(Flat)不同的是它可以自动判断数据集应该分为几组。

在实际数据上应用K-Means算法

# -*- coding:utf-8 -*-
import numpy as np
from sklearn.cluster import KMeans
from sklearn import preprocessing
import pandas as pd

'''
数据集:titanic.xls(泰坦尼克号遇难者/幸存者名单)
<http://blog.topspeedsnail.com/wp-content/uploads/2016/11/titanic.xls>
***字段***
pclass: 社会阶层(1,精英;2,中产;3,船员/劳苦大众)
survived: 是否幸存
name: 名字
sex: 性别
age: 年龄
sibsp: 哥哥姐姐个数
parch: 父母儿女个数
ticket: 船票号
fare: 船票价钱
cabin: 船舱
embarked
boat
body: 尸体
home.dest
******
目的:使用除survived字段外的数据进行k-means分组(分成两组:生/死),然后和survived字段对比,看看分组效果。
'''

# 加载数据
df = pd.read_excel('titanic.xls')
# print(df.shape)  (1309, 14)
# print(df.head())
# print(df.tail())
"""
    pclass  survived                                            name     sex  \
0       1         1                    Allen, Miss. Elisabeth Walton  female
1       1         1                   Allison, Master. Hudson Trevor    male
2       1         0                     Allison, Miss. Helen Loraine  female
3       1         0             Allison, Mr. Hudson Joshua Creighton    male
4       1         0  Allison, Mrs. Hudson J C (Bessie Waldo Daniels)  female

       age  sibsp  parch  ticket      fare    cabin embarked boat   body  \
0  29.0000      0      0   24160  211.3375       B5        S    2    NaN
1   0.9167      1      2  113781  151.5500  C22 C26        S   11    NaN
2   2.0000      1      2  113781  151.5500  C22 C26        S  NaN    NaN
3  30.0000      1      2  113781  151.5500  C22 C26        S  NaN  135.0
4  25.0000      1      2  113781  151.5500  C22 C26        S  NaN    NaN

    home.dest
0                     St Louis, MO
1  Montreal, PQ / Chesterville, ON
2  Montreal, PQ / Chesterville, ON
3  Montreal, PQ / Chesterville, ON
4  Montreal, PQ / Chesterville, ON
"""

# 去掉无用字段
df.drop(['body', 'name', 'ticket'], 1, inplace=True)
# print(df.info())#可以查看数据类型
df.convert_objects(convert_numeric=True)#将object格式转float64格式
df.fillna(0, inplace=True)  # 把NaN替换为0

# 把字符串映射为数字,例如{female:1, male:0}
df_map = {}  # 保存映射关系
cols = df.columns.values
print('cols:',cols)
for col in cols:
    if df[col].dtype != np.int64 and df[col].dtype != np.float64:
        temp = {}
        x = 0
        for ele in set(df[col].values.tolist()):
            if ele not in temp:
                temp[ele] = x
                x += 1

        df_map[df[col].name] = temp
        df[col] = list(map(lambda val: temp[val], df[col]))

for key, value in df_map.items():
   print(key,value)
# print(df.head())

# 由于是非监督学习,不使用label
x = np.array(df.drop(['survived'], 1).astype(float))
# 将每一列特征标准化为标准正太分布,注意,标准化是针对每一列而言的
x = preprocessing.scale(x)

clf = KMeans(n_clusters=2)
clf.fit(x)
# 上面已把数据分成两组

# 下面计算分组准确率是多少
y = np.array(df['survived'])

correct = 0
for i in range(len(x)):
    predict_data = np.array(x[i].astype(float))
    predict_data = predict_data.reshape(-1, len(predict_data))
    predict = clf.predict(predict_data)
    # print(predict[0], y[i])
    if predict[0] == y[i]:
        correct += 1

print(correct * 1.0 / len(x))

执行结果:

$ python sk_kmeans.py 
0.692131398014  # 泰坦尼克号的幸存者和遇难者并不是随机分布的,在很大程度上取决于年龄、性别和社会地位
$ python sk_kmeans.py 
0.307868601986  # 结果出现很大波动,原因是它随机分配组(生:0,死:1)(生:1,死:0)
                # 1-0.307868601986是实际值
$ python sk_kmeans.py 
0.692131398014

来源:http://blog.topspeedsnail.com/archives/10349

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

[549]python实现K-Means算法 的相关文章

随机推荐

  • APM和PIX飞控日志分析入门贴

    我们在飞行中 xff0c 经常会碰到各种各样的问题 xff0c 经常有模友很纳闷 xff0c 为什么我的飞机会这样那样的问题 xff0c 为什么我的飞机会炸机 xff0c 各种问题得不到答案是一件非常不爽的问题 xff0c 在APM和PIX
  • 微电子及集成电路设计常用问题总结(考研面试向)

    mos管的沟道长度调制效应 xff1f 源极导致势垒下降 xff1f 衬底电流体效应 xff1f 衬底偏执效应 xff1f 速度饱和效应 xff1f 举例典型的trade off xff1f mos amp bjt的工作曲线 xff1f 加
  • YOLO详解

    转载自 xff1a https zhuanlan zhihu com p 25236464 从五个方面解读CVPR2016 目标检测论文YOLO Unified Real Time Object Detection 创新 核心思想 效果 改
  • 使用微信监管你的TF训练

    以TensorFlow的example中 xff0c 利用CNN处理MNIST的程序为例 xff0c 我们做了下面一点点小小的修改 1 xff0c 首先导入了itchat和threading两个包分别用于微信和县线程 xff08 因为要有一
  • 你应该知道的9篇深度学习论文(CNNs 理解)

    当时看到英文的博客 xff0c 本想翻译给感兴趣的同学们看看 xff0c 没想到已经有人翻译 xff0c 于是进行了转载 xff0c 留给自己和更多的人学习 xff0c 本文仅供参考 英文博客 xff1a https adeshpande3
  • JS笔记(==和===的介绍)

    61 61 和 61 61 61 介绍 61 61 关系运算符 等于 用于比较两个操作数是否相等的 相等为true xff0c 否则为false 61 不等于 61 61 61 xff1a 绝对等于 用于比较两个操作数是否相等的 相等为tr
  • 全国大学生电子设计竞赛B题感悟-优象光流篇

    今年是2019年电赛国赛年 xff0c 这本是是一个很好的机会冲击国家奖的 xff0c 但是由于个人视野太窄 xff0c 眼光不够长远而错失良机 今年测评结束的时候我就已经预感到了结果 xff0c 记得比赛前去提交作品的时候 xff0c 大
  • 滑模控制以及系统动力学与控制论(1)

    维基百科里是这样定义系统 System 的 System from Latin syst ma in turn from Greek syst ma is a set of entities real or abstract compris
  • 安装docker

    首先信任 Docker 的 GPG 公钥 sudo apt key adv keyserver hkp p80 pool sks keyservers net 80 recv keys 58118E89F3A912897C070ADBF76
  • 我的AI之路(39)--使用深度相机之小觅深度相机

    小觅深度相机的SDK代码在github上 小觅相机的支持库需要从这里https github com slightech MYNT EYE D SDK下载SDK源码后本地编译后再安装 xff0c Ubuntu上的步骤是 xff1a 1 如果
  • DeepSORT C++版的一个bug

    DeepSORT的官方python版实现是https github com nwojke deep sort xff0c C 43 43 版的DeepSORT中https github com shaoshengsong DeepSORT这
  • 使用sudo运行vncserver后导致Ubuntu循环登录进入不了桌面的问题原因及解决办法

    因需要多人同时登录到机器人的Ubuntu主机调试 xff0c 于是安装VNC xff0c 不记得N年以前怎么做的了 xff0c 于是按照网上某文说的先 xff1a sudo apt get install xfce4 插一句 xff1a 也
  • CAS6.1 配置连接数据库,以及修改自定义的密码验证(SpringSecurity)

    一 cas 配置数据库 1 在build gradle中引入jar dependencies Other CAS dependencies modules may be listed here compile 34 org apereo c
  • 怎样学好数电

    随着社会的进步和科学技术的发展 xff0c 数字系统和数字设备已广泛应用于各个领域 xff0c 大规模 xff0c 超大规模集成电路技术的不断完善使得数字电路在现代电子系统的比重越来越大 xff0c 数字电路建立了根本是信号的数字处理 xf
  • 嵌入式经典面试题之选择题

    一 单项选择题 1 如下哪一个命令可以帮助你知道shell命令的用法 xff08 A xff09 A man B pwd C help D more 2 Linux分区类型默认的是 xff1a xff08 B xff09 A vfat B
  • 自定义 Windows RE 体验

    发布时间 2009年10月 更新时间 2009年10月 应用到 Windows 7 Windows Server 2008 R2 https technet microsoft com zh cn library dd744576 v 61
  • Java命令行运行错误: 找不到或无法加载主类

    前言 xff1a 虽然学习Java语言约有两年多 xff0c 但在最近需要使用命令行工具编译并运行Java程序时 xff0c 还是报错了 花费了一些时间 xff0c 解决了该问题 xff0c 发现解决方法在初学Java时使用过 一则 xff
  • 开贴记录STM32工程遇到的各种问题及解决方法

    开贴记录STM32工程遇到的各种问题及解决方法 STM32工程问题集锦 针对工程开发过程中常见问题进行备注 文章目录 STM32工程问题集锦问题列表时钟设置串口设置STM32CUBEIDEADCDMA定时器HardFault 处理方法时钟设
  • [1040]DataWorks中MaxCompute的常用操作命令

    文章目录 表操作1 查看表的详细信息 xff1a 2 通过 96 create table as select 96 语句创建表 xff0c 并在建表的同时将数据复制到新表中 xff1a 3 如果希望源表和目标表具有相同的表结构 xff0c
  • [549]python实现K-Means算法

    K Means是一种聚类 Clustering 算法 xff0c 使用它可以为数据分类 K代表你要把数据分为几个组 xff0c 前文实现的K Nearest Neighbor算法也有一个K xff0c 实际上 xff0c 它们有一个相似之处