我一直在尝试使用Keras.js将Python生成的基本Keras模型实现到网站中。图书馆。现在,我已经训练了模型并将其导出到 model.json
、model_weights.buf
和 model_metadata.json
文件中。现在,我基本上从 github 页面复制并粘贴了测试代码,以查看模型是否会在浏览器中加载,但不幸的是我遇到了错误。这是测试代码。 (编辑:我修复了一些错误,其余错误请参见下文。)
var model = new KerasJS.Model({
filepaths: {
model: 'dist/model.json',
weights: 'dist/model_weights.buf',
metadata: 'dist/model_metadata.json'
},
gpu: true
});
model.ready()
.then(function() {
console.log("1");
// input data object keyed by names of the input layers
// or `input` for Sequential models
// values are the flattened Float32Array data
// (input tensor shapes are specified in the model config)
var inputData = {
'input_1': new Float32Array(data)
};
console.log("2 " + inputData);
// make predictions
return model.predict(inputData);
})
.then(function(outputData) {
// outputData is an object keyed by names of the output layers
// or `output` for Sequential models
// e.g.,
// outputData['fc1000']
console.log("3 " + outputData);
})
.catch(function(err) {
console.log(err);
// handle error
});
编辑:所以我稍微改变了我的程序以与 JS 5 兼容(这对我来说是一个愚蠢的错误),现在我遇到了一个不同的错误。该错误被捕获并记录下来。我得到的错误是:错误:predict() 必须采用一个对象,其中键是模型的命名输入:input。
我相信出现此问题是因为我的数据
变量的格式不正确。我认为如果我的模型接受 28x28 的数字数组,那么 data 也应该是 28x28 的数组,以便它可以正确地“预测”正确的输出。但是,我相信我遗漏了一些东西,这就是抛出错误的原因。 This问题和我的非常相似,但是它是用python而不是JS。再次强调,任何帮助将不胜感激。
最佳答案
好的,所以我明白为什么会发生这种情况。有两个问题。首先,data 数组需要被展平,所以我编写了一个快速函数来获取 2D 输入并将其“展平”为长度为 784 的 1D 数组。然后,因为我使用了顺序模型,数据的键名不应该是 'input_1'
,而应该是 'input'
。这消除了所有错误。
现在,要获取输出信息,我们只需将其存储在一个数组中,如下所示:var out = outputData['output']
。因为我使用了 MNIST 数据集,所以 out
是一个长度为 10 的一维数组,其中包含每个数字是用户编写的数字的概率。从那里,您可以简单地找到概率最高的数字,并将其用作模型的预测。
关于javascript - 使用 Keras.js 将 Keras 模型实现到网站中,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/45470067/