了解 numba 并行化中的竞争条件

2024-01-06

Numba 文档中有一个关于并行竞争条件的示例

import numba as nb
import numpy as np
@nb.njit(parallel=True)
def prange_wrong_result(x):
    n = x.shape[0]
    y = np.zeros(4)
    for i in nb.prange(n):
        y[:]+= x[i]
    return y

我已经运行了,它确实输出了异常结果,例如

prange_wrong_result(np.ones(10000))
#array([5264., 5273., 5231., 5234.])

然后我尝试将循环更改为

import numba as nb
import numpy as np
@nb.njit(parallel=True)
def prange_wrong_result(x):
    n = x.shape[0]
    y = np.zeros(4)
    for i in nb.prange(n):
        y+= x[i]
    return y

它输出

prange_wrong_result(np.ones(10000))
#array([10000., 10000., 10000., 10000.])

我读过一些竞争条件的解释。但我还是不明白

  1. 为什么第二个例子没有赛车条件?有什么区别y[:]= vs y=
  2. 为什么第一个例子中四个元素的输出不一样?

在第一个示例中,您有多个线程/进程共享同一数组并读取+分配给共享数组。这y[:] += x[i]大致相当于:

y[0] += x[i]
y[1] += x[i]
y[2] += x[i]
y[3] += x[i]

事实上+=只是读取、加法和赋值操作的语法糖,所以y[0] += x[i]事实上是:

_value = y[0]
_value = _value + x[i]
y[0] = _value

循环体由多个线程/进程同时执行,这就是竞争条件出现的地方。维基百科上关于竞争条件的示例适用于此处:

这就是返回的数组包含错误值以及每个元素可能不同的原因。因为它根本不确定哪个线程/进程何时运行。因此,在某些情况下,一个元素上存在竞争条件,有时没有,有时多个元素上存在竞争条件。

然而,numba 开发人员在不发生竞争条件的情况下实现了一些受支持的减少。其中之一是y +=。这里重要的是它是变量本身,而不是变量的切片/元素。在这种情况下,numba 会做一些非常聪明的事情。它们为每个线程/进程复制变量的初始值,然后对该副本进行操作。并行循环完成后,它们将复制的值相加。以您的第二个示例为例,假设它使用 2 个进程,则它大致如下所示:

y = np.zeros(4)
y_1 = y.copy()
y_2 = y.copy()
for i in nb.prange(n):
    if is_process_1:
        y_1[:] += x[i]
    if is_process_2:
        y_2[:] += x[i]
y += y_1
y += y_2

由于每个线程都有自己的数组,因此不可能出现竞争条件。为了让 numba 能够推断出这一点,你必须遵守他们的限制。文档指出 numba 创建无竞争条件的并行代码+=关于标量和数组 (y += x[i]), but 不在数组元素/切片上 (y[:] += x[i] or y[1] += x[i]).

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

了解 numba 并行化中的竞争条件 的相关文章

随机推荐