我想编写一个对 ndarray(任意维度)进行操作的函数。
我发现的 cython 示例如下:
@cython.boundscheck(False) # turn off bounds-checking for entire function
@cython.wraparound(False) # turn off negative index wrapping for entire function
def c_fill_na_inplace_f64(np.ndarray[np.int64_t, cast=False, mode="c"] X, np.int64_t fillval):
cdef size_t N = X.size
cdef np.int64_t* = &X[0]
但是,我怀疑对于不同维度的数组我需要不同的索引:
cdef np.int64_t* = &X[0]
cdef np.int64_t* = &X[0,0]
cdef np.int64_t* = &X[0,0,0]
是否有更好的方法来获取数据。
就此而言,如何编写一个接受任意维度的 ndarray 的函数? np.ndarray[ndim=1...
假设未提供 ndim
。
(注意。由于 const bug,我无法使用 MemoryView)。
最佳答案
也许最困难的部分是获得正确的签名。我将使用通用的 python-object +reshape(-1)
,它将数组重新解释为一维数组而不进行复制。
这是一个有点草率的例子。它至少适用于连续数组,对于其他数组则更复杂,也是因为 C 代码必须意识到它,并且仅传递指针 &X[0]
是不够的:
%%cython
cimport numpy as np
def print_nd(object X):
cdef np.ndarray[np.int64_t] arr=X.reshape(-1)
cdef np.int64_t *ptr=&arr[0]
print("addr", <np.uint64_t>(ptr))
print("ptr1:", ptr[0])
这里是一个测试(针对 2 维数组):
%%cython
cimport numpy as np
def print_2d(np.ndarray[np.int64_t, ndim=2] X):
cdef np.int64_t *ptr=&X[0,0]
print("addr", <np.uint64_t>(ptr))
print("ptr1:", ptr[0])
>>> print_nd(a)
addr 94635612809360
ptr1: 1
>>> print_2d(a)
addr 94635612809360
ptr1: 1
如您所见,这两种情况下都是相同的地址。
关于python - 如何在 cython 中获取指向 ndarray 数据(任意维度)开头的指针,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48232319/