Numpy 与 Cython 速度

2024-05-31

我有一个分析代码,它使用 numpy 执行一些繁重的数值运算。只是出于好奇,尝试使用 cython 进行少量更改来编译它,然后我使用 numpy 部分的循环重写它。

令我惊讶的是,基于循环的代码要快得多(8 倍)。我无法发布完整的代码,但我整理了一个非常简单的不相关的计算,显示了类似的行为(尽管时间差异不是那么大):

版本 1(不含 cython)

import numpy as np

def _process(array):

    rows = array.shape[0]
    cols = array.shape[1]

    out = np.zeros((rows, cols))

    for row in range(0, rows):
        out[row, :] = np.sum(array - array[row, :], axis=0)

    return out

def main():
    data = np.load('data.npy')
    out = _process(data)
    np.save('vianumpy.npy', out)

版本 2(使用 cython 构建模块)

import cython
cimport cython

import numpy as np
cimport numpy as np

DTYPE = np.float64
ctypedef np.float64_t DTYPE_t

@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
cdef _process(np.ndarray[DTYPE_t, ndim=2] array):

    cdef unsigned int rows = array.shape[0]
    cdef unsigned int cols = array.shape[1]
    cdef unsigned int row
    cdef np.ndarray[DTYPE_t, ndim=2] out = np.zeros((rows, cols))

    for row in range(0, rows):
        out[row, :] = np.sum(array - array[row, :], axis=0)

    return out

def main():
    cdef np.ndarray[DTYPE_t, ndim=2] data
    cdef np.ndarray[DTYPE_t, ndim=2] out
    data = np.load('data.npy')
    out = _process(data)
    np.save('viacynpy.npy', out)

版本 3(使用 cython 构建模块)

import cython
cimport cython

import numpy as np
cimport numpy as np

DTYPE = np.float64
ctypedef np.float64_t DTYPE_t

@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
cdef _process(np.ndarray[DTYPE_t, ndim=2] array):

    cdef unsigned int rows = array.shape[0]
    cdef unsigned int cols = array.shape[1]
    cdef unsigned int row
    cdef np.ndarray[DTYPE_t, ndim=2] out = np.zeros((rows, cols))

    for row in range(0, rows):
        for col in range(0, cols):
            for row2 in range(0, rows):
                out[row, col] += array[row2, col] - array[row, col]

    return out

def main():
    cdef np.ndarray[DTYPE_t, ndim=2] data
    cdef np.ndarray[DTYPE_t, ndim=2] out
    data = np.load('data.npy')
    out = _process(data)
    np.save('vialoop.npy', out)

将 10000x10 矩阵保存在 data.npy 中时,时间为:

$ python -m timeit -c "from version1 import main;main()"
10 loops, best of 3: 4.56 sec per loop

$ python -m timeit -c "from version2 import main;main()"
10 loops, best of 3: 4.57 sec per loop

$ python -m timeit -c "from version3 import main;main()"
10 loops, best of 3: 2.96 sec per loop

这是预期的还是我缺少优化?版本 1 和 2 给出相同的结果这一事实在某种程度上是预料之中的,但为什么版本 3 更快呢?

Ps.-这不是我需要进行的计算,只是一个显示相同内容的简单示例。


经过轻微修改,版本 3 的速度提高了一倍:

@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
def process2(np.ndarray[DTYPE_t, ndim=2] array):

    cdef unsigned int rows = array.shape[0]
    cdef unsigned int cols = array.shape[1]
    cdef unsigned int row, col, row2
    cdef np.ndarray[DTYPE_t, ndim=2] out = np.empty((rows, cols))

    for row in range(rows):
        for row2 in range(rows):
            for col in range(cols):
                out[row, col] += array[row2, col] - array[row, col]

    return out

计算的瓶颈是内存访问。您的输入数组是 C 排序的,这意味着沿着最后一个轴移动会在内存中产生最小的跳跃。因此,您的内部循环应该沿着轴 1,而不是轴 0。进行此更改可以将运行时间减少一半。

如果您需要在小型输入数组上使用此函数,那么您可以通过使用来减少开销np.empty代替np.ones。为了减少进一步使用的开销PyArray_EMPTY来自 numpy C API。

如果您在非常大的输入数组 (2**31) 上使用此函数,则用于索引的整数(以及在range函数)会溢出。为了安全使用:

cdef Py_ssize_t rows = array.shape[0]
cdef Py_ssize_t cols = array.shape[1]
cdef Py_ssize_t row, col, row2

代替

cdef unsigned int rows = array.shape[0]
cdef unsigned int cols = array.shape[1]
cdef unsigned int row, col, row2

Timing:

In [2]: a = np.random.rand(10000, 10)
In [3]: timeit process(a)
1 loops, best of 3: 3.53 s per loop
In [4]: timeit process2(a)
1 loops, best of 3: 1.84 s per loop

