通过字典使用 numba njit 并行化来加速代码的问题

2024-03-25

我编写了一个代码并尝试使用 numba 来加速代码。代码的主要目标是根据条件对一些值进行分组。在这方面,iter_用于收敛代码以满足条件。下面我准备了一个小案例来重现示例代码:

import numpy as np
import numba as nb

rng = np.random.default_rng(85)

# --------------------------------------- small data volume ---------------------------------------
# values_ = {'R0': np.array([0.01090976, 0.01069902, 0.00724112, 0.0068463 , 0.01135723, 0.00990762,
#                                        0.01090976, 0.01069902, 0.00724112, 0.0068463 , 0.01135723]),
#            'R1': np.array([0.01836379, 0.01900166, 0.01864162, 0.0182823 , 0.01840322, 0.01653088,
#                                        0.01900166, 0.01864162, 0.0182823 , 0.01840322, 0.01653088]),
#            'R2': np.array([0.02430913, 0.02239156, 0.02225379, 0.02093393, 0.02408692, 0.02110411,
#                                        0.02239156, 0.02225379, 0.02093393, 0.02408692, 0.02110411])}
#
# params = {'R0': [3, 0.9490579204466154, 1825, 7.070272000000002e-05],
#           'R1': [0, 0.9729203826820172, 167 , 7.070272000000002e-05],
#           'R2': [1, 0.6031363088057902, 1316, 8.007296000000003e-05]}
#
# Sno, dec_, upd_ = 2, 100, 200
# -------------------------------------------------------------------------------------------------

# ----------------------------- UPDATED (medium and large data volumes) ---------------------------
# values_ = np.load("values_med.npy", allow_pickle=True)[()]
# params = np.load("params_med.npy", allow_pickle=True)[()]
values_ = np.load("values_large.npy", allow_pickle=True)[()]
params = np.load("params_large.npy", allow_pickle=True)[()]

Sno, dec_, upd_ = 2000, 1000, 200
# -------------------------------------------------------------------------------------------------

# values_ = [*values_.values()]
# params = [*params.values()]


# @nb.jit(forceobj=True)
# def test(values_, params, Sno, dec_, upd_):

final_dict = {}
for i, j in enumerate(values_.keys()):
    Rand_vals = []
    goal_sum = params[j][1] * params[j][3]
    tel = goal_sum / dec_ * 10
    if params[j][0] != 0:
        for k in range(Sno):
            final_sum = 0.0
            iter_ = 0
            t = 1
            while not np.allclose(goal_sum, final_sum, atol=tel):
                iter_ += 1
                vals_group = rng.choice(values_[j], size=params[j][0], replace=False)
                # final_sum = 0.0016 * np.sum(vals_group)  # -----> For small data volume
                final_sum = np.sum(vals_group ** 3)        # -----> UPDATED For med or large data volume
                if iter_ == upd_:
                    t += 1
                    tel = t * tel
            values_[j] = np.delete(values_[j], np.where(np.in1d(values_[j], vals_group)))
            Rand_vals.append(vals_group)
    else:
        Rand_vals = [np.array([])] * Sno
    final_dict["R" + str(i)] = Rand_vals

#    return final_dict


# test(values_, params, Sno, dec_, upd_)

首先,在这段代码上应用 numba@nb.jit被使用(forceobj=True用于避免警告和...),这将对性能产生不利影响。nopython也被检查了@nb.njit由于以下原因而出现以下错误不支持 https://github.com/numba/numba/issues/7503(如中提到的1 https://stackoverflow.com/questions/55078628/using-dictionaries-with-numba-njit-function, 2 https://stackoverflow.com/questions/50744686/numba-typingerror-cannot-determine-numba-type-of-class-builtin-function-or) 字典类型输入的:

无法确定 的 Numba 类型

我不知道是否(如何)可以处理Dict from numba.typed(通过将创建的 python 字典转换为 numba Dict)或者如果将字典 to 数组列表有什么优势。我认为,如果某些代码行例如Rand_vals.append(vals_group) or 其他部分或者……从函数中取出或修改以获得与以前相同的结果,但我不知道该怎么做。

