python - 将 tensorflow 数据集输入模型

标签 python tensorflow keras tf.keras

我有一个包含 102 个特征的输入数据集,并且有相应的二进制输出。输出为 0 或 1,具体取决于 102 个特征。

输入:

tf.Tensor(
[-1.72999993e-01 -8.20000023e-02  3.38000000e-01  1.35000005e-01
  ...
  0.00000000e+00  2.00000009e-03], shape=(102,), dtype=float64)

输出:

tf.Tensor([1], shape=(1,), dtype=int32)

我正在尝试遵循此custom training tutorial并按如下方式创建此模型:

train_dataset = tf.data.Dataset.from_tensor_slices((train_x,tf.dtypes.cast(label_x, tf.int32)))
features, labels = next(iter(train_dataset))

model = tf.keras.Sequential([
  tf.keras.layers.Dense(10, activation=tf.nn.relu, input_shape=(102,)),  # input shape required
  tf.keras.layers.Dense(10, activation=tf.nn.relu),
  tf.keras.layers.Dense(1)
])

predictions = model(features)

但是,当我尝试运行它时出现错误:

---------------------------------------------------------------------------

InvalidArgumentError                      Traceback (most recent call last)

<ipython-input-12-d7be7f733930> in <module>()
      6 ])
      7 
----> 8 predictions = model(features)

7 frames

/usr/local/lib/python3.6/dist-packages/six.py in raise_from(value, from_value)

InvalidArgumentError: In[0] is not a matrix. Instead it has shape [102] [Op:MatMul]

最佳答案

您需要调整用于创建数据集的batch,或者调整模型中的input_shape以适应尺寸。

train_x = np.arange(100, dtype=np.int32)
label_x = np.arange(100, dtype=np.int32)

train_dataset = tf.data.Dataset.from_tensor_slices((train_x, label_x)).batch(10)

model = tf.keras.Sequential([
  tf.keras.layers.Dense(10, activation=tf.nn.relu, input_shape=(1,)),  # input shape required
  tf.keras.layers.Dense(10, activation=tf.nn.relu),
  tf.keras.layers.Dense(1)
])

for features, labels in train_dataset:
    pred = model(features[..., tf.newaxis])
print(pred)

#tf.Tensor(
#[[-21.829016]
# [-22.071556]
# [-22.314102]
# [-22.556648]
# [-22.799194]
# [-23.041737]
# [-23.284283]
# [-23.52683 ]
# [-23.76937 ]
# [-24.011917]], shape=(10, 1), dtype=float32)

关于python - 将 tensorflow 数据集输入模型,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58536836/

相关文章:

python - 创建具有多个输入的 TimeseriesGenerator

python - 如何过滤pandas数据框?

python - 使用 chromedriver 和 Selenium 创建 Python 可执行文件

python-3.x - 已存在另一个同名指标

python - 尝试运行 TensorFlow 时的 CUDNN_STATUS_NOT_INITIALIZED

python - “密集”对象没有属性 'op'

python - LSTM 验证

python - 使用队列 Tensorflow 训练模型

java - 将后端代码(Java、Python)与 HTML 集成

python - Tensorflow:在cpu上的多个线程中加载数据