floating-point - 当加数包含舍入误差时如何评估交替序列?

标签 floating-point julia precision numerical-analysis numerical-stability

我想用数值评估线性生灭过程的转移概率

哪里 是二项式系数,

对于大多数参数组合,我能够以可接受的数值误差(使用对数和 Kahan-Neumaier 求和算法)对其进行评估。

当加数的符号交替变化并且数字误差在总和中占主导地位时(在这种情况下条件数趋于无穷大),就会出现问题。发生这种情况时

例如,我在计算 p(1000, 2158, 72.78045, 0.02, 0.01) 时遇到问题。它应该是 0,但我得到了很大的值 log(p) ≈ 99.05811,这对于概率来说是不可能的。

我尝试以多种不同方式重构求和,并使用各种“精确”求和算法,例如 Zhu-Hayes .我总是得到大致相同的错误值,这让我认为问题不在于我对数字求和的方式,而在于每个加数的内部表示。

由于二项式系数,值很容易溢出。我尝试了线性变换,以便将每个(绝对)元素保持在最低正常数和 1 之间的总和中。它没有帮助,我认为这是因为许多类似量级的代数运算。

我现在陷入了死胡同,不知道如何继续。我可以使用任意精度的算术库,但计算成本对于我的 Markov Chain Monte Carlo 应用程序来说太高了。

当我们不能以足够好的精度将部分和存储在 IEEE-754 double 中时,是否有适当的方法或技巧来计算此类和?

这是一个基本的工作示例,其中我仅按最大值重新调整值并使用 Kahan 求和算法求和。显然,大多数值最终都是 Float64 的次正规值。

# this is the logarithm of the absolute value of element h
@inline function log_addend(a, b, h, lα, lβ, lγ)
  log(a) + lgamma(a + b - h) - lgamma(h + 1) - lgamma(a - h + 1) -
  lgamma(b - h + 1) + (a - h) * lα + (b - h) * lβ + h * lγ
end

# this is the logarithm of the ratio between (absolute) elements i and j
@inline function log_ratio(a, b, i, j, q)
  lgamma(j + 1) + lgamma(a - j + 1) + lgamma(b - j + 1) + lgamma(a + b - i) -
  lgamma(i + 1) - lgamma(a - i + 1) - lgamma(b - i + 1) - lgamma(a + b - j) +
  (j - i) * q
end

# function designed to handle the case of an alternating series with λ > μ > 0
function log_trans_prob(a, b, t, λ, μ)
  n = a + b
  k = min(a, b)

  ω = μ / λ
  η = exp((μ - λ) * t)

  if b > zero(b)
    lβ = log1p(-η) - log1p(-ω * η)
    lα = log(μ) + lβ - log(λ)
    lγ = log(ω - η) - log1p(-ω * η)
    q = lα + lβ - lγ

    # find the index of the maximum addend in the sum
    # use a numerically stable method for solving quadratic equations
    x = exp(q)
    y = 2 * x / (1 + x) - n
    z = ((b - x) * a - x * b) / (1 + x)

    sup = if y < zero(y)
      ceil(typeof(a), 2 * z / (-y + sqrt(y^2 - 4 * z)))
    else
      ceil(typeof(a), (-y - sqrt(y^2 - 4 * z)) / 2)
    end

    # Kahan summation algorithm
    val = zero(t)
    tot = zero(t)
    err = zero(t)
    res = zero(t)
    for h in 0:k
      # the problem happens here when we call the `exp` function
      # My guess is that log_ratio is either very big or very small and its
      # `exp` cannot be properly represented by Float64
      val = (-one(t))^h * exp(log_ratio(a, b, h, sup, q))
      tot = res + val
      # Neumaier modification
      err += (abs(res) >= abs(val)) ? ((res - tot) + val) : ((val - tot) + res)
      res = tot
    end

    res += err

    if res < zero(res)
      # sum cannot be negative (they are probabilities), it might be because of
      # rounding errors
      res = zero(res)
    end

    log_addend(a, b, sup, lα, lβ, lγ) + log(res)
  else
    a * (log(μ) + log1p(-η) - log(λ) - log1p(-ω * η))
  end
end

# ≈ 99.05810564477483 => impossible
log_trans_prob(1000, 2158, 72.78045, 0.02, 0.01)

# increasing precision helps but it is too slow for applications
log_trans_prob(BigInt(1000), BigInt(2158), BigFloat(72.78045), BigFloat(0.02),
               BigFloat(0.01))

最佳答案

我终于解决了这个问题并写了一篇论文详细描述了解决方案:https://arxiv.org/abs/1909.10765

简而言之,将每个加数除以和中的第一项得到

p(a, b, t, λ, μ) = ω(a, b, t, λ, μ) 2F1(-a, -b; -(a + b - 1); -z (t, λ, μ))

其中 ω(a, b, t, λ, μ) 是级数中的第一项,2F1 是高斯超几何函数。 超几何函数2F1(-a, -b; -(a + b - k); -z)(ab正整数, k <= 1, z 实数)可以用以下三项递推关系(TTRR)计算:

u(a, b, k) y(b + 1) + v(a, b, k, z) y(b) + w(b, k, z) y(b - 1) = 0

在哪里

u(a, b, k) = (a + b + 1 − k) (a + b − k)

v(a, b, k, z) = − (a + b − k) (a + b + 1 − k + (a − b) z)

w(b, k, z) = − b (b − k) z

如果 b > a 交换两个变量(即 a' = max(a, b)b' = min(a, b) )。

从值 y(0) = 2F1(-a, 0; -(a - k); -z) = 1y( 1) = 2F1(-a, -1; -(a + 1 - k); -z) = 1 + (a z)/(a + 1 - k)

我在 Julia 包中实现了之前的算法 SimpleBirthDeathProcess .

关于floating-point - 当加数包含舍入误差时如何评估交替序列?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/51748069/

相关文章:

python-3.x - Cart-Pole Python 性能比较

julia - 字典的向量化索引

arrays - 如何将长格式(可能稀疏)的 DataFrame 转换为多维 Array 或 NamedArray

hibernate - JPA 注释不保留 BigDecimal 精度

sql - MSSQL - 联合所有具有不同小数精度的

python - 如何在 Python 中将 float 转换为 10 次方的科学记数法?

perl - 为什么 Perl 的 sprintf 不能正确舍入 float ?

java - Java HashMap 拷贝构造函数为什么会影响 float 精度?

linux - 在 x86 linux 上使用软件浮点

java - 在 Java 中从 Matlab 计算复杂的算术表达式