python - 动态定义一个函数

标签 python numpy scipy

我正在尝试编写一个返回最佳参数 a、b 和 c 的曲线拟合函数,这是一个简化的示例:

import numpy
import scipy
from scipy.optimize import curve_fit

def f(x, a, b, c):
    return x * 2*a + 4*b - 5*c

xdata = numpy.array([1,3,6,8,10])
ydata = numpy.array([  0.91589774,   4.91589774,  10.91589774,  14.91589774,  18.91589774])
popt, pcov = scipy.optimize.curve_fit(f, xdata, ydata)

这很好用,但我想让用户有机会提供一些(或不提供)参数 a、b 或 c,在这种情况下,它们应该被视为常量而不是估计值。我如何编写 f 使其只适合用户未提供的参数?

基本上,我需要使用正确的参数动态定义f。例如,如果用户知道 a,则 f 变为:

def f(x, b, c):
    a = global_version_of_a
    return x * 2*a + 4*b - 5*c

最佳答案

collections.namedtuple playbook 中获取一页,您可以使用 exec 来“动态地”定义 func:

import numpy as np
import scipy.optimize as optimize
import textwrap

funcstr=textwrap.dedent('''\
def func(x, {p}):
    return x * 2*a + 4*b - 5*c
''')
def make_model(**kwargs):
    params=set(('a','b','c')).difference(kwargs.keys())
    exec funcstr.format(p=','.join(params)) in kwargs
    return kwargs['func']

func=make_model(a=3, b=1)

xdata = np.array([1,3,6,8,10])
ydata = np.array([  0.91589774,   4.91589774,  10.91589774,  14.91589774,  18.91589774])
popt, pcov = optimize.curve_fit(func, xdata, ydata)
print(popt)
# [ 5.49682045]

注意这行

func=make_model(a=3, b=1)

您可以将任何您喜欢的参数传递给 make_model。您传递给 make_model 的参数成为 func 中的固定常量。保留的任何参数都将成为 optimize.curve_fit 将尝试拟合的自由参数。

例如,上面的a=3和b=1成为func中的固定常量。实际上,exec 语句将它们放在 func 的全局命名空间中。 func 因此被定义为 x 和单个参数 c 的函数。请注意,popt 的返回值是一个长度为 1 的数组,对应于剩余的自由参数 c


关于textwrap.dedent:在上面的例子中,调用textwrap.dedent是不必要的。但在“现实生活”脚本中,funcstr 是在函数内部或更深的缩进级别定义的,textwrap.dedent 允许您编写

def foo():
    funcstr=textwrap.dedent('''\
        def func(x, {p}):
            return x * 2*a + 4*b - 5*c
        ''')

而不是视觉上没有吸引力

def foo():
    funcstr='''\
def func(x, {p}):
    return x * 2*a + 4*b - 5*c
'''

有些人喜欢

def foo():
    funcstr=(
        'def func(x, {p}):\n'
        '    return x * 2*a + 4*b - 5*c'
        )

但我发现分别引用每一行并添加明确的 EOL 字符有点麻烦。但是,它确实为您节省了一次函数调用。

关于python - 动态定义一个函数,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/8196744/

相关文章:

python - 如何将以下输出更改为正确的字符串?

python - 在 Python 中切片列表列表

python - Django 迭代模板中的静态文件

python - 与 C 数组相比,带有 NumPy 数组内存 View 的 Cython 性能较差

python - 更新使用具有重复索引的叉积索引的 2D NumPy 数组

numpy - imshow 的颜色条,以 0 为中心并带有符号刻度

python - pygame.camera "livefeed"在 tkinter 窗口内(在 raspbian 上)

python-3.x - 如何使用 scipy optimization 找到 3 个参数和数据点列表的最小卡方?

parallel-processing - 并行处理对大型数据集的顺序任务的多次评估——GPU 计算的任务?

python - Scipy 稀疏矩阵中的行划分