概率密度估计(Probability Density Estimation)--Part3:混合模型

2023-10-26

引入

在结束了有参估计,无参估计后,现在记录混合模型(Mixture models)。这里附一张有参和无参的对比图(本来应该附在Part 2的,不想回去改了。。):
在这里插入图片描述

字面意思,混合模型就是有参模型和无参模型的混合。

举个例子,高斯模型的混合(Mixture of Gaussians,MoG)。
现有三个高斯模型如下:
在这里插入图片描述
我们可以将其视为:
在这里插入图片描述
其概率密度可以近似表示为: p ( x ) = ∑ j = 1 M p ( x ∣ j ) p ( j ) p(x)=\sum^M_{j=1}p(x|j)p(j) p(x)=j=1Mp(xj)p(j)各变量有如下关系:
在这里插入图片描述
大概解释一下: p ( x ∣ j ) p(x|j) p(xj)表示第 j j j个高斯模型的概率密度, p ( j ) = π j p(j)=\pi_j p(j)=πj表示第 j j j个高斯模型出现的概率,是一种先验概率。
这里记住两点:
在这里插入图片描述
即是说,混合模型概率密度积分为1,然后对于混合模型的每个子模型,我们都要求出对应的三个值(均值,方差和模型先验),这是对于MoG而言的。

求解方法

MLE法

对于简单的(单个)高斯模型,我们可以用MLE近似求解,但在高斯混合中无法使用该方法,因为我们在求参数时并不知道这个参数属于高斯混合的哪个子分布。如下:
在这里插入图片描述
这个 p ( j ∣ x ) p(j|x) p(jx)我们是无法观测到的,毕竟如果这个都知道,那高斯混合就只是多个高斯模型的简单叠加而已。
也就是说单个高斯模型的参数会取决于所有高斯模型的参数。所以MLE不可用。

Clustering

分为两种:

  1. Clustering with soft assignments
  2. Clustering with hard assignments

简单说的话,hard assignments是根据Label划分的,也就是说每个数据点只属于一个label,比如说K-mean;soft assignment是根据概率区分的,每个数据点可能属于多个label,但概率不同。如下图:
在这里插入图片描述
所有点都可能是label1或2
这里只介绍soft assignment,毕竟两个几乎是一样的。这就要用到这章的重点了, E M EM EM算法。

E M EM EM算法
大概的说明

E M EM EM算法是一种迭代算法,本质上它也是一种最大似然估计的方法,其特点是,当我们的数据不完整时,比如说只有观测数据,缺乏隐含数据时,可以用EM算法进行迭代推导。
该算法包括 E E E步骤和 M M M步骤,这里不具体介绍一般性的 E M EM EM算法,主要说明其在高斯混合中如何运用。其步骤大体可以概括为:

  1. 随机初始化各子分布的期望 μ 1 , μ 2 . . . μ m \mu_1,\mu_2...\mu_m μ1,μ2...μm
  2. E-step:计算每个子分布的后验 p ( j ∣ x n ) p(j|x_n) p(jxn)
  3. M-step计算所有数据点的加权平均值
    在这里插入图片描述
    其效果如下图所示:
    在这里插入图片描述
较为详细的说明

现在说一下更加详细的步骤。
假设现在有两组数据,即是观测数据和隐藏数据:

  1. Incomplete (observed) data: X = ( X 1 , X 2 , . . . , X n ) X={(X_1, X_2, ... ,X_n)} X=(X1,X2,...,Xn)
  2. Hidden (unobserved) data: Y = ( Y 1 , Y 2 , . . . , Y n ) Y=(Y_1, Y_2,...,Y_n) Y=(Y1,Y2,...,Yn)

组合后形参完整数据:

  1. Complete data: Z = ( X , Y ) Z=(X,Y) Z=(X,Y)

联合密度为 p ( Z ) = p ( X , Y ) = p ( Y ∣ X ) p ( X ) p(Z)=p(X,Y)=p(Y|X)p(X) p(Z)=p(X,Y)=p(YX)p(X)即是 p ( Z ∣ θ ) = p ( X , Y ∣ θ ) = p ( Y ∣ X , θ ) p ( X ∣ θ ) p(Z|\theta)=p(X,Y|\theta)=p(Y|X,\theta)p(X|\theta) p(Zθ)=p(X,Yθ)=p(YX,θ)p(Xθ)在高斯混合中:
p ( X ∣ θ ) p(X|\theta) p(Xθ)是混合模型的似然
p ( Y ∣ X , θ ) p(Y|X,\theta) p(YX,θ)是混合模型中子分布的估计

