在我的 Python 3.7 新手项目中,许多函数中的参数都是 numpy.ndarray 的。这些必须是二维 r x n 矩阵。行维度 r
至关重要:某些函数需要 1 x n
向量,其他函数需要 2 x n
矩阵,其中 r
向上到三个甚至更多。还有为任何 r x n 数组定义的函数。 (列尺寸n
对于设计目的来说并不是必需的。)
根据我的 Matlab 经验,此要求可能会令人困惑且容易出错。所以我考虑了以下方法:
- 记录方法参数(当然!)
- 单元测试(当然!)
- 在某些函数内进行验证并抛出异常。 (但是,这不是很实用,也不是很高效。)
- 定义数据类:
OneRow
、TwoRows
、ThreeRows
和FourPlusRows
。每个都有一个 ndarray 字段,在构造函数中进行验证。其优点包括类型提示和更好的领域建模,类似于 DDD。缺点是额外的复杂性。
问题:考虑到 Python 3 中引入的类型提示以及函数式编程的趋势,当前解决此问题的 Python 方法是什么?
最佳答案
Python 最好的事情之一是 duck typing ,并且 Numpy 通常与该设计方法非常兼容。假设您有一个纯矢量函数 vecfunc
。您可以在函数的开头添加一些样板,将任何一维数组膨胀为 1 x n
向量:
def vecfunc(arr):
if arr.ndim==1:
arr = arr[None, :]
...function body goes here...
这将避免由于 arr
维度太少而导致的任何问题,并且在大多数情况下仍可能给出正确的行为。但是,它不会执行任何操作来阻止用户传入 r x n x m
数组或 15 x n
数组。最终,您将不得不使用方法3.
来处理一堆这样的东西,并在合适的地方抛出一些异常。例如:
def vecfunc(arr):
if not 0 < arr.ndim < 3:
raise ValueError("arr must have ndim of 1 or 2. arr.ndim: %d" % arr.ndim)
elif arr.ndim==1:
arr = arr[None, :]
如果这让你感觉好一点,两个 numpy
的代码库和 scipy
在需要的时间和地点,在许多函数中进行基于形状的异常检查。
当然,您始终可以不添加此类异常检查,直到开发任何给定函数的最后为止。您可能会对产生合理行为的输入范围感到惊讶。
如果您对类型注释一心一意,您可以通过writing your code using Cython获得类似的东西。例如,如果您想要一个仅接受 2D 整数数组的 add
函数,您可以在 .pyx
文件中编写以下函数:
import numpy as np
def add(long[:, :] arr1, long[:, :] arr2):
assert tuple(arr1.shape) == tuple(arr2.shape)
result = np.zeros((arr1.shape[0], arr1.shape[1]), dtype=np.long)
cdef long[:, :] result_view = result
for x in range(arr1.shape[0]):
for y in range(arr1.shape[1]):
result_view[x, y] = arr1[x, y] + arr2[x, y]
return result
有关编写和编译 Cython 的更多详细信息,请参阅上面链接的文档。
这与其说是“类型注释”,不如说是真正的强类型,但它可能会做你想要的事情。遗憾的是,我无法找到一种方法来固定单个维度的大小,而只能固定维度的总数。
关于python - Python 3 确保数组参数尺寸正确的方法是什么?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53268485/