javascript - 使用tensorflow.js根据类型预测电影的吸引力

标签 javascript machine-learning tensorflow.js

使用默认的基本示例tensorflow.js我正在尝试更改它,因此通过给它一个指定电影类型的数组,它可以预测我是否喜欢这部电影:

  // Define a model for linear regression.
  const model = tf.sequential();
  model.add(tf.layers.dense({units: 1, inputShape: [1]}));

  // Prepare the model for training: Specify the loss and the optimizer.
  model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});
  // Generate some synthetic data for training.

  //[action, adventure, romance]
  const xs = tf.tensor1d([1,1,0]);
  //target data should be rating from 1 to 5
  const ys = tf.tensor1d([3]);

  // Train the model using the data.
  model.fit(xs, ys).then(() => {
    // Use the model to do inference on a data point the model hasn't seen before:
    // Open the browser devtools to see the output
    model.predict(tf.tensor2d([1,0,0])).print();
  });

但是,关于const ys = tf.tensor1d([3]);它抛出一个错误告诉我 Input Tensors should have the same number of samples as target Tensors. Found 3 input sample(s) and 1 target sample(s) ,但我想要从数组 [3] 到 1 到 5 之间的数字进行预测,但我不知道如何使用此示例来实现此目的

最佳答案

样本数量应与目标数量匹配,否则模型无法学习。我更新了您的示例,添加了另一个示例和另一个目标,并更正了形状。

// Define a model for linear regression.
const model = tf.sequential();
model.add(tf.layers.dense({ units: 1, inputDim: 3 }));

// Prepare the model for training: Specify the loss and the optimizer.
model.compile({ loss: 'meanSquaredError', optimizer: 'sgd' });
// Generate some synthetic data for training.

//[action, adventure, romance]
const xs = tf.tensor2d([[1, 1, 0], [1, 0, 1]]);
//target data should be rating from 1 to 5
const ys = tf.tensor2d([[3], [2]]);

// Train the model using the data.
model.fit(xs, ys).then(() => {
  // Use the model to do inference on a data point the model hasn't seen before:
  // Open the browser devtools to see the output
  model.predict(tf.tensor2d([[1, 0, 0]])).print();
});

编译并产生以下结果:

Tensor
     [[1.6977279],]

关于javascript - 使用tensorflow.js根据类型预测电影的吸引力,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49791076/

相关文章:

javascript - 使用 ajax 的 YUI 数据表

javascript - face-api 和 tensorflow.js 在浏览器中不起作用

javascript - 给定 url,如何加载图像,然后将其转换为 Javascript 中的 tf.tensor 数据?

tensorflow - 转换后的 Tensorflow JS 模型报告缺少输入形状

javascript - jQuery 让 div 在点击时具有相同的高度

javascript - 如何定位CSS类名的一部分?

math - 相关系数实际上代表什么

python - 将矩阵划分为 2x2 方阵子矩阵 - maxpooling fprop

python - scikit-learn 中的sample_weight 与class_weight 相比如何?

javascript - 将 JSON 值发送到 PHP