python - Tensorflow - 条件训练

标签 python tensorflow if-statement

我正在以监督方式使用 tensorflow (1.12)训练神经网络。我只想针对具体示例进行训练。这些示例是通过删除子序列来动态创建的,因此我想在 tensorflow 中进行调节。

这是我原来的代码部分:

train_step, gvs = minimize_clipped(optimizer, loss,
                               clip_value=FLAGS.gradient_clip,
                               return_gvs=True)
gradients = [g for (g,v) in gvs]
gradient_norm = tf.global_norm(gradients)
tf.summary.scalar('gradients/norm', gradient_norm)
eval_losses = {'loss1': loss1,
               'loss2': loss2}

训练步骤稍后执行为:

batch_eval, _ = sess.run([eval_losses, train_step])

我正在考虑插入类似的内容

train_step_fake = ????
eval_losses_fake = tf.zeros_like(tensor)
train_step_new = tf.cond(my_cond, train_step, train_step_fake)
eval_losses_new = tf.cond(my_cond, eval_losses, eval_losses_fake)

然后做

batch_eval, _ = sess.run([eval_losses, train_step])

但是,我不知道如何创建一个假的 train_step。

此外,这总体上是一个好主意还是有更顺畅的方法?我正在使用 tfrecords 管道,但没有其他高级模块(如 keras、tf.estimator、eager execution 等)。

显然非常感谢任何帮助!

最佳答案

先回答具体问题。当然可能仅根据tf.cond 结果执行训练步骤。请注意,第二个和第三个参数是 lambda,但更像是:

train_step_new = tf.cond(my_cond, lambda: train_step, lambda: train_step_fake)
eval_losses_new = tf.cond(my_cond, lambda: eval_losses, lambda: eval_losses_fake)

尽管您的直觉认为这可能不是正确的做法,但这是正确的。

在数据到达模型之前过滤掉您想要忽略的数据会更可取(无论是在效率方面还是在阅读和推理代码方面) .

您可以使用 Dataset API 来实现这一点。其中有一个非常有用的 filter() 方法,您可以使用。如果您现在使用数据集 api 来读取 TFRecords,那么这应该像添加以下内容一样简单:

dataset = dataset.filter(lambda x: {whatever op you were going to use in tf.cond})

如果您尚未使用数据集 API,那么现在可能是时候稍微阅读一下并考虑它,而不是使用 tf.cond() 来破坏模型来采取行 Action 为过滤器。

关于python - Tensorflow - 条件训练,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56184095/

相关文章:

javascript - 为什么我的函数没有按预期显示和隐藏元素?

javascript - 如果另一个元素包含特定文本,请单击该元素。使用 Protractor 进行自动测试

php - 如何使 PHP if 语句依赖于 SQL 查询结果

Python字典理解过滤

python - 如何追加 AWS S3 存储桶通知配置

tensorflow - 在 TensorBoard 中找不到任何标量摘要

python - Keras 代码 Q-learning OpenAI gym FrozenLake 有问题

tensorflow - 如何在 tensorflow 的发行版中使用 Tensorflow XLA AOT 支持

python - 在此代码的上下文中, "' module' object is not subscriptable"是什么意思?

python - 类型错误 : __init__() got an unexpected keyword argument 'early_stopping_rounds' for CatBoost in Python