假设我有以下功能:
import numpy, itertools
def my_func(x):
Z = 0
for y in itertools.product([-1, 1], repeat=10):
Z += numpy.exp(numpy.dot(x,y))
return numpy.log(Z)
该函数适用于不包含“太极端”值的输入向量。但是,如果输入值过于极端,在某些时候 numpy.exp 函数会出现溢出。这是一个例子:
x = numpy.random.normal(5, 10, 10)
my_func(x)
到目前为止,该功能运行良好。但是,如果我用极值替换 x 的一个元素,则会出现溢出错误:
x[3] = 6000
my_func(x)
有没有办法重写函数以避免溢出?我知道溢出出现的位置和原因。尽管如此,我还是找不到重写该函数的方法来避免它。
最佳答案
NumPy 有 np.logaddexp(x1, x2)
,它计算 log(exp(x1) + exp(x2))
。所以你可以将循环重写为:
z = np.logaddexp(z, np.dot(x, y))
并跳过最后一次调用np.log
。
关于python - 重写函数以避免 numpy.exp 中的溢出,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/44139895/