python - 使用 numpy 或 scipy 将 3D 数据数组拟合到 1D 函数

标签 python arrays numpy scipy

4我目前正在尝试将大量数据拟合到一个正弦函数中。在我只有一组数据(一维数组)的情况下,scipy.optimize.curve_fit() 工作正常。但是,据我所知,如果函数本身只是一维的,则它不允许输入更高维的数据。我不想使用 for 循环遍历数组,因为这在 python 中运行得非常慢。

到目前为止,我的代码应该与此类似:

from scipy import optimize
import numpy as np    
def f(x,p1,p2,p3,p4): return p1 + p2*np.sin(2*np.pi*p3*x + p4)      #fit function

def fit(data,guess):
   n = data.shape[0] 
   leng = np.arange(n)
   param, pcov = optimize.curve_fit(f,leng,data,guess)
   return param, pcov

其中数据是一个三维数组 (shape=(x,y,z)),我想适应每一行 data[:,a,b]param(4,y,z) 形数组作为输出的函数。 当然,对于多维数据,这会导致

ValueError:操作数无法与形状 (2100,2100) (5) 一起广播

也许有一个简单的解决方案,但我不确定该怎么做。有什么建议吗?

要为我的问题寻找答案非常困难,因为大多数带有这些关键字的主题都与高维函数的拟合有关。

最佳答案

使用 np.apply_along_axis()解决你的问题。只需这样做:

func1d = lambda y, *args: optimize.curve_fit(f, xdata=x, ydata=y, *args)[0] #<-- [0] to get only popt
param = np.apply_along_axis( func1d, axis=2, arr=data )

请看下面的例子:

from scipy import optimize
import numpy as np
def f(x,p1,p2,p3,p4):
    return p1 + p2*np.sin(2*np.pi*p3*x + p4)
sx = 50  # size x
sy = 200 # size y
sz = 100 # size z
# creating the reference parameters
tmp = np.empty((4,sy,sz))
tmp[0,:,:] = (1.2-0.8) * np.random.random_sample((sy,sz)) + 0.8
tmp[1,:,:] = (1.2-0.8) * np.random.random_sample((sy,sz)) + 0.8
tmp[2,:,:] = np.ones((sy,sz))
tmp[3,:,:] = np.ones((sy,sz))*np.pi/4
param_ref = np.empty((4,sy,sz,sx))     # param_ref in this shape will allow an
for i in range(sx):                    # one-shot evaluation of f() to create 
    param_ref[:,:,:,i] = tmp           # the data sample
# creating the data sample
x = np.linspace(0,2*np.pi)
factor = (1.1-0.9)*np.random.random_sample((sy,sz,sx))+0.9
data = f(x, *param_ref) * factor       # the one-shot evalution is here
# finding the adjusted parameters
func1d = lambda y, *args: optimize.curve_fit(f, xdata=x, ydata=y, *args)[0] #<-- [0] to get only popt
param = np.apply_along_axis( func1d, axis=2, arr=data )

关于python - 使用 numpy 或 scipy 将 3D 数据数组拟合到 1D 函数,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/15094619/

相关文章:

arrays - swift 2.0 : Json Array Parsing errors

JavaScript 数组 - forEach 内的拼接调用出现意外结果

javascript - 如何检查数组是否为空或不存在?

python - 是否有一个Python函数(最好是seaborn)可以帮助我用散点图上的一条线连接两组点?

python - 在 h5py 中压缩文件更大

python - 在Python中从列表中的每个元素减去自身的有效方法

python - 如何通过 Python 使用 GeckoDriver 和 Firefox 使 Selenium 脚本无法检测?

python - 如何在Python Flask项目中启用Cloud Foundry的粘性 session ?

python - Numpy:更改所有矩阵元素 10% 的最快方法

python - 带有 bool 输出的opencv python颜色检测