我尝试将一段代码从 Matlab 转换为 Python,但遇到了一些错误:
Matlab:
function [beta] = linear_regression_train(traindata)
y = traindata(:,1); %output
ind2 = find(y == 2);
ind3 = find(y == 3);
y(ind2) = -1;
y(ind3) = 1;
X = traindata(:,2:257); %X matrix,with size of 1389x256
beta = inv(X'*X)*X'*y;
Python:
def linear_regression_train(traindata):
y = traindata[:,0] # This is the output
ind2 = (labels==2).nonzero()
ind3 = (labels==3).nonzero()
y[ind2] = -1
y[ind3] = 1
X = traindata[ : , 1:256]
X_T = numpy.transpose(X)
beta = inv(X_T*X)*X_T*y
return beta
我收到错误:操作数无法与计算 beta 的行上的形状 (257,0,1389) (1389,0,257) 一起广播。
感谢任何帮助!
谢谢!
最佳答案
问题是您正在使用 numpy 数组,而不是像 MATLAB 中那样使用矩阵。默认情况下,矩阵进行矩阵数学运算。因此,X*Y
执行 X
和 Y
的矩阵乘法。然而,对于数组,默认情况下是使用逐个元素的操作。因此,X*Y
将 X
和 Y
的每个相应元素相乘。这相当于 MATLAB 的 .*
操作。
但就像 MATLAB 的矩阵可以进行逐个元素运算一样,Numpy 的数组也可以进行矩阵乘法。所以你需要做的是使用numpy的矩阵乘法而不是它的逐元素乘法。对于 Python 3.5 或更高版本(这是您应该用于此类工作的版本),这只是 @
运算符。所以你的行变成:
beta = inv(X_T @ X) @ X_T @ y
或者,更好的是,您可以使用更简单的 .T
转置,它与 np.transpose
相同,但更简洁(您可以摆脱`np.transpose 行完全):
beta = inv(X.T @ X) @ X.T @ y
对于 Python 3.4 或更早版本,您需要使用 np.dot
,因为这些版本的 Python 没有 @
矩阵乘法运算符:
beta = np.dot(np.dot(inv(np.dot(X.T, X)), X.T), y)
Numpy 有一个矩阵对象,它像 MATLAB 矩阵一样默认使用矩阵运算。 不要使用它!它很慢,支持很差,而且几乎不是你真正想要的。 Python 社区已经围绕数组进行了标准化,因此请使用它们。
traindata
的维度也可能存在一些问题。为了使其正常工作,traindata.ndim
应等于 3
。为了使 y
和 X
为 2D,traindata
应为 3D
。
如果 traindata
是二维的,并且您希望 y
是 MATLAB 样式的“向量”(MATLAB 所谓的“向量”并不是真正的向量),这可能会出现问题)。在 numpy 中,使用像 traindata[:, 0]
这样的单个索引可以减少维数,而像 traindata[:, :1]
这样的切片则不会。因此,当 traindata
为 2D 时,要保持 y
为 2D,只需执行长度为 1 的切片,traindata[:, :1]
。这是完全相同的值,但与 traindata
保持相同的维度数。
注释:使用逻辑索引可以显着简化您的代码:
def linear_regression_train(traindata):
y = traindata[:, 0] # This is the output
y[labels == 2] = -1
y[labels == 3] = 1
X = traindata[:, 1:257]
return inv(X.T @ X) @ X.T @ y
return beta
此外,定义 X
时您的切片是错误的。 Python 切片不包括最后一个值,因此要获得 256 长的切片,您需要执行 1:257
,就像我上面所做的那样。
最后,请记住,对函数内部数组的修改会延续到函数外部,并且索引不会进行复制。因此,您对 y
的更改(将某些值设置为 1
,将其他值设置为 -1
)将影响外部的 traindata
你的功能。如果您想避免这种情况,则需要在进行更改之前制作一份副本:
y = traindata[:, 0].copy()
关于python - 将线性回归从 Matlab 转换为 Python,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/39195204/