python - 分阶段的 Tensorflow 自定义训练

标签 python tensorflow keras deep-learning dropout

我需要使用 Tensorflow/Keras 创建一个自定义训练循环(因为我想要拥有多个优化器并告诉每个优化器应该作用于哪些权重)。

虽然this tutorialthat one too对于这个问题非常清楚,他们错过了一个非常重要的一点:我如何预测训练阶段以及如何预测验证阶段?

假设我的模型有 Dropout 层或 BatchNormalization 层。无论是在训练还是验证中,它们的工作方式肯定是完全不同的。

我如何改编这些教程?这是一个虚拟示例(可能包含一两段伪代码):

# Iterate over epochs.
for epoch in range(3):


    # Iterate over the batches of the dataset.
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        with tf.GradientTape() as tape:

            #model with two outputs
            #IMPORTANT: must be in training phase (use dropouts, calculate batch statistics)
            logits1, logits2 = model(x_batch_train) #must be "training"

            loss_value1 = loss_fn1(y_batch_train[0], logits1)
            loss_value2 = loss_fn2(y_batch_train[1], logits2)

            grads1 = tape.gradient(loss_value1, model.trainable_weights[selection1])    
            grads2 = tape.gradient(loss_value2, model.trainable_weights[selection2])

            optimizer1.apply_gradients(zip(grads1, model.trainable_weights[selection1]))
            optimizer2.apply_gradients(zip(grads2, model.trainable_weights[selection2]))



    # Run a validation loop at the end of each epoch.
    for x_batch_val, y_batch_val in val_dataset:

        ##Important: must be validation phase
            #dropouts are off: calculate all neurons and divide value    
            #batch norms use previously calculated statistics    
        val_logits1, val_logits2 = model(x_batch_val)

        #.... do the evaluations

最佳答案

我认为你可以在调用 tf.keras.Model 时传递一个 training 参数。 ,并且它将被传递到各层:

# On training
logits1, logits2 = model(x_batch_train, training=True)
# On evaluation
val_logits1, val_logits2 = model(x_batch_val, training=False)

关于python - 分阶段的 Tensorflow 自定义训练,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58012094/

相关文章:

python-2.7 - ValueError : Tensor Tensor ("Const:0", shape=(), dtype=float32) 可能无法与 tf.placeholder 一起提供

machine-learning - 如何在 Keras 中提取训练集和验证集?

python - 如何调用 Django 模型中定义的模板中的函数?

string - tensorflow,如何将 tf.string SparseTensor 连接到一维密集张量

python - 创建一个输出字典的 tensorflow 数据集

python - 如何存储CNN的展平结果?

tensorflow - keras SpatialDropout2D 在 TimeDistributed 层中的正确使用 - CNN LSTM 网络

python - 如何在django中保存多对多关系

通过 REPL 与阻塞循环交互的 Python 最佳实践

python - 使用Python计算公交车站之间的时间