在编写机器学习模型时,我发现自己需要计算指标,或者为了可视化目的在回调中运行额外的前向传递。在 PyTorch 中,我使用 torch.no_grad()
执行此操作,这会阻止计算梯度,因此这些操作不会影响优化。
- 此机制在 TensorFlow/Keras 中如何运作?
- Keras 模型是可调用的。所以,像
model(x)
这样的东西是可能的。但是,也可以说model.predict(x)
,它似乎也调用了call
。两者有区别吗?
最佳答案
tensorflow 等效项为 tf.stop_gradient
另外不要忘记,Keras 在使用预测时不会计算梯度(或者只是通过 __call__
调用模型)。
关于python - PyTorch 的 `no_grad` 函数在 TensorFlow/Keras 中的等价物是什么?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/67385963/