python - 过滤numpy数组的行?

标签 python numpy filter

我希望将一个函数应用于 numpy 数组的每一行。如果此函数计算结果为 true,我将保留该行,否则我将丢弃它。例如,我的函数可能是:

def f(row):
    if sum(row)>10: return True
    else: return False

我想知道是否有类似的东西:

np.apply_over_axes()

将函数应用于 numpy 数组的每一行并返回结果。我希望得到类似的东西:

np.filter_over_axes()

这会将一个函数应用于 numpy 数组的每一行,并且只返回函数返回 true 的行。有这样的吗?还是应该只使用 for 循环?

最佳答案

理想情况下,您将能够实现函数的矢量化版本并使用它来执行 boolean indexing .对于绝大多数问题,这是正确的解决方案。 Numpy 提供了相当多的函数,可以作用于各种轴以及所有基本操作和比较,因此大多数有用的条件应该是可向量化的。

import numpy as np

x = np.random.randn(20, 3)
x_new = x[np.sum(x, axis=1) > .5]

如果您绝对确定不能执行上述操作,我建议您使用列表推导(或 np.apply_along_axis)来创建一个 bool 数组以作为索引。

def myfunc(row):
    return sum(row) > .5

bool_arr = np.array([myfunc(row) for row in x])
x_new = x[bool_arr]

这将以相对干净的方式完成工作,但会比矢量化版本慢得多。一个例子:

x = np.random.randn(5000, 200)

%timeit x[np.sum(x, axis=1) > .5]
# 100 loops, best of 3: 5.71 ms per loop

%timeit x[np.array([myfunc(row) for row in x])]
# 1 loops, best of 3: 217 ms per loop

关于python - 过滤numpy数组的行?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/26154711/

相关文章:

python - 如何知道命令或方法是否为 'in place algorithm'?

python - 由于unicode字符,无法在Python中将字符串转换为Json

python - 使用 Lambda 函数时出现语法错误

Python - 每次 A 列等于 0 时从 B 列返回值

java - 如何在 Jersey 2.4 过滤器中获取资源注释?

Yii 中使用 GridView 进行 Ajax 过滤

python - ANSI SQL 相当于 Pandas `factorize()` ?

python - 如何使用 matplotlib 和 seaborn 在绘制的值上显示轴刻度标签?

python - 如何使用多个索引从 NumPy 数组中获取值

python - 打印格式化的 numpy 数组