python - How to pass functions efficiently?

Tags: python algorithm python-2.7 numpy scipy

Motivation

Look at the picture below.

[Figure: red, green and blue curves, with the dominating curve drawn in black]

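Given are the red, blue and green curves. I want to find the dominating curve at every point on the x-axis; it is shown as the black graph in the picture. Because of the properties of the red, green and blue curves (increasing, then constant after a while), this boils down to finding the dominating curve at the far right and then moving to the left, finding all intersection points and updating the dominating curve along the way.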
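The right-to-left scan described above can be sketched in plain Python. This is a minimal sketch, not the author's code: it assumes each pair of curves crosses at most once on every search interval, and it uses a hand-written bisection (bisect_root, a hypothetical helper) in place of scipy.optimize.brentq so that it runs without SciPy. The names upper_envelope and gap are likewise illustrative.

```python
def bisect_root(h, a, b, tol=1e-10):
    """Return a root of h in [a, b], or None if h(a) and h(b) have the same sign."""
    fa, fb = h(a), h(b)
    if fa == 0:
        return a
    if fb == 0:
        return b
    if (fa > 0) == (fb > 0):
        return None  # no sign change, so no bracketed root
    while b - a > tol:
        mid = 0.5 * (a + b)
        fm = h(mid)
        if (fm > 0) == (fa > 0):
            a, fa = mid, fm  # root lies in [mid, b]
        else:
            b = mid  # root lies in [a, mid]
    return 0.5 * (a + b)


def upper_envelope(funcs, lo, hi, gap=1e-6):
    """Scan from hi towards lo; return (breaks, owners) ordered left to right.

    funcs[owners[k]] dominates on [breaks[k], breaks[k + 1]); the last piece runs to hi.
    Assumes each pair of curves crosses at most once per search interval.
    """
    current = max(range(len(funcs)), key=lambda i: funcs[i](hi))
    right = hi
    breaks, owners = [], [current]
    while right - gap > lo:
        best_x, best_i = None, None
        for i, f in enumerate(funcs):
            if i == current:
                continue
            # h > 0 where the current curve is above the competitor f
            h = lambda x, f=f, g=funcs[current]: g(x) - f(x)
            # search slightly left of `right` so the crossing just used is not found again
            x = bisect_root(h, lo, right - gap)
            if x is not None and (best_x is None or x > best_x):
                best_x, best_i = x, i
        if best_x is None:
            break  # no crossing left: `current` dominates all the way down to lo
        breaks.append(best_x)
        owners.append(best_i)
        current, right = best_i, best_x
    breaks.append(lo)
    return breaks[::-1], owners[::-1]


curves = [lambda x: 1.0, lambda x: x, lambda x: 2 * x - 3]
breaks, owners = upper_envelope(curves, 0.0, 5.0)
print(owners)  # [0, 1, 2]
```

In this toy example the constant curve dominates on [0, 1), the line x on [1, 3) and 2x - 3 from 3 onwards, so breaks comes out approximately as [0.0, 1.0, 3.0]. Taking the largest crossing among all competitors is what makes the scan move strictly leftwards.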

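This outlined problem has to be solved T times, with one last twist: the blue, green and red curves of the next iteration are constructed from the dominating solution of the previous iteration plus some varying parameters. In the example picture above, the solution is the black function. This function is used to generate new blue, green and red curves, and then the problem starts again: find the dominating curve for these new curves, and so on.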

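The problem in short
In each iteration I start at the fixed right-most point and evaluate all three functions to see which one dominates. This evaluation takes longer and longer from iteration to iteration. My feeling is that I am not passing the old dominating function in an optimal way to build the new blue, green and red curves. The reason: in an earlier version I ran into a maximum recursion depth error. Other parts of the code need the value of the current dominating function (one of the green, red or blue curves), and computing it also takes longer and longer with each iteration.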
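One rough way to see how building each curve from the previous dominating function can get more expensive: if every curve evaluates the previous function at N quadrature nodes (the code below uses hermgauss(10), so N = 10), and that previous function is itself such a curve, then a single evaluation at nesting depth d triggers N**d evaluations of the initial function u. The sketch below only counts calls with toy stand-ins; make_layer, its fixed offsets and count_base_calls are hypothetical and are not the author's evalv. It assumes, as piecewise does, that the wrapper forwards each point to exactly one curve.

```python
def make_layer(prev, n_nodes):
    """Toy stand-in for one curve: averages prev over n_nodes scaled points."""
    # hypothetical fixed offsets playing the role of Gauss-Hermite nodes
    offsets = [1.0 + 0.01 * k for k in range(n_nodes)]

    def curve(w):
        return sum(prev(w * o) for o in offsets) / n_nodes

    return curve


def count_base_calls(depth, n_nodes=10):
    """Nest `depth` layers on top of a base function and count base calls."""
    calls = [0]

    def base(w):  # plays the role of the initial function u
        calls[0] += 1
        return w

    f = base
    for _ in range(depth):
        f = make_layer(f, n_nodes)
    f(1.0)  # evaluate once, at a single point
    return calls[0]


for d in range(5):
    print(d, count_base_calls(d))  # prints 0 1, 1 10, 2 100, 3 1000, 4 10000 (one pair per line)
```

Passing a function object around is cheap in itself; in this toy model the cost comes from re-evaluating the whole chain of earlier curves on every call.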

For 5 iterations, the time needed just to evaluate the functions at a single point on the right grows:

The results come from

test = A(5, 120000, 100000) 

and then running

test.find_all_intersections()

>>> test.find_all_intersections()
iteration 4
to compute function values it took
0.0102479457855
iteration 3
to compute function values it took
0.0134601593018
iteration 2
to compute function values it took
0.0294270515442
iteration 1
to compute function values it took
0.109843969345
iteration 0
to compute function values it took
0.823768854141

I would like to know why this happens and whether it can be programmed more efficiently.

Detailed code explanation

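I have quickly summarized the most important functions; the full code can be found further below. If there are any other questions about the code, I am happy to elaborate or clarify.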

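  1. Method u: a recurring task is that generating each new batch of the green, red and blue curves above requires the old dominating curve. u is the initialization used in the very first iteration.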

  2. Method _function_template: generates the green, blue and red curves from different parameters and returns a function of a single input.

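  3. Method evalv: the core function that generates the blue, green and red variants each time. Per iteration it takes three varying parameters: vfunction, the dominating function from the previous step, and m and s, two floats that affect the shape of the resulting curve. All other parameters are the same in every iteration. The code contains sample values of m and s for each iteration. For the more technically inclined: it approximates an integral in which m and s are the expected mean and standard deviation of the underlying normal distribution. The approximation uses Gauss-Hermite nodes/weights.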

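  4. Method find_all_intersections: the core method that finds the dominating curve in each iteration. It builds a dominating function by joining the blue, green and red curves piecewise, which is done with the function piecewise.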
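The integral approximated in evalv has the form E[vfunction(w(1 + R))] with R ~ N(m, s²), using the Gauss-Hermite rule E[g(X)] ≈ (1/√π) Σ w_i g(μ + √2 σ x_i) for X ~ N(μ, σ²). A minimal sketch of that rule follows, assuming a 3-point rule with hard-coded nodes and weights so it runs with the standard library only (the code below uses np.polynomial.hermite.hermgauss(10) instead). The names expect_gh and next_curve_value are hypothetical.

```python
import math

# 3-point Gauss-Hermite rule for the weight exp(-x**2) (physicists' convention,
# the same convention as numpy.polynomial.hermite.hermgauss)
NODES = [-math.sqrt(1.5), 0.0, math.sqrt(1.5)]
WEIGHTS = [math.sqrt(math.pi) / 6, 2 * math.sqrt(math.pi) / 3, math.sqrt(math.pi) / 6]


def expect_gh(g, mu, sigma, nodes=NODES, weights=WEIGHTS):
    """Approximate E[g(X)] for X ~ N(mu, sigma**2); exact for polynomials up to degree 5."""
    total = sum(wt * g(mu + math.sqrt(2) * sigma * x) for x, wt in zip(nodes, weights))
    return total / math.sqrt(math.pi)


def next_curve_value(w, vfunction, c, g, m, s):
    """One new curve at a scalar w, mirroring the structure of evalv:
    c + g * E[vfunction(w * (1 + R))] with R ~ N(m, s**2)."""
    return c + g * expect_gh(lambda x: vfunction(w * x), 1.0 + m, s)


# with an identity vfunction the result should be close to 100 * (1 + 0.05) = 105
print(next_curve_value(100.0, lambda x: x, 0.0, 1.0, 0.05, 0.02))
```

Each call to next_curve_value evaluates vfunction once per node, which is why the number of nodes matters once vfunction is itself built from earlier curves.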

The full code is below:

import numpy as np
import pandas as pd
from scipy.optimize import brentq
import multiprocessing as mp
import pathos as pt
import timeit
import math
class A(object):
    def u(self, w):
        _w = np.asarray(w).copy()
        _w[_w >= 120000] = 120000
        _p = np.maximum(0, 100000 - _w)
        return _w - 1000*_p**2

    def __init__(self, T, upper_bound, lower_bound):
        self.T = T
        self.upper_bound = upper_bound
        self.lower_bound = lower_bound

    def _function_template(self, *args):
        def _f(x):
            return self.evalv(x, *args)
        return _f

    def evalv(self, w, c, vfunction, g, m, s, gauss_weights, gauss_nodes):
        _A = np.tile(1 + m + math.sqrt(2) * s * gauss_nodes, (np.size(w), 1))
        _W = (_A.T * w).T
        _W = gauss_weights * vfunction(np.ravel(_W)).reshape(np.size(w),
                                                             len(gauss_nodes))
        evalue = g*1/math.sqrt(math.pi)*np.sum(_W, axis=1)
        return c + evalue

    def find_all_intersections(self):

        # the Gauss-Hermite weights and nodes for integration
        # and additional parameters used for function generation

        gauss = np.polynomial.hermite.hermgauss(10)
        gauss_nodes = gauss[0]
        gauss_weights = gauss[1]
        r = np.asarray([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
                        1., 1., 1., 1., 1., 1., 1., 1., 1.])
        m = [[0.038063407778193614, 0.08475713587463352, 0.15420895520972322],
             [0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
             [0.03836174909668277, 0.08543620707856969, 0.15548297423808233],
             [0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
             [0.038063407778193614, 0.08475713587463352, 0.15420895520972322],
             [0.038063407778193614, 0.08475713587463352, 0.15420895520972322],
             [0.03836174909668277, 0.08543620707856969, 0.15548297423808233],
             [0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
             [0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
             [0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
             [0.038063407778193614, 0.08475713587463352, 0.15420895520972322],
             [0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
             [0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
             [0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
             [0.03836174909668277, 0.08543620707856969, 0.15548297423808233],
             [0.038063407778193614, 0.08475713587463352, 0.15420895520972322],
             [0.038063407778193614, 0.08475713587463352, 0.15420895520972322],
             [0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
             [0.03836174909668277, 0.08543620707856969, 0.15548297423808233],
             [0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
             [0.038212567720998125, 0.08509661835487026, 0.15484578903763624]]

        s = [[0.01945441966324046, 0.04690600929081242, 0.200125178687699],
             [0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
             [0.019529101011406914, 0.04708607140891122, 0.20089341636351565],
             [0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
             [0.01945441966324046, 0.04690600929081242, 0.200125178687699],
             [0.01945441966324046, 0.04690600929081242, 0.200125178687699],
             [0.019529101011406914, 0.04708607140891122, 0.20089341636351565],
             [0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
             [0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
             [0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
             [0.01945441966324046, 0.04690600929081242, 0.200125178687699],
             [0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
             [0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
             [0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
             [0.019529101011406914, 0.04708607140891122, 0.20089341636351565],
             [0.01945441966324046, 0.04690600929081242, 0.200125178687699],
             [0.01945441966324046, 0.04690600929081242, 0.200125178687699],
             [0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
             [0.019529101011406914, 0.04708607140891122, 0.20089341636351565],
             [0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
             [0.019491796104351332, 0.04699612658674578, 0.20050966545654142]]

        self.solution = []

        n_cpu = mp.cpu_count()
        pool = pt.multiprocessing.ProcessPool(n_cpu)

        # this function is used for multiprocessing
        def call_f(f, x):
            return f(x)

        # this function takes differences for getting cross points
        def _diff(f_dom, f_other):
            def h(x):
                return f_dom(x) - f_other(x)
            return h

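        # finds a root of F (the difference of two functions) on [l_bound, u_bound];
        # returns l_bound if there is no sign change or the root lies within 1 of u_bound
        def find_roots(F, u_bound, l_bound):
            try:
                sol = brentq(F, a=l_bound, b=u_bound)
                if np.absolute(sol - u_bound) > 1:
                    return sol
                else:
                    return l_bound
            except ValueError:
                return l_bound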

        # piecewise function
        def piecewise(l_comp, l_f):
            def f(x):
                _ind_f = np.digitize(x, l_comp) - 1
                if np.isscalar(x):
                    return l_f[_ind_f](x)
                else:
                    return np.asarray([l_f[_ind_f[i]](x[i])
                                       for i in range(0, len(x))]).ravel()
            return f

        _u = self.u

        for t in range(self.T-1, -1, -1):
            print('iteration' + ' ' + str(t))

            l_bound, u_bound = 0.5*self.lower_bound, self.upper_bound
            l_ordered_functions = []
            l_roots = []
            l_solution = []

            # build all function variations

            l_functions = [self._function_template(0, _u, r[t], m[t][i], s[t][i],
                                                   gauss_weights, gauss_nodes)
                           for i in range(0, len(m[t]))]

            # get the best solution for the upper bound on the very
            # right-hand side of the wealth interval

            array_functions = np.asarray(l_functions)
            start_time = timeit.default_timer()
            functions_values = pool.map(call_f, array_functions.tolist(),
                                        len(m[t]) * [u_bound])
            elapsed = timeit.default_timer() - start_time
            print('to compute function values it took')
            print(elapsed)

            ind = np.argmax(functions_values)
            cross_points = len(m[t]) * [u_bound]
            l_roots.insert(0, u_bound)
            max_m = m[t][ind]
            l_solution.insert(0, max_m)

            # move from the upper bound towards the lower bound
            # and find the dominating solution by exploring all cross
            # points.

            test = True

            while test:
                l_ordered_functions.insert(0, array_functions[ind])
                current_max = l_ordered_functions[0]

                l_c_max = len(m[t]) * [current_max]
                l_u_cross = len(m[t]) * [cross_points[ind]]

                # Find new cross points on the smaller interval

                diff = pool.map(_diff, l_c_max, array_functions.tolist())
                cross_points = pool.map(find_roots, diff,
                                        l_u_cross, len(m[t]) * [l_bound])

                # update the solution, cross points and current
                # dominating function.

                ind = np.argmax(cross_points)
                l_roots.insert(0, cross_points[ind])
                max_m = m[t][ind]
                l_solution.insert(0, max_m)

                if cross_points[ind] <= l_bound:
                    test = False

            l_ordered_functions.insert(0, l_functions[0])
            l_roots.insert(0, 0)
            l_roots[-1] = np.inf

            l_comp = l_roots[:]
            l_f = l_ordered_functions[:]

            # build piecewise function which is used for next
            # iteration.

            _u = piecewise(l_comp, l_f)
            _sol = pd.DataFrame(data=l_solution,
                                index=np.asarray(l_roots)[0:-1])
            self.solution.insert(0, _sol)
        return self.solution
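The piecewise helper above chooses the active curve for each point with np.digitize(x, l_comp) - 1. The same lookup can be sketched with the standard-library bisect module, assuming breaks is sorted ascending and breaks[0] is at or below every query point (the code above arranges this by inserting 0 at the front and replacing the last root with np.inf; a point below breaks[0] would give index -1 and silently pick the last curve). make_piecewise is a hypothetical name, and like the original it still calls one curve per point.

```python
import bisect


def make_piecewise(breaks, funcs):
    """Return f(x) = funcs[k](x) for breaks[k] <= x < breaks[k + 1].

    For ascending breaks, bisect_right(breaks, x) counts the breakpoints <= x,
    which matches np.digitize(x, breaks), so subtracting 1 gives the piece index.
    """
    def f(x):
        k = bisect.bisect_right(breaks, x) - 1
        return funcs[k](x)
    return f


env = make_piecewise([0.0, 1.0, 3.0], [lambda x: 1.0, lambda x: x, lambda x: 2 * x - 3])
print([env(x) for x in [0.5, 2.0, 4.0]])  # [1.0, 2.0, 5.0]
```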

Best answer

Let's first change the code to print the current iteration:

_u = self.u
for t in range(0, self.T):
    print(t)
    lparams = np.random.randint(self.a, self.b, 6).reshape(3, 2).tolist()
    functions = [self._function_template(_u, *lparams[i])
                 for i in range(0, 3)]
    # evaluate functions
    pairs = list(itertools.combinations(functions, 2))
    fval = [F(diff(*pairs[i]), self.a, self.b) for i in range(0, 3)]
    ind = np.sort(np.unique(np.random.randint(self.a, self.b, 10)))
    _u = _temp(ind, np.asarray(functions)[ind % 3])

Looking at the line that causes this behavior:

fval = [F(diff(*pairs[i]), self.a, self.b) for i in range(0, 3)]

The functions of interest are F and diff. The latter is straightforward; the former is:

def F(f, a, b):
    try:
        brentq(f, a=a, b=b)
    except ValueError:
        pass

Hmm, it swallows the exception. Let's see what happens if we remove the try/except:

def F(f, a, b):
    brentq(f, a=a, b=b)

Immediately, for the first function in the first iteration, an error is raised:

ValueError: f(a) and f(b) must have different signs

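Looking at the docs, this is a precondition of the root-finding function brentq. Let's change the definition once more to monitor this condition in every iteration: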

def F(f, a, b):
    try:
        brentq(f, a=a, b=b)
    except ValueError as e:
        print(e)

The output is

i
f(a) and f(b) must have different signs
f(a) and f(b) must have different signs
f(a) and f(b) must have different signs

where i ranges from 0 to 57. This means the first time F does any real work is for i=58, and it keeps doing so for higher values of i.

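Conclusion: these higher values take longer because

  1. the roots are never computed for the lower values, and
  2. for i > 58, the number of computations grows linearly.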
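Instead of swallowing the ValueError, the bracketing precondition that brentq enforces (f(a) and f(b) must not have the same sign) can be checked explicitly, so that skipped root searches are recorded rather than hidden. A minimal sketch, assuming the same role as F(f, a, b) above; find_root_or_none, the skipped list and bisect_solver are hypothetical names, and bisect_solver stands in for scipy.optimize.brentq only so the example runs without SciPy.

```python
def bisect_solver(f, a, b, tol=1e-10):
    """Simple bisection; assumes f(a) and f(b) do not have the same sign."""
    fa = f(a)
    if fa == 0:
        return a
    if f(b) == 0:
        return b
    while b - a > tol:
        mid = 0.5 * (a + b)
        fm = f(mid)
        if (fm > 0) == (fa > 0):
            a, fa = mid, fm
        else:
            b = mid
    return 0.5 * (a + b)


def find_root_or_none(f, a, b, solver, skipped):
    """Call solver only if [a, b] brackets a root; otherwise record the bracket."""
    fa, fb = f(a), f(b)
    if (fa > 0 and fb > 0) or (fa < 0 and fb < 0):
        skipped.append((a, b))  # same sign at both ends: no root search happens
        return None
    return solver(f, a, b)


skipped = []
diffs = [lambda x: x - 2.0, lambda x: x + 1.0, lambda x: 3.0 - x]
roots = [find_root_or_none(h, 0.0, 5.0, bisect_solver, skipped) for h in diffs]
print(len(skipped))  # 1
```

Printing len(skipped) once per iteration would show directly how many of the pairwise root searches never ran, which is the pattern the printed ValueError messages reveal above.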

For "python - How to pass functions efficiently?", a similar question can be found on Stack Overflow: https://stackoverflow.com/questions/48249253/
