python - 如何使用 numpy 数组作为 Tensorflow CNN 的输入而不会出现维度不匹配的情况

标签 python arrays numpy tensorflow

我将 tensorflow 卷积网络的输入作为 4 级张量 (32, 32, 3, 73257) (73257 来自 imgs 的数量),它们是 numpy 数组,但我的 x 输入的占位符变量是二维的, (无,3072)。 3072来自img高度ximg宽度x channel 。我的问题是,如何 reshape 或使用图像以使它们与占位符兼容?

附注这些图像来自 SVHN 裁剪后的 32x32 数据集

images = np.array(features, dtype=np.float32)
...
x = tf.placeholder(tf.float32, shape=[None, 3072])
...
for _ in range(1000):
  batch = next_batch(50, images, labels)
  train_step.run(feed_dict={x: batch[0], y_: batch[1]})
...
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  for i in range(20000):
    batch = next_batch(50, images, labels)
    if i % 100 == 0:
      train_accuracy = accuracy.eval(feed_dict={
          x: batch[0], y_: batch[1], keep_prob: 1.0})
      print('step %d, training accuracy %g' % (i, train_accuracy))
    train_step.run(feed_dict={x: images, y_: labels, keep_prob: 0.5})

最佳答案

假设您有 73257 张 32 x 32 像素的图像,包含 3 个波段(例如 RGB)。你可以做一个

input = tf.transpose(input, [3, 0, 1, 2])

将最后一个维度放在第一位。张量应类似于 (73257, 32, 32, 3)。

然后做

input = tf.reshape(input, [-1, 3072])

减少维度。张量应类似于 (73257, 3072)。

关于python - 如何使用 numpy 数组作为 Tensorflow CNN 的输入而不会出现维度不匹配的情况,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/46231578/

相关文章:

Python selenium webdriver 代码性能

php - 在 MYSQL 中读取和写入 PHP 数组

python - numpy.unique 生成的列表在哪些方面是唯一的?

Python - 切片数组直到满足特定条件

javascript - 数组中的 indexOf 行为不正确

Python 使用 np.reshape 按一定顺序 reshape 数组

python - uWSGI 并优雅地杀死一个多线程 Flask 应用程序

python - Flask - 错误请求浏览器(或代理)发送了此服务器无法理解的请求

python - 如何创建一个空的 pandas DataFrame,并为每列分配不同的数据类型?

c - 如何在C中使用二维数组