python - 遍历 n 维数组的通用函数

标签 python arrays numpy cython memoryview

使用 Cython,有没有办法编写快速的通用函数,这些函数适用于不同维度的数组?例如,对于这种去混叠函数的简单情况:

import numpy as np
cimport numpy as np

ctypedef np.uint8_t DTYPEb_t
ctypedef np.complex128_t DTYPEc_t


def dealiasing1D(DTYPEc_t[:, :] data, 
                 DTYPEb_t[:] where_dealiased):
    """Dealiasing data for 1D solvers."""
    cdef Py_ssize_t ik, i0, nk, n0

    nk = data.shape[0]
    n0 = data.shape[1]

    for ik in range(nk):
        for i0 in range(n0):
            if where_dealiased[i0]:
                data[ik, i0] = 0.


def dealiasing2D(DTYPEc_t[:, :, :] data, 
                 DTYPEb_t[:, :] where_dealiased):
    """Dealiasing data for 2D solvers."""
    cdef Py_ssize_t ik, i0, i1, nk, n0, n1

    nk = data.shape[0]
    n0 = data.shape[1]
    n1 = data.shape[2]

    for ik in range(nk):
        for i0 in range(n0):
            for i1 in range(n1):
                if where_dealiased[i0, i1]:
                    data[ik, i0, i1] = 0.


def dealiasing3D(DTYPEc_t[:, :, :, :] data, 
                 DTYPEb_t[:, :, :] where_dealiased):
    """Dealiasing data for 3D solvers."""
    cdef Py_ssize_t ik, i0, i1, i2, nk, n0, n1, n2

    nk = data.shape[0]
    n0 = data.shape[1]
    n1 = data.shape[2]
    n2 = data.shape[3]

    for ik in range(nk):
        for i0 in range(n0):
            for i1 in range(n1):
                for i2 in range(n2):
                    if where_dealiased[i0, i1, i2]:
                        data[ik, i0, i1, i2] = 0.

在这里,我需要三个函数来处理一维、二维和三维情况。是否有一种好的方法来编写一个函数来完成所有(合理的)维度的工作?

PS:在这里,我尝试使用内存 View ,但我不确定这是执行此操作的正确方法。我很惊讶 if where_dealiased[i0]: data[ik, i0] = 0.cython -a 命令生成的带注释的 html 中不是白色的。有什么问题吗?

最佳答案

我要说的第一件事是,想要保留这 3 个函数是有原因的,如果使用更通用的函数,您可能会错过 cython 编译器和 c 编译器的优化。

制作一个包装这 3 个函数的函数是非常可行的,它只需将两个数组作为 python 对象,检查形状,然后调用相关的其他函数。

但如果要尝试这样做,那么我会尝试的只是为最高维度编写函数,然后使用较低维度的数组将它们重新转换为较高维度的数组,方法是使用 new axis符号:

cdef np.uint8_t [:] a1d = np.zeros((256, ), np.uint8) # 1d
cdef np.uint8_t [:, :] a2d = a1d[None, :]             # 2d
cdef np.uint8_t [:, :, :] a3d = a1d[None, None, :]    # 3d
a2d[0, 100] = 42
a3d[0, 0, 200] = 108
print(a1d[100], a1d[200])
# (42, 108)

cdef np.uint8_t [:, :] data2d = np.zeros((128, 256), np.uint8) #2d
cdef np.uint8_t [:, :, :, :] data4d = data2d[None, None, :, :] #4d
data4d[0, 0, 42, 108] = 64
print(data2d[42, 108])
# 64

如您所见,内存 View 可以转换为更高的维度,并可用于修改原始数据。在将新 View 传递给最高维函数之前,您可能仍想编写一个包装函数来执行这些技巧。我怀疑这个技巧在你的情况下会很好用,但你必须四处游玩才能知道它是否会用你的数据做你想做的事。

PS: 有一个非常简单的解释。 “额外代码”是生成索引错误、类型错误的代码,它允许您使用 [-1] 从数组的末尾而不是开始(环绕)进行索引。 您可以禁用这些额外的 Python 功能,并通过使用 compiler directives 将其减少为 C 数组功能。 ,例如,要从整个文件中删除这些额外代码,您可以在文件开头添加注释:

# cython: boundscheck=False, wraparound=False, nonecheck=False

编译器指令也可以使用装饰器在函数级别应用。文档解释。

关于python - 遍历 n 维数组的通用函数,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/26207080/

相关文章:

python - 使用 OpenID 的 Pyramid 应用程序

python - 如何将列中的所有项目对齐到 QTableWidget 的中心

python - 在 Python 中计算 BLEU 分数

c# - 如何测试一个对象是否是数组的数组/锯齿状数组

python - 使用 numpy.savez() 保存标题信息字典

python - 成功更新资源后,我收到 JSON 解码错误

ios - 如何使用包含某些键/值的字典填充数组

c - C中解析指针参数

python - 导入 sklearn 时出现不可排序类型错误

python - Pandas DataFrame 到多维 NumPy 数组