使用 C 函数扩展 Numpy

2024-04-02

我正在尝试加速我的 Numpy 代码,并决定实现一个特定的函数,而我的代码大部分时间都在 C 中使用。

我实际上是 C 的菜鸟,但我设法编写了一个函数,将矩阵中的每一行归一化为 1。我可以编译它,并用一些数据(在 C 中)测试它,它满足了我的要求。那时我为自己感到非常自豪。

现在我尝试从 Python 调用我的光荣函数,它应该接受 2d-Numpy 数组。

我尝试过的各种事情是

  • SWIG

  • SWIG + numpy.i

  • ctypes

我的函数有原型

void normalize_logspace_matrix(size_t nrow, size_t ncol, double mat[nrow][ncol]);

因此它需要一个指向可变长度数组的指针并就地修改它。

我尝试了以下纯 SWIG 接口文件:

%module c_utils

%{
extern void normalize_logspace_matrix(size_t, size_t, double mat[*][*]);
%}

extern void normalize_logspace_matrix(size_t, size_t, double** mat);

然后我会这样做(在 Mac OS X 64 位上):

> swig -python c-utils.i

> gcc -fPIC c-utils_wrap.c -o c-utils_wrap.o \
     -I/Library/Frameworks/Python.framework/Versions/6.2/include/python2.6/ \
     -L/Library/Frameworks/Python.framework/Versions/6.2/lib/python2.6/ -c
c-utils_wrap.c: In function ‘_wrap_normalize_logspace_matrix’:
c-utils_wrap.c:2867: warning: passing argument 3 of ‘normalize_logspace_matrix’ from   incompatible pointer type

> g++ -dynamiclib c-utils.o -o _c_utils.so

在 Python 中,我在导入模块时收到以下错误:

>>> import c_utils
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
ImportError: dynamic module does not define init function (initc_utils)

接下来我使用 SWIG + 尝试了这种方法numpy.i:

%module c_utils

%{
#define SWIG_FILE_WITH_INIT
#include "c-utils.h"
%}
%include "numpy.i"
%init %{
import_array();
%}

%apply ( int DIM1, int DIM2, DATA_TYPE* INPLACE_ARRAY2 ) 
       {(size_t nrow, size_t ncol, double* mat)};

%include "c-utils.h"

但是,我没有得到比这更进一步的信息:

> swig -python c-utils.i
c-utils.i:13: Warning 453: Can't apply (int DIM1,int DIM2,DATA_TYPE *INPLACE_ARRAY2). No typemaps are defined.

SWIG 似乎没有找到定义的类型映射numpy.i,但我不明白为什么,因为numpy.i位于同一目录中,SWIG 不会抱怨找不到它。

使用 ctypes,我并没有走得太远,但很快就迷失在文档中,因为我不知道如何向它传递一个 2d 数组,然后返回结果。

那么有人可以向我展示如何让我的函数在 Python/Numpy 中可用的魔术吗?


除非您有充分的理由不这样做,否则您应该使用 cython 来连接 C 和 python。 (我们开始在 numpy/scipy 本身中使用 cython 而不是原始 C)。

你可以在我的 scikits 中看到一个简单的例子talkbox https://github.com/cournape/talkbox(由于 cython 从那时起已经有了很大的改进,我认为你今天可以写得更好)。

def cslfilter(c_np.ndarray b, c_np.ndarray a, c_np.ndarray x):
    """Fast version of slfilter for a set of frames and filter coefficients.
    More precisely, given rank 2 arrays for coefficients and input, this
    computes:

    for i in range(x.shape[0]):
        y[i] = lfilter(b[i], a[i], x[i])

    This is mostly useful for processing on a set of windows with variable
    filters, e.g. to compute LPC residual from a signal chopped into a set of
    windows.

    Parameters
    ----------
        b: array
            recursive coefficients
        a: array
            non-recursive coefficients
        x: array
            signal to filter

    Note
    ----

    This is a specialized function, and does not handle other types than
    double, nor initial conditions."""

    cdef int na, nb, nfr, i, nx
    cdef double *raw_x, *raw_a, *raw_b, *raw_y
    cdef c_np.ndarray[double, ndim=2] tb
    cdef c_np.ndarray[double, ndim=2] ta
    cdef c_np.ndarray[double, ndim=2] tx
    cdef c_np.ndarray[double, ndim=2] ty

    dt = np.common_type(a, b, x)

    if not dt == np.float64:
        raise ValueError("Only float64 supported for now")

    if not x.ndim == 2:
        raise ValueError("Only input of rank 2 support")

    if not b.ndim == 2:
        raise ValueError("Only b of rank 2 support")

    if not a.ndim == 2:
        raise ValueError("Only a of rank 2 support")

    nfr = a.shape[0]
    if not nfr == b.shape[0]:
        raise ValueError("Number of filters should be the same")

    if not nfr == x.shape[0]:
        raise ValueError, \
              "Number of filters and number of frames should be the same"

    tx = np.ascontiguousarray(x, dtype=dt)
    ty = np.ones((x.shape[0], x.shape[1]), dt)

    na = a.shape[1]
    nb = b.shape[1]
    nx = x.shape[1]

    ta = np.ascontiguousarray(np.copy(a), dtype=dt)
    tb = np.ascontiguousarray(np.copy(b), dtype=dt)

    raw_x = <double*>tx.data
    raw_b = <double*>tb.data
    raw_a = <double*>ta.data
    raw_y = <double*>ty.data

    for i in range(nfr):
        filter_double(raw_b, nb, raw_a, na, raw_x, nx, raw_y)
        raw_b += nb
        raw_a += na
        raw_x += nx
        raw_y += nx

    return ty

正如您所看到的,除了在 python 中执行的常见参数检查之外,它几乎是相同的事情(filter_double 是一个函数,如果您愿意,可以在单独的库中用纯 C 编写)。当然,由于它是编译后的代码,因此未能检查您的参数将使您的解释器崩溃,而不是引发异常(不过,最近的 cython 可以在安全性与速度之间进行多种级别的权衡)。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

使用 C 函数扩展 Numpy 的相关文章

随机推荐