我将非常感谢您帮助在这段代码上使用 numba。numba parallelization如果可以的话,将是最理想的(可能是性能方面最适用的方法)解决方案.


Data:

  • 中等数据量:价值观_医学 https://drive.google.com/file/d/1lIeYkbmw6Mjeb6xHY25D5UFuOvBvg9QC/view?usp=sharing, 参数_med https://drive.google.com/file/d/1GjYK13KjyJGab4eyX9V8MRILv_K8S6oa/view?usp=sharing
  • 大数据量:值_大 https://drive.google.com/file/d/1GoQaeYoM5WwBL4jM5ALGw8FN7lyiTLZG/view?usp=sharing, 参数大 https://drive.google.com/file/d/13TsB4mGGpchWd-32POIfHfLzGbQB4SCz/view?usp=sharing

该代码可以转换为 Numba,但并不简单。

首先,从Numba开始必须定义字典和列表类型njit函数不能直接对反射列表进行操作(又名纯Python列表)。在 Numba 中这有点乏味,而且生成的代码也有点冗长:

String = nb.types.unicode_type
ValueArray = nb.float64[::1]
ValueDict = nb.types.DictType(String, ValueArray)
ParamDictValue = nb.types.Tuple([nb.int_, nb.float64, nb.int_, nb.float64])
ParamDict = nb.types.DictType(String, ParamDictValue)
FinalDictValue = nb.types.ListType(ValueArray)
FinalDict = nb.types.DictType(String, FinalDictValue)

然后你需要转换输入字典:

nbValues = nb.typed.typeddict.Dict.empty(String, ValueArray)
for key,value in values_.items():
    nbValues[key] = value.copy()

nbParams = nb.typed.typeddict.Dict.empty(String, ParamDictValue)
for key,value in params.items():
    nbParams[key] = (nb.int_(value[0]), nb.float64(value[1]), nb.int_(value[2]), nb.float64(value[3]))

然后,你需要编写核心功能。np.allclose and np.isin未在 Numba 中实现,因此应手动重新实现。但最主要的一点是 Numba 不支持rngnumpy 对象。我想短期内肯定不会支持。请注意,Numba 有一个随机数实现,它尝试模仿 Numpy 的行为,但种子的管理有点不同。另请注意,结果应该与np.random.xxx如果种子设置为相同的值,Numpy 会起作用(Numpy 和 Numba 具有不同步的不同种子变量)。

@nb.njit(FinalDict(ValueDict, ParamDict, nb.int_, nb.int_, nb.int_))
def nbTest(values_, params, Sno, dec_, upd_):
    final_dict = nb.typed.Dict.empty(String, FinalDictValue)
    for i, j in enumerate(values_.keys()):
        Rand_vals = nb.typed.List.empty_list(ValueArray)
        goal_sum = params[j][1] * params[j][3]
        tel = goal_sum / dec_ * 10
        if params[j][0] != 0:
            for k in range(Sno):
                final_sum = 0.0
                iter_ = 0
                t = 1

                vals_group = np.empty(0, dtype=nb.float64)

                while np.abs(goal_sum - final_sum) > (1e-05 * np.abs(final_sum) + tel):
                    iter_ += 1
                    vals_group = np.random.choice(values_[j], size=params[j][0], replace=False)
                    final_sum = 0.0016 * np.sum(vals_group)
                    # final_sum = 0.0016 * np.sum(vals_group)  # (for small data volume)
                    final_sum = np.sum(vals_group ** 3)        # (for med or large data volume)
                    if iter_ == upd_:
                        t += 1
                        tel = t * tel

                # Perform an in-place deletion
                vals, gr = values_[j], vals_group
                cur = 0
                for l in range(vals.size):
                    found = False
                    for m in range(gr.size):
                        found |= vals[l] == gr[m]
                    if not found:
                        # Keep the value (delete it otherwise)
                        vals[cur] = vals[l]
                        cur += 1
                values_[j] = vals[:cur]

                Rand_vals.append(vals_group)
        else:
            for k in range(Sno):
                Rand_vals.append(np.empty(0, dtype=nb.float64))
        final_dict["R" + str(i)] = Rand_vals
    return final_dict

