javascript - 使用 Keras.js 将 Keras 模型实现到网站中

标签 javascript keras

我一直在尝试使用Keras.js将Python生成的基本Keras模型实现到网站中。图书馆。现在,我已经训练了模型并将其导出到 model.jsonmodel_weights.bufmodel_metadata.json 文件中。现在,我基本上从 github 页面复制并粘贴了测试代码,以查看模型是否会在浏览器中加载,但不幸的是我遇到了错误。这是测试代码。 (编辑:我修复了一些错误,其余错误请参见下文。)

var model = new KerasJS.Model({
    filepaths: {
        model: 'dist/model.json',
        weights: 'dist/model_weights.buf',
        metadata: 'dist/model_metadata.json'
  },
  gpu: true
});

    model.ready()
  .then(function() {
    console.log("1");
    // input data object keyed by names of the input layers
    // or `input` for Sequential models
    // values are the flattened Float32Array data
    // (input tensor shapes are specified in the model config)
    var inputData = {
      'input_1': new Float32Array(data)
    };
    console.log("2 " + inputData);
    // make predictions
    return model.predict(inputData);
  })
  .then(function(outputData) {
    // outputData is an object keyed by names of the output layers
    // or `output` for Sequential models
    // e.g.,
    // outputData['fc1000']
    console.log("3 " + outputData);
  })
  .catch(function(err) {
    console.log(err);
    // handle error
  });

编辑:所以我稍微改变了我的程序以与 JS 5 兼容(这对我来说是一个愚蠢的错误),现在我遇到了一个不同的错误。该错误被捕获并记录下来。我得到的错误是:错误:predict() 必须采用一个对象,其中键是模型的命名输入:input。我相信出现此问题是因为我的数据变量的格式不正确。我认为如果我的模型接受 28x28 的数字数组,那么 data 也应该是 28x28 的数组,以便它可以正确地“预测”正确的输出。但是,我相信我遗漏了一些东西,这就是抛出错误的原因。 This问题和我的非常相似,但是它是用python而不是JS。再次强调,任何帮助将不胜感激。

最佳答案

好的,所以我明白为什么会发生这种情况。有两个问题。首先,data 数组需要被展平,所以我编写了一个快速函数来获取 2D 输入并将其“展平”为长度为 784 的 1D 数组。然后,因为我使用了顺序模型,数据的键名不应该是 'input_1',而应该是 'input'。这消除了所有错误。

现在,要获取输出信息,我们只需将其存储在一个数组中,如下所示:var out = outputData['output']。因为我使用了 MNIST 数据集,所以 out 是一个长度为 10 的一维数组,其中包含每个数字是用户编写的数字的概率。从那里,您可以简单地找到概率最高的数字,并将其用作模型的预测。

关于javascript - 使用 Keras.js 将 Keras 模型实现到网站中,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/45470067/

相关文章:

javascript - 如何每2个字符拆分字符串

tensorflow - 将channel_shift_range添加到Keras预处理(图像增强)中是否允许模型在可变光照情况下使用?

tensorflow - 在自定义 keras 损失中使用 keras 模型

python - Keras VGG 提取特征

javascript - 在 JS 中生成非常长的字符串(数十兆字节)的最有效方法

javascript - 使用非标准事件名称理解 on() 语法

javascript - 如何立即执行 Node 计划作业并每 30 分钟重复一次

javascript - 如何在 Javascript 中将值附加到数组中并保持旧值不变?

python - "How to build your own AlphaZero AI using Python and Keras"中的 stmemory 和 ltmemory

python - 如何用keras近似行列式