在 n 维数组上使用 scipy interpn 和 meshgrid

2024-03-29

我正在尝试翻译大型 4D 数组的 Matlab“interpn”插值,但 Matlab 和 Python 之间的公式存在显着差异。几年前有一个很好的问题/答案here https://stackoverflow.com/questions/39332053/using-scipy-interpolate-interpn-to-interpolate-a-n-dimensional-array/39357219#39357219我一直在尝试与之合作。我想我已经快到了,但显然我的网格插值器还没有正确制定。

我尽可能按照上面链接答案中给出的示例建模我的代码示例,同时使用我实际工作的维度。唯一的变化是我将 rollaxis 切换为 moveaxis,因为前者已被弃用。

本质上,给定 4D 数组 skyrad0 (取决于第一个代码块中定义的四个元素)以及第三个代码块中定义的两个常量和两个 1D 数组,我想要插值的 2D 结果。

from scipy.interpolate import interpn
import numpy as np

# Define the data space in the 4D skyrad0 array
solzen = np.arange(0,70,10)     # 7
aod = np.arange(0,0.25,0.05)    # 5
index = np.arange(1,92477,1)    # 92476
wave = np.arange(350,1050,5)    # 140

# Simulated skyrad for the values above
skyrad0 = np.random.rand(
    solzen.size,aod.size,index.size,wave.size) # 7, 5, 92476, 140

# Data space for desired output values of skyrad 
# with interpolation between input data space
solzen0 = 30                    # 1
aod0 = 0.1                      # 1
index0 = index                  # 92476
wave0 = np.arange(350,1050,10)  # 70

# Matlab
# result = squeeze(interpn(solzen, aod, index, wave,
#                   skyrad0,
#                   solzen0, aod0, index0, wave0))

# Scipy
points = (solzen, aod, index, wave)             # 7, 5, 92476, 140
interp_mesh = np.array(
    np.meshgrid(solzen0, aod0, index0, wave0))  # 4, 1, 1, 92476, 70
interp_points = np.moveaxis(interp_mesh, 0, -1) # 1, 1, 92476, 70, 4
interp_points = interp_points.reshape(
    (interp_mesh.size // interp_mesh.shape[3], 
    interp_mesh.shape[3]))                      # 280, 92476

result = interpn(points, skyrad0, interp_points)

我期待一个 4D 数组“结果”,我可以将其 numpy.squeeze 放入我需要的 2D 答案中,但 interpn 会产生错误:

ValueError: The requested sample points xi have dimension 92476, but this RegularGridInterpolator has dimension 4

在这个例子中,我最困惑的是查询点网格的结构,以及将第一个维度移动到末尾并重塑它。还有更多关于这方面的内容here https://stackoverflow.com/questions/27286537/numpy-efficient-way-to-generate-combinations-from-given-ranges/27286794#27286794,但我仍然不清楚如何将其应用于这个问题。

如果有人能找出我的公式中明显的低效之处,那就太好了。我需要在许多不同的结构上运行这种类型的插值数千次 - 甚至扩展到 6D - 因此效率很重要。

Update下面的答案非常优雅地解决了这个问题。然而,随着计算和数组变得更加复杂,另一个问题出现了,即数组中的元素不单调增加的问题。这是用 6D 重构的问题:

# Data space in the 6D rad_boa array
azimuth = np.arange(0, 185, 5) # 37
senzen = np.arange(0, 185, 5) # 37
wave = np.arange(350,1050,5)    # 140
# wave = np.array([350, 360, 370, 380, 390, 410, 440, 470, 510, 550, 610, 670, 750, 865, 1040, 1240, 1640, 2250]) # 18
solzen = np.arange(0,65,5)     # 13
aod = np.arange(0,0.55,0.05)    # 11
wind = np.arange(0, 20, 5)      # 4

# Simulated rad_boa
rad_boa = np.random.rand(
    azimuth.size,senzen.size,wave.size,solzen.size,aod.size,wind.size,) # 37, 37, 140/18, 13, 11, 4

azimuth0 = 135              # 1
senzen0 = 140               # 1
wave0 = np.arange(350,1010,10) # 66
solzen0 = 30                # 1
aod0 = 0.1                  # 1
wind0 = 10                  # 1

