python - 使用具有多个输入因子的 sklearn 决策树进行回归会产生错误

标签 python input scikit-learn regression

我想用 sklearn 的决策树回归器进行回归。我的输入数据由多个传感器数据组成,我将时间序列划分为更小的窗口,并计算每个时间窗口和每个传感器的平均值和标准差。该数组如下所示(以两个传感器和 100 个时间窗口为例):

features_x = np.array([[x[:,0].mean(), x[:,0].std(), x[:,1].mean(), x[:,1].std()]
                    for x in np.array_split(train_array, 100)])

然后我想预测第三个传感器的平均值:

features_y = np.array([[x[:,2].mean()]
                        for x in np.array_split(train_array, 100)])

然后我使用决策树回归器:

clf = tree.DecisionTreeRegressor()
clf.fit(features_x.reshape(-1,1),features_y.reshape(-1,1))

但是当我使用此代码时,我收到一条错误消息:

ValueError: Number of labels does not match number of samples

这一定是因为我使用具有 4 个“值”的数组作为输入,但使用仅具有 1 个“值”的数组作为输出。但我实际上想使用来自许多传感器的数据作为输入来预测仅另一个传感器的值作为输出。是否有可能使回归以这种方式进行?

编辑:两个特征矩阵均由浮点值组成。 features_x 有 4 列和 100 行,每列是平均值或标准差。每一行都是一个时间窗口。 features_y 有 1 列和 100 行。我只是计算每个时间窗口中一个传感器的平均值。

最佳答案

问题在于您在输入上使用 reshape 函数的位置:

clf.fit(features_x.reshape(-1,1),features_y.reshape(-1,1))

数组 features_x 具有多列,经过此 reshape 后,它只有一列包含所有元素,因此它变得比 features_y 长,因此会出现错误。为了让您更好地了解正在发生的情况,请考虑以下示例:

In [4]: a = np.zeros(8).reshape((4,2))                                                                        

In [5]: a                                                                                                     
Out[5]: 
array([[0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.]])

In [6]: a.reshape(-1,1)                                                                                       
Out[6]: 
array([[0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.]])

您可以使用任何 numpy 数组的 .shape 属性来验证输入和标签的格式。

关于python - 使用具有多个输入因子的 sklearn 决策树进行回归会产生错误,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59841434/

相关文章:

python - 最佳实践 : handle functions with lots of parameters and reserved names

Python 正则表达式lookbehind 交换组

winapi - 使用 ToUnicode(正确,所以它有效)

python - 模型的准确度是 0.86 而 AUC 是 0.50?

python - XGBoost 产生非二进制预测

python - Nginx + Gunicorn + Flask : How to figure out the real base URL

python - 合并(更新\插入)pandas 数据帧的更好方法

jquery - 将输入 [type ='number' ] 上的值乘以 250

java - System.out.print() 读取 Scanner 错误

python - 按升序生成 Kmeans 的质心