python - numpy where 最多接受 3 个参数 - 解决这个问题的方法?

标签 python numpy

假设我有这段代码:

import numpy as np

def myf(c):
    return c*11

def method_A(c):
    return c*999

def method_B(c):
    return c*55

minimum = 30
maximum = 100
the_method = 'A'

b = np.array([1, 20, 35, 3, 45, 52, 78, 101, 127, 135])

我想在 where 中使用 numpy 来满足某些条件。

类似于:

b = np.where( np.logical_or(b < minimum , b > maximum) , b, 
             (if the_method == 'A': method_A(b)) ,
             (if the_method == 'B': method_B(b)))

如果条件b < min or b > max满足,保留 b 中的每个元素,否则,if the_method is A , call method A否则 if method is B, call method B

所以,我尝试了:

b = np.where( np.logical_or(b < minimum , b > maximum) , b, 
             (np.where(the_method == 'A',method_A(b),b)),
             (np.where(the_method == 'B',method_B(b),b))
            )

这给了我 function takes at most 3 arguments (4 given)因为 np.where 不能接受超过 3 个参数。

有办法解决我的问题吗?

最佳答案

the_method 是标量而不是数组,因此您不需要内部 np.where :

if the_method == 'A':
    which_method = method_A 
elif the_method == 'B':
    which_method = method_B 
else:
    raise ValueError

b = np.where(
    (b < minimum) | (b > maximum),
     b, 
     which_method(b)
)

事实上,如果您在这两种方法中都插入一个 print,您会发现与使用 np.where 时相比,它们都会运行。

如果您真的想用一个表达式来实现:

def _raise(x): raise x

b = np.where(
    (b < minimum) | (b > maximum),
     b, 
     (
         method_A if the_method == 'A' else
         method_B if the_method == 'B' else
         _raise(ValueError)
     )(b)
)

关于python - numpy where 最多接受 3 个参数 - 解决这个问题的方法?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/41953928/

相关文章:

python - 如何获取包含 Pyspark Dataframe 中另一列中给出的多列值的列表列?

python - postgresql 中 "TEXT"数据类型的最大大小

python - NumPy ufunc 在一个轴上比另一个轴快 2 倍

python - 将每一行乘以不同的旋转矩阵

python - 是否有一个 numpy 内置函数来拒绝列表中的异常值

python - 属性错误 : 'str' object has no attribute 'insert'

python - django 1.9 不为自定义用户模型创建表

python - 如何使用 matplotlib.pyplot 更改图例大小

python - Python中是否有像TensorFlow的tf.image.resize_images函数那样调整图像大小的resize函数?

python - Tensorflow运行时确定Tensor的形状