我正在尝试使用 fmin 进行逻辑回归,但由于数组形状不同而出现错误。这是代码。
import numpy as np
import scipy.optimize as sp
data= #an array of dim (188,3)
X=data[:,0:2]
y=data[:,2]
m,n=np.shape(X)
y=y.reshape(m,1)
x=np.c_[np.ones((m,1)),X]
theta=np.zeros((n+1,1))
def hypo(x,theta):
return np.dot(x,theta)
def sigmoid(z):
return 1/(1+np.exp(-z))
def gradient(x,y,theta):#calculating Gradient
m=np.shape(x)[0]
t=hypo(x,theta)
hx=sigmoid(t)
J=-(np.dot(np.transpose(np.log(hx)),y)+np.dot(np.transpose(np.log(1-hx)),(1-y)))/m
grad=np.dot(np.transpose(x),(hx-y))/m
J= J.flatten()
grad=grad.flatten()
return J,grad
def costFunc(x,y,theta):
return gradient(x,y,theta)[0]
def Grad():
return gradient(x,y,theta)[1]
sp.fmin( costFunc, x0=theta, args=(x, y), maxiter=500, full_output=True)
显示的错误
File "<ipython-input-3-31a0d7ca38c8>", line 35, in costFunc
return gradient(x,y,theta)[0]
File "<ipython-input-3-31a0d7ca38c8>", line 25, in gradient
t=hypo(x,theta)
File "<ipython-input-3-31a0d7ca38c8>", line 16, in hypo
return np.dot(x,theta)
ValueError: shapes (3,) and (118,1) not aligned: 3 (dim 0) != 118 (dim 0)
任何形式的帮助将不胜感激
最佳答案
data= #an array of dim (188,3)
X=data[:,0:2]
y=data[:,2]
m,n=np.shape(X)
y=y.reshape(m,1)
x=np.c_[np.ones((m,1)),X]
theta=np.zeros((n+1,1))
所以在这之后
In [14]: y.shape
Out[14]: (188, 1) # is this (118,1)?
In [15]: x.shape
Out[15]: (188, 3)
In [16]: theta.shape
Out[16]: (3, 1)
这x
和 theta
可以dotted
- np.dot(x,theta)
, 和 (188,3) 与 (3,1) - 匹配 3。
但这不是你的 costFunc
正在得到。从错误消息中追溯它看起来像 x
是(3,)
, 和 theta
是(118,1)
.这显然不能是 dotted
.
您需要查看如何 fmin
调用你的函数。您的参数顺序正确吗?例如,也许 costFunc(theta, x, y)
是正确的顺序(假设 x
和 y
在 costFunc
是为了匹配 args=(x,y)
。
fmin
的文档包括:
func : callable func(x,*args) The objective function to be minimized. x0 : ndarray Initial guess. args : tuple, optional Extra arguments passed to func, i.e. ``f(x,*args)``.
看起来像fmin
正在喂你的costFunc
3 个参数,大小对应于您的 (theta, x, y)
,即 (3,)
, (118,3)
, (118,1)
.这些数字不太匹配,但我想你明白了。 consFunc
的第一个参数是fmin
会有所不同,其余的请在 args
中提供.
关于Python:ValueError: 形状 (3,) 和 (118,1) 未对齐:3 (dim 0) != 118 (dim 0),我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/28735344/