python - TensorFlow 中基于 bool 掩码的部分更新张量

标签 python tensorflow tensorflow2.0

我想根据某些条件更新部分张量。

我知道 TensorFlow 张量是不可变的,因此创建一个新的张量对我来说没问题。 我尝试了 tensor_scatter_nd_update 方法,但无法使其工作

这是我想要在用 NumPy 编写的 TensorFlow 中复制的代码。

import numpy as np

a = np.random.random((1, 3))
b = np.array([[0, 1, 0]])

c = np.zeros_like(a)
mask = b == 1
c[mask] = np.log(a[mask])

最佳答案

在 TensorFlow 中,我们不会更新实际上是不可变对象(immutable对象)的张量。相反,我们从其他张量创建新的张量,就像在函数式语言中一样。

import tensorflow as tf

a = tf.random.uniform(shape=(1, 3))
b = tf.constant([[0, 1, 0]], dtype=tf.int32)

c = tf.zeros_like(a)
mask = b == 1
c_updated = tf.where(mask, tf.math.log(a), c)
# [[ 0.      , -4.175911,  0.      ]]```

关于python - TensorFlow 中基于 bool 掩码的部分更新张量,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/70816829/

相关文章:

python - 如何在模块中定义与 __all__ 分开的 `from ... import *' api?

python - 无法使用 tf.summary() 为测试集存储准确度

python - python的list.count()有张量eqiv吗

tensorflow - Model.fit() 是否将整个训练数据集上传到 GPU?

python - 使用分布式策略在 Colab TPU 上训练模型

python - tf.dataset 实例的热切执行

python - 如何在 python 中已有的列表中插入第三个嵌套列表?

python - 如何使用 OCR 检测图像中的下标数字?

python - Django ModelView 的 "fields"属性不起作用?

python - 如何让 HMM 处理 Tensorflow 中的实值数据