tensorflow - `model.predict(x,batch_size=n)` 是否使用多核?

标签 tensorflow keras tf.keras

当使用 batch_size>1 运行 model.predict(features,batch_size=n) 时,这些预测是并行进行的吗?也就是同时跨多个内核?

https://github.com/tensorflow/tensorflow/blob/eb7bf384017fd4eabdf500687bec552971cbd41c/tensorflow/python/keras/engine/training_arrays_v1.py

我需要排列数百万列,所以我希望使用 multiprocessing 同时推断出许多预测。然而,这存在重大挑战:Keras + Tensorflow and Multiprocessing in Python

我的一个 friend 建议将数据堆叠起来与 batch_size 一起使用,作为提高性能的解决方法,但我想知道潜在的 yield 是什么,因为这将是一个很大的重写。

最佳答案

这取决于您如何参数化 tf.keras.model.predict() 的函数调用

在文档中可以找到predict的参数,它有一个use_multiprocessing参数:

predict(
    x, batch_size=None, verbose=0, steps=None, callbacks=None, max_queue_size=10,
    workers=1, use_multiprocessing=False
)

除了多处理之外,批量预测而不是单个样本的预测受益于矢量化,这进一步加快了预测速度。

关于tensorflow - `model.predict(x,batch_size=n)` 是否使用多核?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/70885670/

相关文章:

python - 为什么一个 `tf.constant()`的值在TensorFlow中会多次存储在内存中?

python - 如何使用 TensorFlow 获得稳定的结果,设置随机种子

python - 使用 Keras 使用多个指标进行预测

python-3.x - mse 损失函数与隐藏层输出上的正则化损失 (add_loss) 不兼容

python - 在 TensorFlow Functional API 中保存和加载具有相同图表的多个模型

machine-learning - 何时使用tensorflow或api.ai

python - 回归精度

python - 如何在不使用后端函数的情况下在keras中编写自定义函数?

tensorflow - tf.keras.Model 如何区分 tf.data.Dataset 和 TFRecords 中的特征和标签?

python - Tensorflow 上的简单线性回归