python - 在 Keras 到 TPU 模型中使用 tensorflow 学习率衰减

标签 python tensorflow keras

我按照“如何使用 TPU 免费训练 Keras 模型速度提高 20 倍”指南 ( click here ) 在 Google 的 colab TPU 上运行 keras 模型。它工作完美。但是......当我适合我的模型时,我喜欢使用余弦重新启动学习率衰减。我已经将自己的代码编写为 keras 回调,但它无法在此框架中工作,因为 TensorFlow TFOptimizer 类没有可以重置的学习率变量。我看到tensorflow本身在tf.train中有一堆衰减函数,比如tf.train.cosine_decay,但我不知道如何将它嵌入到我的模型中。

这是该博客文章中的基本代码。有人解决了吗?

import tensorflow as tf
import os
from tensorflow.python.keras.layers import Input, LSTM, Bidirectional, Dense, Embedding

def make_model(batch_size=None):
    source = Input(shape=(maxlen,), batch_size=batch_size,
                   dtype=tf.int32, name='Input')
    embedding = Embedding(input_dim=max_features,
                          output_dim=128, name='Embedding')(source)
    lstm = LSTM(32, name='LSTM')(embedding)
    predicted_var = Dense(1, activation='sigmoid', name='Output')(lstm)
    model = tf.keras.Model(inputs=[source], outputs=[predicted_var])
    model.compile(
        optimizer=tf.train.RMSPropOptimizer(learning_rate=0.01),
        loss='binary_crossentropy',
        metrics=['acc'])
    return model

training_model = make_model(batch_size=128)

# This address identifies the TPU we'll use when configuring TensorFlow.
TPU_WORKER = 'grpc://' + os.environ['COLAB_TPU_ADDR']
tf.logging.set_verbosity(tf.logging.INFO)

tpu_model = tf.contrib.tpu.keras_to_tpu_model(
    training_model,
    strategy=tf.contrib.tpu.TPUDistributionStrategy(
        tf.contrib.cluster_resolver.TPUClusterResolver(TPU_WORKER)))

history = tpu_model.fit(x_train, y_train,
                    epochs=20,
                    batch_size=128 * 8,
                    validation_split=0.2)

最佳答案

一种选择是手动设置学习率 - 这里有一个带有回调的 Keras+TPU 示例:https://github.com/tensorflow/tpu/blob/master/models/experimental/resnet50_keras/resnet50.py#L197-L201

关于python - 在 Keras 到 TPU 模型中使用 tensorflow 学习率衰减,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/55163302/

相关文章:

machine-learning - 关于 keras 示例 pretrained_word_embeddings 的问题

python - 用于一维输入的 LSTM - TensorFlow 异常

python - 无法在 VS Code 中导入tensorflow.keras

python - 更好地替代添加前缀和后缀的巨大 "if x==2/if x==3/if x==4"链?

python - 创建模型后清除Python循环中的内存

python - Keras 分类损失的意义

python - LSTM Autoencoder 的这些实现之间的区别?

python - 如何为多列添加功能?

python - 线程 Thread-4 中的异常 - 查找原因或捕获异常

python - 获取 OSError : [Errno 16] Device or resource busy: ' when using tf. keras.models.Sequential.fit_generator