我有一个数组和一个 bool 数组(作为一种热编码)
a = np.arange(12).reshape(4,3)
b = np.array([
[1,0,0],
[0,1,0],
[0,0,1],
[0,0,1],
], dtype=bool)
print(a)
print(b)
# [[ 0 1 2]
# [ 3 4 5]
# [ 6 7 8]
# [ 9 10 11]]
# [[ True False False]
# [False True False]
# [False False True]
# [False False True]]
我想使用 bool 数组来选择元素
print(a[:, [True, False, False]])
# array([[0],
# [3],
# [6],
# [9]])
print(a[:, [False, True, False]])
# array([[ 1],
# [ 4],
# [ 7],
# [10]])
但是这会根据所有行的相同模板 bool 值进行选择。我想在每行的基础上执行此操作:
print(a[:, b])
# IndexError: too many indices for array
我应该在...
中放入什么,这样我就会得到:
print(a[:, ...])
# array([[0],
# [4],
# [8],
# [11]])
编辑:这类似于臭名昭著的 CS231 中使用的内容。类(class):
dscores = a
num_examples = 4
# They had 300
y = b
dscores[range(num_examples),y]
# equivalent to
# a{:,b]
编辑2:在CS231中例如,y
是一维的,不是热编码的!
他们正在做dscores[[rowIdx],[columnIdx]]
最佳答案
通过b
过滤后广播它
a[b][:,None]
Out[168]:
array([[ 0],
[ 4],
[ 8],
[11]])
或者
a[b,None]
Out[174]:
array([[ 0],
[ 4],
[ 8],
[11]])
关于python - Numpy:根据 bool 数组选择元素,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/55635473/