tensorflow - 检查目标 : expected dense_Dense2 to have shape x, 时出错,但得到形状为 y 的数组

标签 tensorflow tensorflow2.0 tensorflow.js

这是我在 tensorflow 中迈出的第一步。

想法

有一些数字模式(数字数组:Pattern = number[])。以及与此模式对应的类别(从 0 到 2 的数字:Category = 0 | 1 | 2)。我遵循结构数据:xs = Pattern[]ys = Category[]

例如:

xs = [[1, 2, 3, 4], [5, 6, 7, 8], ..., [9, 10, 11, 12]];
ys = [1, 0, ..., 2];

我希望神经网络找到 xs[0]xy[0] 之间的匹配,依此类推。我想传递像 [1, 2, 3, 4] 这样的神经网络数据并获得接近 1 的结果。

model.predict(tf.tensor([1, 2, 3, 4])) // ≈1

我的代码

import * as tf from '@tensorflow/tfjs';
require('@tensorflow/tfjs-node');

const xs = tf.tensor2d([
  [1, 2, 3, 4],
  [5, 6, 7, 8],
  [9, 10, 11, 12],
]);
const ys = tf.tensor1d([0, 1, 2]);

const model = tf.sequential();
model.add(tf.layers.dense({ units: 4, inputShape: xs.shape, activation: 'relu' }));
                                   ^ - Pattern length, it is constant
model.add(tf.layers.dense({ units: 3, activation: 'softmax' }));
model.compile({ optimizer: 'adam', loss: 'categoricalCrossentropy', metrics: ['accuracy'] });

model.fit(xs, ys, { epochs: 500 });

我收到以下错误:

Error when checking input: expected dense_Dense1_input to have 3 dimension(s). but got array with shape 3,4

我不明白如何解释我的神经网络数据结构。

最佳答案

模型 inputShape 为 [3,4] 。为了拟合或预测该模型,它需要 [b, 3, 4] 形式的数据,其中 b 是批处理形状。尝试使用 xs 拟合模型时,批量形状丢失。

模型 inputShape 应该是 [4],以便 xs 可以用于预测。可以使用 xs.shape.slice(-1),而不是使用 xs.shape

const xs = tf.tensor2d([
  [1, 2, 3, 4],
  [5, 6, 7, 8],
  [9, 10, 11, 12],
]);
const ys = tf.tensor1d([0, 1, 2]);

const model = tf.sequential();
model.add(tf.layers.dense({ units: 4, inputShape: xs.shape.slice(1), activation: 'relu' }));
                                  
model.add(tf.layers.dense({ units: 3, activation: 'softmax' }));
model.compile({ optimizer: 'adam', loss: 'categoricalCrossentropy', metrics: ['accuracy'] });

model.fit(xs, ys);
model.predict(xs).print()

此外,如果模型的目标是通过使用 softmaxcategoricalCrossentropy 来预测类别,那么标签应该是 one-hot 编码的.

类似答案:

关于tensorflow - 检查目标 : expected dense_Dense2 to have shape x, 时出错,但得到形状为 y 的数组,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/63976057/

相关文章:

python - 尚未在model.summary()上建立此模型错误

tensorflow - 如何在不涉及 IBM 云的情况下转换我用 Tensorflow (python) 训练的模型以用于 TensorflowJS(从我现在的步骤开始)?

python - 无法在Raspberry Pi 3上通过pip3安装opencv-python

tensorflow - 如何为 tensorflow 服务准备预热请求文件?

python - TF2 中的 export_saved_model 是否也保存权重?

python - 关于 tf.function 的跟踪是什么

tensorflow - tf.loadFrozenModel 和 tf.loadModel 的预测时间会不同吗?

javascript - 如何以 Angular 转换数据以校正张量格式?

python - 导入错误: cannot import name 'compile'

tensorflow - 即使尺寸不匹配,自定义损失函数也有效