python - 二分法的网格应用

标签 python numpy bisection

我需要找到广义状态空间的根。也就是说,我有一个离散的维度网格 grid=AxBx(...)xX ,其中我事前不知道它有多少维度(该解决方案应该适用于任何 grid.size )。

我想找到每个状态的根 ( f(z) = 0 ) z里面grid使用 bisection method .说 remainder包含 f(z) ,我知道 f'(z) < 0 .那我需要

  • 增加z如果remainder > 0
  • 减少z如果remainder < 0

Wlog,说矩阵history形状(grid.shape, T)包含 z 的早期值的历史记录对于网格中的每个点,我都需要增加 z (因为 remainder > 0)。然后我需要选择 zAlternative里面history[z, :]那是“最小的那些,大于 z ”。在伪代码中,即:

zAlternative =  hist[z,:][hist[z,:] > z].min()

I had asked this earlier .我得到的解决方案是

b = sort(history[..., :-1], axis=-1)
mask = b > history[..., -1:]
index = argmax(mask, axis=-1)
indices = tuple([arange(j) for j in b.shape[:-1]])
indices = meshgrid(*indices, indexing='ij', sparse=True)
indices.append(index)
indices = tuple(indices)
lowerZ = history[indices]

b = sort(history[..., :-1], axis=-1)
mask = b <= history[..., -1:]
index = argmax(mask, axis=-1)
indices = tuple([arange(j) for j in b.shape[:-1]])
indices = meshgrid(*indices, indexing='ij', sparse=True)
indices.append(index)
indices = tuple(indices)
higherZ = history[indices]

newZ = history[..., -1]
criterion = 0.05
increase = remainder > 0 + criterion
decrease = remainder < 0 - criterion
newZ[increase] = 0.5*(newZ[increase] + higherZ[increase])
newZ[decrease] = 0.5*(newZ[decrease] + lowerZ[decrease])

但是,此代码对我不再起作用。承认这一点让我感到非常难过,但我从来不理解指数的神奇之处,因此很遗憾我需要帮助。

代码实际上做了什么,它分别给我最低最高。也就是说,如果我确定两个特定的 z值(value)观:

history[z1] = array([0.3, 0.2, 0.1])
history[z2] = array([0.1, 0.2, 0.3])

我会得到 higherZ[z1] = 0.3lowerZ[z2] = 0.1 ,也就是极值。这两种情况的正确值都是 0.2 .这里出了什么问题?

如果需要,为了生成测试数据,您可以使用类似

的东西
history = tile(array([0.1, 0.3, 0.2, 0.15, 0.13])[newaxis,newaxis,:], (10, 20, 1))
remainder = -1*ones((10, 20))

测试第二种情况。

预期结果

我调整了 history上面的变量,给出向上和向下的测试用例。预期结果是

lowerZ = 0.1 * ones((10,20))
higherZ = 0.15 * ones((10,20))

也就是说,对于每个点 z在 history[z, :] 中,下一个最高的先前值(higherZ)和下一个最小的先前值(lowerZ)。自所有点z具有完全相同的历史记录 ( [0.1, 0.3, 0.2, 0.15, 0.13] ),它们都将具有相同的 lowerZ 值和 higherZ .当然,一般来说,每个z的历史记录将是不同的,因此这两个矩阵将在每个网格点上包含可能不同的值。

最佳答案

我将您在此处发布的内容与 the solution for your previous post 进行了比较并注意到一些差异。

对于较小 z,你说

mask = b > history[..., -1:]
index = argmax(mask, axis=-1)

他们说:

mask = b >= a[..., -1:]
index = np.argmax(mask, axis=-1) - 1

对于更大 z,你说

mask = b <= history[..., -1:]
index = argmax(mask, axis=-1)

他们说:

mask = b > a[..., -1:]
index = np.argmax(mask, axis=-1)

使用 the solution for your previous post ,我得到:

import numpy as np
history = np.tile(np.array([0.1, 0.3, 0.2, 0.15, 0.13])[np.newaxis,np.newaxis,:], (10, 20, 1))
remainder = -1*np.ones((10, 20))

a = history

# b is a sorted ndarray excluding the most recent observation
# it is sorted along the observation axis
b = np.sort(a[..., :-1], axis=-1)

# mask is a boolean array, comparing the (sorted)
# previous observations to the current observation - [..., -1:]
mask = b > a[..., -1:]

# The next 5 statements build an indexing array.
# True evaluates to one and False evaluates to zero.
# argmax() will return the index of the first True,
# in this case along the last (observations) axis.
# index is an array with the shape of z (2-d for this test data).
# It represents the index of the next greater
# observation for every 'element' of z.
index = np.argmax(mask, axis=-1)

# The next two statements construct arrays of indices
# for every element of z - the first n-1 dimensions of history.
indices = tuple([np.arange(j) for j in b.shape[:-1]])
indices = np.meshgrid(*indices, indexing='ij', sparse=True)

# Adding index to the end of indices (the last dimension of history)
# produces a 'group' of indices that will 'select' a single observation
# for every 'element' of z
indices.append(index)
indices = tuple(indices)
higherZ = b[indices]

mask = b >= a[..., -1:]
# Since b excludes the current observation, we want the
# index just before the next highest observation for lowerZ,
# hence the minus one.
index = np.argmax(mask, axis=-1) - 1
indices = tuple([np.arange(j) for j in b.shape[:-1]])
indices = np.meshgrid(*indices, indexing='ij', sparse=True)
indices.append(index)
indices = tuple(indices)
lowerZ = b[indices]
assert np.all(lowerZ == .1)
assert np.all(higherZ == .15)

这似乎有效

关于python - 二分法的网格应用,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/24098205/

相关文章:

python - 来自 amavisd-new-cronjob sa-sync 的错误

python - 如何在有条件的字符串之间添加字符

python - 如何在科学文章或论文中引用 Python?

python - 查找最近索引值的最快方法

python - 多元线性回归的预测函数

python - 获取两个数组的行的有效方法,这些数组在其列的分数中具有匹配的值

algorithm - 快速 block 放置算法,需要建议吗?

python - 使用 Python 更新 SQLite 数据库中的行

python - 在 Python 中,如何找到排序列表中第一个大于阈值的值的索引?

牛顿与二分法的javascript实现