kmeans算法和kmeans++

2023-11-13

kmeans算法及其优化改进

kmeans聚类算法

算法原理

kmeans的算法原理其实很简单

我用一个最简单的二维散点图来做解释

如上图,我们直观的看到该图可聚成两个分类,我们分别用红点和蓝点表示

image-20211112201540298

下面我们模拟一下Kmeans是怎么对原始的二维散点图做聚类的

首先,随机初始化2个聚类中心(一般是在随机选择两个样本点作为聚类中心)至于什么是聚类中心呢,我们暂且压下不表,现在就把它当成一个点就好。

image-20211112203412124

然后我们就去把所有离红色点进的样本点标成红色,把离蓝色点近的样本点标成蓝色

image-20211112203555261

然后我们重新设定聚类中心的位置,设在哪呢。红色聚类中心就设在现在的红色点的中心(均值),蓝色聚类中心就设在现在的蓝色点的中心(均值),样本颜色重新设为黑色

image-20211112203818576

然后我们继续把所有离红色点进的样本点标成红色,把离蓝色点近的样本点标成蓝色

image-20211112204034130

其实我们现在看来,红色和蓝色已经很分明了,已经达到最初的效果了,但是机器没有我们的眼睛,机器没办法直观的看到现在已经可以停止了。

所以它会继续按照计算均值的方法重新设定聚类中心

image-20211112204447156

然后继续把所有离红色点进的样本点标成红色,把离蓝色点近的样本点标成蓝色

image-20211112204908973

然后继续按照计算均值的方法重新设定聚类中心,但到这里机器会发现我们设定的聚类中心和之前的聚类中心的位置几乎没有改变,这说明我们的算法收敛了,每个样本的类别基本已经确定了。于是算法终止,聚类完成。其实在这个地方有两个指标来表示是否终止,一是计算聚类中心位置的变化,二是计算样本聚类的变化,二者是等价的。

算法步骤

  1. 选定K个聚类中心

  2. 计算每个样本点到K个聚类中心的距离,将样本分类设定为距离最小的聚类中心对应分类

  3. 计算每个分类集合的样本均值,并将其作为新的聚类中心

  4. 重复2,3步骤,直到新的聚类中心与原聚类中心的距离小于设定的阈值即可

算法实现

导入包

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs

定义模型

class Kmeans:
    def __init__(self, k, max_iter=300, thresh=1e-5):
        self.k = k
        self.thresh = thresh
        self.max_iter = max_iter

    def random_centroid_init(self,X):
        # 随机选取K个样本作为聚类中心
        return X[np.random.choice(X.shape[0], size=self.k)]

    def dist(self, x):
        return [np.linalg.norm(x - c) for c in self.centroids]

    def fit_predict(self, X):
        # 初始化聚类中心
        self.centroids = self.random_centroid_init(X)
        for _ in range(self.max_iter):
            # 涂色
            y_pred = np.array([np.argmin(self.dist(x)) for x in X])
            
            # 计算新的聚类中心
            new_centroids = self.centroids.copy()
            for i in range(self.k):
                new_centroids[i] = np.mean(X[y_pred==i],axis=0)
            
            # 如果聚类中心位置基本没有变化,那么终止
            if np.max(np.abs(new_centroids - self.centroids)) < self.thresh:
                break
            
            # 否则更新聚类中心,重复上述步骤
            self.centroids = new_centroids
        return y_pred

生成数据集

X, y = make_blobs(n_samples=1000, n_features=2, centers=3)

训练

model = KMeans(3, 1e-2)
y_pred = model.fit_predict(X)

可视化结果

plt.figure()
plt.subplot(121)
plt.scatter(X[:, 0], X[:, 1], c=y)
plt.subplot(122)
plt.scatter(X[:, 0], X[:, 1], c=y_pred)
plt.show()

优化改进

好讲完Kmeans算法,我们再想一想Kmeans算法有什么问题.

聚类中心初始化

看下面两张图,它是上面的程序偶尔可能运行出来的结果,左图是运气不好的错误聚类,右边是运气好的正确聚类

image-20211113201051989image-20211113201123442

程序的两次运行唯一的区别就是聚类中心初始化是随机的,那么现在问题就出在这里

再看我们前面介绍的例子,假如我们的聚类中心是在这个地方

image-20211112210301542

于是,绿线下方都被分到红色,绿线上方都分到蓝色,我们一求均值,重新设定聚类中心,发现位置也没有太大改变,于是我们得到的聚类就是这样的

image-20211112210543239

很分明,但是显然不是我们希望的聚类

那么这个问题怎么解决,也就是说,随机初始化聚类中心是不好的,我们应该怎样初始化聚类中心?

Kmeans++算法

为了解决初始化的问题,Kmeans++算法有这样的策略

第一,初始的聚类中心一定是在样本中选,这个与Kmeans算法不同,尽管kmeans算法我们一般也是这样做的,但在算法中并没有对其有实际的限制。而Kmeans++的算法流程中,只有在样本中选才能进行这个算法

第二,选择的K个聚类中心需要尽可能的远。

算法过程

