我有一个自定义类Field
的对象,它本质上包裹着一个numpy.ndarray
对象。该对象由两个输入定义:一个值数组 (values
) 和一个切片对象 (segment
),该对象定义这些值应放置在某个较大数组 (网格
)。
我希望能够使用grid
的索引来访问values
的项目。这应该可以通过定义自定义 Field.__getitem__
方法来实现。
import numpy as np
class Field:
def __init__(self, values, segment, grid):
if (not isinstance(segment, slice)) \\
or (not isinstance(values, np.ndarray)) :
raise TypeError
if segment.step not in [1, -1]:
raise ValueError('Segment must be continuous')
if len(grid[segment]) != len(values):
raise ValueError('values length must match segment')
self.values = values
self.segment = segment
self.grid = grid
def __getitem__(self, key):
new_key = ... # <--- Code goes here
return self.values[new_key]
grid = np.array([0.5, 1.5, 2.5, 3.5, 4.5])
values = np.array([42., 43., 44.])
segment = slice(2, 5)
my_field = Field(values, segment, grid)
print(grid[segment]) # output: [2.5, 3.5, 4.5]
print(my_field[2]) # Desired output: 42.
print(my_field[3]) # Desired output: 43.
print(my_field[0]) # Desired output: IndexError
重点是,segment
定义了 grid
中定义 my_field
的位置集。
我的处理方法被证明非常不优雅和笨拙,并且基于定义一些 bool 数组 index = np.zeros_like(grid, dtype=bool); index[segment] = True
然后涉及一些 np.cumsum(index)
...
如何以更简单的方式实现此行为?
最佳答案
您可以使用明确的步骤定义切片:
segment = slice(2, 5, 1)
这是为了确保 __init__
中的 segment.step
返回 1
。然后定义一个方法来检查您的输入key
是否在适当的范围
内:
def __getitem__(self, key):
start, stop = self.segment.start, self.segment.stop
new_key = key - start
if new_key not in range(stop - start):
raise IndexError(f'Key must be in range({start}, {stop})')
return self.values[new_key]
这给出:
my_field = Field(values, segment, grid)
print(grid[segment]) # [2.5, 3.5, 4.5]
print(my_field[2]) # 42.0
print(my_field[3]) # 43.0
print(my_field[0]) # IndexError: Key must be in range(2, 5)
关于python - 如何使用一个数组的索引来定义另一个数组的 __getitem__ ?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53697018/