python - 在 TensorFlow 中初始化变量的最佳方法是什么?

标签 python tensorflow tensor

在 TensorFlow 中,我可以通过两种方式初始化变量:

  1. 在声明变量之前调用global_variable_intializer():

    import tensorflow as tf
    
    # Initialize the global variable and session
    init = tf.global_variables_initializer()
    sess = tf.Session()
    sess.run(init)
    
    W = tf.Variable([.3], tf.float32)
    x = tf.Variable([-.3], tf.float32)
    b = tf.Variable([-.3], tf.float32)
    linear_model = W * x + b
    
  2. 在声明变量后调用global_variable_intializer():

    import tensorflow as tf
    
    W = tf.Variable([.3], tf.float32)
    x = tf.Variable([-.3], tf.float32)
    b = tf.Variable([-.3], tf.float32)
    linear_model = W * x + b 
    
    # Initialize the global variable and session
    init = tf.global_variables_initializer()
    sess = tf.Session()
    sess.run(init)
    

两者有什么区别?初始化变量的最佳方式是什么?

编辑

这是我正在运行的实际程序:

import tensorflow as tf

# Initialize the global variable and session
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)

W = tf.Variable([.3], tf.float32)
b = tf.Variable([-.3], tf.float32)

x = tf.placeholder(tf.float32)
y = tf.placeholder(tf.float32)



linear_model = W * x + b

square_delta = tf.square(linear_model - y)

loss = tf.reduce_sum(square_delta)

fixW = tf.assign(W, [-1.])
fixb = tf.assign(b, [1.])

sess.run([fixW, fixb])

print(sess.run(loss, {x:[1,2,3,4], y:[0,-1,-2,-3]}))

最佳答案

情况1,变量没有初始化,如果你试试

sess.run(linear_model)

它应该给你一些错误(我的编译器上的 FailedPreconditionError)。

案例 2 是工作案例。

命令

tf.global_variables_initializer()

应该在创建完所有变量后调用,否则会报同样的错误。

据我了解,每次调用 tf.Variable 时,与变量相关的节点都会添加到图中。这些是:

Variable/initial_value
Variable
Variable/Assign
Variable/read

(您使用命令获得到目前为止构建的节点

for n in tf.get_default_graph().as_graph_def().node:
    print n.name

)

变量本身没有任何值,直到您在 session 中运行变量/分配节点。

命令

init = tf.global_variables_initializer() 

创建一个节点,其中包含到目前为止构造的所有变量的所有分配节点,并将其关联到 python 变量“init”,以便在执行时执行该行

sess.run(init)

所有变量都获取初始值。

关于python - 在 TensorFlow 中初始化变量的最佳方法是什么?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/44284580/

相关文章:

c++ - 采用 Eigen::Tensor 的函数 - 模板参数推导失败

python - NLTK 无法找到 gs 文件

python - 是否可以通过将 lineEdit 插入 for 来缩短 PyQt5 中的代码?

python - TensorFlow 运行时错误 : MetaGraphDef associated with tags serve could not be found in SavedModel

machine-learning - 'training loss'在机器学习中意味着什么?

python - 如何有效地乘以具有重复行的 torch 张量,而不将所有行存储在内存中或迭代?

python - 程序完成但我收到此警告 : "too much output to process"

python - 为什么在不同线程中调用 asyncio subprocess.communicate 会挂起?

python - keras v1.2.2 与 keras v2+ 的奇怪行为(准确度存在巨大差异)

tensorflow - Tensorflow中哪个函数与Pytorch中的expand_as类似