python - TensorFlow 变量无法初始化

标签 python tensorflow

X=tf.placeholder(tf.float32,[None,32,32,3])
y=tf.placeholder(tf.int64,[None])
is_training=tf.placeholder(tf.bool)

def simple_model(X,y):

    Wconv1=tf.get_variable("Wconv1",shape=[7,7,3,32],use_resource=True)
    bconv1=tf.get_variable('bconv1',shape=[32])
    W1=tf.get_variable('W1',shape=[5408,10])
    b1=tf.get_variable('b1',shape=[10])

    a1=tf.nn.conv2d(X,Wconv1,[1,2,2,1],'VALID')+bconv1
    h1=tf.nn.relu(a1)

    h1_flat=tf.reshape(h1,[-1,5408])
    y_out=tf.matmul(h1_flat,W1)+b1
    return y_out

init=tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)

    sess.run(simple_model(X,y),feed_dict={X:X_train,y:y_train})

错误是

PreconditionError Attempting to use uninitialized variable Wconv1

我不知道代码有什么问题?

最佳答案

tf.global_variables_initializer到那时创建的所有全局变量进行初始化操作。这意味着如果您稍后创建其他变量,它们将不会被该操作初始化。这是因为变量初始值设定项仅包含它们必须初始化的变量列表,并且这不会随着您添加更多变量而改变(事实上,tf.global_variables_initializer() 只是 的快捷方式>tf.variables_initializer(tf.global_variables())tf.variables_initializer(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)))。在您的情况下,在您之前创建 init 之后,将在第二次调用 sess.run 时创建变量。您需要在使用变量创建模型后创建初始化操作:

X=tf.placeholder(tf.float32,[None,32,32,3])
y=tf.placeholder(tf.int64,[None])
is_training=tf.placeholder(tf.bool)

def simple_model(X,y):

    Wconv1=tf.get_variable("Wconv1",shape=[7,7,3,32],use_resource=True)
    bconv1=tf.get_variable('bconv1',shape=[32])
    W1=tf.get_variable('W1',shape=[5408,10])
    b1=tf.get_variable('b1',shape=[10])

    a1=tf.nn.conv2d(X,Wconv1,[1,2,2,1],'VALID')+bconv1
    h1=tf.nn.relu(a1)

    h1_flat=tf.reshape(h1,[-1,5408])
    y_out=tf.matmul(h1_flat,W1)+b1
    return y_out

my_model = simple_model(X,y)
init=tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)

    sess.run(my_model, feed_dict={X:X_train,y:y_train})

关于python - TensorFlow 变量无法初始化,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/52723671/

相关文章:

python - 有什么方法可以更改寄存器的 Django-rest-auth View 吗?

python - 了解 Keras 中语音识别的 CTC 损失

tensorflow - tf.contrib.layers.embed_sequence() 是干什么用的?

Python 分析 - runsnakerun 输出中的列是什么?

python - SQL ORDER BY 时间戳,相同的值

python - Tensorflow 循环中的切片分配

python-3.x - key 错误 : 'Unable to open object (wrong B-tree signature)'

tensorflow - 对全连接层使用单一共享偏差

python - 通过正则表达式批量重命名 - python 解决方法

python - Airflow - GoogleCloudStorageToBigQueryOperator 不呈现模板化的 source_objects