[551]python实现Mean Shift算法

2023-11-18

前文介绍的K-Means算法需要指定K值(分组数),本文实现的MeanShift聚类算法不需要预先知道聚类的分组数,对聚类的形状也没有限制。

为了更好的理解这个算法,本帖使用Python实现Mean Shift算法。

MeanShift算法详细介绍:https://en.wikipedia.org/wiki/Mean_shift

scikit-learn中的MeanShift
import numpy as np
from sklearn.cluster import MeanShift
from matplotlib import pyplot
from mpl_toolkits.mplot3d import Axes3D
from sklearn.datasets.samples_generator import make_blobs
 
fig = pyplot.figure()
ax = fig.add_subplot(111, projection='3d')
 
# 生成3组数据样本
centers = [[2,1,3], [6,6,6], [10,8,9]]
x,_ = make_blobs(n_samples=200, centers=centers, cluster_std=1)
#for i in range(len(x)):
#    ax.scatter(x[i][0], x[i][1], x[i][2])

image.png

# 对上面数据进行分组
clf = MeanShift()
clf.fit(x)
 
labels = clf.labels_    # 每个点对应的组
cluster_centers = clf.cluster_centers_  # 每个组的"中心点"
#print(labels)
print(cluster_centers)
 
colors = ['r', 'g', 'b']
for i in range(len(x)):
    ax.scatter(x[i][0], x[i][1], x[i][2], c=colors[labels[i]])
 
ax.scatter(cluster_centers[:,0], cluster_centers[:,1], cluster_centers[:,2], marker='*', c='k', s=200, zorder=10)
 
pyplot.show()

image.png

MeanShift把上面数据自动分为3组,计算出的三个组的”中心点”为:

[[  1.97566619   1.04212548   3.02410725]
 [  6.01672157   6.18325271   5.96562957]
 [ 10.14455378  12.02394435   9.03499578]]
# 和[[2,1,3], [6,6,6], [10,12,9]]接近;生成的样本越多越接近
使用Python实现Mean Shift算法
# -*- coding:utf-8 -*-
import numpy as np
from matplotlib import pyplot
from mpl_toolkits.mplot3d import Axes3D
from sklearn.datasets.samples_generator import make_blobs


class MeanShift(object):
    def __init__(self, bandwidth=4):
        #bandwidth参数代表点的半径(radius)范围
        self.bandwidth_ = bandwidth

    def fit(self, data):
        centers = {}
        # 把每个点都当做中心点
        for i in range(len(data)):
            centers[i] = data[i]
            # print(centers)
        while True:
            new_centers = []
            for i in centers:
                in_bandwidth = []
                # 取一个点,把在范围内的其它点放到in_bandwidth
                center = centers[i]
                for feature in data:
                    # self.bandwidth_越小分的组越多
                    if np.linalg.norm(feature - center) < self.bandwidth_:
                        in_bandwidth.append(feature)

                new_center = np.average(in_bandwidth, axis=0)
                new_centers.append(tuple(new_center))

            uniques = sorted(list(set(new_centers)))
            prev_centers = dict(centers)
            centers = {}
            for i in range(len(uniques)):
                centers[i] = np.array(uniques[i])

            optimzed = True
            for i in centers:
                if not np.array_equal(centers[i], prev_centers[i]):
                    optimzed = False
                    if not optimzed:
                        break
            if optimzed:
                break

        self.centers_ = centers


if __name__ == '__main__':
    fig = pyplot.figure()
    ax = fig.add_subplot(111, projection='3d')
    centers = [[2, 1, 3], [6, 6, 6], [10, 12, 9]]
    x, _ = make_blobs(n_samples=18, centers=centers, cluster_std=1)
    clf = MeanShift()
    clf.fit(x)
    print(clf.centers_)
    for i in clf.centers_:
        ax.scatter(clf.centers_[i][0], clf.centers_[i][1], clf.centers_[i][2], marker='*', c='k', s=200, zorder=10)

    for i in range(len(x)):
        ax.scatter(x[i][0], x[i][1], x[i][2])

    pyplot.show()

执行结果:

使用Python实现Mean Shift算法

bandwidth参数代表点的半径(radius)范围,bandwidth=20:

使用Python实现Mean Shift算法

bandwidth=2.5:

使用Python实现Mean Shift算法

这个bandwidth可以根据数据样本求出最适合的值。

# -*- coding:utf-8 -*-
import numpy as np
from matplotlib import pyplot
from mpl_toolkits.mplot3d import Axes3D
from sklearn.datasets.samples_generator import make_blobs


