最优传输问题与Sinkhorn算法

2023-05-16

目录

  • 1 引言
  • 2 例子:分甜点
  • 3 最优传输问题
  • 4 Sinkhorn算法
    • 4.1 Sinkhorn距离
    • 4.2 算法流程
    • 4.3 代码实验

1 引言

最近看到一篇特征匹配相关的论文,思想是将特征匹配问题转化为最优传输问题求解,于是我去学习了一下最优传输问题。
本文主要是对博文 Notes on Optimal Transport 的学习做一个记录总结,该博文写的不错,推荐阅读。

2 例子:分甜点

文章作者以一个简单的甜点分配例子引入了最优传输问题。
向量 r = [ 3 , 3 , 3 , 4 , 2 , 2 , 2 , 1 ] ⊤ \mathbf{r}=[3, 3, 3, 4, 2, 2, 2, 1]^{\top} r=[3,3,3,4,2,2,2,1] 表示 n = 8 n=8 n=8 个人需要的甜点数:
在这里插入图片描述
向量 c = [ 4 , 2 , 6 , 4 , 4 ] ⊤ \mathbf{c}=[4, 2, 6, 4, 4]^{\top} c=[4,2,6,4,4] 表示 m = 5 m=5 m=5 种甜点的数量:
5种甜点的数量分布

矩阵 M ∈ R 5 × 8 \mathbf{M}\in \mathbb{R}^{5\times 8} MR5×8 表示每个人对各种甜点的偏好,尺度区间 [ − 2 , 2 ] [-2, 2] [2,2],-2表示非常不喜欢,2表示非常喜欢:
在这里插入图片描述

我们的目标,就是要根据甜点的数量,同时考虑每个人的需求和偏好,将所有甜点合理地分配到每个人手中。

3 最优传输问题

最优运输问题的目标就是以最小的成本将一个概率分布转换为另一个概率分布。上面的分甜点的目标,用最优传输问题的定义来说,就是将概率分布 c \mathbf{c} c 以最小的成本转换到概率分布 r \mathbf{r} r
这就需要我们求得一个分配方案,由矩阵 P ∈ R n × m P\in \mathbb{R}^{n\times m} PRn×m 表示,存储每个人分得的每个甜点的情况。

根据现实条件,这个分配矩阵 P P P 显然具有以下约束:

  1. 分配的甜点数量不能为负数;
  2. 每个人的需求都要满足,即 P P P 的行和服从分布 r \mathbf{r} r
  3. 每种甜点要全部分完,即 P P P 的列和服从分布 c \mathbf{c} c

于是在分布 r \mathbf{r} r c \mathbf{c} c 约束下, P P P 的解空间可以做如下定义:
U ( r , c ) = { P ∈ R > 0 n × m ∣ P 1 m = r , P ⊤ 1 n = c } (1) U(\mathbf{r}, \mathbf{c})=\left\{P \in \mathbb{R}_{>0}^{n \times m} \mid P \mathbf{1}_{m}=\mathbf{r}, P^{\top} \mathbf{1}_{n}=\mathbf{c}\right\} \tag 1 U(r,c)={PR>0n×mP1m=r,P1n=c}(1)
PS:这是博文的原公式,这里我有个疑问,为什么 P P P 的元素要求严格大于0,而不是大于等于0?希望有同学能够解答我的疑惑(感谢)

如前面所述,我们希望最小化转换成本,可以简单地反转偏好矩阵 M \mathbf{M} M 的符号,就可以得到成本矩阵(cost matrix)。于是就有了最优传输问题的公式化表示:
d M ( r , c ) = min ⁡ P ∈ U ( r , c ) ∑ i , j P i j M i j (2) d_{M}(\mathbf{r}, \mathbf{c})=\min _{P \in U(\mathbf{r}, \mathbf{c})} \sum_{i, j} P_{i j} M_{i j} \tag 2 dM(r,c)=PU(r,c)mini,jPijMij(2)

标量 d M d_{M} dM 也被称为推土机距离(earth mover distance),因为它可以解释为至少移动多少“泥土”(成本)才能将一个土堆(分布)变成另一个土堆(分布)。

4 Sinkhorn算法

4.1 Sinkhorn距离

Sinkhorn距离是对推土机距离的一种改进,在其基础上引入了熵正则化项:
d M λ ( r , c ) = min ⁡ P ∈ U ( r , c ) ∑ i , j P i j M i j − 1 λ h ( P ) (3) d_{M}^{\lambda}(\mathbf{r}, \mathbf{c})=\min _{P \in U(\mathbf{r}, \mathbf{c})} \sum_{i, j} P_{i j} M_{i j}-\frac{1}{\lambda} h(P) \tag 3 dMλ(r,c)=PU(r,c)mini,jPijMijλ1h(P)(3)
其中 h ( P ) = − ∑ P i j log ⁡ P i j h(P)=-\sum{P_{ij}\log{P_{ij}}} h(P)=PijlogPij 称作 P P P 的信息熵(information entropy), P P P 分布越均匀,信息熵越大。

