python - 如何在numpy中找到给定行值的列索引?

标签 python numpy

我想在 numpy 中执行一些相对简单的操作:

  • 如果行中有1,则返回包含1的列的索引+1。
  • 如果一行中有零个或多个1,则返回0。

但是我最终得到了一个相当复杂的代码:

predictions = np.array([[1,-1,-1,-1],[-1,1,1,-1],[-1,-1,-1,1],[-1,-1,-1,-1]])

one_count = (predictions == 1).sum(1)
valid_rows_idx = np.where(one_count==1)

result = np.zeros(predictions.shape[0])
for idx in valid_rows_idx:
    result[idx] = np.where(predictions[idx,:]==1)[1] + 1

如果我打印result,程序会打印 [ 1. 0. 4. 0.] 这是期望的结果。

我想知道是否有更简单的方法使用 numpy 编写最后一行。

最佳答案

我不确定这是否更好,但您可以尝试使用 argmax为了那个原因。此外,您不需要使用 for 循环和 np.where 来获取有效索引:

predictions = np.array([[1,-1,-1,-1],[-1,1,1,-1],[-1,-1,-1,1],[-1,-1,-1,-1]])

idx = (predictions == 1).sum(1) == 1
result = np.zeros(predictions.shape[0])
result[idx] = (predictions[idx]==1).argmax(axis=1) + 1

In [55]: result
Out[55]: array([ 1.,  0.,  4.,  0.])

或者您可以使用 np.whereargmax 在一行中完成所有这些工作:

predictions = np.array([[1,-1,-1,-1],[-1,1,1,-1],[-1,-1,-1,1],[-1,-1,-1,-1]])

In [72]: np.where((predictions==1).sum(1)==1, (predictions==1).argmax(axis=1)+1, 0)
Out[72]: array([1, 0, 4, 0])

关于python - 如何在numpy中找到给定行值的列索引?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/35613034/

相关文章:

Python-读取文件后数据未附加到列表和缺少字典键

python - 像交互式 Ruby 一样的交互式 Python?

python - 如何修复自定义 conda 包的 conda UnsatisfiableError?

python - 如何添加列表作为 numpy n 维矩阵的元素

Python:嵌套列表合并并在同一索引中添加 int

python - 混合 Tornado 和sqlalchemy

带有全局常量变量的Python遗传优化多处理,如何加速?

python - 如何将 STL、网格导入 Django views.py 或 models.py

python - 为 32 位和 64 位编写 cython?

Python - 从 3D 数组写入三维图像