da = xr.DataArray(name='Radiance_BOA',
                data=rad_boa,
                dims=['azimuth','senzen','wave','solzen','aod','wind'],
                coords=[azimuth,senzen,wave,solzen,aod,wind])

rad_inc_scaXR = da.loc[azimuth0,senzen0,wave0,solzen0,aod0,wind0].squeeze()

按照目前的情况,它会运行,但是如果将 wave 的定义更改为注释行,则会抛出错误:

KeyError: "not all values found in index 'wave'"

最后,为了回应下面的评论(并帮助提高效率),我包含了 HDF5 文件的结构(在 Matlab 中创建),该“rad_boa”6D 数组实际上是从该文件构建的(上面的示例仅使用模拟随机大批)。实际数据库读入Xarray如下:

sdb = xr.open_dataset(db_path, group='sdb')

And the resulting Xarray looks something like this: Xarray VSC guide


为什么会出现值错误?

首先,scipy.interpolate.interpn要求interp_points.shape[-1]与问题中的维数相同。这就是为什么你会得到一个ValueError从你的代码片段——你的interp_points有 92476 作为n_dims,这与实际的 sim 数量 (4) 冲突。

快速解决

您只需更改操作顺序即可修复此代码片段。如果你挤压的话,你尝试挤压的太早了after插值:

points = (solzen, aod, index, wave)                 # 7, 5, 92476, 140
mg = np.meshgrid(solzen0, aod0, index0, wave0)      # 4, 1, 1, 92476, 70
interp_points = np.moveaxis(mg, 0, -1)              # 1, 1, 92476, 70, 4
result_presqueeze = interpn(points, 
                            skyrad0, interp_points) # 1, 1, 92476, 70
result = np.squeeze(result_presqueeze,
                    axis=(0,1))                     # 92476, 70

我已经更换了interp_mesh with mg在这里,并删除了np.array(这不是必需的,因为np.meshgrid返回一个ndarray目的)。

性能评价

我认为你的代码片段很好,但是你可能希望使用xarray如果您正在处理标记数据,则如下所示:

  • 比无标签更具可读性numpy arrays
  • 还可以使用处理一些后台工作dask https://dask.org/(如果您正在检查大量 6D 数据,则很有用)

Update: 哎呀!这本来应该是.interp, not .loc。下面的代码片段之所以有效,是因为数据点实际上是原始数据点。作为对其他人的警告:

from scipy.interpolate import interpn
import numpy as np
from xarray import DataArray

# Define the data space in the 4D skyrad0 array
solzen = np.arange(0,70,10)     # 7
aod = np.arange(0,0.25,0.05)    # 5
index = np.arange(1,92477,1)    # 92476
wave = np.arange(350,1050,5)    # 140

# Simulated skyrad for the values above
skyrad0 = np.random.rand(
    solzen.size,aod.size,index.size,wave.size) # 7, 5, 92476, 140

# Data space for desired output values of skyrad 
# with interpolation between input data space
solzen0 = 30                    # 1
aod0 = 0.1                      # 1
index0 = index                  # 92476
wave0 = np.arange(350,1050,10)  # 70

def slow():
    points = (solzen, aod, index, wave)                 # 7, 5, 92476, 140
    mg = np.meshgrid(solzen0, aod0, index0, wave0)      # 4, 1, 1, 92476, 70
    interp_points = np.moveaxis(mg, 0, -1)              # 1, 1, 92476, 70, 4
    result_presqueeze = interpn(points, 
                                skyrad0, interp_points) # 1, 1, 92476, 70
    result = np.squeeze(result_presqueeze,
                        axis=(0,1))                     # 92476, 70
    return result

# This function uses .loc instead of .interp!
"""
def fast():
    da = DataArray(name='skyrad0',
                   data=skyrad0,
                   dims=['solzen','aod','index','wave'],
                   coords=[solzen, aod, index, wave])

    result = da.loc[solzen0, aod0, index0, wave0].squeeze()

    return result
"""

通过对OP给出的更新代码片段进行一些修改:

import numpy as np
import xarray as xr
from scipy.interpolate import interpn

