该代码可以转换为 Numba,但并不简单。
首先,从Numba开始必须定义字典和列表类型njit
函数不能直接对反射列表进行操作(又名纯Python列表)。在 Numba 中这有点乏味,而且生成的代码也有点冗长:
String = nb.types.unicode_type
ValueArray = nb.float64[::1]
ValueDict = nb.types.DictType(String, ValueArray)
ParamDictValue = nb.types.Tuple([nb.int_, nb.float64, nb.int_, nb.float64])
ParamDict = nb.types.DictType(String, ParamDictValue)
FinalDictValue = nb.types.ListType(ValueArray)
FinalDict = nb.types.DictType(String, FinalDictValue)
然后你需要转换输入字典:
nbValues = nb.typed.typeddict.Dict.empty(String, ValueArray)
for key,value in values_.items():
nbValues[key] = value.copy()
nbParams = nb.typed.typeddict.Dict.empty(String, ParamDictValue)
for key,value in params.items():
nbParams[key] = (nb.int_(value[0]), nb.float64(value[1]), nb.int_(value[2]), nb.float64(value[3]))
然后,你需要编写核心功能。np.allclose
and np.isin
未在 Numba 中实现,因此应手动重新实现。但最主要的一点是 Numba 不支持rng
numpy 对象。我想短期内肯定不会支持。请注意,Numba 有一个随机数实现,它尝试模仿 Numpy 的行为,但种子的管理有点不同。另请注意,结果应该与np.random.xxx
如果种子设置为相同的值,Numpy 会起作用(Numpy 和 Numba 具有不同步的不同种子变量)。
@nb.njit(FinalDict(ValueDict, ParamDict, nb.int_, nb.int_, nb.int_))
def nbTest(values_, params, Sno, dec_, upd_):
final_dict = nb.typed.Dict.empty(String, FinalDictValue)
for i, j in enumerate(values_.keys()):
Rand_vals = nb.typed.List.empty_list(ValueArray)
goal_sum = params[j][1] * params[j][3]
tel = goal_sum / dec_ * 10
if params[j][0] != 0:
for k in range(Sno):
final_sum = 0.0
iter_ = 0
t = 1
vals_group = np.empty(0, dtype=nb.float64)
while np.abs(goal_sum - final_sum) > (1e-05 * np.abs(final_sum) + tel):
iter_ += 1
vals_group = np.random.choice(values_[j], size=params[j][0], replace=False)
final_sum = 0.0016 * np.sum(vals_group)
# final_sum = 0.0016 * np.sum(vals_group) # (for small data volume)
final_sum = np.sum(vals_group ** 3) # (for med or large data volume)
if iter_ == upd_:
t += 1
tel = t * tel
# Perform an in-place deletion
vals, gr = values_[j], vals_group
cur = 0
for l in range(vals.size):
found = False
for m in range(gr.size):
found |= vals[l] == gr[m]
if not found:
# Keep the value (delete it otherwise)
vals[cur] = vals[l]
cur += 1
values_[j] = vals[:cur]
Rand_vals.append(vals_group)
else:
for k in range(Sno):
Rand_vals.append(np.empty(0, dtype=nb.float64))
final_dict["R" + str(i)] = Rand_vals
return final_dict
请注意,替换实现np.isin
很幼稚,但在您的输入示例的实践中效果很好。
可以使用以下方式调用该函数:
nbFinalDict = nbTest(nbValues, nbParams, Sno, dec_, upd_)
最后,字典应该转换回基本的 Python 对象:
finalDict = dict()
for key,value in nbFinalDict.items():
finalDict[key] = list(value)
这种实现对于小输入来说很快,但对于大输入来说却不是,因为np.random.choice
几乎花费了所有的时间(>96%)。问题是这个函数是显然不是最佳的当请求的项目数量很少时(这是你的情况)。事实上,令人惊讶的是,它以输入数组的线性时间运行,而不是以请求的项目数量的线性时间运行。
进一步优化
该算法可以完全重写,以仅提取 12 个随机项,并以更有效的方式从主 currant 数组中丢弃它们。想法是交换n
数组末尾的项目(小目标样本)与随机位置的其他项目,然后检查总和,重复此过程,直到满足条件,最后提取视图到最后n
调整视图大小之前的项目,以便丢弃最后的项目。所有这一切都可以在O(n)
时间而不是O(m)
时间 地点m
是主电流数组的大小n << m
(例如,12 VS 20_000)。它也可以在没有任何昂贵的分配的情况下进行计算。这是生成的代码:
@nb.njit(nb.void(ValueArray, nb.int_, nb.int_))
def swap(arr, i, j):
arr[i], arr[j] = arr[j], arr[i]
@nb.njit(FinalDict(ValueDict, ParamDict, nb.int_, nb.int_, nb.int_))
def nbTest(values_, params, Sno, dec_, upd_):
final_dict = nb.typed.Dict.empty(String, FinalDictValue)
for i, j in enumerate(values_.keys()):
Rand_vals = nb.typed.List.empty_list(ValueArray)
goal_sum = params[j][1] * params[j][3]
tel = goal_sum / dec_ * 10
values = values_[j]
n = params[j][0]
if n != 0:
for k in range(Sno):
final_sum = 0.0
iter_ = 0
t = 1
m = values.size
assert n <= m
group = values[-n:]
while np.abs(goal_sum - final_sum) > (1e-05 * np.abs(final_sum) + tel):
iter_ += 1
# Swap the group view with other random items
for pos in range(m - n, m):
swap(values, pos, np.random.randint(0, m))
# For small data volume:
# final_sum = 0.0016 * np.sum(group)
# For med/large data volume
final_sum = 0.0
for v in group:
final_sum += v ** 3
if iter_ == upd_:
t += 1
tel *= t
assert iter_ > 0
values = values[:m-n]
Rand_vals.append(group)
else:
for k in range(Sno):
Rand_vals.append(np.empty(0, dtype=nb.float64))
final_dict["R" + str(i)] = Rand_vals
return final_dict
除了更快之外,这种实现的好处还在于更简单。结果看起来与之前的实现非常相似,尽管随机性使得结果检查变得棘手(特别是因为该函数不使用相同的方法来选择随机样本)。请注意,此实现不会删除中的项目values
那些在group
与前一个相反(但这可能不是想要的)。
基准
以下是我的机器上最后一次实现的结果(不包括编译和转换时间):
Provided small input (embedded in the question):
- Initial code: 42.71 ms
- Numba code: 0.11 ms
Medium input:
- Initial code: 3481 ms
- Numba code: 11 ms
Large input:
- Initial code: 6728 ms
- Numba code: 20 ms
请注意,转换时间与计算时间大致相同。
最后一个实现是快316~388倍比小输入上的初始代码。
Notes
请注意,由于字典和列表类型,编译时间需要几秒钟。
请注意,虽然可以并行化实现,但只能并行化最具包容性的循环。问题是只有很少的项目需要计算,而且时间已经相当短了(这不是多线程的最佳情况)。 rng.choice)肯定会导致并行循环无法很好地扩展。 --> 此外,列表/字典无法从多个线程安全地写入,因此需要在整个函数中使用 Numpy 数组才能做到这一点(或添加已经很昂贵的额外转换)。此外,Numba 并行性往往会显着增加本来就很重要的编译时间。最后,结果的确定性较差,因为每个 Numba 线程都有自己的随机数生成器种子,并且无法使用以下方法预测线程计算的项目prange
(取决于目标平台上选择的并行运行时)。请注意,在 Numpy 中,通常的随机函数默认使用一个全局种子(已弃用的方式),并且 RNG 对象有自己的种子(新的首选方式)。