class MeanShift(object):

    def __init__(self, bandwidth=None, bandwidth_step=100):
        self.bandwidth_ = bandwidth
        self.bandwidth_step_ = bandwidth_step

    def fit(self, data):
        if self.bandwidth_ == None:
            all_data_center = np.average(data, axis = 0)
            self.bandwidth_ = np.linalg.norm(all_data_center) /self.bandwidth_step_
        print(self.bandwidth_)
        centers = {}
        # 把每个点都当做中心点
        for i in range(len(data)):
            centers[i] = data[i]
            # print(centers)
        while True:
            new_centers = []
            for i in centers:
                in_bandwidth = []
                # 取一个点,把在范围内的其它点放到in_bandwidth
                center = centers[i]
                w = [i for i in range(self.bandwidth_step_)][::-1]
                for feature in data:
                    distance = np.linalg.norm(feature - center)
                    if distance == 0:
                        distance = 0.000000001
                    w_index = int(distance /self.bandwidth_)
                    if w_index > self. bandwidth_step_ -1:
                        w_index = self. bandwidth_step_ -1
                    in_bandwidth += (w[w_index] **2) * [feature]

                new_center = np.average(in_bandwidth, axis=0)
                new_centers.append(tuple(new_center))

            uniques = sorted(list(set(new_centers)))
            tmp = []
            for i in uniques:
                for ii in uniques:
                    if i == ii:
                        pass
                    elif np.linalg.norm(np.array(i) - np.array(ii)) <= self.bandwidth_:
                        tmp.append(ii)
                        break
            for i in tmp:
                try:
                    uniques.remove(i)
                except:
                    pass

            prev_centers = dict(centers)
            centers = {}
            for i in range(len(uniques)):
                centers[i] = np.array(uniques[i])

            optimzed = True
            for i in centers:
                if not np.array_equal(centers[i], prev_centers[i]):
                    optimzed = False
                    if not optimzed:
                        break
            if optimzed:
                break

        self.centers_ = centers

    def predict(self, data):
        self.labels_ = {}
        for i in range(len(centers)):
            self.labels_[i] = []
        for feature in data:
            distances = [np.linalg.norm(feature - self.centers_[center]) for center in self.centers_]
            clf = distances.index(min(distances))
            self.labels_[clf].append(feature)
在实际数据上应用Mean Shift算法

数据集:titanic.xls(泰坦尼克号遇难者/幸存者名单)。目的:对乘客进行分类,看看这几组人有什么共同特点。

# -*- coding:utf-8 -*-
import numpy as np
from sklearn.cluster import MeanShift,estimate_bandwidth

from sklearn import preprocessing
import pandas as pd


'''
数据集:titanic.xls(泰坦尼克号遇难者/幸存者名单)
<http://blog.topspeedsnail.com/wp-content/uploads/2016/11/titanic.xls>
***字段***
pclass: 社会阶层(1,精英;2,中层;3,船员/劳苦大众)
survived: 是否幸存
name: 名字
sex: 性别
age: 年龄
sibsp: 哥哥姐姐个数
parch: 父母儿女个数
ticket: 船票号
fare: 船票价钱
cabin: 船舱
embarked
boat
body: 尸体
home.dest
******
目的:使用除survived字段外的数据进行means shift分组,看看能分为几组,这几组人有什么共同特点
'''

# 加载数据
df = pd.read_excel('titanic.xls')
# print(df.shape)  (1309, 14)
# print(df.head())
# print(df.tail())
"""
    pclass  survived                                            name     sex  \
0       1         1                    Allen, Miss. Elisabeth Walton  female
1       1         1                   Allison, Master. Hudson Trevor    male
2       1         0                     Allison, Miss. Helen Loraine  female
3       1         0             Allison, Mr. Hudson Joshua Creighton    male
4       1        G 0  Allison, Mrs. Hudson J C (Bessie Waldo Daniels)  female

       age  sibsp  parch  ticket      fare    cabin embarked boat   body  \
0  29.0000      0      0   24160  211.3375       B5        S    2    NaN
1   0.9167      1      2  113781  151.5500  C22 C26        S   11    NaN
2   2.0000      1      2  113781  151.5500  C22 C26        S  NaN    NaN
3  30.0000      1      2  113781  151.5500  C22 C26        S  NaN  135.0
4  25.0000      1      2  113781  151.5500  C22 C26        S  NaN    NaN

    home.dest
0                     St Louis, MO
1  Montreal, PQ / Chesterville, ON
2  Montreal, PQ / Chesterville, ON
3  Montreal, PQ / Chesterville, ON
4  Montreal, PQ / Chesterville, ON
"""

