Numba 和 numpy 数组分配:为什么这么慢?

2024-03-09

我最近使用 Cython 和 Numba 来加速进行数值模拟的 python 的小片段。起初,使用 numba 进行开发似乎更容易。然而,我发现很难理解 numba 何时会提供更好的性能,何时不会。

意外性能下降的一个例子是当我使用该函数时np.zeros()在编译函数中分配一个大数组。例如,考虑三个函数定义:

import numpy as np 
from numba import jit 

def pure_python(n):
    mat = np.zeros((n,n), dtype=np.double)
    # do something
    return mat.reshape((n**2))

@jit(nopython=True)
def pure_numba(n):
    mat = np.zeros((n,n), dtype=np.double)
    # do something
    return mat.reshape((n**2))

def mixed_numba1(n):
    return mixed_numba2(np.zeros((n,n)))

@jit(nopython=True)

def mixed_numba2(array):
    n = len(array)
    # do something
    return array.reshape((n,n))

# To compile 
pure_numba(10)
mixed_numba1(10)

自从#do something是空的,我不期望pure_numba功能变得更快。然而,我没想到性能会如此下降:

n=10000
%timeit x = pure_python(n)
%timeit x = pure_numba(n)
%timeit x = mixed_numba1(n)

我获得(Mac上的python 3.7.7,numba 0.48.0)

4.96 µs ± 65.9 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
344 ms ± 7.76 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
3.8 µs ± 30.3 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

在这里,当我使用该函数时,numba 代码要慢得多np.zeros()在已编译的函数内。当np.zeros()位于函数之外。

我在这里做错了什么,还是应该总是分配大数组,比如这些由 numba 编译的外部函数?

Update

