我正在尝试在 nd 数组上使用 numpy 的 setdiff1d 函数:
import numpy as np
#a,b being ndarrays
in_a_not_b = np.setdiff1d(a,b)
但它不起作用,因为它是按 nd 数组元素工作的。
例如,如果:
a = [[1,2,3],[4,5,6]]
b = [[7,2,3],[4,5,6],[7,8,9]]
我希望输出是:
[[1,2,3]]
但这里是:
[1]
是否有一种简单的方法将 setdiff1d 推广到 nd 数组?
最佳答案
您可以连接数组[b, a]
,然后检查沿axis=0
的唯一性,并仅选择那些大于len(b )
(即来自 a
的唯一值):
__, indices = np.unique(np.concatenate([b, a]), return_index=True, axis=0)
indices = indices[indices >= len(b)] - len(b)
result = a[indices]
这将返回沿第一个轴的 a
的唯一值(但这正是 setdiff1d
无论如何都会做的)。元素的顺序在这里很重要,因此如果不重要,那么您需要沿最后一个轴对数组进行排序。
关于python - 将 numpy.set1d 推广到 nd 数组,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57486168/