我有一个名为 tensor
的 rank-3 张量,形状为 [batch_size, axis_1, axis_2]
并想将其拆分为 batch_size
切片像这样沿着第一个轴:
batch_size = tf.shape(tensor)[0]
batch_items = tf.split(tensor, num_or_size_splits=batch_size, axis=0)
不幸的是,这不起作用,因为 batch_size
的值在图的构造过程中尚不清楚。
我该如何解决这个问题?
我收到这个错误:
TypeError: Expected int for argument 'num_split' not <tf.Tensor 'decoded_predictions/strided_slice_15:0' shape=() dtype=int32>.
奇怪的是,尝试在其他 TensorFlow 函数中使用 batch_size
似乎可行:
tensor = tf.reshape(tensor, [batch_size, -1])
尽管 batch_size
的值在图形构造期间未知,但仍能正常工作。
tf.split()
是否有问题?
最佳答案
解决方法是:
batch_items = tf.map_fn(fn=lambda k: tensor[...,k],
elems=tf.range(batch_size),
dtype=tf.float32)
不过,我仍然对更好的解决方案感兴趣。
关于python - TensorFlow:将张量拆分为 `batch_size` 片,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49088778/