对于不完整的数据(观测数据),其似然为: L ( θ ∣ X ) = p ( X ∣ θ ) = ∏ n = 1 N p ( X n ∣ θ ) L(\theta|X)=p(X|\theta)=\prod_{n=1}^{N}p(X_n|\theta) L(θX)=p(Xθ)=n=1Np(Xnθ)对于完整数据(Z),其似然为:
在这里插入图片描述
在这里我们虽然不知道 Y Y Y,但如果我们知道当前的参数猜测 θ i − 1 \theta^{i-1} θi1,我们就能用它来预测 Y Y Y
在这里我们计算完整数据的对数似然的期望,如下:
在这里插入图片描述
其中 X X X θ i − 1 \theta^{i-1} θi1是已知的。更进一步展开如下:
在这里插入图片描述
这个等式是根据均值和积分的关系写出来的。即是如下关系:
E [ x ] = ∫ X x f ( x ) d x E[x]=\int_Xxf(x)dx E[x]=Xxf(x)dx其中 f ( x ) f(x) f(x)是概率密度(这部分不太确定,有错的话麻烦大家指出)。
我们需要最大化这个 Q Q Q函数。
接下来是 E M EM EM算法:
E-step(expectation): 计算 p ( y ∣ X , θ i − 1 ) p(y|X,\theta^{i-1}) p(yX,θi1)以便计算 Q ( θ , θ i − 1 ) Q(\theta,\theta^{i-1}) Q(θ,θi1);
M-step(maximization): 最大化 Q Q Q函数求出 θ \theta θ θ ^ = a r g m a x θ Q ( θ , θ i − 1 ) \hat{\theta}=arg max_\theta Q(\theta,\theta^{i-1}) θ^=argmaxθQ(θ,θi1)
这是一种迭代运算,我们要确保每次迭代中,第 i i i次的结果至少和第 i − 1 i-1 i1的一样好,即是: Q ( θ i , θ i − 1 ) ≥ Q ( θ i − 1 , θ i − 1 ) Q(\theta^i,\theta^{i-1})\geq Q(\theta^{i-1},\theta^{i-1}) Q(θi,θi1)Q(θi1,θi1)
若该期望值对于 θ \theta θ来说是最大的,则可以认为(这部分未理解): L ( θ i ∣ X ) ≥ L ( θ i − 1 ∣ X ) L(\theta^i|X)\geq L(\theta^{i-1}|X) L(θiX)L(θi1X)
也就是说,在每次迭代中观测数据的对数似然都会不断增大(或者至少保持不变),最终达到局部最大值。
所以,在实际运用中,初始化对于 E M EM EM算法很重要,一个不好的初始化可能会使结果停在一个不好的局部最优值中。

高斯混合中的 E M EM EM算法(EM for Gaussian Mixtures)

步骤:

  1. 初始化参数 μ 1 , σ 1 , π 1 . . . \mu_1,\sigma_1, \pi_1... μ1,σ1,π1...
  2. 循环,直到满足终止条件:
    1. E-step: 计算每个数据点对于每个子分布的后验分布:
      在这里插入图片描述
      这里的 α \alpha α可以理解成每个数据点属于各个子分布的权重或者说概率。
    2. M-step: 使用E步骤的权重进行更新数据: 在这里插入图片描述至此,高斯混合的 E M EM EM算法就结束了,然后还有最后一个问题,这部分不太理解,但把结论贴上来吧,以后再探究:

在这里插入图片描述

(附)作业相关代码

在这里插入图片描述
也就是给出数据点,然后用EM算法进行高斯拟合:

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal

def load_data(path):
    return np.loadtxt(path)

def init_params(k, d = 2):
    pi = np.ones(k) * 1/k
    mu = [np.random.rand(2, 1) for i in range(k)]
    cov = [np.eye(2) for i in range(k)]

    return pi, mu, cov

