python - NP阵列的形状为(2186, 128)。我有兴趣将获得的数组应用到SVM

标签 python arrays numpy machine-learning svm

我提取了 CNN 倒数第二层(#第 12 层)的输出。提取的 NP 数组的形状为 (2186, 128)。我有兴趣将获得的数组应用到 SVM 中。

提取特征的代码:

import numpy as np
X_train=np.array(get_activations(model=model,layer=12, X_batch=x_train)[0], dtype=np.float32)
print(X_train)

这给了我形状 (2186, 128) 的输出

将上述 np 数组应用到 SVM 的代码:

from sklearn.svm import SVC
clf = SVC()
clf.fit(X_train, y)
SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
    decision_function_shape='ovr', degree=3, gamma='auto', kernel='rbf',
    max_iter=-1, probability=False, random_state=None, shrinking=True,
    tol=0.001, verbose=False)

这给出了错误:

ValueError                                Traceback (most recent call last)
<ipython-input-54-2d6b8b03f3c1> in <module>()
----> 1 clf.fit(X_train, y)
      2 SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
      3     decision_function_shape='ovr', degree=3, gamma='auto', kernel='rbf',
      4     max_iter=-1, probability=False, random_state=None, shrinking=True,
      5     tol=0.001, verbose=False)

~/anaconda3/lib/python3.6/site-packages/sklearn/svm/base.py in fit(self, X, y, sample_weight)
    147         self._sparse = sparse and not callable(self.kernel)
    148 
--> 149         X, y = check_X_y(X, y, dtype=np.float64, order='C', accept_sparse='csr')
    150         y = self._validate_targets(y)
    151 

~/anaconda3/lib/python3.6/site-packages/sklearn/utils/validation.py in check_X_y(X, y, accept_sparse, dtype, order, copy, force_all_finite, ensure_2d, allow_nd, multi_output, ensure_min_samples, ensure_min_features, y_numeric, warn_on_dtype, estimator)
    576                         dtype=None)
    577     else:
--> 578         y = column_or_1d(y, warn=True)
    579         _assert_all_finite(y)
    580     if y_numeric and y.dtype.kind == 'O':

~/anaconda3/lib/python3.6/site-packages/sklearn/utils/validation.py in column_or_1d(y, warn)
    612         return np.ravel(y)
    613 
--> 614     raise ValueError("bad input shape {0}".format(shape))
    615 
    616 

ValueError: bad input shape (2186, 3)

最佳答案

看起来你的y有one-hot-encoding(3类)。将其转换为整数标签 (0, 1, 2) 就可以了。

关于python - NP阵列的形状为(2186, 128)。我有兴趣将获得的数组应用到SVM,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50890791/

相关文章:

Python 在用户生成的整数列表、While 循环和 Try-Except 中防止重复

python - 当我尝试执行 pip install 时出现错误

python - 使用两个 bool 数组索引 2D np.array 时出现意外行为

python - 给定条件的矩阵上的 Numpy 高级索引

python - 在 numpy 中获取唯一行位置的更快方法是什么

python - 检索与 pandas 中另一列中元素第一次出现相对应的列中的值 - python

python - 为什么我不能按用户属性查询?

javascript - 递归循环 JavaScript 后使用自定义 HTML 元素作为 JSON 键变量

c++ - 如何在 C++ 中创建一个大小为 1 位的元素数组

python - 我可以在 PyPy 中嵌入 CPython 吗?