首先随机选取第一个聚类中心,计算各个样本点到距离最近的聚类中心的距离。

让距离越远的点被选择为新的聚类中心的概率越大,重复上述步骤,直到选择出所有的聚类中心。

算法实现

这里怎么将距离的大小体现在选择的概率上其实有很多方式,最极端的一种就是,距离最大的概率为1,其余为0。

实现如下

def max_centroid_init(self,X):
    centroids = []
    centroids.append(X[np.random.choice(X.shape[0])])
    for i in range(self.k-1):
        index = np.argmax([np.min([np.linalg.norm(x - c) for c in centroids]) for x in X])
        centroids.append(X[index])
    return np.array(centroids)

比较温和一点的是这样,以距离/(距离和)作为每个样本被选择的概率

例如现在有3个样本离自己的聚类中心的距离分别为 5 , 10 , 10 5,10,10 5,10,10​,其和为 25 25 25

我们随机一个 0 0 0 25 25 25之间的数 n u m b e r number number​​​​

如果这个数在 0 0 0​到 5 5 5​以内,我们选择第一个样本,在 5 5 5 15 15 15​​以内我们选择第二个样本

其实就是从第一个样本开始遍历, n u m b e r = n u m b e r − D ( x ) number = number-D(x) number=numberD(x)​​, n u m b e r number number​​ 什么时候小于0,此时遍历到的样本就是新的聚类中心​

实现如下

def soft_centroid_init(self,X):
    centroids = []
    centroids.append(X[np.random.choice(X.shape[0])])
    for i in range(self.k-1):
        D = [np.min([np.linalg.norm(x - c) for c in centroids]) for x in X]
        number = np.random.choice(np.sum(D))
        for i,d in enumerate(D):
            number -= d
            if number<0:
                centroids.append(X[i])
                break
    return np.array(centroids)

其它的改进优化我就不介绍了,我暂时不关注性能问题

改进全部代码

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs

class Kmeans:
    def __init__(self, k, init='pp-soft', max_iter=300, thresh=1e-5):
        self.k = k
        self.thresh = thresh
        self.max_iter = max_iter
        self.init = init

    def random_centroid_init(self,X):
        # 随机选取K个样本作为聚类中心
        return X[np.random.choice(X.shape[0], size=self.k)]
    
    def max_centroid_init(self,X):
        centroids = []
        centroids.append(X[np.random.choice(X.shape[0])])
        for i in range(self.k-1):
            index = np.argmax([np.min(self.dist(x)) for x in X])
            centroids.append(X[index])
        return np.array(centroids)

    def soft_centroid_init(self,X):
        centroids = []
        centroids.append(X[np.random.choice(X.shape[0])])
        for i in range(self.k-1):
            D = [np.min(self.dist(x)) for x in X]
            number = np.random.choice(int(np.sum(D)))
            for i,d in enumerate(D):
                number -= d
                if number<0:
                    centroids.append(X[i])
                    break
        return np.array(centroids)

    def dist(self, x):
        return [np.linalg.norm(x - c) for c in self.centroids]

    def fit_predict(self, X):
        # 初始化聚类中心
        if self.init == 'random':
            self.centroids = self.random_centroid_init(X)
        elif self.init == 'pp-max':
            self.centroids = self.max_centroid_init(X)
        else:
            self.centroids = self.soft_centroid_init(X)
        for _ in range(self.max_iter):
            # 涂色
            y_pred = np.array([np.argmin(self.dist(x)) for x in X])
            
            # 计算新的聚类中心
            new_centroids = self.centroids.copy()
            for i in range(self.k):
                new_centroids[i] = np.mean(X[y_pred==i],axis=0)
            
            # 如果聚类中心位置基本没有变化,那么终止
            if np.max(np.abs(new_centroids - self.centroids)) < self.thresh:
                break
            
            # 否则更新聚类中心,重复上述步骤
            self.centroids = new_centroids
        return y_pred
    
X, y = make_blobs(n_samples=1000, n_features=2, centers=3)

model = Kmeans(3)
y_pred = model.fit_predict(X)

plt.figure()
plt.subplot(121)
plt.scatter(X[:, 0], X[:, 1], c=y)
plt.subplot(122)
plt.scatter(X[:, 0], X[:, 1], c=y_pred)
plt.show()
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

kmeans算法和kmeans++ 的相关文章

