python - 加速 Tensorflow 2.0 渐变带

标签 python tensorflow keras

我一直在关注卷积 VAE 的 TF 2.0 教程,位于 here .

由于它是急切的,所以梯度是手动计算的,然后使用 tf.GradientTape() 手动应用。

for epoch in epochs:
  for x in x_train:
    with tf.GradientTape() as tape:
      loss = compute_loss(model, x)
    apply_gradients(tape.gradient(loss, model.trainable_variables))

该代码的问题在于它非常慢,每个周期大约需要 40-50 秒。 如果我将批量大小增加很多(到 2048 左右),那么最终每个周期大约需要 8 秒,但模型的性能会下降很多。

另一方面,如果我做一个更传统的模型(即使用基于惰性图的模型而不是渴望模型),例如 here ,那么即使批量大小很小,每个 epoch 也需要 8 秒。

model.add_loss(lazy_graph_loss)
model.fit(x_train epochs=epochs)

根据这些信息,我的猜测是 TF2.0 代码的问题在于手动计算损失和梯度。

有什么方法可以加快TF2.0代码的速度,使其更接近正常代码?

最佳答案

我找到了解决方案:TensorFlow 2.0引入了functions的概念,它将 eager 代码转换为图形代码。

用法非常简单。唯一需要的更改是所有相关函数(例如 compute_lossapply_gradients)都必须使用 @tf.function 进行注释。

关于python - 加速 Tensorflow 2.0 渐变带,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54968545/

相关文章:

linux - Python子进程超时与耗时任务

python - 使用python向sql server发送数据

tensorflow - 从源代码构建 TensorFlow 时出错

python - 如何使用 TensorBoard 可视化具有自定义模型子类的 keras 模型?

javascript - 错误 "Sequential.fromConfig called without an array of configs "

python - Keras 损失一直很低,但准确性开始很高然后下降

python - 将模块列为字符串并导入它们

python - Pytorch - 如何提取 MLP 网络的特征(权重、偏差、节点数、隐藏层)?

python-3.x - session 图为空

python - 为什么在 keras 中增加批处理大小时,使用的 GPU 内存量没有增加?