假设我有这段代码:
import numpy as np
def myf(c):
return c*11
def method_A(c):
return c*999
def method_B(c):
return c*55
minimum = 30
maximum = 100
the_method = 'A'
b = np.array([1, 20, 35, 3, 45, 52, 78, 101, 127, 135])
我想在 where 中使用 numpy 来满足某些条件。
类似于:
b = np.where( np.logical_or(b < minimum , b > maximum) , b,
(if the_method == 'A': method_A(b)) ,
(if the_method == 'B': method_B(b)))
如果条件b < min or b > max
满足,保留 b 中的每个元素,否则,if the_method is A , call method A
否则 if method is B, call method B
所以,我尝试了:
b = np.where( np.logical_or(b < minimum , b > maximum) , b,
(np.where(the_method == 'A',method_A(b),b)),
(np.where(the_method == 'B',method_B(b),b))
)
这给了我 function takes at most 3 arguments (4 given)
因为 np.where 不能接受超过 3 个参数。
有办法解决我的问题吗?
最佳答案
the_method
是标量而不是数组,因此您不需要内部 np.where
:
if the_method == 'A':
which_method = method_A
elif the_method == 'B':
which_method = method_B
else:
raise ValueError
b = np.where(
(b < minimum) | (b > maximum),
b,
which_method(b)
)
事实上,如果您在这两种方法中都插入一个 print
,您会发现与使用 np.where
时相比,它们都会运行。
如果您真的想用一个表达式来实现:
def _raise(x): raise x
b = np.where(
(b < minimum) | (b > maximum),
b,
(
method_A if the_method == 'A' else
method_B if the_method == 'B' else
_raise(ValueError)
)(b)
)
关于python - numpy where 最多接受 3 个参数 - 解决这个问题的方法?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/41953928/