动机
看看下面的图片。
给定的是红色、蓝色和绿色曲线。我想在 x
轴上的每个点找到支配曲线。这在图片中显示为黑色图形。从红色、绿色和蓝色曲线的属性(一段时间后增加并保持不变)归结为找到最右侧的主导曲线,然后向左侧移动找到所有交点并更新主导曲线曲线。
这个概述的问题应该解决 T
次。这个问题还有最后一个转折点。下一次迭代的蓝色、绿色和红色曲线是通过上一次迭代的主导解加上一些变化的参数构建的。如上图示例: 解决方案是黑色功能。此函数用于生成新的蓝色、绿色和红色曲线。然后问题再次开始,为这些新曲线等找到主导曲线。
简而言之问题
在每次迭代中,我从固定的最右边开始,评估所有三个函数,看看哪个是主导函数。这种评估在迭代中花费的时间越来越长。 我的感觉是,我没有以最佳方式通过旧的主导函数来构建新的蓝色、绿色和红色曲线。原因:我在早期版本中遇到了最大递归深度错误。 代码的其他部分需要当前主导函数的值(绿色、红色或蓝色曲线必不可少)随着迭代越来越长。
对于 5 次迭代,仅评估右侧一点上的函数增长:
结果产生于
test = A(5, 120000, 100000)
然后运行
test.find_all_intersections()
>>> test.find_all_intersections()
iteration 4
to compute function values it took
0.0102479457855
iteration 3
to compute function values it took
0.0134601593018
iteration 2
to compute function values it took
0.0294270515442
iteration 1
to compute function values it took
0.109843969345
iteration 0
to compute function values it took
0.823768854141
我想知道为什么会这样,是否可以更有效地对其进行编程。
详细代码解释
我快速总结了最重要的功能。完整的代码可以在下面进一步找到。如果对代码有任何其他问题,我很乐意详细说明/澄清。
方法
u
:用于生成新批处理的循环任务 上面的绿色、红色和蓝色曲线我们需要旧的主导曲线。u
是要在第一次迭代中使用的初始化。方法
_function_template
:该函数生成 通过使用不同的参数绘制绿色、蓝色和红色曲线。它返回 单个输入的函数。方法
eval
:这是每次生成蓝绿红版本的核心函数。每次迭代采用三个不同的参数:vfunction
,它是上一步中的主导函数,m
和s
,这是两个参数(flaots ) 影响结果曲线的形状。其他参数在每次迭代中都相同。在代码中,每次迭代都有m
和s
的示例值。对于更令人讨厌的:它是近似一个积分,其中m
和s
是基础正态分布的预期均值和标准差。近似值是通过 Gauss-Hermite 节点/权重完成的。方法
find_all_intersections
:这是查找的核心方法 每次迭代都是主导的。它构建了一个主导 通过蓝色、绿色和红色的分段连接来发挥作用 曲线。这是通过函数piecewise
实现的。
完整代码如下
import numpy as np
import pandas as pd
from scipy.optimize import brentq
import multiprocessing as mp
import pathos as pt
import timeit
import math
class A(object):
def u(self, w):
_w = np.asarray(w).copy()
_w[_w >= 120000] = 120000
_p = np.maximum(0, 100000 - _w)
return _w - 1000*_p**2
def __init__(self, T, upper_bound, lower_bound):
self.T = T
self.upper_bound = upper_bound
self.lower_bound = lower_bound
def _function_template(self, *args):
def _f(x):
return self.evalv(x, *args)
return _f
def evalv(self, w, c, vfunction, g, m, s, gauss_weights, gauss_nodes):
_A = np.tile(1 + m + math.sqrt(2) * s * gauss_nodes, (np.size(w), 1))
_W = (_A.T * w).T
_W = gauss_weights * vfunction(np.ravel(_W)).reshape(np.size(w),
len(gauss_nodes))
evalue = g*1/math.sqrt(math.pi)*np.sum(_W, axis=1)
return c + evalue
def find_all_intersections(self):
# the hermite gauss weights and nodes for integration
# and additional paramters used for function generation
gauss = np.polynomial.hermite.hermgauss(10)
gauss_nodes = gauss[0]
gauss_weights = gauss[1]
r = np.asarray([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1.])
m = [[0.038063407778193614, 0.08475713587463352, 0.15420895520972322],
[0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
[0.03836174909668277, 0.08543620707856969, 0.15548297423808233],
[0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
[0.038063407778193614, 0.08475713587463352, 0.15420895520972322],
[0.038063407778193614, 0.08475713587463352, 0.15420895520972322],
[0.03836174909668277, 0.08543620707856969, 0.15548297423808233],
[0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
[0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
[0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
[0.038063407778193614, 0.08475713587463352, 0.15420895520972322],
[0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
[0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
[0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
[0.03836174909668277, 0.08543620707856969, 0.15548297423808233],
[0.038063407778193614, 0.08475713587463352, 0.15420895520972322],
[0.038063407778193614, 0.08475713587463352, 0.15420895520972322],
[0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
[0.03836174909668277, 0.08543620707856969, 0.15548297423808233],
[0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
[0.038212567720998125, 0.08509661835487026, 0.15484578903763624]]
s = [[0.01945441966324046, 0.04690600929081242, 0.200125178687699],
[0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
[0.019529101011406914, 0.04708607140891122, 0.20089341636351565],
[0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
[0.01945441966324046, 0.04690600929081242, 0.200125178687699],
[0.01945441966324046, 0.04690600929081242, 0.200125178687699],
[0.019529101011406914, 0.04708607140891122, 0.20089341636351565],
[0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
[0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
[0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
[0.01945441966324046, 0.04690600929081242, 0.200125178687699],
[0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
[0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
[0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
[0.019529101011406914, 0.04708607140891122, 0.20089341636351565],
[0.01945441966324046, 0.04690600929081242, 0.200125178687699],
[0.01945441966324046, 0.04690600929081242, 0.200125178687699],
[0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
[0.019529101011406914, 0.04708607140891122, 0.20089341636351565],
[0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
[0.019491796104351332, 0.04699612658674578, 0.20050966545654142]]
self.solution = []
n_cpu = mp.cpu_count()
pool = pt.multiprocessing.ProcessPool(n_cpu)
# this function is used for multiprocessing
def call_f(f, x):
return f(x)
# this function takes differences for getting cross points
def _diff(f_dom, f_other):
def h(x):
return f_dom(x) - f_other(x)
return h
# finds the root of two function
def find_roots(F, u_bound, l_bound):
try:
sol = brentq(F, a=l_bound,
b=u_bound)
if np.absolute(sol - u_bound) > 1:
return sol
else:
return l_bound
except ValueError:
return l_bound
# piecewise function
def piecewise(l_comp, l_f):
def f(x):
_ind_f = np.digitize(x, l_comp) - 1
if np.isscalar(x):
return l_f[_ind_f](x)
else:
return np.asarray([l_f[_ind_f[i]](x[i])
for i in range(0, len(x))]).ravel()
return f
_u = self.u
for t in range(self.T-1, -1, -1):
print('iteration' + ' ' + str(t))
l_bound, u_bound = 0.5*self.lower_bound, self.upper_bound
l_ordered_functions = []
l_roots = []
l_solution = []
# build all function variations
l_functions = [self._function_template(0, _u, r[t], m[t][i], s[t][i],
gauss_weights, gauss_nodes)
for i in range(0, len(m[t]))]
# get the best solution for the upper bound on the very
# right hand side of wealth interval
array_functions = np.asarray(l_functions)
start_time = timeit.default_timer()
functions_values = pool.map(call_f, array_functions.tolist(),
len(m[t]) * [u_bound])
elapsed = timeit.default_timer() - start_time
print('to compute function values it took')
print(elapsed)
ind = np.argmax(functions_values)
cross_points = len(m[t]) * [u_bound]
l_roots.insert(0, u_bound)
max_m = m[t][ind]
l_solution.insert(0, max_m)
# move from the upper bound twoards the lower bound
# and find the dominating solution by exploring all cross
# points.
test = True
while test:
l_ordered_functions.insert(0, array_functions[ind])
current_max = l_ordered_functions[0]
l_c_max = len(m[t]) * [current_max]
l_u_cross = len(m[t]) * [cross_points[ind]]
# Find new cross points on the smaller interval
diff = pool.map(_diff, l_c_max, array_functions.tolist())
cross_points = pool.map(find_roots, diff,
l_u_cross, len(m[t]) * [l_bound])
# update the solution, cross points and current
# dominating function.
ind = np.argmax(cross_points)
l_roots.insert(0, cross_points[ind])
max_m = m[t][ind]
l_solution.insert(0, max_m)
if cross_points[ind] <= l_bound:
test = False
l_ordered_functions.insert(0, l_functions[0])
l_roots.insert(0, 0)
l_roots[-1] = np.inf
l_comp = l_roots[:]
l_f = l_ordered_functions[:]
# build piecewise function which is used for next
# iteration.
_u = piecewise(l_comp, l_f)
_sol = pd.DataFrame(data=l_solution,
index=np.asarray(l_roots)[0:-1])
self.solution.insert(0, _sol)
return self.solution
最佳答案
让我们首先更改代码以输出当前迭代:
_u = self.u
for t in range(0, self.T):
print(t)
lparams = np.random.randint(self.a, self.b, 6).reshape(3, 2).tolist()
functions = [self._function_template(_u, *lparams[i])
for i in range(0, 3)]
# evaluate functions
pairs = list(itertools.combinations(functions, 2))
fval = [F(diff(*pairs[i]), self.a, self.b) for i in range(0, 3)]
ind = np.sort(np.unique(np.random.randint(self.a, self.b, 10)))
_u = _temp(ind, np.asarray(functions)[ind % 3])
查看导致该行为的行,
fval = [F(diff(*pairs[i]), self.a, self.b) for i in range(0, 3)]
感兴趣的函数是F
和diff
。后者很直接,前者:
def F(f, a, b):
try:
brentq(f, a=a, b=b)
except ValueError:
pass
嗯,吞下异常,让我们看看如果我们:
def F(f, a, b):
brentq(f, a=a, b=b)
立即,对于第一个函数和第一次迭代,抛出一个错误:
ValueError: f(a) and f(b) must have different signs
查看 docs这是求根函数 brentq
的先决条件。让我们再次更改定义以在每次迭代中监视此条件。
def F(f, a, b):
try:
brentq(f, a=a, b=b)
except ValueError as e:
print(e)
输出是
i f(a) and f(b) must have different signs f(a) and f(b) must have different signs f(a) and f(b) must have different signs
i
的范围从 0 到 57。意思是,函数 F
第一次做任何真正的工作是针对 i=58
.它会不断这样做以获得更高的 i
值。
结论:这些较高的值需要更长的时间,因为:
- 从不计算较低值的根
i>58
的计算次数呈线性增长
关于python - 如何高效的传递函数?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48249253/