请注意,替换实现np.isin很幼稚,但在您的输入示例的实践中效果很好。

可以使用以下方式调用该函数:

nbFinalDict = nbTest(nbValues, nbParams, Sno, dec_, upd_)

最后,字典应该转换回基本的 Python 对象:

finalDict = dict()
for key,value in nbFinalDict.items():
    finalDict[key] = list(value)

这种实现对于小输入来说很快,但对于大输入来说却不是,因为np.random.choice几乎花费了所有的时间(>96%)。问题是这个函数是显然不是最佳的当请求的项目数量很少时(这是你的情况)。事实上,令人惊讶的是,它以输入数组的线性时间运行,而不是以请求的项目数量的线性时间运行。


进一步优化

该算法可以完全重写,以仅提取 12 个随机项,并以更有效的方式从主 currant 数组中丢弃它们。想法是交换n数组末尾的项目(小目标样本)与随机位置的其他项目,然后检查总和,重复此过程,直到满足条件,最后提取视图到最后n调整视图大小之前的项目,以便丢弃最后的项目。所有这一切都可以在O(n)时间而不是O(m)时间 地点m是主电流数组的大小n << m(例如,12 VS 20_000)。它也可以在没有任何昂贵的分配的情况下进行计算。这是生成的代码:

@nb.njit(nb.void(ValueArray, nb.int_, nb.int_))
def swap(arr, i, j):
    arr[i], arr[j] = arr[j], arr[i]

@nb.njit(FinalDict(ValueDict, ParamDict, nb.int_, nb.int_, nb.int_))
def nbTest(values_, params, Sno, dec_, upd_):
    final_dict = nb.typed.Dict.empty(String, FinalDictValue)
    for i, j in enumerate(values_.keys()):
        Rand_vals = nb.typed.List.empty_list(ValueArray)
        goal_sum = params[j][1] * params[j][3]
        tel = goal_sum / dec_ * 10
        values = values_[j]
        n = params[j][0]

        if n != 0:
            for k in range(Sno):
                final_sum = 0.0
                iter_ = 0
                t = 1

                m = values.size
                assert n <= m
                group = values[-n:]

                while np.abs(goal_sum - final_sum) > (1e-05 * np.abs(final_sum) + tel):
                    iter_ += 1

                    # Swap the group view with other random items
                    for pos in range(m - n, m):
                        swap(values, pos, np.random.randint(0, m))

                    # For small data volume:
                    # final_sum = 0.0016 * np.sum(group)

                    # For med/large data volume
                    final_sum = 0.0
                    for v in group:
                        final_sum += v ** 3

                    if iter_ == upd_:
                        t += 1
                        tel *= t

                assert iter_ > 0
                values = values[:m-n]
                Rand_vals.append(group)
        else:
            for k in range(Sno):
                Rand_vals.append(np.empty(0, dtype=nb.float64))
        final_dict["R" + str(i)] = Rand_vals
    return final_dict

除了更快之外,这种实现的好处还在于更简单。结果看起来与之前的实现非常相似,尽管随机性使得结果检查变得棘手(特别是因为该函数不使用相同的方法来选择随机样本)。请注意,此实现不会删除中的项目values那些在group与前一个相反(但这可能不是想要的)。


基准

以下是我的机器上最后一次实现的结果(不包括编译和转换时间):

Provided small input (embedded in the question):
 - Initial code:   42.71 ms
 - Numba code:      0.11 ms

Medium input:
 - Initial code:   3481 ms
 - Numba code:       11 ms

Large input:
 - Initial code:   6728 ms
 - Numba code:       20 ms

请注意,转换时间与计算时间大致相同。

最后一个实现是快316~388倍比小输入上的初始代码。


Notes

请注意,由于字典和列表类型,编译时间需要几秒钟。

