python - 在处理 MNIST 数据集时 opencv 中出现大小错误

标签 python opencv tensorflow machine-learning keras

我正在使用 OpenCV 和 ML 模块训练 MLP。我收到一个未知错误,但无法修复它:

"error: OpenCV(3.4.3) /io/opencv/modules/ml/src/data.cpp:257: error: (-215:Assertion failed) samples.type() == CV_32F || samples.type() == CV_32S in function 'setData'"

这是我的代码:

    from keras.datasets import mnist
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train.shape, y_train.shape
import numpy as np
np.unique(y_train)
import matplotlib.pyplot as plt
%matplotlib inline
for i in range(10):
  plt.subplot(2, 5, i+1)
  plt.imshow(X_train[i, :, :], cmap='gray')
  plt.axis('off')
from sklearn.preprocessing import OneHotEncoder
enc=OneHotEncoder(sparse=False, dtype=np.float32)
y_train_pre=enc.fit_transform(y_train.reshape(-1,1))
y_test_pre=enc.fit_transform(y_test.reshape(-1,1))
X_train_pre=X_train.reshape((X_train.shape[0], -1))
X_train_pre=X_train.astype(np.float32) /255.0
X_test_pre=X_test.reshape((X_test.shape[0], -1))
X_test_pre=X_test.astype(np.float32) / 255.0
import cv2
mlp=cv2.ml.ANN_MLP_create()
mlp.setLayerSizes(np.array([784, 512, 512, 10]))
mlp.setActivationFunction(cv2.ml.ANN_MLP_SIGMOID_SYM, 2.5, 1.0)
mlp.setTrainMethod(cv2.ml.ANN_MLP_BACKPROP)
mlp.setBackpropWeightScale(0.00001)
term_mode= (cv2.TERM_CRITERIA_MAX_ITER + cv2.TERM_CRITERIA_EPS)
term_max_iter=10
term_eps=0.01
mlp.setTermCriteria((term_mode, term_max_iter, term_eps))
mlp.train(X_train_pre, cv2.ml.ROW_SAMPLE, y_train_pre)

运行最后一个单元格后出现错误。意思就是在训练的时候!我无法修复它,但它们与层的大小有什么关系吗?或者使用 numpy 进行类型转换?如果你们能指导我,这会对我有帮助。预先感谢各位。

最佳答案

图像需要是一维向量,但它们是以形状[28,28]放入的。例如,这将 reshape 图像并起作用:

mlp.train(X_train_pre.reshape(60000,-1), cv2.ml.ROW_SAMPLE, y_train_pre)

关于python - 在处理 MNIST 数据集时 opencv 中出现大小错误,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54426725/

相关文章:

python - 在 Keras 中实现批量相关损失

tensorflow - 在 tensorflow 中,如何使用dynamic_decode的输出计算序列丢失

python - Pandas :为什么你可以对具有不同索引的系列进行算术运算,但不能进行比较?

python - 如何启用 python 的 json 包来编码 attrdict.AttrDict 对象?

python - Pandas read_csv : Columns are being imported as rows

python - 使用字典替换列表中的项目

opencv - opencv 3.0上的haartraining目录以构建mergevec

c++ - 受光照影响的边缘检测

python - 使用StereoBM测量绝对距离

python - 如何在 TensorFlow 图中正确引发异常