熵正则化参数 λ \lambda λ 负责调整信息熵的影响程度, λ \lambda λ 越大,信息熵的影响越小,最终结果受成本矩阵的影响更大,即更多地考虑每个人的喜好;反之,最终结果则更倾向于均匀分配,每种甜点将平均分配给每个人。

4.2 算法流程

新增的熵正则化项似乎让问题更加难以优化,但Sinkhorn算法提供了一种简单且有效的方法应对这一问题,Sinkhorn算法认为,最优分配矩阵 P λ ∗ P^*_\lambda Pλ 的元素应该具有如下形式:
( P λ ∗ ) i j = α i β j e − λ M i j (4) (P^*_\lambda)_{ij}=\alpha_i \beta_j e^{-\lambda M_{ij}} \tag 4 (Pλ)ij=αiβjeλMij(4)
其中正是 α 1 , . . . , α n \alpha_1,...,\alpha_n α1,...,αn β 1 , . . . , β n \beta_1,...,\beta_n β1,...,βn 使得 P ∗ P^* P 满足分配矩阵的三个约束。如何推导出这一形式可以参考SuperGlue中的最优传输算法详解一文。

具体流程如下:

给定: 代价矩阵 M M M, 分布 r \mathbf{r} r, 分布 c \mathbf{c} c, 熵正则化参数 λ \lambda λ
初始化: 分配矩阵 P λ = e − λ M P_\lambda=e^{-\lambda M} Pλ=eλM
重复:

  1. 缩放行,使得 P P P 的行和逼近分布 r \mathbf{r} r
  2. 缩放列,使得 P P P 的列和逼近分布 c \mathbf{c} c

直到: 收敛

4.3 代码实验

以下是Sinkhorn代码实现:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


r = np.array([3, 3, 3, 4, 2, 2, 2, 1])
c = np.array([4, 2, 6, 4, 4])
M = np.array(
    [[2, 2, 1, 0, 0], 
    [0, -2, -2, -2, -2], 
    [1, 2, 2, 2, -1], 
    [2, 1, 0, 1, -1],
    [0.5, 2, 2, 1, 0], 
    [0, 1, 1, 1, -1], 
    [-2, 2, 2, 1, 1], 
    [2, 1, 2, 1, -1]],
    dtype=float) 
M = -M # 将M变号,从偏好转为代价

def compute_optimal_transport(M, r, c, lam, eplison=1e-8):
    """
    Computes the optimal transport matrix and Slinkhorn distance using the
    Sinkhorn-Knopp algorithm

    Inputs:
        - M : cost matrix (n x m)
        - r : vector of marginals (n, )
        - c : vector of marginals (m, )
        - lam : strength of the entropic regularization
        - epsilon : convergence parameter

    Outputs:
        - P : optimal transport matrix (n x m)
        - dist : Sinkhorn distance
    """
    n, m = M.shape  # 8, 5
    P = np.exp(-lam * M) # (8, 5)
    P /= P.sum()  # 归一化
    u = np.zeros(n) # (8, )
    # normalize this matrix
    while np.max(np.abs(u - P.sum(1))) > eplison: # 这里是用行和判断收敛
        # 对行和列进行缩放,使用到了numpy的广播机制,不了解广播机制的同学可以去百度一下
        u = P.sum(1) # 行和 (8, )
        P *= (r / u).reshape((-1, 1)) # 缩放行元素,使行和逼近r
        v = P.sum(0) # 列和 (5, )
        P *= (c / v).reshape((1, -1)) # 缩放列元素,使列和逼近c
    return P, np.sum(P * M) # 返回分配矩阵和Sinkhorn距离

我们来看看在不同 λ \lambda λ 下,得到的分配矩阵有什么特点:

lam = 0.1

P, d = compute_optimal_transport(M,
        r,
        c, lam=lam)

partition = pd.DataFrame(P, index=np.arange(1, 9), columns=np.arange(1, 6))
ax = partition.plot(kind='bar', stacked=True)
print('Sinkhorn distance: {}'.format(d))
ax.set_ylabel('portions')
ax.set_title('Optimal distribution ($\lambda={}$)'.format(lam))

在这里插入图片描述

可以看到每个人分配得到的甜点基本上都符合初始甜点的分布比例 c = [ 4 , 2 , 6 , 4 , 4 ] ⊤ \mathbf{c}=[4, 2, 6, 4, 4]^{\top} c=[4,2,6,4,4]

试着调大 λ \lambda λ
在这里插入图片描述
可以看到最终的分配向每个人的偏好靠拢了。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