org_df = pd.DataFrame.copy(df)

# 去掉无用字段
df.drop(['body', 'name'], 1, inplace=True)

df.convert_objects(convert_numeric=True)#将object格式转float64格式
df.fillna(0, inplace=True)  # 把NaN替换为0

# 把字符串映射为数字,例如{female:1, male:0}
df_map = {}  # 保存映射
cols = df.columns.values
for col in cols:
    if df[col].dtype != np.int64 and df[col].dtype != np.float64:
        temp = {}
        x = 0
        for ele in set(df[col].values.tolist()):
            if ele not in temp:
                temp[ele] = x
                x += 1

        df_map[df[col].name] = temp
        df[col] = list(map(lambda val: temp[val], df[col]))

for key, value in df_map.items():
   print(key,value)
# print(df.head())

# 由于是非监督学习,不使用label
x = np.array(df.drop(['survived'], 1).astype(float))
# 将每一列特征标准化为标准正太分布,注意,标准化是针对每一列而言的
x = preprocessing.scale(x)

clf = MeanShift()
clf.fit(x)

labels = clf.labels_
cluster_centers = clf.cluster_centers_
print('labels:',labels)
print('cluster_centers:',cluster_centers)
n_cluster = len(np.unique(labels))
print('n_cluster:',n_cluster)

org_df['group'] = np.nan
for i in range(len(x)):
    org_df['group'].iloc[i] = labels[i]

survivals = {}
for i in range(n_cluster):
    temp_df = org_df[org_df['group'] == float(i)]
    survival_cluster = temp_df[(temp_df['survived'] == 1)]
    survial = 1.0 * len(survival_cluster) / len(temp_df)
    survivals[i] = survial
print(survivals)

# MeanShift自动把数据分成了三组,每组对应的生还率为(有时分成4组):
# {0: 0.37782982045277125, 1: 0.8333333333333334, 2: 0.1}
# 你可以详细分析一下org_df, 看看这几组人的共同特点是什么
# print(org_df[ org_df['group'] == 2 ])
# print(org_df[ org_df['group'] == 2 ].describe())
org_df.to_excel('group.xls')

来源:http://blog.topspeedsnail.com/archives/10366

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

[551]python实现Mean Shift算法 的相关文章

  • 论文笔记: MOGRIFIER LSTM

    2020 ICLR 修改传统LSTM 当前输入和隐藏状态充分交互 从而获得更佳的上下文相关表达 1 Mogrifier LSTM LSTM的输入X和隐藏状态H是完全独立的 机器学习笔记 GRU gruc UQI LIUWJ的博客 CSDN博

