假设我们有 4 个张量,a
、b
、c
和 d
,它们都具有相同的维度(batch_size, T, C)
,我们要创建一个新的张量 X
,其形状为 (batch_size, T*4, C)
> 其中 T*4
在所有张量之间交错循环。
例如,如果a
、b
、c
和d
是全一、二的张量,分别是三和四,我们期望 X
类似于
[[[1,1,1...],
[2,2,2...],
[3,3,3...],
[4,4,4...],
[1,1,1...],
[2,2,2...],
.
.
.
]]
最佳答案
在我看来,您的示例数组实际上具有形状 (batch_size, T, C*4)
而不是 (batch_size, T*4, C)
。不管怎样,你可以通过 tf.concat、tf.reshape 和 tf.transpose 得到你需要的东西。一个更简单的 2d 示例如下:
A = tf.ones([2,3])
B = tf.ones([2,3]) * 2
AB = tf.concat([A,B], axis=1)
AB = tf.reshape(AB, [-1, 3])
AB.eval() #array([[1., 1., 1.],
# [2., 2., 2.],
# [1., 1., 1.],
# [2., 2., 2.]], dtype=float32)
连接 A 和 B 以获得形状为 (2,6) 的矩阵。然后你 reshape 它的形状,使行交错。要在 3d 中执行此操作,乘以 4 的维度必须是最后一个维度。因此,您可能需要使用 tf.transpose,使用 concat 和 reshape 进行交错,然后再次转置以重新排序维度。
关于python - TensorFlow 连接/堆叠 N 个张量交错最后一个维度,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56831972/