where process你的版本是3.

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Numpy 与 Cython 速度 的相关文章

  • 如何将 typeshed 与 mypy 一起使用?

    我克隆了typeshed https github com python typeshed但我不知道如何告诉 mypy 使用它包含的类型提示 我在 mypy help 中没有看到任何选项 mypy 存储库确实包含对 typeshed 存储库
  • 我如何知道Python的unicode函数识别的所有支持的编码

    Python 有一个unicode将字节流转换为 unicode 字符串的内置函数 我只是希望我能查询所有可用的encoding在我的系统上 但如何 这个问题的原因是 有人使用 MAC OS X 向我发送了一封内容编码为 iso 2022
  • 如何在嵌套列表中查找给定元素?

    这是我的迭代解决方案 def exists key arg if not arg return False else for element in arg if isinstance element list for i in elemen
  • 为什么Flask后台线程获取错误的数据库信息?

    为了将实时数据库信息推送到客户端 我在服务器端使用flask socketio 通过使用websocket将所有实时数据库信息推送到客户端 我的视图文件有一个片段 from models import Host from flask soc
  • 使用 cx_oracle 返回 MERGE 中受影响的行数

    如何在 CX Oracle 中执行 MERGE INTO sql 命令来获取受影响的行数 当我在cx oracle 上执行MERGE SQL 时 我得到的cursor rowcount 为 1 有没有办法获取受合并影响的行数 由于 cx o
  • 使用 Python 访问内存映射文件

    我希望利用激战 2 中的内存映射文件 该文件旨在链接到 Mumble 以获得位置音频 该文件包含有关字符坐标的信息和其他有用的信息 我已经能够使用此脚本访问坐标信息 import mmap import struct last while
  • 映射 2 个数据帧并替换目标数据帧中匹配值的标头

    我有一个数据框 df1 SAP Name SAP Class SAP Sec Avi 5 C Rison 6 A Slesh 7 B San 8 C Sud 7 B df2 Name Fi Class Avi 5 Rison 6 Slesh
  • 为什么我的字符串中出现不需要的换行符?

    这应该很简单 这很愚蠢 但我无法让它发挥作用 我有一个在读取文件时定义的标头 if gene env in line or gene HIV2gp7 in line header line 现在这个标题看起来像 gt lcl NC 0018
  • Python Jinja2 调用宏会导致(不需要的)换行符

    我的 JINJA2 模板如下所示 macro print if john name if name John Hi John endif endmacro Hello World print if john Foo print if joh
  • 在 Javascript 中实现 Zobrist 哈希

    我需要在 Javascript 中为国际象棋引擎实现 Zobrist 哈希 我想知道实现此目的的最佳方法是什么 现在 我不是计算机科学家 也从未上过正式的算法和数据结构课程 所以如果我在这方面有点偏离 我很抱歉 据我了解 我需要一个 64
  • 使用 SQLAlchemy 查询 Pandas DataFrame 时重命名列

    当您将数据查询到 pandas 数据帧时 有没有办法保留 SqlAlchemy 属性名称 这是我的数据库的简单映射 对于 school 表 我将数据库名称 SchoolDistrict 重命名为较短的 district 我从 DBA 中删除
  • Python父类访问子私有变量

    以下代码会生成错误 class A object def say something self print self foo print self bar class B A def init self self foo hello sel
  • python请求ssl握手失败

    每次我尝试这样做 requests get https url 我收到这条消息 import requests gt gt gt requests get https reviews gethuman com companies Trace
  • 如何忽略 Sentry 捕获中的某些 Python 错误

    我已将 Sentry 配置为捕获 Django Celery 应用程序中的所有错误 它工作正常 但我发现一个令人讨厌的用例是当我必须重新启动我的 Celery 工作人员 PostgreSQL 数据库或消息服务器时 这会导致数千种各种 无法访
  • Python httplib 和 POST

    我目前正在使用别人编写的一段代码 它用httplib向服务器发出请求 它以正确的格式提供所有数据 例如消息正文 标头值等 问题是 每次尝试发送 POST 请求时 数据都在那里 我可以在客户端看到它 但没有任何内容到达服务器 我已经阅读了库规
  • 从Python列表中挑选出具有特定索引的项目

    我确信在 Python 中有一种很好的方法可以做到这一点 但我对这门语言还很陌生 所以如果这是一个简单的方法 请原谅我 我有一个列表 我想从该列表中挑选某些值 我想要挑选的值是列表中索引在另一个列表中指定的值 例如 indexes 2 4
  • 如何使用 opencv python 根据检测到的物体的位置生成其热图

    我需要根据对象的位置生成其热图 示例 视频帧中检测到的绿色球 如果它长时间停留在某个位置 那么该位置应该是红色的 并且球在短时间内经过的帧中的位置必须是蓝色的 这样我就需要生成热图 提前致谢 那么你在这里可以做的是 1 首先定义一个热图作为
  • tkinter 库 treectrl 转换为 exe 安装程序时出现 cx_freeze 错误

    我使用的是 python 版本 3 7 我使用了这个名为 treectrl 的外部库 当我运行 py 文件时它工作得很好 但是当我使用 cx freeze 转换为 exe 文件时 它给了我错误 NomodulleFound 名为 tkint
  • pandas groupby 中两个系列的最大值和最小值

    是否可以从 groupby 中的两个系列中获取最小值和最大值 例如下面的情况 分组时c 我怎样才能得到最小值和最大值a and b同时 df pd DataFrame a 10 20 3 40 55 b 5 14 8 50 60 c x x
  • Maya python 连接选择的属性

    我一直在尝试制作一个简单的脚本 它将采用两个视口选择 然后基本上将第二个视口的旋转连接到第一个 我不确定如何正确地从视口选择中为对象创建变量 这是我的尝试 但不起作用 import maya cmds as cmds sel cmds ls

随机推荐