Tensorflow 估计器 API : How to pass parameter from input function

标签 tensorflow tensorflow-estimator

我正在尝试为我的模型添加类权重作为超参数,但是为了计算权重,我需要读取输入数据,这发生在 input_fn 内部,然后传递给 estimator.fit() . input_fn 的输出只是特征,标签应该具有相同的形状 num_examples * num_features。我的问题 - 有没有办法将数据从 input_fn 传播到 model_fn 的超参数映射?或者作为替代 - 也许 input_fn 数据集有一个包装器,它允许对少数/欠采样多数以及批处理进行过采样 - 在这种情况下,我不需要任何参数来传播。

最佳答案

特征和标签都可以是张量字典(不仅仅是一个张量)。张量可以是您想要的任何形状,但通常为 num_examples * ...

如果您不使用任何预定义的估计器,最简单的方法是添加另一个具有计算权重所需的功能,计算模型中的权重,然后使用它们(乘以损失或将其作为参数传递) .

您还可以访问 input_fn 中的超参数,以便您可以在那里计算权重并将其添加为单独的列。

如果您使用 jar 装估算器,请查看文档。我看到他们中的大多数都支持 weight_column_name。在这种情况下,只需将其命名为您在特征字典中用于权重值的名称。

或者,如果所有其他方法都失败了,您可以在将数据提供给 tensorflow 之前以您想要的方式对数据进行采样。

关于Tensorflow 估计器 API : How to pass parameter from input function,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48372837/

相关文章:

tensorflow - 使用分层抽样的 tensorflow 数据集

tensorflow - 如何有效地标记我的图像分类数据集?

python - 使用tensorflow猜数字

python - TensorFlow ExportOutputs、PredictOuput 并指定 signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY

python - Tensorflow Estimator : loss not decreasing when using tf. feature_column.embedding_column 用于分类变量列表

python - 如何将 tensorflow 变量打印到小数点后两位?

python - lstm预测结果延迟现象

python - 在 Keras 中从 .jpg 创建标准化数据集

python-3.x - Tensorflow 错误 "UnimplementedError: Cast string to float is not supported"- 使用估计器的线性分类器模型

tensorflow - key 错误 : "The name ' boosted_trees/QuantileAccumulator/' refers to an Operation not in the graph." when loading saved model