python - Tensorflow 2.0 stack() 引发未初始化的张量错误

标签 python tensorflow

我正在编写一个自定义层,我需要在其中循环遍历批处理维度,然后遍历图像的 rgb 维度。我仍在尝试了解 Tensorflow 如何实现 for 循环,但我不确定这与我在此处提出的错误有关。

下面是一些伪代码:

    @tf.function()
    def _crop_and_resize(self, imgs, boxes, to_size):
        # prepare kernel_h and kernel_w

        n_images = tf.shape(imgs)[0]
        outputs = tf.TensorArray(dtype=tf.float32, size=n_images)
        for i in tf.range(n_images):
            # in the call to _bilinear we enter the inner loop
            output = self._bilinear(
                kernel_h[i],
                kernel_w[i],
                imgs[i])
            outputs.write(i, output)
        return outputs.stack()


    def _bilinear(self, kernel_h, kernel_w, img):
        n_channels = tf.shape(img)[2]
        result_channels = tf.TensorArray(dtype=tf.float32, size=n_channels)
        for i in tf.range(n_channels):
            result_channels.write(i,
                tf.matmul(
                    tf.matmul(kernel_h, tf.tile(img[:, :, i], [1, 1])),
                    kernel_w, transpose_b=True))
        return tf.transpose(result_channels.stack(), perm=[1,2,0])

我收到以下错误:

InvalidArgumentError: Tried to stack list which only contains uninitialized tensors and has a non-fully-defined element_shape: [?,?,?] [[{{node model_17/att_1/PartitionedCall/TensorArrayV2Stack/TensorListStack}}]] [Op:__inference_distributed_function_11150] Function call stack: distributed_function



我见过很多使用 TensorArray 的例子。和 stack以这种方式用于单个 for 循环,但我不确定我的嵌套 for 循环是否导致问题。

最佳答案

我有一个类似的问题,并通过此错误响应中的评论解决了它:https://github.com/tensorflow/tensorflow/issues/30409#issuecomment-508962873
基本上,在急切模式下, .stack() 调用就地工作以方便起见,但在图形设置中,您需要将 .stack() 调用链接为图形中的节点,例如

outputs = outputs.write(i, output)


这为我解决了它。

关于python - Tensorflow 2.0 stack() 引发未初始化的张量错误,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/61062833/

相关文章:

Python——sympysolve()返回另一个方程而不是值

python - 如何使用 Keras 功能 API 模型的输出作为另一个模型的输入

tensorflow - tensorflow 中的 tf.GraphKeys.TRAINABLE_VARIABLES 和 tf.GraphKeys.UPDATE_OPS 有什么区别?

python - 为我的 PyQt 应用程序选择 IPython Qt 控制台

python - Theano 中的名称冲突

python - ffmpeg|sed 命令的 subprocess.call 格式?

Tensorflow:空计算图和垃圾收集

python - ValueError : labels shape must be [batch_size, labels_dimension], 得到 (128, 2)

gradient-descent - TensorFlow的ReluGrad声称输入不是有限的

python - CSV 到文本文件布局