python - 如何在 tensorflow 中更新张量内的子张量?

标签 python python-3.x tensorflow

我正在使用 MNIST,我有一个大小为 [?,28,28,1] 的梯度张量,我想将其中的一些 [28,28,1] 子张量归零,我应该如何做到这一点?

我知道我需要将子张量归零的索引(作为列表)。我尝试做这样的事情(如下所示)但是,scatter.update只能改变变量而不是张量。我还尝试叠加所需的零和一的子张量,但无法建立所需的结果。

dy_dx, = tf.gradients(loss, x_adv) zeroes = tf.zeros(dy_dx[0].get_shape(), tf.float32) dy_dx = tf.scatter_update(dy_dx, indices, zeroes)

谢谢!

最佳答案

我建议创建一个 TensorFlow 常量,在您想要清零的位置使用零,在其他位置使用零。然后您可以创建一个操作,使用 tf.multiply 对常量和 dy_dx 进行逐元素乘法。根据图形的结构,您可能需要在下次调用 session.run 时将结果提供给 dy_dx;您可以用提要数据替换任何张量,包括变量和常量。

顺便说一下,如果你只想将 dropout 应用到输入层,你可以使用 tf.layers.dropout

关于python - 如何在 tensorflow 中更新张量内的子张量?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47731409/

相关文章:

python - 子进程模块无法运行命令

python - SQLite 跨具有重复分组行的多列进行 SELECT 查询

python - 模块未找到错误: No module named 'onnxruntime'

python 3.4.3 64bit 无法运行此安装所需的程序

python - Keras 的基本二进制分类不起作用

python - 什么是 __methods__ 以及它为什么调用 __getattr__?

python - 读取csv文件的小数问题

python-3.x - Cloud Run 连接到 Cloud SQL 模块时出错 Python

python - 如何从主函数中获取某个函数的局部占位符?

python - 将 >2GB 数据传递给 tf.estimator