随机推荐

  • css3 transform + deviceorientation实现图片旋转效果

    1 陀螺仪deviceorientation的使用 参考 关于陀螺仪deviceorientation https segmentfault com a 1190000007183883 2 transform各属性的具体使用 参考 深入理
  • 计算机组成原理——单周期CPU

    单周期CPU 项目代码 实验原理 MIPS指令 rom coe文件 代码 顶层模块SingleCycleCPU display外围模块 PC instructionMemory Alu模块 DataMemory ControlUnit 旧的
  • 排序(六):归并排序

    排序算法系列文章 排序 一 冒泡排序 排序 二 选择排序 排序 三 堆排序 排序 四 插入排序 排序 五 二分搜索 排序 六 归并排序 排序 七 快速排序 排序 八 希尔排序 目录 排序算法系列文章 归并排序 Merge Sort 基本思想
  • Python之文件的读写

    文章目录 前言 一 打开和关闭文件 open和close 1 打开文件 2 关闭文件 mode的方式 几种读取文件的函数 写入文件的函数 二 with open as操作文件 1 with open as与open close的区别 总结
  • Ubuntu部署OpenStack zed版本neutron报错:Feature ‘linuxbridge‘ is experimental and has to be explicitly enab

    系统版本 Ubuntu 22 04 1 LTS OpenStack版本 zed 组件 Neutron 组件报错内容 Feature linuxbridge is experimental and has to be explicitly e
  • GLSL中texture3D获得的值大小

    使用OpenGL的glTexImage3D 获得纹理数据 再在片元着色器对数据进行处理texture3D 得到的数据已被压缩到0 1 openGL函数glTexImage3D导入数据后 在GLSL中 数据被进行了压缩 glTexImage3
  • python3GUI--音乐播放器(精简版)By:PyQt5(附下载地址)

    文章目录 一 前言 二 预览 1 主界面 2 歌单页 3 歌词页 4 播放列表 5 mini 6 设置 三 心得 1 解耦 2 体验优化 3 歌词显示 4 双击歌曲后发生什么 四 总结 一 前言 传送门 1 python3GUI 打造一款音
  • 关于Linux下的pid文件

    1 pid文件的内容 用cat命令查看 可以看到内容只有一行 记录了该进程的ID 2 pid文件的作用防止启动多个进程副本 3 pid文件的原理进程运行后会给 pid文件加一个文件锁 只有获得pid文件 固定路径固定文件名 写入权限 F W
  • Elasticsearch聚合分析、mget批量查询、bulk批量更新

    Elasticsearch分组集合 一 分组聚合操作 开启fielddata属性 1 在ElasticSearch中默认fielddata默认是false的 因为开启Text的fielddata后对内存的占用很高 如果进行聚合查询时候就需要
  • Redfish协议测试工具–Postman

    1 工具和资料获取 2 简单使用说明 1 GET类举例 2 PATCH类举例 3 常见命令 1 工具和资料获取 Postman工具获取 服务器Redfish接口说明文档 使用前必读接口文档中 适用的产品 查看自己的服务器是否支持此协议 2
  • 简单sql注入

    报错注入找列数 确定为16 联合查询找回显点 查询数据库和数据库版本 版本为5 0以上 需要对查询的内容加密否则报错 结果不是需要的 查询所有的表 获得表名cms users 获得字段usename password 得到账号密码
  • 用java代码验证char类型数据占几个字节

    char为字符型数据 储存单个字符 但阿拉伯数字 英文字母 标点符号等皆为字符型数据 占用字节看似错综复杂 但是char也为脱离计算机基本 二进制储存机制 char本质上内存中皆存储字符编码 1 127为ASCII码 也就是常用的字符 但在
  • 关于iOS9中的App Transport Security(ATS)相关说明及适配

    iOS9中新增App Transport Security 简称ATS 特性 主要使到原来请求的时候用到的HTTP 都转向TLS1 2协议进行传输 这也意味着所有的HTTP协议都强制使用了HTTPS协议进行传输 原文如下 App Trans
  • VS2010:error C2061: 语法错误

    实例 类名 类中包含的头文件 point iostream line point flat flat line 输出错误 error C2061 语法错误 标识符 flat 解决办法 前置声明 line h class flat
  • 区块链读书笔记04 - 以太坊

    区块链读书笔记04 以太坊 以太坊 Ethereum 以太坊关键概念 账户 Account 交易 Transaction 消息 Messsage Gas 合约 contract 以太坊虚拟机 EVM DApp 去中心化应用 以太坊架构 以太
  • 网站域名服务器加密,网站域名利用https防劫持方法

    原标题 网站域名利用https防劫持方法 公共 DNS HttpDNS 的部署成本过高 并且具有一定的技术门槛 在面对无孔不入的 DNS 劫持时有时候其实有点力不从心 那么如何简单有效低成本的加强域名防劫持呢 只需要给网站开启 HTTPS
  • mysql jdbc 多数据源_springboot多数据源(oracle、mysql)

    1 application properties配置 server port 8085 server tomcat uri encoding utf 8 MySQL spring datasource primary driver clas
  • java基于BufferedImage进行图片数字识别预处理

    参考文章链接 1 https blog csdn net kobesdu article details 8142068 2 https blog csdn net fjssharpsword article details 5265184
  • 从此刻开始走进HTML的大门!!!

    文章目录 什么是HTML呢 HTML的结构又是怎么样的呢 学习HTML的标签 标题标签 段落标签 文本格式化标签 换行标签 字符实体 容器标签 图片标签 超链接标签 列表标签 什么是HTML呢 HTML 英文全称是 Hyper Text M
  • kmeans算法和kmeans++

    kmeans算法及其优化改进 kmeans聚类算法 算法原理 kmeans的算法原理其实很简单 我用一个最简单的二维散点图来做解释 如上图 我们直观的看到该图可聚成两个分类 我们分别用红点和蓝点表示 下面我们模拟一下Kmeans是怎么对原始