我发现自己在几个不同的场景中一直面临这个问题。于是想到这里分享一下,看看有没有最优的解决方法。
假设我有一个大数组,其中包含任何 X 和另一个与 X 大小相同的数组,称为 y,上面有 x 所属的标签。所以像下面这样。
X = np.array(['obect1', 'object2', 'object3', 'object4', 'object5'])
y = np.array([0, 1, 1, 0, 2])
我想要的是构建一个字典/散列,它使用标签集作为键,X中所有带有这些标签的对象的索引作为项目 .所以在这种情况下,所需的输出将是:
{0: (array([0, 3]),), 1: (array([1, 2]),), 2: (array([4]),)}
请注意,实际上 X 上的内容并不重要,但为了完整起见,我将其包括在内。
现在,我对这个问题的幼稚解决方案是遍历所有标签并使用np.where==label
构建字典。更详细地说,我使用这个函数:
def get_key_to_indexes_dic(labels):
"""
Builds a dictionary whose keys are the labels and whose
items are all the indexes that have that particular key
"""
# Get the unique labels and initialize the dictionary
label_set = set(labels)
key_to_indexes = {}
for label in label_set:
key_to_indexes[label] = np.where(labels==label)
return key_to_indexes
现在我的问题的核心是: 有没有办法做得更好?有没有一种自然的方法可以使用 numpy 函数来解决这个问题?我的方法是否被误导了?
作为次要的横向问题:上述定义中解决方案的复杂性是什么?我认为解决方案的复杂性如下:
或者换句话说,标签的数量乘以在大小为 y 的集合中使用 np.where
的复杂性加上从数组中创建集合的复杂性。这是正确的吗?
PD我找不到与此特定问题相关的帖子,如果您有更改标题或任何内容的建议,我将不胜感激。
最佳答案
如果在遍历时使用字典存储索引,则只需要遍历一次:
from collections import defaultdict
def get_key_to_indexes_ddict(labels):
indexes = defaultdict(list)
for index, label in enumerate(labels):
indexes[label].append(index)
缩放看起来很像你为你的选项分析过的,因为它上面的函数是 O(N),其中 N 是 y
的大小,因为检查一个值是否在字典中是 O( 1).
所以有趣的是,由于 np.where
的遍历速度快得多,只要只有少量标签,您的函数就会更快。当有许多不同的标签时,我的速度似乎更快。
以下是函数的扩展方式:
蓝线是你的函数,红线是我的。线条样式表示不同标签的数量。 {10: ':', 100: '--', 1000: '-.', 10000: '-'}
。你可以看到我的功能相对独立于标签的数量,而当标签很多时你的功能很快变慢。如果您的标签很少,最好使用自己的标签。
关于python - numpy 数组 : a dictionary (hash) of label to index 中标签索引的最快逆运算,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/36353258/