我在 Python 2.7 中有一个 numpy 索引数组,它对应于字典中的一个值。所以我想从字典中创建一个对应值的 numpy 数组。代码可能会立即清晰:
import numpy as np
indices = np.array([(0, 1), (2, 0), (2, 0)], dtype=[('A', int), ('B', int)])
d = {(0, 1): 10,
(2, 0): 9}
values = d[(indices['A'], indices['B'])]
最后一行的调用是不可哈希的(我试图找到一个 way to make a np.array hashable 但没有成功):
TypeError: unhashable type: 'numpy.ndarray'
我可以用一个循环来代替它,但这需要很长时间来编写变量 values
:
np.array([d[(indices[i]['A'], indices[i]['B'])] for i in range(len(indices))])
或者 dict 是否有任何替代方案可以使此类任务成为 pythonic,即更快?变量 indices
无法更改,但我可以更改 dict
的类型。
编辑
实际索引数组还包含其他条目。这就是为什么我把调用写得如此复杂:
indices = np.array([(0, 1, 's'), (2, 0, 's'), (2, 0, 't')],
dtype=[('A', int), ('B', int), ('C', str)])
最佳答案
我相信您可以对此使用列表理解(它会比普通的 for
循环方法快一点)。示例 -
values = [d[tuple(a)] for a in indices]
请注意,我使用 d
而不是 dict
,因为不建议使用 dict
作为变量名,因为这会影响内置类型 dict
。
演示 -
In [73]: import numpy as np
In [74]: indices = np.array([(0, 1), (2, 0), (2, 0)], dtype=[('A', int), ('B', int)])
In [76]: d = {(0, 1): 10,
....: (2, 0): 9}
In [78]: values = [d[tuple(a)] for a in indices]
In [79]: values
Out[79]: [10, 9, 9]
对于更大的阵列,更快的方法是使用 np.vectorize()
向量化 dict.get()
方法,然后将其应用于 indices
数组。示例 -
vecdget = np.vectorize(lambda x: d.get(tuple(x)))
vecdget(indices)
带计时结果的演示 -
In [88]: vecdget = np.vectorize(lambda x: d.get(tuple(x)))
In [89]: vecdget(indices)
Out[89]: array([10, 9, 9])
In [98]: indices = np.array([(0, 1), (2, 0), (2, 0)] * 100, dtype=[('A', int), ('B', int)])
In [99]: %timeit [d[tuple(a)] for a in indices]
100 loops, best of 3: 1.72 ms per loop
In [100]: %timeit vecdget(indices)
1000 loops, best of 3: 341 µs per loop
@hpaulj 在评论中建议的新方法的时间测试 - [d.get(x.item()) for x in indices]
-
In [114]: %timeit [d.get(x.item()) for x in indices]
1000 loops, best of 3: 417 µs per loop
In [115]: %timeit vecdget(indices)
1000 loops, best of 3: 331 µs per loop
In [116]: %timeit [d.get(x.item()) for x in indices]
1000 loops, best of 3: 354 µs per loop
In [117]: %timeit vecdget(indices)
1000 loops, best of 3: 262 µs per loop
关于python - 将 : np. 索引数组转换为相应字典条目的 np.array,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/32870709/