假设我定义了以下 Cython 类
cdef class Kernel:
cdef readonly double a
def __init__(self, double a):
self.a = a
cdef public double GetValue(self, double t):
return self.a*t
现在我想定义另一个具有内核序列作为属性的扩展类型。就像是:
cdef class Model:
cdef readonly Kernel[:] kernels
cdef unsigned int n_kernels
def __init__(self, Kernel[:] ker):
self.kernels = ker
self.n_kernels = ker.shape[0]
cdef double Run(self, double t):
cdef int i
cdef double out=0.0
for i in range(self.n_kernels):
out += self.kernels[i].GetValue(t)
return out
然而,这不起作用。首先我需要替换Kernel[:]
with object[:]
,否则我会收到以下错误gcc
‘PyObject’ has no member named ‘__pyx_vtab’
如果我使用object[:]
一切编译正常,但在尝试访问时出现错误GetValue
method:
AttributeError: "AttributeError: "'cytest.Kernel' object has no attribute 'GetValue'" in 'cytest.Model.Run'
我想要什么
- 访问
cdef
的方法Kernel
来自cdef
method Run
,没有 Python 开销。
- 对内核元素进行类型检查。
我目前的解决方法
目前我使用以下解决方案,但不满足上述要求:
cdef class Kernel:
cdef readonly double a
def __init__(self, double a):
self.a = a
cpdef public double GetValue(self, double t):
return self.a*t
cdef class Model:
cdef readonly object[:] kernels
cdef unsigned int n_kernels
def __init__(self, object[:] ker):
self.kernels = ker
self.n_kernels = ker.shape[0]
def Run(self, double t):
cdef int i
cdef double out=0.0
for i in range(self.n_kernels):
out += self.kernels[i].GetValue(t)
return out
即我将 Kernel 类的方法声明为cpdef
以便可以从 Python 访问它们并使用object[:]
.
Question
有没有办法在 Cython 中实现上述第 1 点和第 2 点而不需要 Python 开销?
在此先感谢您的时间。
注意:我事先不知道序列的长度。
Edit
根据@DavidW的建议,我修改了代码如下
# module cytest
import cython
cdef class Kernel:
cdef readonly double a
def __init__(self, double a ):
self.a = a
cdef public double GetValue(self, double t):
return self.a*t
cdef class Model:
cdef readonly Kernel[:] kernels
### added this attribute
cdef Kernel k
cdef unsigned int n_kernels
def __cinit__(self, Kernel[:] ker):
self.kernels = ker
self.n_kernels = ker.shape[0]
cpdef double Run(self, double t):
cdef int i
cdef double out=0.0
for i in range(self.n_kernels):
# now i assign to the new attribute each time
# and access the cdef method from it
self.k = self.kernels[i]
out += self.k.GetValue(t)
return out
现在它编译并运行良好(并且比我以前的解决方法更快),即使我在访问时仍然有一些 python 开销Kernel[:]
属性。
我在这里举了一个构建和调用的示例Model
import cytest
import numpy as np
ker_list = [cytest.Kernel(i*1.0) for i in range(3)]
# transform it to a numpy array
# to be able to pass it to the 'Model' constructor
ker_arr = np.array(ker_list)
# create a model instance
model = cytest.Model(ker_arr)
# call the method Run
print model.Run(1.0)