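I need to find the roots of a generalized state space. That is, I have a discrete grid of dimensions `grid=AxBx(...)xX`, where I don't know in advance how many dimensions it has (the solution should work for any `grid.size`).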
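I want to find the root (`f(z) = 0`) in `z` for every state inside `grid` using the bisection method. Say `remainder` contains `f(z)`, and I know that `f'(z) < 0`. Then I need to

- increase `z` if `remainder > 0`
- decrease `z` if `remainder < 0`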
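The update rule above can be written as a vectorised bisection that keeps an explicit lower and upper bracket for each grid point, instead of searching the history. This is a minimal sketch under stated assumptions, not the author's method. The function name `bisect_grid` and its parameters are hypothetical. It also assumes that `f` is elementwise and that `f(lo) >= 0 >= f(hi)` holds at every point. It uses only broadcasting and `np.where`, so it works for a grid with any number of dimensions.

```python
import numpy as np


def bisect_grid(f, lo, hi, tol=1e-8, max_iter=200):
    """Vectorised bisection for every point of a grid of any shape.

    Assumptions (hypothetical helper, not from the question):
    - f(z) is elementwise and decreasing in z, i.e. f'(z) < 0;
    - f(lo) >= 0 >= f(hi) at every grid point, so each root is bracketed.
    """
    lo = np.array(lo, dtype=float)  # copies, so the caller's arrays stay untouched
    hi = np.array(hi, dtype=float)
    for _ in range(max_iter):
        z = 0.5 * (lo + hi)
        remainder = f(z)
        # remainder > 0: since f is decreasing, the root lies above z -> raise lo.
        increase = remainder > 0
        lo = np.where(increase, z, lo)
        # otherwise the root lies at or below z -> lower hi.
        hi = np.where(increase, hi, z)
        if np.all(hi - lo < tol):
            break
    return 0.5 * (lo + hi)


# Example on a 10 x 20 grid: f(z) = c - z is decreasing with root z = c.
c = np.linspace(0.1, 0.9, 200).reshape(10, 20)
root = bisect_grid(lambda z: c - z, np.zeros((10, 20)), np.ones((10, 20)))
```

Suppose every step is a plain bisection step and the initial bounds count as history. Then the nearest earlier values on either side of the current `z` are exactly `lo` and `hi`. Tracking the brackets directly can therefore replace the history lookup described next.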
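WLOG, say the matrix `history` of shape `(grid.shape, T)` (the grid dimensions followed by a time axis of length `T`) contains the history of earlier values of `z`. For every point in the grid I need to increase `z` (because `remainder > 0`). I then need to pick the `zAlternative` inside `history[z, :]` that is "the smallest of those larger than `z`". In pseudocode:

```python
zAlternative = hist[z,:][hist[z,:] > z].min()
```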
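A slow but explicit version of this pseudocode is useful as a reference for checking any vectorised solution. The sketch below treats the last entry along the final axis of `history` as the current `z` and the earlier entries as the previous values, which is what the solution code further down does. The function name `next_higher_loop` and the NaN convention for "no larger earlier value" are illustrative choices. `np.ndindex` visits every grid point regardless of how many dimensions the grid has.

```python
import numpy as np


def next_higher_loop(history):
    """Reference (slow) version of the pseudocode: smallest earlier value > current z.

    Assumes history.shape == grid.shape + (T,), with history[..., -1] the current z.
    Returns NaN where no earlier value is larger (hypothetical convention).
    """
    out = np.full(history.shape[:-1], np.nan)
    # np.ndindex yields one index tuple per grid point, for any number of dimensions.
    for idx in np.ndindex(*history.shape[:-1]):
        z = history[idx][-1]          # current value at this grid point
        previous = history[idx][:-1]  # earlier values at this grid point
        larger = previous[previous > z]
        if larger.size:
            out[idx] = larger.min()
    return out
```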
I had asked this earlier. The solution I got was:

```python
from numpy import arange, argmax, meshgrid, sort

b = sort(history[..., :-1], axis=-1)
mask = b > history[..., -1:]
index = argmax(mask, axis=-1)
indices = tuple([arange(j) for j in b.shape[:-1]])
indices = meshgrid(*indices, indexing='ij', sparse=True)
indices.append(index)
indices = tuple(indices)
lowerZ = history[indices]
b = sort(history[..., :-1], axis=-1)
mask = b <= history[..., -1:]
index = argmax(mask, axis=-1)
indices = tuple([arange(j) for j in b.shape[:-1]])
indices = meshgrid(*indices, indexing='ij', sparse=True)
indices.append(index)
indices = tuple(indices)
higherZ = history[indices]
newZ = history[..., -1]
criterion = 0.05
increase = remainder > 0 + criterion
decrease = remainder < 0 - criterion
newZ[increase] = 0.5*(newZ[increase] + higherZ[increase])
newZ[decrease] = 0.5*(newZ[decrease] + lowerZ[decrease])
```
However, this code no longer works for me. It pains me to admit it, but I never understood the magic of indexing, so unfortunately I need help.
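What the code actually does is give me the lowest and the highest value, respectively. That is, if I pick two specific `z` values:

```python
history[z1] = array([0.3, 0.2, 0.1])
history[z2] = array([0.1, 0.2, 0.3])
```

I get `higherZ[z1] = 0.3` and `lowerZ[z2] = 0.1`, i.e. the extremes. The correct value in both cases is `0.2`. What is going wrong here?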
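Both outcomes can be reproduced on a grid of just these two points. The sketch below follows the question's masks and its final `history[indices]` lookup; for a 1-D grid, the meshgrid index tuple reduces to a plain row index. In both cases the mask row is entirely `False`, and `np.argmax` of an all-`False` row returns `0`. That index is then applied to the unsorted `history` rather than to the sorted `b`, so the first stored value comes back.

```python
import numpy as np

# The two histories from the question, stacked on a 1-D grid of two points.
history = np.array([[0.3, 0.2, 0.1],   # z1: current value 0.1
                    [0.1, 0.2, 0.3]])  # z2: current value 0.3

b = np.sort(history[..., :-1], axis=-1)  # sorted earlier values: [[0.2, 0.3], [0.1, 0.2]]
current = history[..., -1:]              # current values, last axis kept for broadcasting

# Question's higherZ mask: no earlier value of z1 is <= 0.1, so row 0 is all False.
mask_hi = b <= current
index_hi = np.argmax(mask_hi, axis=-1)   # all-False row -> argmax returns 0
higher_z1 = history[0, index_hi[0]]      # index 0 of the *unsorted* row: 0.3, not 0.2

# Question's lowerZ mask: no earlier value of z2 is > 0.3, so row 1 is all False.
mask_lo = b > current
index_lo = np.argmax(mask_lo, axis=-1)   # again 0
lower_z2 = history[1, index_lo[1]]       # index 0 of the unsorted row: 0.1, not 0.2
```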
If needed, you can generate test data with something like

```python
history = tile(array([0.1, 0.3, 0.2, 0.15, 0.13])[newaxis,newaxis,:], (10, 20, 1))
remainder = -1*ones((10, 20))
```

to test the second case.
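Expected result

I adjusted the `history` variable above so that it gives both an upward and a downward test case. The expected result is

```python
lowerZ = 0.1 * ones((10,20))
higherZ = 0.15 * ones((10,20))
```

That is, for every point `z`, I want the next-higher previous value in `history[z, :]` (`higherZ`) and the next-lower previous value (`lowerZ`). All points `z` have exactly the same history (`[0.1, 0.3, 0.2, 0.15, 0.13]`), so they all get the same `lowerZ` and `higherZ` values. In general, of course, each `z` has a different history, so the two matrices can contain different values at every grid point.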
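The expected values can also be computed without sorting or any index construction. The sketch below replaces earlier values on the wrong side of the current one with `±inf` and then takes a min or max along the last axis. This is a different technique from the sort-and-`argmax` approach used in this thread. The function name `neighbours_where` is hypothetical. The sketch assumes that all history values are finite and returns NaN where no neighbour exists.

```python
import numpy as np


def neighbours_where(history):
    """Per grid point: largest earlier value below the current one (lower) and
    smallest earlier value above it (higher). NaN where none exists.

    Assumes finite values and history[..., -1] = current z (hypothetical helper).
    """
    current = history[..., -1:]   # keep the last axis so it broadcasts
    previous = history[..., :-1]
    above = np.where(previous > current, previous, np.inf)   # hide values <= current
    below = np.where(previous < current, previous, -np.inf)  # hide values >= current
    higher = above.min(axis=-1)
    lower = below.max(axis=-1)
    higher[np.isinf(higher)] = np.nan  # nothing above the current value
    lower[np.isinf(lower)] = np.nan    # nothing below the current value
    return lower, higher
```

With the test data above, this gives `0.1` and `0.15` at every point. Each grid point costs O(T) instead of the O(T log T) of a sort, and no index tuples need to be built.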
Best answer

I compared what you posted here with the solution for your previous post and noticed some differences.

For the smaller z, you have:

```python
mask = b > history[..., -1:]
index = argmax(mask, axis=-1)
```

They have:

```python
mask = b >= a[..., -1:]
index = np.argmax(mask, axis=-1) - 1
```

For the larger z, you have:

```python
mask = b <= history[..., -1:]
index = argmax(mask, axis=-1)
```

They have:

```python
mask = b > a[..., -1:]
index = np.argmax(mask, axis=-1)
```
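The effect of these differences shows up on a single grid point. In the sketch below, `b` is one already-sorted row of earlier values, and `pick` is a hypothetical helper that applies the previous solution's two masks. `b > current` with `argmax` finds the first value strictly above `current`. `b >= current` with `argmax(...) - 1` steps back to the last value strictly below it.

```python
import numpy as np

# One grid point: earlier values already sorted, as b is in the solution.
b = np.array([0.1, 0.15, 0.2, 0.3])


def pick(current):
    """Hypothetical helper: (lowerZ, higherZ) for one sorted row b."""
    # higherZ: first sorted value strictly greater than current.
    hi_index = np.argmax(b > current)
    # lowerZ: one position before the first value >= current,
    # i.e. the last value strictly less than current.
    lo_index = np.argmax(b >= current) - 1
    return b[lo_index], b[hi_index]


lower, higher = pick(0.13)  # the test data's current value
```

With `current = 0.15`, which ties an earlier value, `pick` returns `0.1` and `0.2`, so the tied value is excluded on both sides. The masks still have two edge cases. If `current` is below every earlier value, `argmax(...) - 1` is `-1`, which wraps around to the largest value. If `current` is above every earlier value, the all-`False` mask makes `argmax` return `0`, so the "higher" value becomes the smallest one.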
Using the solution for your previous post, I get:
```python
import numpy as np
history = np.tile(np.array([0.1, 0.3, 0.2, 0.15, 0.13])[np.newaxis,np.newaxis,:], (10, 20, 1))
remainder = -1*np.ones((10, 20))
a = history
# b is a sorted ndarray excluding the most recent observation
# it is sorted along the observation axis
b = np.sort(a[..., :-1], axis=-1)
# mask is a boolean array, comparing the (sorted)
# previous observations to the current observation - [..., -1:]
mask = b > a[..., -1:]
# The next 5 statements build an indexing array.
# True evaluates to one and False evaluates to zero.
# argmax() will return the index of the first True,
# in this case along the last (observations) axis.
# index is an array with the shape of z (2-d for this test data).
# It represents the index of the next greater
# observation for every 'element' of z.
index = np.argmax(mask, axis=-1)
# The next two statements construct arrays of indices
# for every element of z - the first n-1 dimensions of history.
indices = tuple([np.arange(j) for j in b.shape[:-1]])
indices = np.meshgrid(*indices, indexing='ij', sparse=True)
# Adding index to the end of indices (the last dimension of history)
# produces a 'group' of indices that will 'select' a single observation
# for every 'element' of z
indices.append(index)
indices = tuple(indices)
higherZ = b[indices]
mask = b >= a[..., -1:]
# Since b excludes the current observation, we want the
# index just before the next highest observation for lowerZ,
# hence the minus one.
index = np.argmax(mask, axis=-1) - 1
indices = tuple([np.arange(j) for j in b.shape[:-1]])
indices = np.meshgrid(*indices, indexing='ij', sparse=True)
indices.append(index)
indices = tuple(indices)
lowerZ = b[indices]
assert np.all(lowerZ == .1)
assert np.all(higherZ == .15)
```
This seems to work.
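The sparse `meshgrid` tuple above selects one element per grid point along the last axis. `np.take_along_axis` (available since NumPy 1.15) does the same thing without building the tuple by hand. The sketch below keeps the same masks as the solution. It also handles the cases where `argmax` would silently return a wrong index, marking grid points with no earlier value on that side as NaN. The function name `neighbours_sorted` and the NaN convention are assumptions, not part of the answer.

```python
import numpy as np


def neighbours_sorted(history):
    """(lowerZ, higherZ) with the answer's masks, via np.take_along_axis.

    Hypothetical helper: NaN marks grid points with no earlier value on that side.
    """
    b = np.sort(history[..., :-1], axis=-1)  # sorted earlier values
    current = history[..., -1:]              # current z, last axis kept
    n = b.shape[-1]

    # higherZ: first sorted value strictly above current (same mask as the answer).
    mask_hi = b > current
    hi_idx = np.argmax(mask_hi, axis=-1)
    higher = np.take_along_axis(b, hi_idx[..., np.newaxis], axis=-1)[..., 0]
    # An all-False row means nothing is larger; argmax would silently give 0.
    higher = np.where(mask_hi.any(axis=-1), higher, np.nan)

    # lowerZ: one before the first value >= current (same mask as the answer).
    mask_lo = b >= current
    # If no value is >= current, every earlier value is below it: use the last one.
    lo_idx = np.where(mask_lo.any(axis=-1), np.argmax(mask_lo, axis=-1) - 1, n - 1)
    lower = np.take_along_axis(b, np.clip(lo_idx, 0, None)[..., np.newaxis], axis=-1)[..., 0]
    # lo_idx == -1: the first sorted value is already >= current, so nothing is smaller.
    lower = np.where(lo_idx >= 0, lower, np.nan)
    return lower, higher
```

On the test data, this reproduces `lowerZ == 0.1` and `higherZ == 0.15`. For the two histories in the question, it gives `0.2` where a neighbour exists and NaN otherwise. How the bisection update should treat NaN points (for example, by falling back to a fixed outer bound) is a modelling choice this sketch leaves open.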
Original question: python - Grid application of bisection method, on Stack Overflow: https://stackoverflow.com/questions/24098205/