tensorflow - How to create a list of tensors in a for loop and use tf.stack in TensorFlow 2

Tags: tensorflow

I am trying to create a list of tensors and stack them together using a for loop in TensorFlow 2. I wrote the following test example:

import tensorflow as tf

@tf.function
def test(x):
    tensor_list = []
    for i in tf.range(x):
        tensor_list.append(tf.ones(4)*tf.cast(i, tf.float32))
    return  tf.stack(tensor_list)

result = test(5)
print(result)

But I get the following error:
Traceback (most recent call last):
  File "test.py", line 10, in <module>
    result = test(5)
  File "/root/.pyenv/versions/summarization-abstractive/lib/python3.6/site-packages/tensorflow_core/python/eager/def_function.py", line 457, in __call__
    result = self._call(*args, **kwds)
  File "/root/.pyenv/versions/summarization-abstractive/lib/python3.6/site-packages/tensorflow_core/python/eager/def_function.py", line 503, in _call
    self._initialize(args, kwds, add_initializers_to=initializer_map)
  File "/root/.pyenv/versions/summarization-abstractive/lib/python3.6/site-packages/tensorflow_core/python/eager/def_function.py", line 408, in _initialize
    *args, **kwds))
  File "/root/.pyenv/versions/summarization-abstractive/lib/python3.6/site-packages/tensorflow_core/python/eager/function.py", line 1848, in _get_concrete_function_internal_garbage_collected
    graph_function, _, _ = self._maybe_define_function(args, kwargs)
  File "/root/.pyenv/versions/summarization-abstractive/lib/python3.6/site-packages/tensorflow_core/python/eager/function.py", line 2150, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/root/.pyenv/versions/summarization-abstractive/lib/python3.6/site-packages/tensorflow_core/python/eager/function.py", line 2041, in _create_graph_function
    capture_by_value=self._capture_by_value),
  File "/root/.pyenv/versions/summarization-abstractive/lib/python3.6/site-packages/tensorflow_core/python/framework/func_graph.py", line 915, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/root/.pyenv/versions/summarization-abstractive/lib/python3.6/site-packages/tensorflow_core/python/eager/def_function.py", line 358, in wrapped_fn
    return weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/root/.pyenv/versions/summarization-abstractive/lib/python3.6/site-packages/tensorflow_core/python/framework/func_graph.py", line 905, in wrapper
    raise e.ag_error_metadata.to_exception(e)
tensorflow.python.framework.errors_impl.InaccessibleTensorError: in converted code:

    test.py:8 test  *
        return  tf.stack(tensor_list)
    /root/.pyenv/versions/summarization-abstractive/lib/python3.6/site-packages/tensorflow_core/python/util/dispatch.py:180 wrapper
        return target(*args, **kwargs)
    /root/.pyenv/versions/summarization-abstractive/lib/python3.6/site-packages/tensorflow_core/python/ops/array_ops.py:1165 stack
        return gen_array_ops.pack(values, axis=axis, name=name)
    /root/.pyenv/versions/summarization-abstractive/lib/python3.6/site-packages/tensorflow_core/python/ops/gen_array_ops.py:6304 pack
        "Pack", values=values, axis=axis, name=name)
    /root/.pyenv/versions/summarization-abstractive/lib/python3.6/site-packages/tensorflow_core/python/framework/op_def_library.py:793 _apply_op_helper
        op_def=op_def)
    /root/.pyenv/versions/summarization-abstractive/lib/python3.6/site-packages/tensorflow_core/python/framework/func_graph.py:544 create_op
        inp = self.capture(inp)
    /root/.pyenv/versions/summarization-abstractive/lib/python3.6/site-packages/tensorflow_core/python/framework/func_graph.py:603 capture
        % (tensor, tensor.graph, self))

    InaccessibleTensorError: The tensor 'Tensor("mul:0", shape=(4,), dtype=float32)' cannot be accessed here: it is defined in another function or code block. Use return values, explicit Python locals or TensorFlow collections to access it. Defined in: FuncGraph(name=while_body_8, id=139870442952744); accessed from: FuncGraph(name=test, id=139870626510608).

Does anyone know what I am doing wrong? How can I create a list of tensors and stack them together with a for loop in TensorFlow 2?

Best answer

Looping over a tensor should generally be done with tf.map_fn. Here is a working solution:

import tensorflow as tf
import numpy as np

@tf.function
def test(x):

    tensor_list = tf.map_fn(lambda inp: tf.ones(4)*tf.cast(inp, tf.float32), x, dtype=tf.dtypes.float32)

    return  tf.stack(tensor_list)

result = test(np.arange(5))
print(result)

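Note that you have to pass an actual array to the test() function. Alternatively, you could call tf.range() inside the tf.function to convert the scalar into a tensor.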
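
As for why the original code fails: inside a tf.function, AutoGraph converts `for i in tf.range(x)` into a graph-level while loop, and the loop body is traced as its own function graph (the `while_body_8` named in the error). Tensors created there and appended to a Python list cannot be read from the outer graph, which is what the InaccessibleTensorError reports. The sketch below shows two ways around this, both accepting a scalar. The function names are made up for illustration, and `fn_output_signature` assumes a reasonably recent TensorFlow 2 release; on older versions such as the one in the traceback, the `dtype=` argument used in the answer plays the same role.

```python
import tensorflow as tf


@tf.function
def stack_with_map_fn(n):
    # Build the index tensor inside the graph, so the caller can pass a scalar.
    indices = tf.range(n)
    # tf.map_fn returns its per-element results already stacked along axis 0,
    # so no separate tf.stack call is needed.
    return tf.map_fn(
        lambda i: tf.ones(4) * tf.cast(i, tf.float32),
        indices,
        fn_output_signature=tf.float32,
    )


@tf.function
def stack_with_tensor_array(n):
    # tf.TensorArray stands in for the Python list: it can be carried
    # through the graph-level loop that AutoGraph generates.
    ta = tf.TensorArray(tf.float32, size=n)
    for i in tf.range(n):
        # write() returns the updated array, so it must be reassigned.
        ta = ta.write(i, tf.ones(4) * tf.cast(i, tf.float32))
    # stack() turns the collected elements into a single tensor.
    return ta.stack()


print(stack_with_map_fn(5))
print(stack_with_tensor_array(5))
```

Both functions should return a float32 tensor of shape (5, 4) whose i-th row is filled with the value i. The tf.TensorArray variant keeps the explicit loop from the question, while the tf.map_fn variant stays closer to the answer.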

A similar question to this one can be found on Stack Overflow: https://stackoverflow.com/questions/58811289/
