我正在使用 tensorflow 构建基于 CNN 的文本分类。有些数据集很大,有些数据集很小。
我使用 feed_dict 通过从系统内存(不是 GPU 内存)采样数据来为网络提供数据。网络是逐批训练的。每个数据集的批量大小固定为 1024。
我的问题是: 网络按批处理进行训练,每个批处理代码从系统内存中检索数据。因此,无论数据集有多大,代码都应该以相同的方式处理它,对吗?
但是我在处理大数据集时遇到了内存不足的问题,而对于小数据集它工作得很好。我很确定系统内存足以保存所有数据。所以 OOM 问题是关于 tensorflow 的,对吗?
是我代码写错了,还是tensorflow内存管理的问题?
非常感谢!
最佳答案
我认为你的批量大小是 1024 太大了。创建了很多矩阵开销,特别是如果你使用 AgaGrad Adam 等、dropout、attention 和/或更多。尝试使用较小的值(例如 100)作为批量大小。应该可以很好地解决和训练。
关于 tensorflow 内存不足,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/37736071/