请注意,虽然可以并行化实现,但只能并行化最具包容性的循环。问题是只有很少的项目需要计算,而且时间已经相当短了(这不是多线程的最佳情况)。 rng.choice)肯定会导致并行循环无法很好地扩展。 --> 此外,列表/字典无法从多个线程安全地写入,因此需要在整个函数中使用 Numpy 数组才能做到这一点(或添加已经很昂贵的额外转换)。此外,Numba 并行性往往会显着增加本来就很重要的编译时间。最后,结果的确定性较差,因为每个 Numba 线程都有自己的随机数生成器种子,并且无法使用以下方法预测线程计算的项目prange(取决于目标平台上选择的并行运行时)。请注意,在 Numpy 中,通常的随机函数默认使用一个全局种子(已弃用的方式),并且 RNG 对象有自己的种子(新的首选方式)。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

通过字典使用 numba njit 并行化来加速代码的问题 的相关文章

  • 取消的分支与常规分支有何不同?

    特别是对于 SPARC Assembly 取消的分支与常规分支有何不同 我一直认为 当我需要填充分支指令的 nop 延迟槽时 需要取消分支指令 但是 我认为我在这一部分上是不正确的 因为您可以在不取消分支的情况下填充 nop 如果不采用分支
  • 如何提高MySQL INSERT和UPDATE性能?

    我们数据库中的 INSERT 和 UPDATE 语句的性能似乎正在下降 并导致我们的 Web 应用程序性能不佳 表是InnoDB 应用程序使用事务 我可以做一些简单的调整来加快速度吗 我认为我们可能会遇到一些锁定问题 我怎样才能找到答案 你
  • Google BigQuery:检索每行的最后版本

    我有一个 Google BigQuery 表 其中包含所有版本的资源 每次创建 更新 删除资源时 都会添加一个新行 并递增版本号 该数字将是添加行时的时间戳 ID ResourceID Action Count Timestamp ABC
  • 主键删除需要多长时间?

    画一个简单的表结构 Table1 Table2 ID lt ID Name gt Table1ID Name Table1有几百万行 例如 350 万行 我通过主键发出删除 DELETE FROM Table1 WHERE ID 100 中
  • 在 Android 谷歌地图中绘制 4K 折线

    我现在正在开发一个适用于 Android 设备的应用程序 主要功能是在地图上绘制折线以显示城市中每条街道的交通情况 不幸的是 当我绘制大约 3K 折线时 数量会根据屏幕尺寸和缩放级别而减少 我的地图变得非常慢 我没有提及绘制所有线条的时间
  • 如何使用 VBA 将符号/图标格式化为单元格而不使用条件格式

    我使用 VBA 代码放置条件格式以覆盖大型表格中的值 每个单元格使用 2 个公式来确定使用 3 个符号中的哪一个 我需要根据列使用不同的单元格检查每个单元格的值 因此据我了解 我必须将条件格式规则单独放置在每个单元格上 以确保每个单元格中的
  • 让 GHC 生成“带进位加法 (ADC)”指令

    下面的代码将表示 192 位数字的两个未装箱字三元组添加到新的未装箱字三元组中 并且还返回任何溢出 LANGUAGE MagicHash LANGUAGE UnboxedTuples import GHC Prim plusWord2 Wo
  • gcc总是做这种优化吗? (公共子表达式消除)

    作为示例 假设表达式sys gt pot atoms item gt P kind mass在循环内求值 循环只改变item 因此表达式可以简化为atoms item gt P kind mass通过将变量定义为atoms sys gt p
  • PHP include():文件大小和性能

    一个没有经验的PHP问题 我有一个 PHP 脚本文件 我需要在不同页面的很多地方多次包含该文件 我可以选择将包含的文件分解为几个较小的文件 并根据需要包含这些文件 或者 我可以将它们全部保存在一个 PHP 文件中 我想知道在这种情况下使用较
  • 加载实体实例需要超过 1 秒

    我在EF中遇到了一件有趣的事情 如果我们使用基础实体获取子实体 则加载实体需要更多时间 我的模型看起来像这样 public abstract class BaseDocument public Guid Id get set public
  • MSMQ 慢速队列读取

    我正在使用一个开源 Net 库 它在底层使用 MSMQ 大约一两周后 服务速度变慢 时间不准确 但一般猜测 看来发生的情况是来自 MSMQ 的消息每 10 秒才被读取一次 通常 它们会立即被读取 因此 它们将在 T 10 秒 T 20 秒
  • 如何分析Android应用程序的电池使用情况并对其进行优化?

    我想分析我的应用程序的电池使用情况 我的意思是应用程序的各个部分 例如 广播接收器 监听器 服务等 使用多少电池 我需要一个详细的列表 从列表中 我想优化电池的使用 方法与使用内存分析器类似 http android developers
  • 静态方法是否会立即编译(JIT)?

    根据我的理解 CLR 编译器对实例方法和静态方法的处理方式相同 并且每当首次调用该方法时 IL 代码都会进行 JIT 编译 今天我和同事讨论了 他告诉我静态方法与实例方法的处理方式不同 即 静态方法在程序集加载到应用程序域后立即进行 JIT
  • 如果我将一个大函数声明为内联函数怎么办?

    我搜索了一些相关问题 例如C 中内联函数的好处 https stackoverflow com questions 145838 benefits of inline functions in c 但我还有疑问 如果内联函数只是为了 为编译
  • 我想优化这个短循环

    我想优化这个简单的循环 unsigned int i while j 0 j is an unsigned int with a start value of about N 36 000 000 float sub 0 i 1 unsig
  • 如何提高QNX6下Eclipse IDE的性能

    我们在 VMWare 环境中通过 QNX6 运行 Eclipse 速度非常慢 Eclipse 是这样启动的 usr qnx630 host qnx6 x86 usr qde eclipse eclipse data root workspa
  • 从 foreach 循环赋值

    我想并行化一个循环 例如 td lt data frame cbind c rep 1 4 2 rep 1 5 rep 1 10 2 names td lt c val id res lt rep NA NROW td for i in l
  • Scipy 最小化 fmin - 语法问题

    我有一个函数 它接受多个参数 一个数组和两个浮点数 并返回一个标量 浮点数 现在我想通过改变两个参数来最小化这个函数 两个浮点数 该数组在函数内部 解包 然后使用其内容 数组和浮点数 如何使用 SciPy 的 fmin 函数来完成此操作 我
  • 如何减少 JSF 中的 javax.faces.ViewState

    减少 JSF 中视图状态隐藏字段大小的最佳方法是什么 我注意到我的视图状态约为 40k 这会在每次请求和响应时下降到客户端并返回到服务器 特别是到达服务器时 这对用户来说会显着减慢 我的环境 JSF 1 2 MyFaces Tomcat T
  • Rglpk - 梦幻足球阵容优化器 - For 循环输出的 Rbind

    我有一个使用 Rgplk 的梦幻足球阵容优化器 它使用for循环生成多个最佳阵容 其数量由用户输入 代码如下 Lineups lt list for i in 1 Lineup no matrix lt rbind as numeric D

