python - TensorFlow:将张量拆分为 `batch_size` 片

标签 python tensorflow

我有一个名为 tensor 的 rank-3 张量,形状为 [batch_size, axis_1, axis_2] 并想将其拆分为 batch_size 切片像这样沿着第一个轴:

batch_size = tf.shape(tensor)[0]

batch_items = tf.split(tensor, num_or_size_splits=batch_size, axis=0)

不幸的是,这不起作用,因为 batch_size 的值在图的构造过程中尚不清楚。

我该如何解决这个问题?

我收到这个错误:

TypeError: Expected int for argument 'num_split' not <tf.Tensor 'decoded_predictions/strided_slice_15:0' shape=() dtype=int32>.

奇怪的是,尝试在其他 TensorFlow 函数中使用 batch_size 似乎可行:

tensor = tf.reshape(tensor, [batch_size, -1])

尽管 batch_size 的值在图形构造期间未知,但仍能正常工作。

tf.split() 是否有问题?

最佳答案

解决方法是:

batch_items = tf.map_fn(fn=lambda k: tensor[...,k],
                        elems=tf.range(batch_size),
                        dtype=tf.float32)

不过,我仍然对更好的解决方案感兴趣。

关于python - TensorFlow:将张量拆分为 `batch_size` 片,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49088778/

相关文章:

python - 如何使用 keras 计算具有 4 个神经元的输出的类权重?

python - Tensorflow 2 坐标分类器

python - 无法将 TensorFlow (Keras) 模型转换为 ONNX

python - Pandas 在将所有日期转换为一周的开始日期时出错

python - 尝试设置 cookie 但得到 TypeError : string indices must be integers, not str

python - 在 Keras 模型中保存元数据/信息

python - keras tensorflow 代码中的K.eval(loss)是什么

python - 如何使用 scikit learn 计算多类案例的准确率、召回率、准确率和 f1 分数?

Python Spark Streaming 仅运行一次

python - 向 HSM 发出 DE 命令时在响应代码中获取 02