python - 在 Google Quickdraw 数据集上应用机器学习算法

标签 python machine-learning scikit-learn

我正在尝试应用 python scikit-learn 包中提供的机器学习算法来从一组涂鸦图像中预测涂鸦名称。

因为我是机器学习的初学者,而且我还不了解神经网络是如何工作的。我想尝试使用 scikit-learn 的算法。

我在名为 quickdraw 的 api 的帮助下下载了涂鸦(猫和吉他)。 .

然后我使用以下代码加载图像

import numpy as np
from PIL import Image
import random

#To hold image arrays
images = []

#0-cat, 1-guitar
target = []

#5000 images of cats and guitar each
for i in range(5000):

   #cat images are named like cat0.png, cat1.png ...
   img = Image.open('data/cats/cat'+str(i)+'.png')
   img = np.array(img)
   img = img.flatten()
   images.append(img)
   target.append(0)

   #guitar images are named like guitar0.png, guitar1.png ...
   img = Image.open('data/guitars/guitar'+str(i)+'.png')
   img = np.array(img)
   img = img.flatten()
   images.append(img)
   target.append(1)

random.shuffle(images)
random.shuffle(target)

然后我应用了算法:-

from sklearn.model_selection import train_test_split
X_train,X_test,y_train,y_test = train_test_split(images,target,test_size=0.2, random_state=0)

from sklearn.naive_bayes import GaussianNB
GB = GaussianNB()
GB.fit(X_train,y_train)
print(GB.score(X_test,y_test))

运行上述代码(也使用其他算法,如 SVM、MLP)后,我的系统就卡住了。我已经强制关机才能回来。我不确定为什么会发生这种情况。

我尝试通过更改来减少要加载的图像数量

for i in range(5000):

for i in range(1000):

但是我的准确率只有 50% 左右

最佳答案

首先,请允许我这么说:

Since I'm a complete beginner in machine learning and I have no knowledge about >how neural network work yet. I wanted to try with scikit-learn's algorithms.

总的来说,这不是学习 ML 的好方法,我强烈建议你至少开始学习基础知识,否则你根本无法知道发生了什么(这不是你可以通过尝试来弄清楚的)它)。

回到你的问题,将朴素贝叶斯方法应用于原始图像这不是一个好的策略:问题是图像的每个像素都是一个特征,并且对于图像,你可以获得非常高的数字轻松地调整尺寸(还假设每个像素独立于其邻居,这不是您想要的)。 NB 通常与文档一起使用,请查看 wikipedia 上的这个示例。可能会帮助您更多地了解该算法。

简而言之,NB 归结为计算联合条件概率,归结为计算特征(维基百科示例中的单词)的同时出现像素在你的例子中,这又归结为计算一个巨大的事件矩阵,你需要用它来制定你的 NB 模型。

现在,如果您的矩阵由一组文档中的所有单词组成,那么这在时间和空间上都会变得相当昂贵 (O(n^2)/2),其中 n是特征的数量;相反,想象一下矩阵由训练集中的所有像素组成,就像您在示例中所做的那样......这会爆炸得非常快。

这就是为什么将数据集削减为 1000 个图像可以让您的电脑不会耗尽内存。 希望对您有所帮助。

关于python - 在 Google Quickdraw 数据集上应用机器学习算法,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58362063/

相关文章:

python - Pandas 'Replacement not allowed with overlapping keys and values'

machine-learning - 为什么在决策树中使用交叉熵而不是0/1损失

python - scikit-learn 中逻辑回归的输入格式与 R 中一样

python - 如何将csv文件保存在django的静态文件夹中?

python - 运行 flask 应用程序时将数据传递给配置?

python - 在keras中拆分图层的输出

python - 使用神经网络根据用户输入预测结果

python - scikit learn 插补 NaN 以外的值

python-3.x - 使大型数据集的 CountVectorizer 更快

python - 使用 BeautifulSoup 进行网页抓取(Jupyter Notebook)