随机推荐

  • 一次excle导入数值精度失真处理过程(附java、python、goland实现代码)

    在一次excle导入中通过java poi包导入数值过长时出现数值失真的问题 100283710028672000000 在通过java导入时变成了100283710028672010000 现在通过goland java python三种
  • Jupyter Notebook与Markdown知识点汇总(一)

    知识点汇总 安装与启动 软件简介 安装与启动 新建Notebook 操作教程 认识界面 运行Jupyter notebook 新建notebook 修改文件名 菜单栏详情 熟悉工具栏 单元 快捷键 Markdown知识点汇总 运行Pytho
  • 【已解决】如何Python利用matplotlib绘制三维曲面图(可自由旋转的三维图)

    1 需求 在做电机的电磁设计时 需要对某一些参数进行优化 因此从电磁仿真软件Maxwell中导出了数据 部分数据如下图所示 可是这样无法直观地看出参数的影响 因此将其调整为矩阵形式 如下图所示 这样虽然已经能够比较直观地看出输入参数 电流和
  • 如何使用Lua脚本来实现原子性操作

    找一个让你开心一辈子的人 才是爱情的目标 最好的 往往就是在你身边最久的 在Redis中 Lua脚本可以用于实现原子性操作 原子性操作指的是一组操作要么全部执行成功 要么全部不执行 使用Lua脚本可以将多个Redis命令组合成一个原子性操作
  • js的new操作符做了哪些事情

    js的new操作符做了哪些事情 new 操作符新建了一个空对象 这个对象原型指向构造函数的prototype 执行构造函数后返回这个对象
  • 简述什么是静态测试、动态测试、黑盒测试、白盒测试、α测试 β测试

    简述什么是静态测试 动态测试 黑盒测试 白盒测试 测试 测试 静态测试是不运行程序本身而寻找程序代码中可能存在的错误或评估程序代码的过程 动态测试是实际运行被测程序 输入相应的测试实例 检查运行结果与预期结果的差异 判定执行结果是否符合要求
  • css 标签默认样式及清除

    标签默认样式及清除 标签默认样式 一些HTML标签在浏览器中会有默认样式 例如 body标签会有margin 8px ul标签会有margin 16px 0 及padding left 40px 当我们在切图软件中进行尺寸或位置测量的时候
  • 在本地jupyter使用其他环境(如:云创)

    接上一篇笔记 1 打开本地cmd 也可以使用x shell 输入命令 ssh L8888 localhost 8888 root 你的IP 此处的IP是远程需要用到的环境的IP 如下图所示 输入命令 enter之后 如下图所示 2 此时 仍
  • 拟合椭圆分开内外两组点集并计算两两之间的距离

    通过拟合椭圆 区分开内外两组点 然后计算两两的距离 from Ransac Process import RANSAC import cv2 import numpy as np import math from operator impo
  • 逻辑漏洞小结之SRC篇(值得收藏,反复看!)

    最近在挖各大src 主要以逻辑漏洞为主 想着总结一下我所知道的一些逻辑漏洞分享一下以及举部分实际的案例展示一下 方便大家理解 主要从两个方面看 业务方面与漏洞方面 接下来就从拿到网站的挖掘步骤进行逐一介绍各个逻辑漏洞 一 业务 注册 1 短
  • 使用matlab中的随机森林进行数据回归预测

    在MATLAB中使用随机森林进行数据回归预测 你可以遵循以下步骤 准备数据集 将你的特征矩阵X和目标变量向量y加载到MATLAB工作空间中 确保X和y的维度匹配 拆分数据集 将数据集划分为训练集和测试集 可以使用cvpartition函数进
  • app毕业设计开题报告基于Uniapp实现的鲜花商城App

    更多项目资源 最下方联系我们 目录 Uniapp项目介绍 资料获取 Uniapp项目介绍 计算机毕业设计安卓App毕设项目之基于APP的鲜花商城 IT实战课堂 哔哩哔哩 bilibili计算机毕业设计安卓App毕设项目之基于APP的鲜花商城
  • 如何 ubuntu下启动/停止/重启MySQL

    如何启动 停止 重启MySQL 一 启动方式 1 使用 service 启动 service mysql start 2 使用 mysqld 脚本启动 etc inint d mysql start 3 使用 safe mysqld 启动
  • 安全保护策略:iOS应用程序代码保护的关键步骤和技巧

    转载 怎么保护苹果手机移动应用程序ios ipa文件中的代码 目录 转载 怎么保护苹果手机移动应用程序ios ipa文件中的代码 代码混淆步骤 1 选择要混淆保护的ipa文件 2 选择要混淆的类名称 3 选择要混淆保护的函数 方法 4 配置
  • [教程]VC++6.0的简单使用

    鉴于许多同学的vc 6 0无法正常使用 并且不会创建工程及文件 还有的同学会遇到一些编译的问题 我在这里做个小教程 1 工具的准备 首先 我把需要的资源给大家 一共就两个文件 一个安装文件 另一个是MSDEV exe 用于替换 链接 htt
  • 解决IDEA每次新建maven工程都要设置目录问题

    使用IDEA每次新建工程时都要设置这三个目录 解决方法 设置好就 了
  • 磁盘数据线接触不良的故障排查

    手头有个小型主机 运行centos 发现工作很不稳定 经常启动不起来 就算启动起来也会在几分钟内出现各种IO错误 可能出现以下几种报错 1 只读文件系统 Read only file system 尝试对磁盘写入的时候可能出现这个错误 2
  • 2 Linux系统高级

    2 Linux系统高级 1 Linux用户与权限 1 1 文件权限概述 Linux操作系统是 多任务多用户 操作系统 每当我们使用用户名登录操作系统时 Linux都会对该用户进行认证 授权审计等操作 操作系统为了识别每个用户 会给每个用户定
  • 大数据开发必备面试题Spark篇合集

    1 Hadoop 和 Spark 的相同点和不同点 Hadoop 底层使用 MapReduce 计算架构 只有 map 和 reduce 两种操作 表达能力比较欠缺 而且在 MR 过程中会重复的读写 hdfs 造成大量的磁盘 io 读写操作
  • [551]python实现Mean Shift算法

    前文介绍的K Means算法需要指定K值 分组数 本文实现的MeanShift聚类算法不需要预先知道聚类的分组数 对聚类的形状也没有限制 为了更好的理解这个算法 本帖使用Python实现Mean Shift算法 MeanShift算法详细介