azimuth = np.arange(0, 185, 5) # 37
senzen = np.arange(0, 185, 5) # 37
#wave = np.arange(350,1050,5)    # 140
wave = np.asarray([350, 360, 370, 380, 390, 410, 440, 470, 510,
                   550, 610, 670, 750, 865, 1040, 1240, 1640, 2250]) # 18
solzen = np.arange(0,65,5)     # 13
aod = np.arange(0,0.55,0.05)    # 11
wind = np.arange(0, 20, 5)      # 4

coords = [azimuth, senzen, wave, solzen, aod, wind]

azimuth0 = 135              # 1
senzen0 = 140               # 1
wave0 = np.arange(350,1010,10) # 66
solzen0 = 30                # 1
aod0 = 0.1                  # 1
wind0 = 10                  # 1

interp_coords = [azimuth0, senzen0, wave0, solzen0, aod0, wind0]

# Simulated rad_boa
rad_boa = np.random.rand(
    *map(lambda x: x.size, coords)) # 37, 37, 140/18, 13, 11, 4

def slow():
    mg = np.meshgrid(*interp_coords)
    interp_points = np.moveaxis(mg, 0, -1)
    result_presqueeze = interpn(coords, 
                                rad_boa, interp_points)
    result = np.squeeze(result_presqueeze)
    return result

def fast():
    da = xr.DataArray(name='Radiance_BOA',
                    data=rad_boa,
                    dims=['azimuth','senzen','wave','solzen','aod','wind'],
                    coords=coords)

    interp_dict = dict(zip(da.dims, interp_coords))

    rad_inc_scaXR = da.interp(**interp_dict).squeeze()
    return rad_inc_scaXR

这相当快:

>>> %timeit slow()
2.09 ms ± 85.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
>>> %timeit fast()
343 ms ± 6.77 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
>>> np.array_equal(slow(),fast())
True

您可以找到更多关于xarray插值法here http://xarray.pydata.org/en/stable/interpolation.html。数据集实例具有非常相似的语法。

也可以根据需要更改插值方法(也许,人们可能希望提供关键字参数method='nearest' to .interp对于离散插值问题)。

更高级的东西

如果您希望实现更高级的东西,我建议也许使用 MARS(多元自适应回归样条)的实现之一。它介于标准回归和插值之间,适用于多维情况。在 Python 3 中,你最好的选择是pyearth https://github.com/scikit-learn-contrib/py-earth.

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

在 n 维数组上使用 scipy interpn 和 meshgrid 的相关文章

