python - tensorflow 中的优化器如何访问在单独函数中创建的变量

标签 python tensorflow namespaces

代码中感兴趣的行后面跟着多个井号 (#) 符号

为了理解目的,我在 tensorflow 中运行一个简单的线性回归。我使用的代码是:

def generate_dataset():
#y = 2x+e where is the normally distributed error
x_batch = np.linspace(-1,1,101)
y_batch = 2*x_batch +np.random.random(*x_batch.shape)*0.3
return x_batch, y_batch

def linear_regression():   ##################
x = tf.placeholder(tf.float32, shape = (None,), name = 'x')
y = tf.placeholder(tf.float32, shape = (None,), name = 'y')
with tf.variable_scope('lreg') as scope: ################
    w = tf.Variable(np.random.normal()) ##################
    y_pred = tf.multiply(w,x)
    loss = tf.reduce_mean(tf.square(y_pred - y))
return x,y, y_pred, loss
def run():
x_batch, y_batch = generate_dataset()
x, y, y_pred, loss = linear_regression()
optimizer = tf.train.GradientDescentOptimizer(0.2).minimize(loss)

init = tf.global_variables_initializer()
with tf.Session() as session:
    session.run(init) 
    feed_dict = {x: x_batch, y: y_batch}
    for _ in range(30):
        loss_val, _ = session.run([loss, optimizer], feed_dict)
        print('loss:', loss_val.mean())
    y_pred_batch = session.run(y_pred, {x:x_batch})

    print(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)) ############
    print(session.run(fetches = [w])) #############
run()      

我似乎无法通过对“w”或“lreg/w”的提取调用来获取变量的值(它实际上是一个操作吗?)“w”,如果我理解正确的话是由于“w”是在 Linear_regression() 中定义的,并且它不将其 namespace 借给 run()。但是,我可以通过对其变量名称“lreg/vairable:0”的 fetch 调用来访问“w”。优化器工作得很好并且更新被完美应用

优化器如何访问“w”并应用更新,如果您能让我深入了解如何在 Linear_regression() 和 run() 之间共享操作“w”,那就太好了

最佳答案

您创建的每个操作和变量都是 tensorflow 中的一个节点 graph 。当您没有显式创建图表时(就像您的情况一样),则会使用默认图表。

此行将 w 添加到默认图表中。

 w = tf.Variable(np.random.normal())

该行访问图形以执行计算

loss_val, _ = session.run([loss, optimizer], feed_dict)

您可以像这样检查图表

tf.get_default_graph().as_graph_def()

关于python - tensorflow 中的优化器如何访问在单独函数中创建的变量,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/46406938/

相关文章:

python - tensorflow 文本生成

r - 从 .onLoad 包函数中调用 getNamespaceExports()

python - 计算Python程序的行数

python - 在多个不同名称的列上合并两个 Pandas 数据框

python - 在OpenCV中将图像序列转换为视频+阅读范围

tensorflow - 在 Tensorflow 对象检测 API 中将图像裁剪到边界框

python - 在 TensorFlow 2 中使用 tf.ConfigProto 初始化 tf.Session 有什么等价物?

python - 用于算术运算的 BFS

c++ - 类函数和命名空间之间重新声明的符号

xml - 使用 XSL 替换默认命名空间