python - Python 3 确保数组参数尺寸正确的方法是什么?

标签 python arrays numpy type-hinting

在我的 Python 3.7 新手项目中,许多函数中的参数都是 numpy.ndarray 的。这些必须是二维 r x n 矩阵。行维度 r 至关重要:某些函数需要 1 x n 向量,其他函数需要 2 x n 矩阵,其中 r 向上到三个甚至更多。还有为任何 r x n 数组定义的函数。 (列尺寸n对于设计目的来说并不是必需的。)

根据我的 Matlab 经验,此要求可能会令人困惑且容易出错。所以我考虑了以下方法:

  1. 记录方法参数(当然!)
  2. 单元测试(当然!)
  3. 在某些函数内进行验证并抛出异常。 (但是,这不是很实用,也不是很高效。)
  4. 定义数据类:OneRowTwoRowsThreeRowsFourPlusRows。每个都有一个 ndarray 字段,在构造函数中进行验证。其优点包括类型提示和更好的领域建模,类似于 DDD。缺点是额外的复杂性。

问题:考虑到 Python 3 中引入的类型提示以及函数式编程的趋势,当前解决此问题的 Python 方法是什么?

最佳答案

Python 最好的事情之一是 duck typing ,并且 Numpy 通常与该设计方法非常兼容。假设您有一个纯矢量函数 vecfunc。您可以在函数的开头添加一些样板,将任何一维数组膨胀为 1 x n 向量:

def vecfunc(arr):
    if arr.ndim==1:
        arr = arr[None, :]

    ...function body goes here...

这将避免由于 arr 维度太少而导致的任何问题,并且在大多数情况下仍可能给出正确的行为。但是,它不会执行任何操作来阻止用户传入 r x n x m 数组或 15 x n 数组。最终,您将不得不使用方法3.来处理一堆这样的东西,并在合适的地方抛出一些异常。例如:

def vecfunc(arr):
    if not 0 < arr.ndim < 3:
        raise ValueError("arr must have ndim of 1 or 2. arr.ndim: %d" % arr.ndim)
    elif arr.ndim==1:
        arr = arr[None, :]

如果这让你感觉好一点,两个 numpy 的代码库和 scipy在需要的时间和地点,在许多函数中进行基于形状的异常检查。

当然,您始终可以不添加此类异常检查,直到开发任何给定函数的最后为止。您可能会对产生合理行为的输入范围感到惊讶。

如果您对类型注释一心一意,您可以通过writing your code using Cython获得类似的东西。例如,如果您想要一个仅接受 2D 整数数组的 add 函数,您可以在 .pyx 文件中编写以下函数:

import numpy as np

def add(long[:, :] arr1, long[:, :] arr2):
    assert tuple(arr1.shape) == tuple(arr2.shape)

    result = np.zeros((arr1.shape[0], arr1.shape[1]), dtype=np.long)
    cdef long[:, :] result_view = result

    for x in range(arr1.shape[0]):
        for y in range(arr1.shape[1]):
            result_view[x, y] = arr1[x, y] + arr2[x, y]

    return result

有关编写和编译 Cython 的更多详细信息,请参阅上面链接的文档。

这与其说是“类型注释”,不如说是真正的强类型,但它可能会做你想要的事情。遗憾的是,我无法找到一种方法来固定单个维度的大小,而只能固定维度的总数。

关于python - Python 3 确保数组参数尺寸正确的方法是什么?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53268485/

相关文章:

python - Python Flask 上有单例模式吗?

python - Numpy 优化

python - 通过求和降低数组的分辨率

python - 在Python中缩小图像的一部分

python - 当Python中最后一个字符是 `\`时如何创建原始字符串

python - 无法在 MacOS sierra 上使用 pip3 安装 mysqlclient

python - 类型错误 : expected string or buffer | Python

python - numpy reshape()和transpose()之间有交互规则吗?

java - 开关内的数组

javascript - 如何制作排序数组和计数数组按钮?