python - 如何使用基于 numpy ndarray 的索引覆盖 getitem 方法?

标签 python numpy-ndarray

我正在尝试构建一个基于 numpy.ndarray 的类,其中 __getitem____setitem__方法提供基于像这样的索引的数据 Point类:

import numpy as np

class Point:
    def __init__(self, number=20):
        dt = np.dtype([("x",np.float64), ("y",np.float64), ("alive",np.bool)])
        self.points = np.zeros((int(number),1), dtype=dt)
        self.points["alive"] = True

#    def __getitem__(self, i):
#        mask = self.points["alive"] == True
#        print("get")
#        return self.points[mask].__getitem__(i)

    def __getitem__(self, i):
        mask = self.points["alive"] == True
        print("get")
        return self.points[mask][i]

    def __setitem__(self, i, item):
        mask = self.points["alive"] == True
        print("set")
        self.points[mask][i] = item

如果我尝试:

p = Point()
print(p[0])
>>>>get
>>>>(0., 0., True)
print(p[0]["alive"])
>>>>get
>>>>True
p[0]["alive"] = False
>>>>get
print(p.points[0]["alive"])
>>>>[ True]

所以没有考虑修改,但我没有收到错误,就像我正在修改副本一样。我也很困惑,因为我没有调用 __setitem__方法但是 __getitem__方法。我使用 __getitem__ 尝试了另一个实现ndarray的但也存在同样的问题。

我做错了什么以及如何正确执行此操作?

最佳答案

我找到了如何使用 pandas 来做到这一点,它允许在没有链式索引的情况下执行此操作,并获取数组的 View 而不是副本:

import numpy as np
import pandas as pd

class Point:
    def __init__(self, number=20):
        d = {"x":np.zeros((int(number),)), "y":np.zeros((int(number),))}
        self.points = pd.DataFrame(d)
        self.alive = pd.Series(np.ones((int(number),),dtype=bool))

    def __getitem__(self, i):
        return self.points.loc[self.alive,i]

    def __setitem__(self, i, item):
        self.points.loc[self.alive,i] = item

这给出了正确的行为:

p = Point(3)
print(p[:])

>>>>     x    y
>>>> 0  0.0  0.0
>>>> 1  0.0  0.0
>>>> 2  0.0  0.0

p.alive[0] = False
print(p[:])

>>>>     x    y
>>>> 1  0.0  0.0
>>>> 2  0.0  0.0

p["x"] = p["x"] + 5
print(p[:])

>>>>     x    y
>>>> 1  5.0  0.0
>>>> 2  5.0  0.0

关于python - 如何使用基于 numpy ndarray 的索引覆盖 getitem 方法?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59826394/

相关文章:

python - 如何在 python 中从具有多个条件的数组 A 获取 bool 数组?

python - 迭代 numpy 数组的函数

python - 使用 numpy/ctypes 公开 C 分配的内存缓冲区的更安全方法?

python - 线程减慢响应时间 - python

python - 如何在 numpy 数组中应用条件语句?

python - 将字符串转换为numpy.ndarray python

python - BeautifulSoup:将标签(包含其他标签)拆分为两个字符串

python - 如何阻止警告对话框停止执行控制它的 Python 程序?

python - 获取嵌套字典中所有键的列表

python - 通过合并邻近的值来减少排序的数字列表的更好方法?