def EM(iter_times, data, k):
    pi, mu, cov = init_params(k)
    N = data.shape[0]
    alpha = np.zeros((N, k))
    likelihood_list = []

    def update_alpha(alpha, N, k):
        for i in range(N):
            for j in range(k):
                p_ij = pi[j] * multivariate_normal.pdf(data[i], mu[j].flatten(), cov[j])
                alpha[i][j] = p_ij
                likelihood[i] += p_ij
            alpha[i] = alpha[i] / np.sum(alpha[i])
        return alpha

    sigma = np.empty((N, 2, 2))
    for i in range(N):
        d = data[i, :].reshape(2, -1)
        sigma[i] = np.dot(d, d.T)

    
    for i in range(iter_times):
        likelihood = np.zeros(N)
        # E-step: update alpha
        alpha = update_alpha(alpha, N, k)
        likelihood_list.append(np.sum(np.log(likelihood)))
        # M-step: update mu, cov, pi
        N_j = np.sum(alpha, axis = 0)

        for j in range(k):
            # update mu
            alpha_x = 0
            alpha_x_mu = 0
            for n in range(N):
                alpha_x += (alpha[n][j] * data[n])
            alpha_x = alpha_x.reshape(2,1)
            mu[j] = alpha_x / N_j[j]
            # update pi
            pi[j] = N_j[j] / N
            # update cov
            for n in range(N):
                alpha_x_mu += alpha[n][j]*(data[n] - mu[j].T)*(data[n] - mu[j].T).T
            cov[j] = alpha_x_mu / N_j[j]

    plot_contour(iter_times, data, mu, cov, k)
    if iter_times == 30:
        plt.figure()
        plt.plot(np.arange(1, 31), np.array(likelihood_list))
        plt.title('log-likelihood for every iteration')
        plt.xlabel('iteration')
        plt.ylabel('log-likelihood')
        plt.grid()
        plt.show()
    
def plot_contour(iter_num, data, mu, cov, k):
    plt.figure()
    plt.title("iter_num = %d" % iter_num)
    plt.scatter(data[:,0], data[:,1])
    x_min, x_max = plt.gca().get_xlim()
    y_min, y_max = plt.gca().get_ylim()

    num = 50
    x = np.linspace(x_min, x_max, num)
    y = np.linspace(y_min, y_max, num)
    X, Y = np.meshgrid(x, y)

    for sub_k in range(k):
        Z = np.zeros_like(X)
        for i in range(len(x)):
            for j in range(len(y)):
                z = multivariate_normal.pdf(np.array([[x[i]], [y[j]]]).flatten(), mu[sub_k].flatten(), cov[sub_k])
                Z[j, i] = z
        plt.contour(X, Y, Z)
        plt.show()

path = "./dataSets/gmm.txt"

def main():
    k = 4
    iter_lst = [1, 3, 5, 10, 30]
    pi, mu, cov = init_params(k)
    data = load_data(path)
    for i in iter_lst:
        EM(i, data, k)

if __name__ == '__main__':
    main()

效果如下图所示:
迭代1次:
在这里插入图片描述
迭代3次:
在这里插入图片描述
迭代5次:
在这里插入图片描述
迭代10次:
在这里插入图片描述
迭代30次:
在这里插入图片描述
每次迭代的对数似然:
在这里插入图片描述

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

概率密度估计(Probability Density Estimation)--Part3:混合模型 的相关文章

  • win7 svn服务器搭建过程

    svn简介 https baike baidu com item subversion 7818587 fr aladdin SVN服务端分为 Subversion和VisualSVN Server 这里 我选择了VisualSVN Ser

