python - 如何获取包含不同形状张量的 TensorArray 中的值

标签 python tensorflow

我通过 tf.while_loop() 得到一个 TensorArray,其中包含不同形状张量的列表,但我不知道如何将它们作为带有张量的普通列表。

例如:

TensorArray([[1,2], [1,2,3], ...]) -> [Tensor([1,2]), Tensor([1,2,3]), ...]
res = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True, infer_shape=False)
res = res.write(0, (1, 2))
res = res.write(0, (1, 2, 3))
with tf.Session() as sess:                                                        
     print sess.run(res.stack())

我在 sess.run(res.stack())

中收到错误消息

TensorArray has inconsistent shapes. Index 0 has shape: [2] but index 1 has shape: [3]

最佳答案

通常,您无法在张量数组中创建张量列表,因为它的大小仅在图执行时才知道。但是,如果您提前知道大小,则可以自己列出读取操作的列表:

import tensorflow as tf

with tf.Graph().as_default(), tf.Session() as sess:
    res = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True, infer_shape=False)
    res = res.write(0, (1, 2))
    res = res.write(1, (1, 2, 3))
    print(res.size()) # Value only known on graph execution
    # Tensor("TensorArraySizeV3:0", shape=(), dtype=int32)
    # Can make a list if the size is known in advance
    tensors = [res.read(i) for i in range(2)]
    print(tensors)
    # [<tf.Tensor 'TensorArrayReadV3:0' shape=<unknown> dtype=int32>, <tf.Tensor 'TensorArrayReadV3_1:0' shape=<unknown> dtype=int32>]
    print(sess.run(tensors))
    # [array([1, 2]), array([1, 2, 3])]

否则,您仍然可以使用 while 循环来迭代张量数组。例如,您可以像这样打印其内容:

import tensorflow as tf

with tf.Graph().as_default(), tf.Session() as sess:
    res = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True, infer_shape=False)
    res = res.write(0, (1, 2))
    res = res.write(1, (1, 2, 3))
    def loop_body(i, res):
        # Must import the following in Python 2:
        # from __future__ import print_function
        with tf.control_dependencies([tf.print(res.read(i))]):
            return i + 1, res
    i, res = tf.while_loop(
        lambda i, res: i < res.size(),
        loop_body,
        (tf.constant(0, tf.int32), res))
    print(sess.run(i))
    # [1 2]
    # [1 2 3]
    # 2

关于python - 如何获取包含不同形状张量的 TensorArray 中的值,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56786937/

相关文章:

python - Bash 或 python 用于更改文件中的间距

python - 如何在python3中执行部分映射交叉?

python - 使用 Saver v2 的 Tensorflow 加载模型

python - 损失函数正在减少但度量函数保持不变?

python - RuntimeError : Expected all tensors to be on the same device, 但在恢复训练时发现至少两个设备,cuda :0 and cpu!

python - PyObject_CallMethod 返回 Null,提供了简单示例

python - 无法让简单的 TFRecord 阅读器工作

python - 训练过程中GAN结果图像是相同的

tensorflow - 使用 tensorflow 模型在 Google Cloud 中进行预测失败

python - 来自 plt.subplots() 的 matplotlib 轴的精确类型注释数组 (numpy.ndarray)