python - 将 Pandas DataFrame 传递给 Scipy.optimize.curve_fit

标签 python pandas scipy mathematical-optimization model-fitting

我想知道使用 Scipy 来拟合 Pandas DataFrame 列的最佳方法。如果我有一个数据表(Pandas DataFrame),其中包含列(ABCDZ_real),其中 Z 取决于 A、B、C 和 D,我想拟合每个 DataFrame 行(系列)的函数,该函数对 Z 进行预测 (Z_pred)。

每个要拟合的函数的签名是

func(series, param_1, param_2...)

其中series是与DataFrame的每一行对应的Pandas Series。我使用 Pandas Series,以便不同的函数可以使用不同的列组合。

我尝试使用将 DataFrame 传递给 scipy.optimize.curve_fit

curve_fit(func, table, table.loc[:, 'Z_real'])

但由于某种原因,每个 func 实例都会传递整个数据表作为其第一个参数,而不是每行的 Series。我还尝试将 DataFrame 转换为 Series 对象列表,但这会导致我的函数传递一个 Numpy 数组(我认为是因为 Scipy 执行从 Series 列表到 Numpy 数组的转换,这不会保留 Pandas系列对象)。

最佳答案

您对 curve_fit 的调用不正确。来自 the documentation :

xdata : An M-length sequence or an (k,M)-shaped array for functions with k predictors.

The independent variable where the data is measured.

ydata : M-length sequence

The dependent data — nominally f(xdata, ...)

在这种情况下,您的自变量 xdata 是 A 到 D 列,即 table[['A', 'B', 'C', 'D']],您的因变量 ydatatable['Z_real']

另请注意,xdata 应为 (k, M) 数组,其中 k 是预测变量(即列)的数量M 是观测值的数量(即行)。因此,您应该转置输入数据帧,使其为 (4, M) 而不是 (M, 4),即 table[['A', ' B'、'C'、'D']].T

curve_fit的整个调用可能如下所示:

curve_fit(func, table[['A', 'B', 'C', 'D']].T, table['Z_real'])

这是一个显示多元线性回归的完整示例:

import numpy as np
import pandas as pd
from scipy.optimize import curve_fit

X = np.random.randn(100, 4)     # independent variables
m = np.random.randn(4)          # known coefficients
y = X.dot(m)                    # dependent variable

df = pd.DataFrame(np.hstack((X, y[:, None])),
                  columns=['A', 'B', 'C', 'D', 'Z_real'])

def func(X, *params):
    return np.hstack(params).dot(X)

popt, pcov = curve_fit(func, df[['A', 'B', 'C', 'D']].T, df['Z_real'],
                       p0=np.random.randn(4))

print(np.allclose(popt, m))
# True

关于python - 将 Pandas DataFrame 传递给 Scipy.optimize.curve_fit,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/35233664/

相关文章:

python - 计算多个 pandas 数据帧中某个 pandas 行的特定列值低于另一个特定 pandas 行的次数

python - 如何可逆地将 Pandas 数据帧存储到磁盘或从磁盘加载

python - 使用 ODEINT 或其他方法求解许多耦合微分方程组

Python:将 Dataframe 的最多 3 列合并为 1 列,但 3 列中的任何一个都不存在

python - 在不使用 python 模块的情况下从 .csv 的列中查找最小值

python - 提高pandas groupby的性能

Python:使用 sklearn 时为 "ValueError: setting an array element with a sequence"

python - 访问 scipy.sparse.csr_matrix,所有行都没有零列 j

python - Pandas 新手 : Sort by nth row in dataframe

python - 模拟 Flask 的 `send_from_directory` 用于测试