python - 如何使用一个数组的索引来定义另一个数组的 __getitem__ ?

标签 python oop slice

我有一个自定义类Field的对象,它本质上包裹着一个numpy.ndarray对象。该对象由两个输入定义:一个值数组 (values) 和一个切片对象 (segment),该对象定义这些值应放置在某个较大数组 (网格)。

我希望能够使用grid的索引来访问values的项目。这应该可以通过定义自定义 Field.__getitem__ 方法来实现。

import numpy as np

class Field:
    def __init__(self, values, segment, grid):
        if (not isinstance(segment, slice)) \\
        or (not isinstance(values, np.ndarray)) :
            raise TypeError
        if segment.step not in [1, -1]:
            raise ValueError('Segment must be continuous')
        if len(grid[segment]) != len(values):
            raise ValueError('values length must match segment')

        self.values = values
        self.segment = segment 
        self.grid = grid

    def __getitem__(self, key):
        new_key = ...  # <--- Code goes here
        return self.values[new_key]

grid = np.array([0.5, 1.5, 2.5, 3.5, 4.5])

values = np.array([42., 43., 44.])
segment = slice(2, 5)

my_field = Field(values, segment, grid)
print(grid[segment])  # output: [2.5, 3.5, 4.5]
print(my_field[2])  # Desired output: 42.
print(my_field[3])  # Desired output: 43.
print(my_field[0])  # Desired output: IndexError

重点是,segment 定义了 grid 中定义 my_field 的位置集。 我的处理方法被证明非常不优雅和笨拙,并且基于定义一些 bool 数组 index = np.zeros_like(grid, dtype=bool); index[segment] = True 然后涉及一些 np.cumsum(index) ...

如何以更简单的方式实现此行为?

最佳答案

您可以使用明确的步骤定义切片:

segment = slice(2, 5, 1)

这是为了确保 __init__ 中的 segment.step 返回 1。然后定义一个方法来检查您的输入key是否在适当的范围内:

def __getitem__(self, key):
    start, stop = self.segment.start, self.segment.stop
    new_key = key - start
    if new_key not in range(stop - start):
        raise IndexError(f'Key must be in range({start}, {stop})')
    return self.values[new_key]

这给出:

my_field = Field(values, segment, grid)
print(grid[segment])  # [2.5, 3.5, 4.5]
print(my_field[2])    # 42.0
print(my_field[3])    # 43.0
print(my_field[0])    # IndexError: Key must be in range(2, 5)

关于python - 如何使用一个数组的索引来定义另一个数组的 __getitem__ ?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53697018/

相关文章:

python - 删除文本中的相关连字符

python - 以 block 的形式循环遍历 Pandas Dataframe

python - objective C/iOS 中 python 的 file.read() 的等价物是什么?

python - 在 Python 中列出 JSON 字段

c++ - C++如何在内部实现多态性?

java - 我应该如何在Java中实现接口(interface)? [代码正确性]

java - 直接调用方法 vs 方法重载

Python 切片和负步幅——为什么这些例子明显矛盾?

python - 在Python中对列表进行切片

arrays - 复制变量时的数组地址