tensorflow - 实现tensorflow高级api时出错

标签 tensorflow machine-learning deep-learning tensor

我正在尝试实现提供高级 API 的 tensorflow ,特别是基线分类器。然而,当尝试训练模型时,我得到以下结果

错误:

NotFoundError (see above for traceback): Key baseline/bias not found in checkpoint
     [[Node: save/RestoreV2 = RestoreV2[dtypes=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_INT64], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2/tensor_names, save/RestoreV2/shape_and_slices)]]

代码:

import tensorflow as tf
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split

def digit_cross():
    # Number of classes, one class for each of 10 digits.
    num_classes = 10

    digit = datasets.load_digits()
    x = digit.data
    y = digit.target
    x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=.3, random_state=42)
    y_train_index = np.arange(y_train.size)

    train_input_fn = tf.estimator.inputs.numpy_input_fn(
        x={"x": np.array(x_train)},
        y=np.array(y_train),
        num_epochs=None,
        shuffle=False)

    # Build BaselineClassifier
    classifier = tf.estimator.BaselineClassifier(n_classes=num_classes,
                                                 model_dir="./checkpoints_tutorial17-1/")

    # Fit model.
    classifier.train(train_input_fn)

digit_cross()

最佳答案

您似乎在 model_dir="./checkpoints_tutorial17-1/" 中有一个检查点,它来自另一个模型,而不是来自 BaselineClassifier。具体来说,该文件夹中有一个检查点文件和 model.ckpt-* 文件。

如 tensorflow 记录:

  • model_dir: Directory to save model parameters, graph and etc. This can also be used to load checkpoints from the directory into a estimator to continue training a previously saved model. If PathLike object, the path will be resolved. If None, the model_dir in config will be used if set. If both are set, they must be same. If both are None, a temporary directory will be used.

这里,BaselineClassifier将首先构建一个使用baseline/bias的图。然后它发现 model_dir 中有一个先前的检查点。它将尝试加载此检查点,您应该会看到一条信息(如果您已完成tf.logging.set_verbosity(tf.logging.INFO)),内容类似于

"INFO:tensorflow:Restoring parameters from .../checkpoints_tutorial17-1\model.ckpt-..."

由于 model_dir 中的此检查点不是来自 BaselineClassifier,因此它不会有基线/偏差BaselineClassifier 找不到它,因此会抛出错误。

关于tensorflow - 实现tensorflow高级api时出错,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50146225/

相关文章:

python-3.x - 如何从Colaboratory下载大文件(例如模型的权重)?

python - Tensorflow._api.v2.train 没有属性 'AdamOptimizer'

python - 不同时间步长的数据形状和 LSTM 输入

python - 在K-NN中如何显示用于做出正确预测的k个案例集

python - 异构 DataFrame 上的 StratifiedKfold

python - 计算损失时检查标签( tensorflow )

tensorflow - 如何从 .cfg 文件加载 darknet YOLOv3 模型并从 .weights 文件加载权重,并将模型与权重保存到 .h5 文件?

python - TensorFlow 优化器是否最小化 API 实现的小批量?

algorithm - 感知器算法是否适用于二进制输入?

python - 导入torch(pytorch)时发生错误