python - 我需要什么 K.clear_session() 和 del 模型(Keras with Tensorflow-gpu)?

标签 python tensorflow memory-management keras

我在做什么
我正在训练并使用卷积神经元网络 (CNN) 进行图像分类,使用 Keras 和 Tensorflow-gpu 作为后端。

我正在使用什么
- PyCharm 社区 2018.1.2
- Python 2.7 和 3.5(但不能同时使用)
- Ubuntu 16.04
- Keras 2.2.0
- Tensorflow-GPU 1.8.0 作为后端

我想知道的
在许多代码中,我看到人们使用

from keras import backend as K 

# Do some code, e.g. train and save model

K.clear_session()

或使用后删除模型:

del model

关于 clear_session 的 keras 文档说:“销毁当前的 TF 图并创建一个新的。有助于避免旧模型/层造成困惑。” - https://keras.io/backend/

这样做有什么意义,我也应该这样做吗?在加载或创建新模型时,我的模型无论如何都会被覆盖,那何必呢?

最佳答案

K.clear_session() 在您连续创建多个模型时很有用,例如在超参数搜索或交叉验证期间。您训练的每个模型都会向图中添加节点(可能有数千个)。 TensorFlow 会在您(或 Keras)调用 tf.Session.run()tf.Tensor.eval() 时执行整个图表,因此您的模型会变得越来越慢进行训练,您也可能会耗尽内存。清除 session 会删除以前模型中遗留的所有节点,释放内存并防止速度变慢。


21/06/19 编辑:

TensorFlow 默认是惰性求值的。 TensorFlow 操作不会立即评估:创建张量或对其执行一些操作会在数据流图中创建节点。当您调用 tf.Session.run()tf.Tensor.eval() 时,通过一次性评估图表的相关部分来计算结果。这样 TensorFlow 就可以构建一个执行计划,将可以并行执行的操作分配给不同的设备。它还可以将相邻节点折叠在一起或删除冗余节点(例如,如果您连接两个张量,然后再次将它们分开而不改变)。更多详情,请参阅 https://www.tensorflow.org/guide/graphs

您的所有 TensorFlow 模型都以一系列张量和张量运算的形式存储在图中。机器学习的基本操作是张量点积——神经网络的输出是输入矩阵和网络权重的点积。如果你有一个单层感知器和 1000 个训练样本,那么每个 epoch 至少会创建 1000 个张量操作。如果您有 1,000 个 epoch,那么在考虑预处理、后处理和更复杂的模型(如循环网络、编码器-解码器、注意力模型等)之前,您的图最后至少包含 1,000,000 个节点。

问题是最终图表会太大而无法放入视频内存(在我的情况下为 6 GB),因此 TF 会将图表的部分内容从视频传输到主内存并返回。最终它甚至会变得对于主内存(12 GB)来说太大,并开始在主内存和硬盘之间移动。不用说,这让事情变得难以置信,而且随着训练的进行越来越慢。在开发此保存模型/清除 session /重新加载模型流程之前,我计算过,按照我经历的每个时代的减速率,我的模型完成训练所需的时间将超过宇宙的年龄。

Disclaimer: I haven't used TensorFlow in almost a year, so this might have changed. I remember there being quite a few GitHub issues around this so hopefully it has since been fixed.

关于python - 我需要什么 K.clear_session() 和 del 模型(Keras with Tensorflow-gpu)?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50895110/

相关文章:

python - Python 2.7 中的 Google Cloud 客户端库

javascript - Chart.js 未更新 Flask python 或接收来自 Flask 的数据

objective-c - 释放我已经完成的 NSString 会导致崩溃

python - 如何以快速且内存高效的方式替换列中的值

c++ - 将 new[] 与 delete 配对怎么可能只导致内存泄漏?

python - 使用 pygit2 提交时未跟踪的目录

python - 从 __main__ 命令行调用单元测试失败

tensorflow - 在 Google Colab 中使用多个 GPU 在 Tensorflow 中进行分布式训练

tensorflow - Tensorboard - 按范围过滤步骤

python - Tensorflow仅初始化特定范围