我正在使用 Tensorflow v1.15。我有以下示例中给出的 TensorArray
的非常基本的实现:
import tensorflow as tf
an_array = tf.TensorArray(dtype=tf.float32, size=5, dynamic_size=True, clear_after_read=False, element_shape=(16, 7, 2))
for i in range(5):
val = tf.random.normal(shape=(16, 7, 2))
an_array.write(i, val)
print(tf.Session().run(val))
tensors = [an_array.read(j) for j in range(5)]
print(tf.Session().run(tensors))
for 循环中的 print
不会打印全零,而最后一个 print
语句会打印。为什么会这样?谢谢。
最佳答案
tf.TensorArray.write
返回一个新的 tf.TensorArray
写操作发生的地方。一般来说,这个函数的输出应该替换之前对数组的引用:
import tensorflow as tf
an_array = tf.TensorArray(dtype=tf.float32, size=5, dynamic_size=True,
clear_after_read=False, element_shape=(16, 7, 2))
for i in range(5):
val = tf.random.normal(shape=(16, 7, 2))
# Replace tensor array reference with the written one
an_array = an_array.write(i, val)
print(tf.Session().run(val))
tensors = [an_array.read(j) for j in range(5)]
print(tf.Session().run(tensors))
关于python - 在 Tensorflow 中读取 TensorArray 总是返回零,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/60259037/