python - 使用神经网络进行分类

标签 python machine-learning neural-network classification

我已经建立了一个用于分类的神经网络,但是在尝试编译时,我遇到了输入和输出维度的问题:

from keras.models import Sequential
from keras.layers import Dense

# data splited into input (X) and output (y) variables
model = Sequential()
model.add(Dense(12, input_dim=456, activation='relu'))
model.add(Dense(8, activation='relu'))
model.add(Dense(8, activation='sigmoid'))
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

这是我的 yX 的尺寸

print(y.shape, X.shape)
(8000, 1) (8000, 456, 3)

我有 8000 个子集,其中包含 456 个粒子(x,y,z); 我有 y 范围从 0 到 7 的标签;这也是为什么我的输出层有 8 个节点。

但是当我适应

model.fit(X, y, epochs=15, batch_size=10)

我不明白为什么会发生这个错误:

ValueError: Error when checking input: expected dense_26_input to have 2 dimensions, but got array with shape (8000, 456, 3)

有什么建议吗?

最佳答案

要回答您的问题,您可以通过执行以下操作来实现您想要的目标:

from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense

model = Sequential()
model.add(Dense(12, input_shape=(456,3), activation='relu'))

model.add(Dense(8, activation='relu'))
model.add(Dense(8, activation='sigmoid'))
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

model.summary()

编辑:

我认为您正在寻找的是这种类型的架构:

from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense, Flatten

model = Sequential()
model.add(Dense(12, input_shape=(456,3), activation='relu'))
model.add(Flatten())
model.add(Dense(8, activation='relu'))
model.add(Dense(8, activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

model.summary()

这样它只输出8个标签

关于python - 使用神经网络进行分类,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59988844/

相关文章:

python - 与 SciPy.optimize 的并行性

python - Numpy 和 TensorFlow 之间的差异

java - 使用 DL4J 评估图像,类似于 AlphaGo

sql - 在数据库中存储神经网络的最佳实践

python - tensorflow 损失中的logits可以是占位符

python - 如何让python优雅地失败?

python - 为什么我不能在 Python 中访问父类(super class)的私有(private)变量?

python - 使用 python graphviz ImportError : No module named _gv

python - 如何将 pandas.to_datetime 与 "strange"字符串格式一起使用

r - 从数据帧创建稀疏矩阵