python - 是否可以在每个训练步骤中获得目标函数值?

标签 python machine-learning tensorflow deep-learning

在通常的TensorFlow训练循环中,比如

train_op = tf.train.AdamOptimizer().minimize(cross_entropy)

with tf.Session() as sess:
    for i in range(num_steps):
        # ...
        train_op.run(feed_dict = feed_dict)

train_op.run 返回 None

但是,有时收集中间结果很有用,例如目标值或准确性。

添加额外的 sess.run 调用将需要再次进行前向传播,从而增加运行时间:

train_op = tf.train.AdamOptimizer().minimize(cross_entropy)

with tf.Session() as sess:
    for i in range(num_steps):
        # ...
        o, a = sess.run([objective, accuracy], feed_dict = feed_dict)
        train_op.run(feed_dict = feed_dict)

是否可以在 TensorFlow 中一次完成此操作?


编辑:

人们建议

sess.run([objective, accuracy, train_op], feed_dict = feed_dict)

但结果取决于列表元素的执行顺序:

[objective, accuracy, train_op]

似乎未定义 -- you get different results depending on whether CUDA is being used .

最佳答案

只需将您的 train_op 添加到要评估的节点列表中即可。

o, a, _ = sess.run([objective, accuracy, train_op], feed_dict = feed_dict)

关于训练步骤及其在评估中的顺序,我做了如下小实验:

import tensorflow as tf
x = tf.Variable(0, dtype=tf.float32)
loss = tf.nn.l2_loss(x-1)
train_opt = tf.train.GradientDescentOptimizer(1)
train_op = train_opt.minimize(loss)
init_op = tf.global_variables_initializer()

sess = tf.Session()
sess.run(init_op)
x_val, _, loss_val = sess.run([x, train_op, loss])
# returns x_val = 1.0, loss_val = 0.5

情况比我最初想象的还要困惑。似乎给定的是,提取的执行顺序不取决于它们在列表中的各自位置:x_valloss_val 将是相同的,独立于它们列表中的位置。

但是,正如@MaxB 所注意到的,它们的执行顺序是无法保证的。在 GPU 上运行上述代码时,x_val 设置为 0.0,即初始值。但是,在 CPU 上运行时,x_val 为 1.0,即 train_op 更新后的值。

这种依赖于配置的行为可能仅限于通过训练操作更新的变量,正如上面的实验所表明的,但不能保证它们来自 tf 的文档。

关于python - 是否可以在每个训练步骤中获得目标函数值?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/43856480/

相关文章:

python - 如何在没有弃用函数的情况下迭代 tf.dataset?

python - "OSError: [Errno 22] Invalid argument"当读取大文件时

python - 可以安全地假设(并教导)Python 字典将保持有序吗?

python - pip 目标并在 setup.py 中安装后

python - 让 Keras/Tensorflow 输出 OneHotCategorical,但操作没有梯度

python - "Cannot convert a ndarray into a Tensor or Operation."尝试从 tensorflow 中的 session.run 获取值时出错

python - 如何使用循环显示查询字典

python - GridSearchCV 是否存储所有参数组合的所有分数?

excel - Azure机器学习工作室: how to add a dataset from a local Excel file?

python - 是否有使用 OPA UA 传输数据的 IEC 61131/IEC 61499 PLC 功能 block ?