python - 从 ndarray 继承调用 __getitem__

标签 python numpy

您好,我正在尝试从 ndarray 派生一个类。我坚持在 docs 中找到的食谱但是当我重写 __getiem__() 函数时,我得到了一个我不明白的错误。我确定这是它应该如何工作,但我不明白如何正确地做到这一点。我的类基本上添加了一个“dshape”属性,如下所示:

class Darray(np.ndarray):
    def __new__(cls, input_array, dshape, *args, **kwargs):
        obj = np.asarray(input_array).view(cls)
        obj.SelObj = SelObj
        obj.dshape = dshape
        return obj

    def __array_finalize__(self, obj):
        if obj is None: return
        self.info = getattr(obj, 'dshape', 'N')  

    def __getitem__(self, index):        
        return self[index]

当我现在尝试做的时候:

D = Darray( ones((10,10)), ("T","N"))

解释器将因最大深度递归而失败,因为他一遍又一遍地调用 __getitem__

有人可以向我解释为什么以及如何实现 getitem 函数吗?

干杯, 大卫

最佳答案

can someone explain to me why and how one would implement a getitem function?

对于您当前的代码,不需要 __getitem__。当我删除 __getitem__ 实现时,您的类工作正常(未定义的 SelObj 除外)。

最大递归深度错误的原因是__getitem__的定义,它使用了self[index]:self.__getitem__(index )。如果您必须重写 __getitem__,请确保调用 __getitem__ 的父类(super class)实现:

def __getitem__(self, index):
    return super(Darray, self).__getitem__(index)

至于为什么要这样做:有很多原因需要重写这个函数,例如您可以将名称与数组的行相关联:

class NamedRows(np.ndarray):
    def __new__(cls, rows, *args, **kwargs):
        obj = np.asarray(*args, **kwargs).view(cls)
        obj.__row_name_idx = dict((n, i) for i, n in enumerate(rows))
        return obj

    def __getitem__(self, idx):
        if isinstance(idx, basestring):
            idx = self.__row_name_idx[idx]
        return super(NamedRows, self).__getitem__(idx)

演示:

>>> a = NamedRows(["foo", "bar"], [[1,2,3], [4,5,6]])
>>> a["foo"]
NamedRows([1, 2, 3])

关于python - 从 ndarray 继承调用 __getitem__,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/19305128/

相关文章:

csv - genfromtxt dtype=None 返回错误的形状

python - 按公共(public)日期对数组数据进行排序

Python fromtimestamp() 方法不一致

python - 没有使用请求库的 cookie

python - 显示评估选择的输出 - Sublime Text Python REPL

python - 在python中将大图像文件读取为数组

python - 从数组中删除重复的元素

python - 转动Python字典。到 Excel 工作表

python - 导入错误 : No module named 'telegram. vendor.ptb_urllib3

python - 带有 numpy 的二维数组