我有一个形状为 [None, 9, 2]
的 tensorflow 的输入(其中 None
是批处理)。
要对其执行进一步的操作(例如 matmul),我需要将其转换为 [None, 18]
形状。怎么做?
最佳答案
您可以使用 tf.reshape() 轻松完成,而无需知道批量大小。
x = tf.placeholder(tf.float32, shape=[None, 9,2])
shape = x.get_shape().as_list() # a list: [None, 9, 2]
dim = numpy.prod(shape[1:]) # dim = prod(9,2) = 18
x2 = tf.reshape(x, [-1, dim]) # -1 means "all"
-1
最后一行表示整列,无论运行时的批处理大小如何。您可以在 tf.reshape() 中看到它.更新:形状 = [无,3,无]
谢谢@kbrose。对于超过 1 个维度未定义的情况,我们可以使用 tf.shape()与 tf.reduce_prod()或者。
x = tf.placeholder(tf.float32, shape=[None, 3, None])
dim = tf.reduce_prod(tf.shape(x)[1:])
x2 = tf.reshape(x, [-1, dim])
tf.shape() 返回一个可以在运行时评估的形状张量。 tf.get_shape() 和 tf.shape() 的区别可见in the doc .
我也在另一个 .contrib.layers.flatten() 中尝试过。第一种情况最简单,但不能处理第二种情况。
关于tensorflow - 在 tensorflow 中展平批处理,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/36668542/