最优传输问题与Sinkhorn算法 的相关文章

  • 76-高斯核函数

    高斯核函数 上一篇博客详细的介绍了什么是核函数 xff0c 并且主要以多项式核函数为例 这篇博客主要学习一种特殊的核函数 xff0c 而且它也是 SVM 算法使用最多的一种核函数 xff1a 高斯核函数 核函数我们通常表示成 xff1a 那
  • 数仓实践:总线矩阵架构设计

    如何设计一套切实可行的数据仓库呢 xff1f 我们要明白 xff0c 对于数据仓库的设计是不能完全依赖于业务的需求 xff0c 但往往又必须要服务于业务的价值 因此 xff0c 在构建数据仓库前 xff0c 我们往往会通过总线矩阵设计 xf
  • 05 反向传播

    反向传播 上一篇博客介绍了从输入 X 样本开始 xff0c 通过一组 w w w 参数 xff0c 得到了一个得分值 xff0c 然后又将得分值经过 y 61
  • 07 神经网络整体架构

    神经网络整体架构 我们先看看神经网络是什么样子的 xff0c 如下图 可以说神经网络是一个层次的结构 xff0c 有一个输入层 xff0c 隐层 1 xff0c 隐层 2 和输出层 可以说是由多个层组成了一个完整的神经网络 输入层相当于输入
  • 地震勘探原理(一)之地震波的基本概念

    绪论 一 石油勘探的主要方法 地质法 岩石露头物探法 覆盖区 连续测量 间接钻井法 一点 直接勘探 二 地球物理勘探方法 重力法 岩石密度差异磁法 演示磁性差异电法 岩石电性差异地震勘探方法 岩石弹性差异 xff08 用得最多 xff0c
  • 地震勘探原理(二)之时距曲线

    文章目录 什么是时距曲线 xff1f 直达波的时距曲线水平界面的共炮点反射波时距曲线方程 xff08 一个分界面 xff09 倾斜界面的共炮点反射波时距曲线正常时差倾角时差 xff08 dip moveout xff09 时局曲面和时间场的
  • 地震勘探原理(四)之频谱分析概述

    文章目录 一 频谱的基本概念二 频谱的主要特征 振幅谱和相位谱三 获取频谱的方法四 傅里叶展式的重要性质五 地震波频谱特征及其应用六 线性时不变系统的滤波方程七 频率滤波参数的选择 一 频谱的基本概念 频谱 xff08 Spectrum S
  • 模糊C均值聚类算法

    学习了一下模糊聚类中的模糊 C 均值聚类算法 Fuzzy C Means Clustering Fuzzy 意为模糊 xff0c 其中包括几种模糊的方式 xff0c 这里使用的是最简单的方式 xff0c 它是基于概率的概念 我们把每一个点属
  • 数据建模之查文献找数据以及数据预处理

    1 查文献 知网 xff1a 先看硕博士论文谷歌学术镜像 xff1a http scholar scqylaw com Open Access Library xff1a https www oalib com 2 找数据 优先 xff1a
  • 数学建模之论文

    一篇完整的数模论文 包括摘要 最重要 问题重述 模型假设和符号说明 模型建立与求解 最长 模型的优缺点与改进方法 参考文献和附录 1 摘要 最重要 论文研究的问题 43 使用的方法 43 得到的结果 43 每一部分的大致步骤 2 问题重述
  • Deformable Convolution 可变形卷积

    可变形卷积概念出自2017年论文 xff1a Deformable Convolutional Networks 顾名思义 xff0c 可变形卷积的是相对于标准卷积的概念而来 a 一个经典的 3 3 3 times3 3 3
  • 模块化

    模块化 遵守固定的规则 xff0c 把一个 大文件 拆成 独立并互相依赖 的 多个小模块 优点 xff1a 提高了代码的 复用性 提高了代码的 可维护性 可以实现 按需加载 模块化规范 xff1a 降低沟通成本 xff0c 方便模块间的相互
  • 栈的应用:左右符号匹配

    说明 xff1a 在编译器中 xff0c 都有这么一个左右符号匹配的功能 xff0c 这里通过栈来模拟实现这一功能 xff1b 这里采用了代码复用的方法 xff0c 即使用了LinkStack链栈 xff0c 详见 LinkStack链栈
  • windows安装gcc

    完整报错 xff1a RuntimeError Error building extension 39 fused 39 1 3 C Program Files NVIDIA GPU Computing Toolkit CUDA v11 4
  • openstack-mitaka(一) 架构简介

    官网 xff1a OpenStack Docs 概况 1 openstack概况 OpenStack是一个云操作系统 xff0c 它控制整个数据中心的计算 存储和网络资源的大型池 OpenStack通过各种补充服务提供基础设施即服务 Inf
  • ITK和VTK读取DICOM图像文件

    ITK和VTK读取DICOM图像文件 ITK读取DICOM图像 相比于VTK类库中vtkDICOMImageReader类读取DICOM序列图像 xff0c 借助ITK类库实现对DICOM序列图像的读取要复杂许多 但是 xff0c 使用IT
  • 一招完美解决vscode安装go插件失败问题

    vscode 安装go插件 前置用vscode新建一个go文件使用go mod 代理来安装 前置 从https studygolang com dl下载go1 14 6 windows amd64 msi安装即可 xff0c 安装路径选择默

随机推荐