这似乎与矩阵的延迟初始化有关np.zeros((n,n)) when n足够大(参见Numpy 中的 Zeros 函数的性能 https://stackoverflow.com/questions/44487786/performance-of-zeros-function-in-numpy ).

for n in [1000, 2000, 5000]:
    print('n=',n)
    %timeit x = pure_python(n)
    %timeit x = pure_numba(n)
    %timeit x = mixed_numba1(n)

给我:

n = 1000
468 µs ± 15.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
296 µs ± 6.55 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
300 µs ± 2.26 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
n = 2000
4.79 ms ± 182 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
4.45 ms ± 36 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
4.54 ms ± 127 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
n = 5000
270 µs ± 4.66 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
104 ms ± 599 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
119 µs ± 1.24 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

tl;dr Numpy 使用 C 内存函数,而 Numba 必须分配零

我编写了一个脚本来绘制几个选项完成所需的时间,当选项的大小变大时,Numba 的性能似乎会严重下降。np.zeros数组到达2048*2048*8 = 32 MB在我的机器上如下图所示。

Numba 的实施np.zeros与创建一个空数组并通过迭代数组的维度来用零填充它一样快(这是Numba 嵌套循环图中的绿色曲线)。实际上可以通过设置来双重检查NUMBA_DUMP_IR运行脚本之前的环境变量(见下文)。与转储进行比较时numba_loop没有太大区别。

有趣的是,np.zeros通过 32 MB 阈值获得了一点提升。

尽管我远非专家,但我的最佳猜测是 32 MB 限制是操作系统或硬件瓶颈,来自同一进程的缓存中可以容纳的数据量。如果超过这个值,将数据移入和移出缓存来对其进行操作的操作将非常耗时。

相比之下,Numpy 使用calloc https://github.com/numpy/numpy/blob/590facafcbd454fc02f96aa52c1c3412d6cccd14/numpy/core/src/multiarray/alloc.c#L212-L232获取一些内存段,并承诺在访问数据时用零填充数据。

这就是我所取得的进展,我意识到这只是答案的一半,但也许知识渊博的人可以阐明实际发生的情况。

Numba IR 转储:

---------------------------IR DUMP: pure_numba_zeros----------------------------
label 0:
    n = arg(0, name=n)                       ['n']
    $2load_global.0 = global(np: <module 'numpy' from '/lib/python3.8/site-packages/numpy/__init__.py'>) ['$2load_global.0']
    $4load_attr.1 = getattr(value=$2load_global.0, attr=zeros) ['$2load_global.0', '$4load_attr.1']
    del $2load_global.0                      []
    $10build_tuple.4 = build_tuple(items=[Var(n, script.py:15), Var(n, script.py:15)]) ['$10build_tuple.4', 'n', 'n']
    $12load_global.5 = global(np: <module 'numpy' from '/lib/python3.8/site-packages/numpy/__init__.py'>) ['$12load_global.5']
    $14load_attr.6 = getattr(value=$12load_global.5, attr=double) ['$12load_global.5', '$14load_attr.6']
    del $12load_global.5                     []
    $18call_function_kw.8 = call $4load_attr.1($10build_tuple.4, func=$4load_attr.1, args=[Var($10build_tuple.4, script.py:15)], kws=[('dtype', Var($14load_attr.6, script.py:15))], vararg=None) ['$10build_tuple.4', '$14load_attr.6', '$18call_function_kw.8', '$4load_attr.1']
    del $4load_attr.1                        []
    del $14load_attr.6                       []
    del $10build_tuple.4                     []
    mat = $18call_function_kw.8              ['$18call_function_kw.8', 'mat']
    del $18call_function_kw.8                []
    $24load_method.10 = getattr(value=mat, attr=reshape) ['$24load_method.10', 'mat']
    del mat                                  []
    $const28.12 = const(int, 2)              ['$const28.12']
    $30binary_power.13 = n ** $const28.12    ['$30binary_power.13', '$const28.12', 'n']
    del n                                    []
    del $const28.12                          []
    $32call_method.14 = call $24load_method.10($30binary_power.13, func=$24load_method.10, args=[Var($30binary_power.13, script.py:16)], kws=(), vararg=None) ['$24load_method.10', '$30binary_power.13', '$32call_method.14']
    del $30binary_power.13                   []
    del $24load_method.10                    []
    $34return_value.15 = cast(value=$32call_method.14) ['$32call_method.14', '$34return_value.15']
    del $32call_method.14                    []
    return $34return_value.15                ['$34return_value.15']

生成图表的脚本:

import numpy as np
from numba import jit
from time import time
import os
import matplotlib.pyplot as plt

os.environ['NUMBA_DUMP_IR'] = '1'

def numpy_zeros(n):
    mat = np.zeros((n,n), dtype=np.double)
    return mat.reshape((n**2))

@jit(nopython=True)
def numba_zeros(n):
    mat = np.zeros((n,n), dtype=np.double)
    return mat.reshape((n**2))

@jit(nopython=True)
def numba_loop(n):
    mat = np.empty((n * 2,n), dtype=np.float32)
    for i in range(mat.shape[0]):
        for j in range(mat.shape[1]):
            mat[i, j] = 0.
    return mat.reshape((2 * n**2))

# To compile
numba_zeros(10)
numba_loop(10)

os.environ['NUMBA_DUMP_IR'] = '0'

max_n = 4100
time_deltas = {
    'numpy_zeros': [],
    'numba_zeros': [],
    'numba_loop': [],
}
call_count = 10
for n in range(0, max_n, 10):
    for f in (numpy_zeros, numba_zeros, numba_loop):
        start = time()
        for i in range(call_count):
              x = f(n)
        delta = time() - start
        time_deltas[f.__name__].append(delta / call_count)
        print(f'{f.__name__:25} n = {n}: {delta}')
    print()

size = np.arange(0, max_n, 10) ** 2 * 8 / 1024 ** 2
fig, ax = plt.subplots()
plt.xticks(np.arange(0, size[-1], 16))
plt.axvline(x=32, color='gray', lw=0.5)
ax.plot(size, time_deltas['numpy_zeros'], label='Numpy zeros (calloc)')
ax.plot(size, time_deltas['numba_zeros'], label='Numba zeros')
ax.plot(size, time_deltas['numba_loop'], label='Numba nested loop')
ax.set_xlabel('Size of array in MB')
ax.set_ylabel(r'Mean $\Delta$t in s')
plt.legend(loc='upper left')
plt.show()
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Numba 和 numpy 数组分配:为什么这么慢? 的相关文章

  • 在 django ORM 中查询时如何将 char 转换为整数?

    最近开始使用 Django ORM 我想执行这个查询 select student id from students where student id like 97318 order by CAST student id as UNSIG
  • 将html数据解析成python列表进行操作

    我正在尝试读取 html 网站并提取其数据 例如 我想查看公司过去 5 年的 EPS 每股收益 基本上 我可以读入它 并且可以使用 BeautifulSoup 或 html2text 创建一个巨大的文本块 然后我想搜索该文件 我一直在使用
  • 独立滚动矩阵的行

    我有一个矩阵 准确地说 是 2d numpy ndarray A np array 4 0 0 1 2 3 0 0 5 我想滚动每一行A根据另一个数组中的滚动值独立地 r np array 2 0 1 也就是说 我想这样做 print np
  • Pandas Merge (pd.merge) 如何设置索引和连接

    我有两个 pandas 数据框 dfLeft 和 dfRight 以日期作为索引 dfLeft cusip factorL date 2012 01 03 XXXX 4 5 2012 01 03 YYYY 6 2 2012 01 04 XX
  • 如何将张量流模型部署到azure ml工作台

    我在用Azure ML Workbench执行二元分类 到目前为止 一切正常 我有很好的准确性 我想将模型部署为用于推理的 Web 服务 我真的不知道从哪里开始 azure 提供了这个doc https learn microsoft co
  • 如何在 Python 中解析和比较 ISO 8601 持续时间? [关闭]

    Closed 这个问题正在寻求书籍 工具 软件库等的推荐 不满足堆栈溢出指南 help closed questions 目前不接受答案 我正在寻找一个 Python v2 库 它允许我解析和比较 ISO 8601 持续时间may处于不同单
  • 从Python中的字典列表中查找特定值

    我的字典列表中有以下数据 data I versicolor 0 Sepal Length 7 9 I setosa 0 I virginica 1 I versicolor 0 I setosa 1 I virginica 0 Sepal
  • Python,将函数的输出重定向到文件中

    我正在尝试将函数的输出存储到Python中的文件中 我想做的是这样的 def test print This is a Test file open Log a file write test file close 但是当我这样做时 我收到
  • “隐藏”内置类对象、函数、代码等的名称和性质[关闭]

    Closed 这个问题需要多问focused help closed questions 目前不接受答案 我很好奇模块中存在的类builtins无法直接访问的 例如 type lambda 0 name function of module
  • pyspark 将 twitter json 流式传输到 DF

    我正在从事集成工作spark streaming with twitter using pythonAPI 我看到的大多数示例或代码片段和博客是他们从Twitter JSON文件进行最终处理 但根据我的用例 我需要所有字段twitter J
  • Cython 和类的构造函数

    我对 Cython 使用默认构造函数有疑问 我的 C 类 Node 如下 Node h class Node public Node std cerr lt lt calling no arg constructor lt lt std e
  • Jupyter Notebook 找不到 Python 模块

    不知道发生了什么 但每当我使用 ipython 氢 原子 或 jupyter 笔记本时都找不到任何已安装的模块 我知道我安装了 pandas 但笔记本说找不到 我应该补充一点 当我正常运行脚本时 python script py 它确实导入
  • 从 NumPy ndarray 中选择行

    我只想从 a 中选择某些行NumPy http en wikipedia org wiki NumPy基于第二列中的值的数组 例如 此测试数组的第二列包含从 1 到 10 的整数 gt gt gt test numpy array nump
  • 使用特定颜色和抖动在箱形图上绘制数据点

    我有一个plotly graph objects Box图 我显示了箱形 图中的所有点 我需要根据数据的属性为标记着色 如下所示 我还想抖动这些点 下面未显示 Using Box我可以绘制点并抖动它们 但我不认为我可以给它们着色 fig a
  • 如何断言 Unittest 上的可迭代对象不为空?

    向服务提交查询后 我会收到一本字典或一个列表 我想确保它不为空 我使用Python 2 7 我很惊讶没有任何assertEmpty方法为unittest TestCase类实例 现有的替代方案看起来并不正确 self assertTrue
  • Pandas 将多行列数据帧转换为单行多列数据帧

    我的数据框如下 code df Car measurements Before After amb temp 30 268212 26 627491 engine temp 41 812730 39 254255 engine eff 15
  • 根据列 value_counts 过滤数据框(pandas)

    我是第一次尝试熊猫 我有一个包含两列的数据框 user id and string 每个 user id 可能有多个字符串 因此会多次出现在数据帧中 我想从中导出另一个数据框 一个只有那些user ids列出至少有 2 个或更多string
  • 如何在 pygtk 中创建新信号

    我创建了一个 python 对象 但我想在它上面发送信号 我让它继承自 gobject GObject 但似乎没有任何方法可以在我的对象上创建新信号 您还可以在类定义中定义信号 class MyGObjectClass gobject GO
  • python import inside函数隐藏现有变量

    我在我正在处理的多子模块项目中遇到了一个奇怪的 UnboundLocalError 分配之前引用的局部变量 问题 并将其精简为这个片段 使用标准库中的日志记录模块 import logging def foo logging info fo
  • 使用随机放置的 NaN 创建示例 numpy 数组

    出于测试目的 我想创建一个M by Nnumpy 数组与c随机放置的 NaN import numpy as np M 10 N 5 c 15 A np random randn M N A mask np nan 我在创建时遇到问题mas

随机推荐

  • DialogFragment 按钮被推出屏幕 API 24 及更高版本

    我正在定制DialogFragment显示可选择的数据列表 该列表太长 无法在不滚动的情况下显示在屏幕上 对于 API 23 及以下版本 一切似乎都工作正常 但当我在 API 24 上进行测试时 DialogFragment 的按钮不再可见
  • 从 Firebase 数据库获取的数据显示在 3 个单独的警报对话框中,而不是一个

    我正在从中获取一些数据FirebaseDatabase然后将它们放入array然后尝试以List这是一个习惯AlertDialog 这是代码 query mDatabase child child child anotherChild ch
  • Spring MVC 4:“application/json”内容类型未正确设置

    我有一个使用以下注释映射的控制器 RequestMapping value json method RequestMethod GET produces application json ResponseBody public String
  • 如何强制删除Python对象?

    我很好奇的细节 del 在 python 中 何时 为什么应该使用它以及不应该使用它 我经历了惨痛的教训才知道 它并不像人们天真地期望的析构函数那样 因为它并不是与 new init class Foo object def init se
  • Jquery中动态选择Drop Down

    我有 4 个下拉菜单 默认情况下 每个 drop 都有一个 select 选项 每个盒子都有一个唯一的 ID 如您所见 如果上面的下拉列表值为 select 则禁用第二个下拉列表 仅当该值不是 select 时才会启用 这是我的代码 doc
  • Java ImageIO.read 导致 OSX 挂起

    我必须在 Mac OSX 上读取图像时执行一些操作 但是在调用 ImageIO read File 时它似乎挂起 似乎也没有出现堆栈跟踪 它实际上只是挂起 想知道其他人是否遇到过这个问题 我已经成功地写了一张图片 只是阅读方面似乎有问题 使
  • C# 中的双向适配器和可插拔适配器模式有什么区别?

    双向适配器和可插入适配器都可以访问这两个类 并且还可以更改需要更改的方法的行为 以下是我的代码 双向适配器 public interface IAircraft bool Airborne get void TakeOff int Heig
  • 即使删除 Entitlements.plist 后,Xamarin 钥匙串错误中也找不到有效的 iPhone 代码签名密钥

    我收到这个错误即使删除 Entitlements plist 后 Xamarin 钥匙串错误中仍发现 iPhone 代码签名密钥当尝试使用 Xamarin Studios 构建 HelloWorld iPhone 应用程序时 我了解在真实设
  • 使用 jemmy 测试 java web start 应用程序

    我需要使用 Jemmy 创建一些 gui 测试 但我不知道如何使用 javaws 应用程序启动它 在教程 示例 等中是这样的 new ClassReference org netbeans jemmy explorer GUIBrowser
  • Python Pickling 字典 EOFError

    我有几个脚本在服务器上运行 用于腌制和取消腌制各种字典 它们都使用相同的基本代码进行酸洗 如下所示 SellerDict open home hostadl SellerDictkm rb SellerDictionarykm pickle
  • Angular 6 不应用 scss 样式

    我有一个组件页面和相应的样式表 但是 component scss 中的类不适用于该页面 没有错误 我仍然想知道为什么 这是我的产品详细信息 page component html div h1 Product Detail Page h1
  • 向 EditText 字段添加阴影效果

    我正在尝试设计一个编辑文本字段像这样有阴影 底部和右侧 尝试谷歌搜索并搜索了许多SO讨论 但所有讨论都是针对TextView而不是EditText 这是我的代码 向输入文本添加阴影 但不向 TextField 添加阴影
  • 使用 urllib2 HTTPS 登录

    我目前有一个小脚本 可以下载网页并提取一些我感兴趣的数据 没什么花哨的 目前我正在下载页面 如下所示 import commands command wget output document quiet http user USER htt
  • 从共享菜单中泄露 IntentReceiver

    我通过在特定活动中单击按钮来打开发送菜单 Intent i new Intent Intent ACTION SEND i setType text plain i putExtra Intent EXTRA TEXT meh try st
  • 什么是合适的数据结构和数据库模式来存储逻辑规则?

    前言 我没有规则引擎 构建规则 建模规则 实现规则数据结构等方面的经验 因此 我不知道我在做什么 也不知道我下面的尝试是否偏离了基础 我试图弄清楚如何存储和处理以下假设场景 为了简化我的问题 假设我有一种游戏类型 用户购买一个对象 其中可能
  • 将 Dictionary 序列化为 BSON 时出现 BsonSerializationException

    我最近搬到了新的 MongoDB C 驱动程序 v2 0 https www nuget org packages MongoDB Driver 2 0 0来自已弃用 v1 9 https www nuget org packages mo
  • 如何在 try catch 语句中重新请求输入

    string l Console ReadLine try int Parse l catch FormatException Console WriteLine Invalid input Please enter 1 2 or 3 正如
  • 在python中将字典的字典写入csv

    我有一本字典 我想将其写入 csv 我的字典看起来像 dict object1 time1 value1 value2 time2 value3 value4 object2 time1 value5 value6 time2 value7
  • 使用 boost 序列化抽象类时出错

    我正在尝试序列化我的数据结构 以便将它们写入 TCP 套接字 到目前为止我发现我的问题是序列化 我什至尝试使用 BOOST SERIALIZATION ASSUME ABSTRACT T 但我找不到任何与我的程序类似的工作示例以及如何正确实
  • Numba 和 numpy 数组分配:为什么这么慢?

    我最近使用 Cython 和 Numba 来加速进行数值模拟的 python 的小片段 起初 使用 numba 进行开发似乎更容易 然而 我发现很难理解 numba 何时会提供更好的性能 何时不会 意外性能下降的一个例子是当我使用该函数时n