我需要在低至 -150
的范围内计算以下函数的积分:
import numpy as np
from scipy.special import ndtr
def my_func(x):
return np.exp(x ** 2) * 2 * ndtr(x * np.sqrt(2))
问题是这部分函数
np.exp(x ** 2)
趋于无穷大——当 x
的值小于大约 -26
时,我得到 inf
。
还有这部分功能
2 * ndtr(x * np.sqrt(2))
相当于
from scipy.special import erf
1 + erf(x)
趋向于 0。
所以,一个非常非常大的数字乘以一个非常非常小的数字应该会得到一个合理大小的数字——但是,python
却给了我 nan
.
我该怎么做才能避免这个问题?
最佳答案
我认为 @askewchan 的解决方案和 scipy.special.log_ndtr
的组合会成功的:
from scipy.special import log_ndtr
_log2 = np.log(2)
_sqrt2 = np.sqrt(2)
def my_func(x):
return np.exp(x ** 2) * 2 * ndtr(x * np.sqrt(2))
def my_func2(x):
return np.exp(x * x + _log2 + log_ndtr(x * _sqrt2))
print(my_func(-150))
# nan
print(my_func2(-150)
# 0.0037611803122451198
对于 x <= -20
, log_ndtr(x)
uses a Taylor series expansion of the error function to iteratively compute the log CDF directly ,这比简单地采用 log(ndtr(x))
在数值上更稳定.
更新
正如您在评论中提到的,exp
如果 x
也可以溢出足够大。虽然您可以使用 mpmath.exp
解决此问题, 一个更简单和更快的方法是转换为 np.longdouble
在我的机器上,它可以表示最大 1.189731495357231765e+4932 的值:
import mpmath
def my_func3(x):
return mpmath.exp(x * x + _log2 + log_ndtr(x * _sqrt2))
def my_func4(x):
return np.exp(np.float128(x * x + _log2 + log_ndtr(x * _sqrt2)))
print(my_func2(50))
# inf
print(my_func3(50))
# mpf('1.0895188633566085e+1086')
print(my_func4(50))
# 1.0895188633566084842e+1086
%timeit my_func3(50)
# The slowest run took 8.01 times longer than the fastest. This could mean that
# an intermediate result is being cached 100000 loops, best of 3: 15.5 µs per
# loop
%timeit my_func4(50)
# The slowest run took 11.11 times longer than the fastest. This could mean
# that an intermediate result is being cached 100000 loops, best of 3: 2.9 µs
# per loop
关于python - 欺骗 numpy/python 来表示非常大和非常小的数字,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/32303006/