javascript - TensorFlow.js 中关于 tf.Model 的内存管理

标签 javascript tensorflow tensorflow.js

我是 TensorFlow 的新手。

https://js.tensorflow.org/tutorials/core-concepts.html 中的“内存管理:disposetf.tidy”部分说我们必须以特殊方式管理内存。

但是,tfjs-layers 中的类(例如 tf.ModelLayer)似乎没有 dispose tf.tidy 不接受这些作为返回值。

所以我的问题是:

  • tf.Model 是否自动管理内存?
  • 如果不是,我该如何正确管理内存?

示例代码:

function defineModel(
    regularizerRate: number,
    learningRate: number,
    stateSize: number,
    actionSize: number,
): tf.Model {
    return tf.tidy(() => { // Compile error here, I couldn't return model.
        const input = tf.input({
            name: "INPUT",
            shape: [stateSize],
            dtype: "int32" as any, // TODO(mysticatea): https://github.com/tensorflow/tfjs/issues/120
        })
        const temp = applyHiddenLayers(input, regularizerRate)
        const valueOutput = applyValueLayer(temp, regularizerRate)
        const policyOutput = applyPolicyLayer(temp, actionSize, regularizerRate)
        const model = tf.model({
            inputs: [input],
            outputs: [valueOutput, policyOutput],
        })

        // TODO(mysticatea): https://github.com/tensorflow/tfjs/issues/98
        model.compile({
            optimizer: tf.train.sgd(LEARNING_RATE),
            loss: ["meanSquaredError", "meanSquaredError"],
        })
        model.lossFunctions[1] = softmaxCrossEntropy

        return model
    })
}

最佳答案

您应该只在直接操作张量时使用 tf.tidy()。

构建模型时,您还没有直接操纵张量,而是在设置层如何组合在一起的结构。这意味着您不需要将模型创建包装在 tf.tidy() 中。

只有当您调用“predict()”或“fit()”时,我们才处理具体的 Tensor 值并需要处理内存管理。

当“predict()”被调用时,它返回一个张量,您必须处理它,或者用“tidy()”包围它。

对于“fit()”,我们在内部为您完成所有内存管理。 “fit()”的返回值是纯数字,因此您不需要将其包装在“tidy()”中。

关于javascript - TensorFlow.js 中关于 tf.Model 的内存管理,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49691508/

相关文章:

javascript - Backbone 中的预处理模型

javascript - 如何在 angularjs 的 for 循环中动态创建 $rootScope.$on 事件监听器

python - tf.train.AdamOptimizer 和在 keras.compile 中使用 adam 有什么区别?

python - 类型错误: ('Keyword argument not understood:' , 'inputs' )

javascript - tensorflow 如何使用大于 32 位的整数数据进行基本数学运算?

Tensorflow.js 转换后的模型预测的结果与卡住模型不同/不准确

javascript - 网络 worker 中的 tensorflow.js

php - 更新 magento 产品页面上的主要产品 url

javascript - div中的按钮不起作用

tensorflow - 如何在tensorflow keras中访问自定义层的递归层