python - TensorFlow 创建动态形状变量

标签 python tensorflow

我需要创建一个只有在执行时才知道形状的tf.Variable

我将代码简化为以下要点。我需要在占位符中找到大于 4 的数字,并且在结果张量中需要scatter_update 将第二项转换为 24 常量。 p>

import tensorflow as tf

def get_variable(my_variable):
    greater_than = tf.greater(my_variable, tf.constant(4))
    result = tf.boolean_mask(my_variable, greater_than)
    # result = tf.Variable(tf.zeros(tf.shape(result)), trainable=False, expected_shape=tf.shape(result), validate_shape=False)   # doesn't work either
    result = tf.get_variable("my_var", shape=tf.shape(my_variable), dtype=tf.int32)
    result = tf.scatter_update(result, [1], 24)
    return result

input = tf.placeholder(dtype=tf.int32, shape=[5])
    created_variable = get_variable(input)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    result = sess.run(created_variable, feed_dict={input: [2, 7, 4, 6, 9]})
    print(result)

我找到了few questions但他们没有答案,也没有帮助我。

最佳答案

我遇到了同样的问题,偶然发现了相同的未解答的问题,并设法拼凑出一个解决方案,用于创建在图形创建时具有动态形状的变量。请注意,必须在第一次执行 tf.Session.run(...) 之前或首次执行时定义形状。

import tensorflow as tf

def get_variable(my_variable):
    greater_than = tf.greater(my_variable, tf.constant(4))
    result = tf.boolean_mask(my_variable, greater_than)
    zerofill = tf.fill(tf.shape(my_variable), tf.constant(0, dtype=tf.int32))
    # Initialize
    result = tf.get_variable(
        "my_var", shape=None, validate_shape=False, dtype=tf.int32, initializer=zerofill
    )
    result = tf.scatter_update(result, [1], 24)
    return result

input = tf.placeholder(dtype=tf.int32, shape=[5])
created_variable = get_variable(input)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    result = sess.run(created_variable, feed_dict={input: [2, 7, 4, 6, 9]})
    print(result)

诀窍是使用 shape=Nonevalidate_shape=False 创建一个 tf.Variable 并提交一个 tf.Variable 。形状未知的张量作为初始值设定项。

关于python - TensorFlow 创建动态形状变量,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/52235203/

相关文章:

python - 字典列表中的最大值和最小值,其中每个字典键都有一个值

tensorflow - 将 Tensorflow 模型从对象检测 API 转换为 uff

tensorflow - TensorFlow 中的条件执行

tensorflow - 如何获得 TensorFlow 字符串的长度?

python - 如何编写 Keras 自定义指标来过滤或屏蔽某些值?

python - 如何使用python在另一个图像中查找图像

python - 捕获输入错误后第二次迭代出现问题

python - 如何在python中的POST请求中发送urlencoded参数

android - 格式错误的加密 mp3 到 m3u8

python - tensorflow 推理时的批量归一化