python - tf.keras model.predict 导致内存泄漏

标签 python tensorflow keras google-colaboratory

在 google colab 上工作。使用 tf.keras和 tensorflow 版本 2.3.0
我快疯了,因为我不能使用我训练过的模型来运行预测 model.predict因为它用完了 CPU RAM。我已经能够用一个非常小的例子来重现这个问题。

import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Input,Conv2D, Activation

matrixSide = 512 #define a big enough matrix to give memory issues

inputL = Input([matrixSide,matrixSide,12]) #create a toy model
l1 = Conv2D(32,3,activation='relu',padding='same') (inputL) #120
l1 = Conv2D(64,1,activation='relu',padding='same')(l1)
l1 = Conv2D(64,3,activation='relu',padding='same')(l1)
l1 = Conv2D(1,1,padding='same')(l1)
l1 = Activation('linear')(l1)
model = Model(inputs= inputL,outputs = l1)


#run predictions
inImm = np.zeros((64,matrixSide,matrixSide,12))
for i in range (60):
  print(i)
  outImm = model.predict(inImm)
# K.clear_session() #somebody suggested it...
基本上,在 GPU 上工作时,它在前 4 次迭代中使用 3.0 GB 的 CPU RAM,然后上升到 7,然后到 10 然后它崩溃,因为它耗尽了所有可用的 RAM!
在 CPU 上运行时,它会持续更多次迭代,有时甚至将其使用的 RAM 量从 9 GB 减少到 3 GB,但最终在 20 次左右的迭代后仍然崩溃。
前面的示例( Keras predict loop memory leak using tf.data.Dataset but not with a numpy array )在使用 tf.data 时也有类似的问题但不是 numpy。有人建议在 tensorflow 1.14 的 github 问题上做一个 K.clear_session在每个循环中……但这无济于事!
关于如何解决这个问题的任何想法?

最佳答案

这是我将此作为错误发布到 Tensorflow 后的理解。
将代码更改为;

in_imm = np.zeros((64,matrix_side,matrix_side,12))
for i in range (60):
  print(i)
  tensor = tf.convert_to_tensor(in_imm, dtype=tf.float32)
  out_imm = model.predict(tensor)
在带有 numpy 输入的 for 循环中使用 tf.keras.Model.predict 每次迭代都会创建一个新图,因为 numpy 数组是使用不同的签名创建的。将 numpy 数组转换为张量可保持相同的签名并避免创建新图。

关于python - tf.keras model.predict 导致内存泄漏,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/64199384/

相关文章:

machine-learning - 如何在Keras中实现稀疏均方误差损失

python - 神经网络中具有不同样本大小的多个输入

python - 想要创建当前任务下游的 Airflow 任务

Python:将列表与范围列表合并

python - 如何处理句子中的换行符? - 斯帕西 NER

python - 如何根据 tensorflow 中另一个矩阵获得的最大值和次要值以及索引来获取矩阵中每一行的值?

python - TensorFlow:tf.train.batch 中 dequeue_up_to 比 dequeue_many 慢吗?

tensorflow - 如何在 tensorflow2.0 的 keras 模型中使用 tf.train.ExponentialMovingAverage

python - Tensorflow - 损失不减少

python - 我应该使用哪个,为什么?有关系吗? SafeUnicode 还是 django.utils.safestring.mark_safe()?