javascript - tensorflow.js model.predict() 打印张量 [[NaN],]

标签 javascript tensorflow machine-learning deep-learning tensorflow.js

我是机器学习和 tensorflow.js 的新手,我试图预测下一组的值,但结果是“NaN”。我究竟做错了什么 ?

正在关注 this Github example

 async function myFirstTfjs(arr) {
    // Create a simple model.
    const model = tf.sequential();
    model.add(tf.layers.dense({units: 1, inputShape: [2]}));

    // Prepare the model for training: Specify the loss and the optimizer.
    model.compile({
      loss: 'meanSquaredError',
      optimizer: 'sgd'
    });
    const xs = tf.tensor([[1,6],
        [2,0],
        [3,1],
        [4,2],
        [5,3],
        [6,4],
        [7,5],
        [8,6],
        [9,0],
        [10,1],
        [11,2],
        [12,3],
        [13,4],
        [14,5],
        [15,6],
        [16,0],
        [17,1],
        [18,2],
        [19,3],
        [20,4],
        [21,5],
        [22,6],
        [23,0],
        [24,1],
        [25,2],
        [26,3]]);
    const ys = tf.tensor([104780,30280,21605,42415,32710,30385,35230,97795,31985,34570,35180,30095,36175,57300,104140,30735,28715,36035,34515,42355,38355,110080,26745,35315,40365,30655], [26, 1]);
    // Train the model using the data.
    await model.fit(xs, ys, {epochs: 500});
    // Use the model to do inference on a data point the model hasn't seen.
  model.predict(tf.tensor(arr, [1, 2])).print();
  }
  myFirstTfjs([28,5]);

最佳答案

发生的事情是 ys 中的大值导致了一个非常大的错误。那个大的错误,加上(默认)学习率,导致模型过度校正和不稳定。如果降低学习率,模型将会收敛。

const learningRate = 0.0001;
const optimizer = tf.train.sgd(learningRate);

model.compile({
  loss: 'meanSquaredError',
  optimizer: optimizer,      
});

关于javascript - tensorflow.js model.predict() 打印张量 [[NaN],],我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50123067/

相关文章:

tensorflow - 拟合生成器函数中的数据增强误差

python - 混淆矩阵生成 a 和 b 作为标签,但不是我需要的

machine-learning - 我应该先执行交叉验证然后再进行网格搜索吗?

javascript - 初始 Handlebars.js 无功能

javascript - 根据所选偏好匹配数据

php - Javascript - 转义特定字符

tensorflow - 如何在 tensorflow 中读取 utf-8 编码的二进制字符串?

javascript - 删除 CSS 中的特定动画

python - tensorflow 如何更改数据集

python - 使用混淆矩阵了解多标签分类器