numpy 向量化而不是 for 循环

2024-05-30

我用 Python 写了一些代码,运行良好,但速度很慢;我认为是由于 for 循环。我希望可以使用 numpy 命令加速以下操作。让我定义目标。

假设我有一个 2D numpy 数组all_CMs尺寸row x col。例如考虑一个6x11数组(见下图)。

  1. 我想计算所有行的平均值,即sumⱼ aᵢⱼ 生成一个数组。这当然可以轻松完成。 (我称这个值为CM_tilde)

  2. 现在,为了each row我想计算一些选定值的平均值,即通过计算它们的总和并将其除以所有列的数量来计算低于特定阈值的所有值(N)。如果该值高于此定义的阈值,CM_tilde添加值(整行的平均值)。这个值称为CM

  3. 随后,CM从行中的每个元素中减去值

除此之外,我想要一个 numpy 数组或列表,其中所有这些CM列出了值。

如图:

以下代码可以工作,但速度非常慢(特别是当数组变大时)

CM_tilde = np.mean(data, axis=1)
N = data.shape[1]
data_cm = np.zeros(( data.shape[0], data.shape[1], data.shape[2] ))
all_CMs = np.zeros(( data.shape[0], data.shape[2]))
for frame in range(data.shape[2]):
    for row in range(data.shape[0]):
        CM=0
        for col in range(data.shape[1]):
            if data[row, col, frame] < (CM_tilde[row, frame]+threshold):
               CM += data[row, col, frame]
            else:
               CM += CM_tilde[row, frame]
        CM = CM/N
        all_CMs[row, frame] = CM
        # calculate CM corrected value
        for col in range(data.shape[1]):
            data_cm[row, col, frame] = data[row, col, frame] - CM
    print "frame: ", frame
return data_cm, all_CMs

有任何想法吗?


将您正在做的事情矢量化非常容易:

import numpy as np

#generate dummy data
nrows=6
ncols=11
nframes=3
threshold=0.3
data=np.random.rand(nrows,ncols,nframes)

CM_tilde = np.mean(data, axis=1)
N = data.shape[1]

all_CMs2 = np.mean(np.where(data < (CM_tilde[:,None,:]+threshold),data,CM_tilde[:,None,:]),axis=1)
data_cm2 = data - all_CMs2[:,None,:]

将此与您的原件进行比较:

In [684]: (data_cm==data_cm2).all()
Out[684]: True

In [685]: (all_CMs==all_CMs2).all()
Out[685]: True

逻辑是我们使用大小的数组[nrows,ncols,nframes]同时地。主要技巧是利用Python的广播,通过转动CM_tilde大小的[nrows,nframes] into CM_tilde[:,None,:]大小的[nrows,1,nframes]。然后,Python 将为每一列使用相同的值,因为这是此修改后的单一维度CM_tilde.

通过使用np.where我们选择(基于threshold) 是否要获取对应的值data,或者,再次,广播值CM_tilde。一个新的用途np.mean允许我们计算all_CMs2.

在最后一步中,我们通过直接减去这个新的来利用广播all_CMs2从相应的元素data.

通过查看临时变量的隐式索引,可能有助于以这种方式矢量化代码。我的意思是你的临时变量CM生活在一个循环中[nrows,nframes],并且其值在每次迭代时都会重置。这意味着CM实际上是一个数量CM[row,frame](后来显式分配给二维数组all_CMs),从这里很容易看出,您可以通过总结适当的CMtmp[row,col,frames]沿其列尺寸的数量。如果有帮助,您可以命名np.where(...)部分作为CMtmp为此目的,然后计算np.mean(CMtmp,axis=1)从那。显然,结果相同,但可能更透明。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

numpy 向量化而不是 for 循环 的相关文章

