python-3.x - 在 tensorflow 中读取数据

标签 python-3.x machine-learning tensorflow

我正在从文档中学习如何使用tensorflow。但是,我无法理解以下两个功能。我尝试搜索文档,但无法得到明确的答案。

input_fn = tf.contrib.learn.io.numpy_input_fn({"x":x_train}, y_train,
                                          batch_size=4,
                                          num_epochs=1000)
eval_input_fn = tf.contrib.learn.io.numpy_input_fn(
    {"x":x_eval}, y_eval, batch_size=4, num_epochs=1000)

此外,如果您能解释一下函数中的参数是什么,那就太好了。提前致谢!

最佳答案

我同意documentation没有很好地解释这些值(value)观。

第一个参数是网络输入值的字典。

第二个是输出值。

batch_size是训练网络时一次获取的项目数,当你有大量数据时,如果你用每个值训练网络,训练会太慢,所以你使用批处理,因此如果 x_train 有 16 个项目且批处理大小为 4,则网络将使用前 4 个项目进行训练,然后使用下 4 个项目进行训练,直到使用所有 16 个输入。

num_epochs 是在您停止尝试优化网络节点之前前后传递的次数。

对于批量大小和轮数,您可以根据您想要花费多少时间进行训练以及性能的好坏来调整这些参数。

关于terms的精彩描述.

取自 an example ,然后您可以使用 input_fn 获取要传递以运行该数据的特征和目标值。

例如:

with self.test_session() as session:
  input_fn = numpy_io.numpy_input_fn({"x":x_train}, y_train,
                                      batch_size=4,
                                      num_epochs=1000)
  features, target = input_fn()

  res = session.run([features, target])

这两个函数做同样的事情,只是input_fn用于训练网络,eval_input_fn用于评估,所以唯一的区别是x和y值,因为当您评估训练数据的性能时,性能将是上限或最佳情况。

关于python-3.x - 在 tensorflow 中读取数据,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/44975585/

相关文章:

machine-learning - 使用 tensorflow 后端在 keras 中定义 pinball 损失函数

python - 如何修复 "element not interactable"异常?

python-3.x - Selenium - 文本属性仅在调试器检查后才可用

python-3.x - 计算滚动 3 天 Pandas 中的不同计数?

tensorflow - PyCaret ValueError - 单元格为空

tensorflow - 不同 channel 上不求和的卷积层 - Keras

python - 无法更新标签文本

python-3.x - Keras 方法 'predict' 和 'predict_generator' 具有不同的结果

Python-图像识别分类器

python - 如何使用 Tensorflow 2/Keras 保存和继续训练具有多个模型部分的 GAN