python - 只训练tensorflow中的一些变量

标签 python python-2.7 tensorflow

我正在使用 tensorflow 进行梯度体面分类。

train_op = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

这里的 cost 是我在优化中使用的成本函数。 在 Session 中启动 Graph 后,Graph 可以作为:

sess.run(train_op, feed_dict)

这样,成本函数中的所有变量都将更新,以最小化成本。

这是我的问题。训练时如何只更新成本函数中的一些变量..?有没有办法将创建的变量转换为常量或其他东西......?

最佳答案

有几个很好的答案,这个主题应该已经关闭了: stackoverflow Quora

只是为了避免人们再次点击这里:

tensorflow 优化器的最小化函数为此目的采用了一个 var_list 参数:

first_train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                     "scope/prefix/for/first/vars")
first_train_op = optimizer.minimize(cost, var_list=first_train_vars)

second_train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                      "scope/prefix/for/second/vars")                     
second_train_op = optimizer.minimize(cost, var_list=second_train_vars)

我照原样从mrry

要获取您应该使用的名称列表而不是 "scope/prefix/for/second/vars",您可以使用:

tf.get_default_graph().get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)

关于python - 只训练tensorflow中的一些变量,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/38989046/

相关文章:

python - 批量大小可变的 TensorFlow 数据集 `from_generator`

python - 使用广播根据向量中的元素乘以矩阵行?

python - Python类继承问题

python - 如何动态更新 matplotlib 表格单元格文本

mysql - web2py是否支持在数据库中创建触发器

python - 为什么keras model.fit with sample_weight 初始化时间长?

python - 每次 python 脚本运行时更改 Excel 工作表

python - 在 python 中加密

Python .lower 似乎没有正确地小写所有 unicode 字符(Python 2.7)

c++ - 如何使用 nvlink GPU 创建 Tensorflow session