tensorflow - 使用输入 fn 在 Tensorflow 估计器中进行预测

标签 tensorflow classification predict

我使用来自 https://github.com/tensorflow/tensorflow/blob/r1.3/tensorflow/examples/learn/wide_n_deep_tutorial.py 的教程代码并且代码运行良好,直到我尝试做出预测而不是仅仅评估它。我试图制作另一个看起来像这样的预测函数(只需删除参数 y):

def input_fn_predict(data_file, num_epochs, shuffle):
  """Input builder function."""
  df_data = pd.read_csv(
      tf.gfile.Open(data_file),
      names=CSV_COLUMNS,
      skipinitialspace=True,
      engine="python",
      skiprows=1)
  # remove NaN elements
  df_data = df_data.dropna(how="any", axis=0)
  labels = df_data["income_bracket"].apply(lambda x: ">50K" in x).astype(int)
  return tf.estimator.inputs.pandas_input_fn( #removed paramter y
      x=df_data,
      batch_size=100,
      num_epochs=num_epochs,
      shuffle=shuffle,
      num_threads=5)

并这样称呼它:
predictions = m.predict(
      input_fn=input_fn_predict(test_file_name, num_epochs=1, shuffle=True)
  )
  for i, p in enumerate(predictions):
      print(i, p)
  • 我做得对吗?
  • 为什么我得到预测 81404 而不是 16282(测试文件中的行数)?
  • 每行包含如下内容:

  • {'probabilities': array([ 0.78595656, 0.21404342], dtype=float32), 'logits': array([-1.3007226], dtype=float32), 'classes': array(['0'], dtype=object), 'class_ids': array([0]), 'logistic': array([ 0.21404341], dtype=float32)}



    我怎么读?

    最佳答案

    您需要设置shuffle=False因为要预测新标签,您需要维护数据顺序。

    下面是我运行预测的代码(我已经测试过了)。输入文件类似于测试数据(在 csv 中),但没有标签列。

    
    
        def predict_input_fn(data_file):
            global CSV_COLUMNS
            CSV_COLUMNS = CSV_COLUMNS[:-1]
            df_data = pd.read_csv(
                tf.gfile.Open(data_file),
                names=CSV_COLUMNS,
                skipinitialspace=True,
                engine='python',
                skiprows=1
            )
    
            # remove NaN elements
            df_data = df_data.dropna(how='any', axis=0)
    
            return tf.estimator.inputs.pandas_input_fn(
                x=df_data,
                num_epochs=1,
               shuffle=False
            )
    
    

    调用它:
    
    
        predict_file_name = 'tutorials/data/adult.predict'
        results = m.predict(
            input_fn=predict_input_fn(predict_file_name)
        )
        for result in results:
            print 'result: {}'.format(result)
    
    

    一个样本的预测结果如下:
    
    
        {
            'probabilities': array([0.78595656, 0.21404342], dtype = float32),
            'logits': array([-1.3007226], dtype = float32),
            'classes': array(['0'], dtype = object),
            'class_ids': array([0]),
            'logistic': array([0.21404341], dtype = float32)
        }
    
    

    每个字段的含义是什么
  • '概率':数组([0.78595656,0.21404342],dtype = float32) .它预测输出标签是 class-0(在这种情况下 <=50K)
    置信度 0.78595656
  • 'logits':数组([-1.3007226],dtype = float32)
    等式 1/(1+e^(-z)) 中的 z 值为 -1.3。
  • '类':数组(['0'],dtype = object)
    类标签为 0
  • 关于tensorflow - 使用输入 fn 在 Tensorflow 估计器中进行预测,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/46948172/

    相关文章:

    tensorflow - 在tensorflow中生成pb文件

    python - 如何对张量执行阈值处理

    r - 将列值分类到具有特定输出的新列中

    matlab - 如何利用支持向量机提高分类精度

    python - Python 中基于字典的文本分类

    python - 属性错误 : module 'keras.backend' has no attribute 'image_dim_ordering'

    python - 如何可视化 TFRecord?

    r - 预测值不对用户输入使用react R Shiny

    r - 在 R 中组合多个模型对子集数据的预测的简单方法

    algorithm - 推荐系统和基线预测器