python - 函数内的 Tensorflow variable_scope : tf. 占位符和 tf.get_variable

标签 python tensorflow naming

我只是试图了解 TensorFlow 命名行为,但仍然需要一些说明。 我从事的一个项目在张量的命名方面遇到了麻烦,因为它们是在函数中预定义的,稍后会调用该函数。

所以我这里有以下示例:

import tensorflow as tf


    def foo():

        with tf.variable_scope("foo", reuse=True):

            a = tf.placeholder(tf.float32,name="a")
            b = tf.placeholder(tf.float32,name="b")

        return a,b
    ##

    a,b  = foo()

    print(a)
    print(b)

我得到输出:

Tensor("foo/a:0", dtype=float32)
Tensor("foo/b:0", dtype=float32)

当我再次调用它时,我得到输出:

Tensor("foo_1/a:0", dtype=float32)
Tensor("foo_1/b:0", dtype=float32)

为什么会出现这样的情况呢?我将重用设置为 true,因此我希望张量再次位于同一变量范围“foo”中,或者程序会抛出“张量已定义”之类的错误。

所以,我尝试了使用 tf.get_variable 的解决方法:

    def foo():
    with tf.variable_scope("foo", reuse=True):

        a = tf.get_variable("v", [1])


    return a
##

a1 = foo()
print(a1)

graph = tf.get_default_graph()
#call tensors by name in tensorflow to avoid confusion with the naming
graph.get_tensor_by_name("foo/v:0")

在这里,我总是得到相同的输出:

<tf.Variable 'foo/v:0' shape=(1,) dtype=float32_ref>

不幸的是,我无法使用变量,因为您无法为它们定义动态形状。您需要占位符来定义可变形状。 有人可以解释一下为什么程序继续为占位符创建新的variable_scopes,但当我调用 tf.get_variable() 时却没有?

谢谢!

最佳答案

您可以通过在名称后添加“/”来强制重复使用范围,即:tf.variable_scope("foo/", reuse=True):

但这并不能解决您的问题。

对于变量,调用tf.Variable将始终创建一个新变量,而调用 tf.get_variable如果已经存在,将重用它。

但是对于占位符来说,没有 tf.get_placeholder .

您可以做的就是在 foo 之外定义占位符,仅一次,并使用 tf.get_default_graph().get_tensor_by_name(name) 按名称获取它们或者在需要时直接使用 python 变量。

示例为 get_tensor_by_name :

import tensorflow as tf

with tf.name_scope("scope"):
    tf.placeholder(tf.float32,name="a")
    tf.placeholder(tf.float32,name="b")

def foo():
    a = tf.get_default_graph().get_tensor_by_name("scope/a:0")
    b = tf.get_default_graph().get_tensor_by_name("scope/b:0")

    return a,b

a,b = foo()

print(a)
print(b)

请注意,与变量不同,占位符不维护可重用或不可重用的状态。它们只是指向稍后将提供的张量的“指针”。它们不应该是模型的一部分,而应该是模型的输入,因此无论如何您都不应该多次创建它们。

关于python - 函数内的 Tensorflow variable_scope : tf. 占位符和 tf.get_variable,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50974629/

相关文章:

python - 在 Python 行为测试框架中处理异常

python - 如何绕过第一组条件

tensorflow - 在 IBM power8 上安装 TensorFlow

python - 如何创建具有本地连接层和密集父层的本地连接层?

ruby-on-rails - ActiveRecord 模型名称的约定

python - 值与一组值的矢量化比较

Python - 在列表列表中查找列表

python - Tensorflow CNN 模型未训练?恒定损耗和准确性

naming-conventions - "Pascal Case"这个词是从哪里来的?

python - 'site' 中的 'site-packages' 到底是什么意思?