随机推荐

  • 我什么时候应该使用 Response.Redirect(url, true)?

    我正在重定向到一个错误页面 其中包含一条经过美化的错误消息Application Error 在 Global asax 中 目前它说 Response Redirect Error aspx true 应该是 Response Redir
  • MongoDB:cursor.toArray 返回 Promise { }

    情况 我写了一个查询 var results db collection diseases find ttl txt regex data options i toArray Problem 然后我打印了results到控制台 if res
  • VBA 中运行时错误 429,但类已注册

    我正在尝试重新创建一个程序 该程序使用 JavaScript 打开与 PLC 的连接 然后在网页上显示各种信息 由于各种原因 我宁愿将其以 MS Access 的形式保存 并且一直在努力寻找合适的 dll 来使用 Jet32X dll 如果
  • Chrome 中的 HTML5 视频边框半径不起作用

    我试图让我的 HTML5 视频具有透明的左上角和左下角圆角 就像使用 border radius 时的行为一样 不幸的是 在 Chrome 中 由于某种原因 border radius 在 HTML 视频标签上不起作用 但在 IE10 和
  • numpy stride_tricks.as_strided 与滚动窗口的列表理解

    在处理滚动窗口时 我以列表理解的方式编写函数 np std x i i framesize for i in range 0 len x framesize hopsize 最近我发现numpy lib stride tricks as s
  • bash 重命名带有空格的文件时出错 - mv 目标不是目录

    我正在尝试重命名一堆包含空格的文件 去掉空格 我以为我找到了正确的 bash 命令 for f in txt do mv f f done 但是 这会给每个文件带来错误 mv 目标不是目录 如果我在命令中将 mv 替换为 echo mv 它
  • 确定PDF文件中的页数[关闭]

    Closed 这个问题需要多问focused help closed questions 目前不接受答案 我需要使用 C 代码 NET 2 0 确定指定 PDF 文件的页数 PDF 文件将从文件系统读取 而不是从 URL 读取 有谁知道如何
  • 如何在 Javascript 中格式化时间戳以将其显示在图表中? UTC 没问题

    基本上 我收到原始时间戳 需要将它们格式化为 HH MM SS 格式 这是一个提供灵活的 UTC 日期格式的函数 它接受类似于 Java 的 SimpleDateFormat 的格式字符串 function formatDate date
  • Rails:如何在搜索结果中使用构面

    我有一个铁路应用程序 我正在其中搜索维修店 搜索类方法如下所示 def self search params if params repairshop Repairshop where approved gt true if params
  • 通用 ELF 中的重定位(EM:40)

    我尝试从 Ubuntu 交叉编译到Friendly arm 但出现了奇怪的错误 root kevin VirtualBox home kevin Desktop makef make ARCH arm CROSS COMPILE arm n
  • 无法调用按钮命令:应用程序已被破坏

    下面给出了使用 Tkinter 和 Python 创建独立窗口的代码 import Tkinter Tkinter NoDefaultRoot win1 Tkinter Tk win2 Tkinter Tk Tkinter Button w
  • 连接远程redis服务器

    我想对 redis conf 进行一些更改 以便每当我输入 redis cli 时 它都会将我连接到远程服务器上安装的 redis 我知道我们可以通过以下方式连接到安装在远程服务器上的redis redis cli h IP Address
  • JAVA:如何创建 http url 连接选择要使用的 IP 地址

    我在多个 NIC 上配置了一个公共 IP 地址池 在我的 JAVA 项目中 该项目在 LINUX 计算机上运行 我需要从池中选择一个特定的 IP 地址 并使用该 IP 创建一个 HttpURLConnecion 此外 我将在池上循环 每次使
  • 模拟用户活动

    我想模拟 Windows 计算机中的用户活动 例如鼠标左键单击 此外我想执行预定义的步骤可重复性 有没有可用的工具 请建议我一个简单又好的方法来做到这一点 我已经使用 AutoIT v3 很长时间了 强烈推荐它 http www autoi
  • Python 异步任务和 CPU 密集型任务?

    我最近一直在使用 Flask 在 python 中开发一个宠物项目 它是一个简单的 Pastebin 具有服务器端语法突出显示 pygments 的支持 因为这是一项成本高昂的任务 所以我将语法突出显示委托给了 celery 任务队列 并在
  • 如何设置 html“select”元素选项的样式?

    这是我的 HTML
  • Laravel - 未找到模型类

    当开始使用模型时 我收到以下错误 找不到班级帖子 我所做的一切 使用命令创建模型php artisan make model 尝试从表中获取所有条目posts with echo Post all 我使用了以下代码 路由器 php Rout
  • React 中是否可以从容器内触发包含组件的渲染?

    所以我得到了App它实现了一个componentDidMount and render 应用程序包含 2 个组件 一个 一个AutoComplete输入 另一个是CardView 该计划是 一旦用户从列表中选择了一个项目AutoComple
  • 当有大量可用内存时出现 OutOfMemoryException

    我们有一个在 5 个 服务器 节点 16 个核心 每个 128 GB 内存 上运行的应用程序 在每台计算机上加载近 70 GB 的数据 该应用程序是分布式的并为并发客户端提供服务 因此 有大量的套接字使用 类似地 对于多个线程之间的同步 有
  • 通过字典使用 numba njit 并行化来加速代码的问题

    我编写了一个代码并尝试使用 numba 来加速代码 代码的主要目标是根据条件对一些值进行分组 在这方面 iter 用于收敛代码以满足条件 下面我准备了一个小案例来重现示例代码 import numpy as np import numba