tensorflow - 在 tensorflow 中展平批处理

标签 tensorflow

我有一个形状为 [None, 9, 2] 的 tensorflow 的输入(其中 None 是批处理)。

要对其执行进一步的操作(例如 matmul),我需要将其转换为 [None, 18]形状。怎么做?

最佳答案

您可以使用 tf.reshape() 轻松完成,而无需知道批量大小。

x = tf.placeholder(tf.float32, shape=[None, 9,2])
shape = x.get_shape().as_list()        # a list: [None, 9, 2]
dim = numpy.prod(shape[1:])            # dim = prod(9,2) = 18
x2 = tf.reshape(x, [-1, dim])           # -1 means "all"
-1最后一行表示整列,无论运行时的批处理大小如何。您可以在 tf.reshape() 中看到它.

更新:形状 = [无,3,无]

谢谢@kbrose。对于超过 1 个维度未定义的情况,我们可以使用 tf.shape()tf.reduce_prod()或者。
x = tf.placeholder(tf.float32, shape=[None, 3, None])
dim = tf.reduce_prod(tf.shape(x)[1:])
x2 = tf.reshape(x, [-1, dim])

tf.shape() 返回一个可以在运行时评估的形状张量。 tf.get_shape() 和 tf.shape() 的区别可见in the doc .

我也在另一个 .contrib.layers.flatten() 中尝试过。第一种情况最简单,但不能处理第二种情况。

关于tensorflow - 在 tensorflow 中展平批处理,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/36668542/

相关文章:

python - 有没有办法加速 tf.keras 中的嵌入层?

tensorflow - OutputProjectionWrapper 与 RNN 之上的全连接层

python-3.x - "import keras"和 "import tensorflow.keras"有什么区别

tensorflow - 在 tensorflow 中跨 session 使用共享变量

python - 在联邦学习中将数据拆分为训练和测试

neural-network - 如何重用现有的神经网络来使用 TensorFlow 训练新的神经网络?

python - 类型错误 : while_loop() got an unexpected keyword argument 'maximum_iterations' In Jupyter Azure

python - 如何为keras使用自定义损失函数

python - tensorflow 2.0,模型.fit(): Your input ran out of data

python - 在 Keras 中实现 Rprop 算法