tensorflow - 如何使用 TensorFlow JS 为 csv 数据集创建模型并计算预测结果

标签 tensorflow tensor tensorflow.js

我是 TensorFlow JS 的新手。 我按照 TensorFlow JS 文档创建模型并训练它以根据创建的模型计算预测结果。

但我不知道如何为 CSV 文件训练创建的模型并计算 CSV 文件中两列或多列的预测结果。

有人可以指导我使用 CSV 文件创建、训练模型并计算预测结果的示例吗?

const csvUrl = 'https://storage.googleapis.com/tfjs-examples/multivariate-linear-regression/data/boston-housing-train.csv';

function save(model) {
    return model.save('downloads://boston_model');
}

function load() {
    return tf.loadModel('indexeddb://boston_model');
}

async function run() {
  // We want to predict the column "medv", which represents a median value of a
  // home (in $1000s), so we mark it as a label.
  const csvDataset = tf.data.csv(
    csvUrl, {
      columnConfigs: {
        medv: {
          isLabel: true
        }
      }
    });
  // Number of features is the number of column names minus one for the label
  // column.
  const numOfFeatures = (await csvDataset.columnNames()).length - 1;

  // Prepare the Dataset for training.
  const flattenedDataset =
    csvDataset
    .map(([rawFeatures, rawLabel]) =>
      // Convert rows from object form (keyed by column name) to array form.
      [Object.values(rawFeatures), Object.values(rawLabel)])
    .batch(10);

  // Define the model.
  const model = tf.sequential();
  model.add(tf.layers.dense({
    inputShape: [numOfFeatures],
    units: 1
  }));
  model.compile({
    optimizer: tf.train.sgd(0.000001),
    loss: 'meanSquaredError'
  });

  // Fit the model using the prepared Dataset
  model.fitDataset(flattenedDataset, {
    epochs: 10,
    callbacks: {
      onEpochEnd: async (epoch, logs) => {
        console.log(epoch, logs.loss);
      }
    }
  });

  const savedModel=save(model);
}

run().then(() => console.log('Done'));

最佳答案

使用 tf.data.csv您可以使用 csv 文件训练模型。

但是浏览器不能直接读取文件。因此,您必须在本地服务器上提供 csv 文件

更新

您的模型仅使用一个感知器。使用多个感知器可以帮助提高模型的准确性,即添加多个层。你可以看看here至于它是如何完成的。

关于tensorflow - 如何使用 TensorFlow JS 为 csv 数据集创建模型并计算预测结果,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54785024/

相关文章:

python - 用于高流量应用程序实时预测的生产环境中的 TensorFlow - 如何使用?

python - tensorflow 类型错误 : Can't convert 'numpy.int64' object to str implicitly

python - 如何加载经过训练的 TensorFlow 模型以使用不同的批量大小进行预测?

pytorch - 如何在张量中复制输入 channel ?

tensorflow - TensorFlow 中的批处理是什么?

keras - Tensorflow.js 使用自定义层和已训练的模型

javascript - TensorFlow JS 异常 - 无法开始训练,因为正在进行另一个 ft() 调用

tensorflow - efficientnet.tfkeras 与 tf.keras.applications.efficientnet

python - 在 C++ 中索引 tensorflow 输出张量

javascript - 如何在nodejs(tensorflow.js)中训练模型?