我正在尝试为具有 numpy 数组输入参数的函数制作一个缓存装饰器
from functools import lru_cache
import numpy as np
from time import sleep
a = np.array([1,2,3,4])
@lru_cache()
def square(array):
sleep(1)
return array * array
square(a)
但 numpy 数组不可散列,
TypeError Traceback (most recent call last)
<ipython-input-13-559f69d0dec3> in <module>()
----> 1 square(a)
TypeError: unhashable type: 'numpy.ndarray'
因此需要将它们转换为元组。我的工作和缓存正确:
@lru_cache()
def square(array_hashable):
sleep(1)
array = np.array(array_hashable)
return array * array
square(tuple(a))
但我想把它全部包裹在一个装饰器中,到目前为止我已经尝试过:
def np_cache(function):
def outter(array):
array_hashable = tuple(array)
@lru_cache()
def inner(array_hashable_inner):
array_inner = np.array(array_hashable_inner)
return function(array_inner)
return inner(array_hashable)
return outter
@np_cache
def square(array):
sleep(1)
return array * array
But 缓存不起作用。计算已执行但未正确缓存,因为我总是等待 1 秒。
我在这里缺少什么?我正在猜测lru_cache
没有获得正确的上下文并且它在每次调用中都被实例化,但我不知道如何修复它。
我试过盲目地扔functools.wraps
装饰者到处都没有运气。