我正在尝试构建一个基于 numpy.ndarray 的类,其中 __getitem__
和__setitem__
方法提供基于像这样的索引的数据 Point
类:
import numpy as np
class Point:
def __init__(self, number=20):
dt = np.dtype([("x",np.float64), ("y",np.float64), ("alive",np.bool)])
self.points = np.zeros((int(number),1), dtype=dt)
self.points["alive"] = True
# def __getitem__(self, i):
# mask = self.points["alive"] == True
# print("get")
# return self.points[mask].__getitem__(i)
def __getitem__(self, i):
mask = self.points["alive"] == True
print("get")
return self.points[mask][i]
def __setitem__(self, i, item):
mask = self.points["alive"] == True
print("set")
self.points[mask][i] = item
如果我尝试:
p = Point()
print(p[0])
>>>>get
>>>>(0., 0., True)
print(p[0]["alive"])
>>>>get
>>>>True
p[0]["alive"] = False
>>>>get
print(p.points[0]["alive"])
>>>>[ True]
所以没有考虑修改,但我没有收到错误,就像我正在修改副本一样。我也很困惑,因为我没有调用 __setitem__
方法但是 __getitem__
方法。我使用 __getitem__
尝试了另一个实现ndarray
的但也存在同样的问题。
我做错了什么以及如何正确执行此操作?
最佳答案
我找到了如何使用 pandas 来做到这一点,它允许在没有链式索引的情况下执行此操作,并获取数组的 View 而不是副本:
import numpy as np
import pandas as pd
class Point:
def __init__(self, number=20):
d = {"x":np.zeros((int(number),)), "y":np.zeros((int(number),))}
self.points = pd.DataFrame(d)
self.alive = pd.Series(np.ones((int(number),),dtype=bool))
def __getitem__(self, i):
return self.points.loc[self.alive,i]
def __setitem__(self, i, item):
self.points.loc[self.alive,i] = item
这给出了正确的行为:
p = Point(3)
print(p[:])
>>>> x y
>>>> 0 0.0 0.0
>>>> 1 0.0 0.0
>>>> 2 0.0 0.0
p.alive[0] = False
print(p[:])
>>>> x y
>>>> 1 0.0 0.0
>>>> 2 0.0 0.0
p["x"] = p["x"] + 5
print(p[:])
>>>> x y
>>>> 1 5.0 0.0
>>>> 2 5.0 0.0
关于python - 如何使用基于 numpy ndarray 的索引覆盖 getitem 方法?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59826394/