随机推荐

  • Android 工具测试库模块覆盖率

    我继承了一个android项目来设置代码覆盖率 由于我对 android 没有做过太多的工作 在 gradle 中也几乎没有做过多少工作 所以我开始寻找有用的教程 令人惊讶的是 前几个教程非常有帮助 我能够包含 jacoco gradle
  • Pandas 根据条件替换数据框值

    我有一个主数据框 df Colour Item Price Blue Car 40 Red Car 30 Green Truck 50 Green Bike 30 然后我有一个价格修正数据框 df pc Colour Item Price
  • 如何在Phone类库项目中添加ResourceDictionary并访问它

    我正在开发一个项目 其中我有一个引用图书馆项目的子项目 在我的库项目 电话类库 中 如何创建 ResourceDictionary xaml 其中我需要添加一些样式并在 xaml 文件和 cs 文件中使用它 我需要访问 xaml 文件中的
  • 类型错误:无法连接“str”和“int”对象有人可以帮助新手使用他们的代码吗?

    感谢任何帮助 还有任何重大缺陷或您在格式或基本方面看到的任何重大缺陷 请指出 谢谢 day raw input How many days locations raw input Where to days str day location
  • 使用 InputStream 通过 TCP 套接字接收多个图像

    每次我从相机捕获图像时 我试图将多个图像自动从我的 Android 手机一张一张地发送到服务器 PC 问题是read 函数仅在第一次时阻塞 因此 从技术上讲 只有一张图像被接收并完美显示 但在那之后当is read 回报 1 该功能不阻塞
  • 在 Keras 中使用有状态 LSTM 训练多变量多级数回归问题

    我有时间序列P过程 每个过程的长度各不相同 但都有 5 个变量 维度 我试图预测测试过程的估计寿命 我正在用有状态的方法来解决这个问题LSTM在喀拉斯 但我不确定我的训练过程是否正确 我将每个序列分成长度的批次30 所以每个序列都是这样的形
  • 使用mockery和sinon模拟类方法

    我正在学习使用带有 sinon 的节点模块模拟进行单元测试 仅使用模拟和普通类 我就可以成功注入模拟 不过 我想注入一个 sinon 存根而不是一个普通的类 但我在这方面遇到了很多麻烦 我试图嘲笑的班级 function LdapAuth
  • 批量电子邮件仅限 80 封电子邮件 (GMAIL)?

    Gmail 在此处列出了其电子邮件限制 https support google com a answer 166852 hl en https support google com a answer 166852 hl en 但是 我收到
  • 当 C 中没有足够的内存用于静态分配时会发生什么?

    当您动态分配内存时 例如malloc 1024 sizeof char 结果指针设置为NULL如果没有足够的可用内存来满足请求 当没有足够的内存来满足静态分配时会发生什么 例如char c 1024 char c 1024 不一定是静态分配
  • Intern JS - 如何在链式 Command 方法中使用 Promise.all()?

    我是用 Intern JS 编写测试的新手 并且一直在遵循他们的文档来使用对象接口 https theintern github io intern interface object and 页面对象 https theintern git
  • Jquery Ajax 调用返回 403 状态

    我有一个 jquery Ajax 调用来实现会话的 keepalive 这个 keepAlive 方法将每 20 分钟调用一次 function keepAlive ajax type POST url KeepAliveDummy asp
  • Sql 查询:Sum,表中所有可能的行组合

    SQL Server 2008 R2 表结构示例 create table TempTable ID int identity value int insert into TempTable values 6 insert into Tem
  • Java小程序找不到JavaPOS配置文件

    我创建了一个小程序 它使用 JavaPOS 与用户本地系统上的支付终端进行通信 当从 Eclipse IDE 中运行时 该小程序可以正常工作 但在浏览器中运行时则不然 在浏览器中 小程序似乎找不到 jpos res jpos propert
  • 如何使breezejs所需的验证器允许空字符串

    在breezejs中允许所需属性中存在空字符串的首选方式是什么 I found 这个答案 https stackoverflow com questions 19658297 how does breeze saves empty stri
  • 从多个表中选择 - 一对多关系

    我有这样的表 表产品 身份证 姓名 表格图像 产品 ID 网址 订单号 表价 产品 ID 组合 货币 价格 表数量 产品 ID 组合 数量 表 Product 与其他表是一对多关系 我需要查询表并得到类似这样的结果 伪数组 ProductI
  • 用于一个自定义字段的 Jackson 反序列化器?

    我相信我们需要一个自定义反序列化器来对我们类中的一个字段执行特定的操作 看来一旦我这样做了 我现在就负责反序列化所有其他字段 有没有办法让杰克逊反序列化所有字段except我在这里关心的那个人 public class ThingDeser
  • geom_polygon 的渐变填充

    此代码生成一个包含 3 个多边形的图表 我正在创建一个显示 3 个多边形的图表 如果有更好的方法来绘制多边形 我不太感兴趣 实际上这些多边形代表事件 并且这些事件有一个持续时间 首先 我感兴趣的是使用渐变填充每个多边形的可能性 librar
  • Kendo 刷新 (DropDownList.refresh()) 不起作用错误未定义

    我试图在另一个 DropDownList 更改后刷新下拉列表 但 Refresh 方法未定义错误正在升级 我尝试再次读取数据源 它显示它正在加载 但数据仍然相同 帮助解决这个问题请 Code DropDownList1 change fun
  • 为什么 C++20 范围不只提供管道语法?

    我知道这个问题听起来很奇怪 所以这里有一些背景信息 最近 我很失望地了解到 C 20 范围内的映射缩减并不像人们所期望的那样工作 即 const double val data transform accumulate 不起作用 你必须这样
  • numpy 向量化而不是 for 循环

    我用 Python 写了一些代码 运行良好 但速度很慢 我认为是由于 for 循环 我希望可以使用 numpy 命令加速以下操作 让我定义目标 假设我有一个 2D numpy 数组all CMs尺寸row x col 例如考虑一个6x11数