python - Cython/numpy 与纯 numpy 的最小二乘拟合

标签 python numpy cython least-squares

学校的助教向我展示了这段代码,作为最小二乘拟合算法的示例。

import numpy as np
#return the coefficients (a0,..aN) of the fit y=a0+a1*x+..an*x^n
#with associated sigma dy
#x,y,dy are all np.arrays with dtype= np.float64
def fit_poly(x,y,dy,n):
  V = np.asmatrix(np.diag(dy**2))
  M = []

  for k in range(n+1):
      M.append(x**k)
  M = np.asmatrix(M).T
  theta = (M.T*V.I*M).I*M.T*V.I*np.asmatrix(y).T
  cov_t = (M.T*V.I*M).I

  return np.asarray(theta.T)[0], np.asarray(cov_t)

我正在尝试使用 cython 优化他的代码。我得到了这个代码

cimport numpy as np
import numpy as np
cimport cython

@cython.boundscheck(False)
@cython.wraparound(False) 
cpdef poly_c(np.ndarray[np.float64_t, ndim=1] x ,
   np.ndarray[np.float64_t, ndim=1] y np.ndarray[np.float64_t,ndim=1]dy , np.int n):

   cdef np.ndarray[np.float64_t, ndim=2] V, M

   V=np.asmatrix(np.diag(dy**2),dtype=np.float64)
   M=np.asmatrix([x**k for k in range(n+1)],dtype=np.float64).T

   return ((M.T*V.I*M).I*M.T*V.I*(np.asmatrix(y).T))[0],(M.T*V.I*M).I

但是两个程序的运行时似乎是相同的,我确实使用了“断言”来确保输出相同。我错过了什么/做错了什么?

感谢您的宝贵时间,希望您能帮助我。

ps:这是我进行分析的代码(不确定我是否可以调用此分析,但可以使用)

import numpy as np
from polyC import poly_c
from time import time
from pancho_fit import fit_poly
#pancho's the T.A,sup pancho
x=np.arange(1,1000)
x=np.asarray(x,dtype=np.float64)
y=3*x+np.random.random(999)
y=np.asarray(y,dtype=np.float64)
dy=np.array([y.std() for i in range(1,1000)],dtype=np.float64)
t0=time()
a,b=poly_c(x,y,dy,4)
#a,b=fit_poly(x,y,dy,4)
print("time={}s".format(time()-t0))

最佳答案

除了 [x**k for k in range(n+1)] 之外,我没有看到 cython 有任何迭代需要改进。大部分作用发生在基质产品中。这些已经通过编译代码完成(使用 np.dot 表示 ndarrays)。

并且n只有4,迭代次数并不多。

但是为什么要迭代这个呢?

In [24]: x=np.arange(1,1000.)
In [25]: M1=x[:,None]**np.arange(5)
# np.matrix(M1)

做同样的事情。

所以不,这看起来不像是一个好的 cython 候选者 - 除非您准备好以可编译的细节写出所有这些矩阵产品。

我也会跳过 asmatrix 内容并使用常规的 dot@einsum,但那就是更多的是风格问题而不是速度问题。

关于python - Cython/numpy 与纯 numpy 的最小二乘拟合,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/38387393/

相关文章:

Python - Virtualenv,python 3?

python - 如何为元组的每个元素添加一个数字?

python - 异构数据记录和分析

python - 安装 python-dev 和链接库后,Cython 中的 Hello World 程序因 gcc 而失败

python - 获取正则表达式匹配后的第一个单词

python - 无法保存背景减去视频Python openCV

python - 无法理解 numpy.random.RandomState

numpy - Tensorflow:在嵌套范围内按名称获取变量或张量

在 cython 中调用内置的 gcc?

python - 向量化 for 循环 NumPy