随机推荐

  • Java笔记:UDP基础使用与广播

    文章目录 目的 作为客户端使用 作为服务器使用 广播 广播地址获取 广播功能演示 总结 目的 UDP是比较基础常用的网络通讯方式 这篇文章将介绍Java中UDP基础使用的一些内容 本文中使用 Packet Sender 工具进行测试 其官网
  • Java使用DES加密解密

    一 DES算法 DES Data Encryption Standard 数据加密标准 它是由IBM公司研制的一种对称密码算法 DES是一个分组加密算法 典型的DES以64位分组对数据加密 加密和解密用的是用一个算法 总长度64位 8字节
  • Spring:基于xml文件的控制反转(ioc)

    1 环境搭建 导入spring使用最基本的坐标
  • VMware Workstation 不可恢复错误: (vmx)

    errors VMware Workstation 不可恢复错误 vmx Exception 0xc0000006 disk error while paging has occurred 日志文件位于 K vmware centos vm
  • 运用决策表设计测试用例

    逻辑关系 逻辑关系 logic relationship 即 依赖关系 在项目管理中 指表示两个活动 前导活动和后续活动 中一个活动的变更将会影响到另一个活动的关系 强制依赖关系 所做工作中固有的依赖关系 可自由处理的依赖关系 由项目队伍确
  • MyBatis:尝试解决Spring Boot集成MyBatis 懒加载时序列化失败的三种方法以及原因FAIL_ON_EMPTY_BEANS

    MyBatis 解决No serializer found for class org apache ibatis executor loader javassist JavassistProxyFactory EnhancedResult
  • python3 Flask 简单入门(MVC模板类)

    跟上一篇文章一样的内容 Flask默认支持的模板是jinja2 jinja2简单实用 1 在Jinja2模板中 我们用 name 表示一个需要替换的变量 很多时候 还需要循环 条件判断等指令语句 在Jinja2中 用 表示指令 2 循环输出
  • win10 装黑苹果 完整教程

    一 材料准备 1 虚拟机软件VMware 2 适用于Windows版本的VMware解锁安装Mac OS的补丁 3 Mac OS X 10 10的黑苹果镜像 以上材料我都为你贴心地准备齐了 在我的云盘获取 链接 https pan baid
  • VUE3+Element-Plus form表单封装

    VUE3 Element Plus form表单封装 新建form组件页面 创建index vue 新建form组件页面 在components中创建新组件 将需要的form表单中常用的UI组件引入 vue3创建组件和vue2中多少有点区别
  • 大学《数据库原理与技术》复习题(二)

    数据库复习题 一 选择题 1 B 是按照一定的数据模型组织的 长期存储在计算机内 可为多个用户共享的数据的集合 A 数据库系统 B 数据库 C 关系数据库 D 数据库管理系统 2 数据库系统的基础是 A 数据结构 B 数据库管理系统 C 操
  • LVGL V8

    本文适用于LVGL V8版本 LVGL simulator vs2019 官方工程 lv sim visual studio 使用注意事项 1 将官方工程从github上下载下来 最好使用git 将整个工程clone下来 因为工程内部有依赖
  • c++坑人

    大家好 我是LCR 今天为大家带来的是c 中的弹窗病毒 当然你也可以把它理解为坑人代码 如果喜欢这篇文章 可以给我点一个赞吗 代码解释 system是c语言库里面自带的一个函数 start的原本意思为 跳转 后面本应接网址 当你的后面为空时
  • 多功能翻译工具:全球翻译、润色和摘要生成

    openai translator openai translator Stars 18 1k License AGPL 3 0 这个项目是一个多功能翻译工具 由 OpenAI 提供支持 可以进行全球单词翻译 单词润色和摘要生成等操作 提供
  • python项目导出依赖包requirements.txt文件

    只导出当前项目依赖包 注意 使用 pip freeze gt requirements txt 会导出大量无用的文件 包括很多个包信息 其实这里是把你当前 python 环境的所有包的相关信息导出来了 如果我们只需导出当前项目所需的依赖包
  • 如何创建线程,多线程下又如何上锁保护公共资源?

    目录 一 创建线程几种方法 1 继承thread类 重写run方法 2 实现runnable接口 重写run方法 3 使用匿名类 或 lamda表达式 让代码更简洁 4 Callable 接口 5 使用线程池创建线程 二 多线程下 需要上锁
  • canvas画布合成

  • windows自动颁发证书

    首先去配置组策略 计算机配置 windows设置 安全设置 公钥策略 证书注册策略和证书服务客户端 不需要勾选禁用用户配置注册策略服务器 用户配置也这样配置 最后进入证书管理器 找到证书模板 右键证书管理 看见一个计算机 去右键 安全这里允
  • 虚拟内存笔记

    虚拟内存 为什么要有虚拟内存 有些进程实际需要的内存很大 超过物理内存的容量 比如一个几十G的游戏 要运行在内存为8G的计算机上 由于多道程序设计 主存是同时可以存放多个进程的逻辑及数据的 这就使得每个进程可用的物理内存更加稀缺 不可能无限
  • [1194]GitLab在web端合并分支

    文章目录 gitlab 在 web 端合并分支 1 1 发起合并操作 1 2 选择源分支和目标分支 1 3 输入合并备注 1 4 合并检查 1 5 完成合并 1 6 查看提交记录 修改的文件及内容 gitlab 在 web 端合并分支 1
  • 概率密度估计(Probability Density Estimation)--Part3:混合模型

    目录 引入 求解方法 MLE法 Clustering E M EM EM算法 大概的说明 较为详细的说明 高斯混合中的