随机推荐

  • 通过按同一个按钮来打开/关闭 MapKit 叠加?

    我有一个带有工具栏按钮的 MapView 按下该按钮时会向 MapView 添加叠加层 我想要的是按钮 IBAction 检查地图上是否已经有覆盖物 如果有 则删除 如果没有 则添加它们 我当前添加叠加层的代码如下 IBAction wat
  • JacksonFeature 破坏了 JsonIgnoreProperties

    我有这样的 pojo JsonIgnoreProperties ignoreUnknown true public class SNAPIResponse public String status public String message
  • Keras:从 ImageDataGenerator 或 Predict_generator 获取真实标签 (y_test)

    我在用ImageDataGenerator flow from directory 从目录生成批量数据 模型成功构建后 我想获得真实和预测类标签的两列数组 和model predict generator validation genera
  • 在mawk中使用strftime函数

    我正在尝试创建 AWK 脚本 该脚本将根据某种模式过滤输入文件 并使用 strftime 函数进行一些计算 2 HB 2 n print strftime Y 使用的解释器是mawk 使用此命令触发此脚本时 awk f script3 in
  • 使用curl将文件推送到GitHub存储库

    我想在 GitHub 存储库上创建 推送 新文件 而不需要git工具 因为git我的工具不可用PHP主持人 所以我做了一些研究 发现GitHub REST API https docs github com en rest 我尝试使用cur
  • 电池的最佳使用

    作为一名程序员 我可以采取哪些措施来确保我的应用程序不会占用大量资源并耗尽电池 根据您正在编写的应用程序 其中一些可能适用于您 不要使用过多的网络调用 尝试维护不经常更改的数据缓存 并且仅在上次刷新 10 秒后运行完全刷新 阻止它们向服务器
  • SQLite Swift 中有多少种方式进行 CRUD 操作?

    当我在 SQLite 中进行 CRUD 操作时 我很困惑 因为有人对我说你可以使用 FMDB 库进行 CRUD 操作 有人说 GRDB 所以 我的问题是 在 SQLite 中有多少种方法可以进行 CRUD 操作 哪种方法是正确的 我认为 G
  • 如何在 Jquery 验证中处理 html 元素 id/name 中的特殊字符?

    我有一个 HTML 表单 它在 ids 中使用特殊字符 该表单使用 JQuery 验证插件来验证用户输入 具体来说 id 包括 GUID 以下是示例代码
  • 在 Eclipse 中,Source -> Format 在“Maven Pom Editor”中被禁用

    当打开pom xml在 Eclipse 中使用 Maven Pom Editor 并切换到选项卡pom xml我无法格式化该文件 Hitting Ctrl Shift F在完全未格式化的文件中不会执行任何操作 通过上下文菜单时Source
  • Python 中的递归、记忆和可变默认参数

    Base 的意思是不只使用lru cache 所有这些都 足够快 我并不是在寻找最快的算法 但时间安排让我感到惊讶 所以我希望我能了解一些有关 Python 如何 工作 的知识 简单循环 尾递归 def fibonacci n a b 0
  • Flask 应用偶尔挂起

    我一直在开发一个 Flask 应用程序 它使用 Twilio 处理 SMS 消息 将它们存储在数据库中 并通过 JSONP GET 请求提供对前端的访问 我已经使用supervisord对其进行了守护进程 这似乎工作得很好 但每隔几天它就会
  • Erlang / Golang 端口示例中的缓冲区大小

    我有一个粗略的 Erlang to Golang 端口示例 将数据从 Erlang 传递到 Golang 并回显响应 问题是我可以传输的数据量似乎仅限于 2 8 字节 见下文 我认为问题可能出在 Golang 方面 没有创建足够大的缓冲区
  • JavaScript 中的继承和 Super

    我正在学习 JavaScript 的第三天 我遇到了这段代码 class B constructor name this name name printn return this name class A extends B constru
  • ajaxForm 错误回调内的表单对象

    我试图在 ajaxForm 的错误方法中访问我的表单对象 foo ajaxForm error function where s my foo object error 可以接受 3 个参数 但它们都不是表单对象 这也返回 url 但同样没
  • 为什么 CSS Grid 的自动填充属性在列方向上不起作用

    我正在练习用行自动填充属性 但是 它并没有按照我的意愿进行 我想创建具有高度的行minmax 140px 200px 而是获取一行高度为 200px 的行 其余行为 18px 为什么会发生这种情况 body html height 100
  • 使用ajax上传文件到远程服务器

    我对服务器端没有任何控制权 是否可以在 Iframe 中上传并加载远程服务器给出的结果 请分享一些代码 谢谢 使用名称声明 iframe 并在表单元素中定位该名称
  • 调整大小和滚动问题(JS/HTML)

    有两个容器 第一个是小视口 第二个是巨大的工作区 因此 用户滚动视口以在工作区中移动 我想通过 CSS 属性实现放大 缩小功能tranform 但是在这个过程中我遇到了一个难题 并没有找到精确的解决方案 问题是 当用户放大 缩小时 工作区中
  • 带有 @MappedSuperclass 的 Hibernate TABLE_PER_CLASS 不会创建 UNION 查询

    我正在尝试创建一系列对象 这些对象全部存储在单独的表中 但所有这些表上都有一组共同的字段 我希望 Hibernate 对所有这些表进行 UNION 但不包括超类作为表 当我用以下方式注释超类时 MappedSuperclass Inheri
  • 插入、删除、最大值 O(1)

    有人能告诉我哪种数据结构支持 O 1 的插入 删除 最大操作吗 这是一个经典的面试问题 通常是这样提出的 设计一个类似堆栈的数据结构 在 O 1 时间内执行压入 弹出和最小 或最大 操作 没有空间限制 答案是 您使用两个堆栈 主堆栈和最小
  • 在 n 维数组上使用 scipy interpn 和 meshgrid

    我正在尝试翻译大型 4D 数组的 Matlab interpn 插值 但 Matlab 和 Python 之间的公式存在显着差异 几年前有一个很好的问题 答案here https stackoverflow com questions 39