(原始问题已编辑) 在 Tensorflow 中,我们经常需要定义包含变量的函数,以在中间神经网络层上实现。有没有办法评估这个输出,例如:
import tensorflow as tf
def Mult(mult):
A = tf.get_variable([2,2], initializer = tf.zeros_initializer)
B = tf.get_variable([2,2], initializer = tf.zeros_initializer)
return mult*tf.matmul(A,B)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
# Here we alter the variables A,B in some fashion e.g. in an optimisation algorithm
print(Mult(A,B))
产生错误
最佳答案
代码存在几个问题。首先,由于您将变量 A
和 B
传递给函数,因此不需要在函数内初始化它们。所以这个函数应该是这样的:
def Mult(A,B):
return tf.matmul(A,B)
接下来,您需要在初始化变量之前定义它们,因此行
sess.run(tf.global_variables_initializer())
应该在定义变量之后出现。
此外,如果你想得到Mult(A,B)
的数值,你应该打印
sess.run(Mult(A,B))
而不仅仅是Mult(A,B)
(因为后者只会给你一个张量对象)。
最后,您需要为您定义的变量提供一个名称。 [2,2]
是形状(函数的第二个参数)。第一个参数应该是名称。
这是更正后的代码:
import tensorflow as tf
def Mult(A,B):
return tf.matmul(A,B)
A = tf.get_variable('A',shape=[2,2], initializer = tf.zeros_initializer)
B = tf.get_variable('B',shape=[2,2], initializer = tf.zeros_initializer)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(Mult(A,B)))
这会打印
[[ 0. 0.]
[ 0. 0.]]
关于 tensorflow 中的函数,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/45869275/