python - 在Integrate.quad中使用上限作为数组

标签 python scipy

我有一组 (x,y) 数据点,我希望使用 scipy.optimize 中的 curve_fit 函数来拟合它们。 但是执行 curve_fit 函数会出现以下错误。

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-3-27ed918f0be6> in <module>()
      1 print np.any(xdata)
----> 2 popt, pcov = curve_fit(func,xdata,ydata,p0=[0.01,371,700],)

D:\anaconda\lib\site-packages\scipy\optimize\minpack.pyc in curve_fit(f, xdata, ydata, p0, sigma, absolute_sigma, check_finite, bounds, method, jac, **kwargs)
    740         # Remove full_output from kwargs, otherwise we're passing it in twice.
    741         return_full = kwargs.pop('full_output', False)
--> 742         res = leastsq(func, p0, Dfun=jac, full_output=1, **kwargs)
    743         popt, pcov, infodict, errmsg, ier = res
    744         cost = np.sum(infodict['fvec'] ** 2)

D:\anaconda\lib\site-packages\scipy\optimize\minpack.pyc in leastsq(func, x0, args, Dfun, full_output, col_deriv, ftol, xtol, gtol, maxfev, epsfcn, factor, diag)
    375     if not isinstance(args, tuple):
    376         args = (args,)
--> 377     shape, dtype = _check_func('leastsq', 'func', func, x0, args, n)
    378     m = shape[0]
    379     if n > m:

D:\anaconda\lib\site-packages\scipy\optimize\minpack.pyc in _check_func(checker, argname, thefunc, x0, args, numinputs, output_shape)
     24 def _check_func(checker, argname, thefunc, x0, args, numinputs,
     25                 output_shape=None):
---> 26     res = atleast_1d(thefunc(*((x0[:numinputs],) + args)))
     27     if (output_shape is not None) and (shape(res) != output_shape):
     28         if (output_shape[0] != 1):

D:\anaconda\lib\site-packages\scipy\optimize\minpack.pyc in func_wrapped(params)
    452     if transform is None:
    453         def func_wrapped(params):
--> 454             return func(xdata, *params) - ydata
    455     elif transform.ndim == 1:
    456         def func_wrapped(params):

<ipython-input-2-774b4351c6a2> in func(x, gb, V0, D)
     13     T = x
     14 
---> 15     integrand = integrate.quad((lambda z: (z**3)/(np.exp(z)-1)),1e-5,D/T)
     16     coeff = (9.0*Na*kb*T)*((T/D)**3.0)
     17     U = coeff*integrand[0]

D:\anaconda\lib\site-packages\scipy\integrate\quadpack.pyc in quad(func, a, b, args, full_output, epsabs, epsrel, limit, points, weight, wvar, wopts, maxp1, limlst)
    321     if (weight is None):
    322         retval = _quad(func, a, b, args, full_output, epsabs, epsrel, limit,
--> 323                        points)
    324     else:
    325         retval = _quad_weight(func, a, b, args, full_output, epsabs, epsrel,

D:\anaconda\lib\site-packages\scipy\integrate\quadpack.pyc in _quad(func, a, b, args, full_output, epsabs, epsrel, limit, points)
    370 def _quad(func,a,b,args,full_output,epsabs,epsrel,limit,points):
    371     infbounds = 0
--> 372     if (b != Inf and a != -Inf):
    373         pass   # standard integration
    374     elif (b == Inf and a != -Inf):

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

Python代码是:

import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
%matplotlib inline
import scipy.integrate as integrate
from scipy.optimize import curve_fit

def func(x,gb,V0,D):    
    Na=(6.022e23)
    kb=(1.38e-23)
    T = x    
    integrand = integrate.quad((lambda z: (z**3)/(np.exp(z)-1)),1e-5,D/T)
    coeff = (9.0*Na*kb*T)*((T/D)**3.0)
    U = coeff*integrand[0]
    V = (gb*U)+V0
    return V

xdata=[70,80,100,150,200,250,300]
ydata=[372.106,372.141,372.186,372.467,372.699,372.99,373.598]

plt.plot(xdata,ydata,'x')
plt.xlabel('T')
plt.ylabel('V')
plt.show()

print np.any(xdata)
popt, pcov = curve_fit(func,xdata,ydata,p0=[0.01,371,700])

我的自变量x是积分的上限。这就是quad向我返回以下类型的错误消息的原因吗?

具有多个元素的数组的真值是不明确的。使用a.any()或a.all()

如果是这种情况,我该如何规避?

最佳答案

该问题是由于 integrate.quad() 中的变量 D/T 造成的。

由于D/T是一个数组,因此不能将其视为积分的上限。

我不确定你的函数想要实现的目标是什么。但您只需将 D/T 更改为 D/T[0]np.mean(D/T) ,错误就会消失.

关于python - 在Integrate.quad中使用上限作为数组,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50167892/

相关文章:

python - Tkinter Notebook 中的可滚动页面

python - 向 csr_matrix 添加元素的正确方法是什么?

python - 基于网格的多元三次插值

python - 我如何弄清楚为什么 Python 在 Windows 7 64 位中抛出这个非描述性错误

python - 在python中高效生成 "shifted"高斯核

python - scikit 分类器的准确率非常低(朴素贝叶斯、DecisionTreeClassifier)

python - 重置/更新 Couchbase 计数器的 TTL

python - 并发 future 以非阻塞方式将任务提交到进程池

python - coverage.py 针对 .pyc 文件

python - 使用 ffmpeg 保存 matplotlib 动画时遇到问题