python - 使用 Tensorflow 中的索引为 2D 张量赋值

标签 python tensorflow keras tensor assign

我有一个二维张量 A,我希望将其非零条目替换为另一个张量 B,如下所示。

A = tf.constant([[1.0,0,1.0],[0,1.0,0],[1.0,0,1.0]],dtype=tf.float32)
B = tf.constant([1.0,2.0,3.0,4,0,5.0],dtype=tf.float32)

所以我希望最终的 A 为

 A = tf.constant([[1.0,0.0,2.0],[0,3.0,0.0],[4.0,0.0,5.0]],dtype=tf.float32)

我得到 A 的非零元素的索引如下

where_nonzero = tf.not_equal(A, tf.constant(0, dtype=tf.float32))
indices = tf.where(where_nonzero)

indices = <tf.Tensor: shape=(5, 2), dtype=int64, numpy=
array([[0, 0],
   [0, 2],
   [1, 1],
   [2, 0],
   [2, 2]])>

有人可以帮忙吗?

最佳答案

IIUC,您应该能够使用tf.tensor_scatter_nd_update:

import tensorflow as tf

A = tf.constant([[1.0,0,1.0],[0,1.0,0],[1.0,0,1.0]],dtype=tf.float32)
B = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0],dtype=tf.float32)

where_nonzero = tf.not_equal(A, tf.constant(0, dtype=tf.float32))
indices = tf.where(where_nonzero)
A = tf.tensor_scatter_nd_update(A, indices, B)
print(A)
tf.Tensor(
[[1. 0. 2.]
 [0. 3. 0.]
 [4. 0. 5.]], shape=(3, 3), dtype=float32)

关于python - 使用 Tensorflow 中的索引为 2D 张量赋值,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/71222431/

相关文章:

python - 如何清除这两个小部件以及它们被其他小部件清除

Python 并发 : Wait for one out of few futures and cancel the rest

python - 尝试解码 tensorflow 数据时获取 "AttributeError: ' MapDataset' 对象没有属性 'prefetch'

python - 使用 tf.keras CNN 进行图像分类有关内核以及如何创建测试数据集的问题?

python - 了解具有不同批量大小的 Keras LSTM 模型

python - 为什么使用plt.imshow显示后输出与变量的值不同?

python - 需要找到特定长度的所有回文

python - Tensorflow build() 如何从 tf.keras.layers.Layer 工作

python - tf.reshape 在添加额外维度的情况下不起作用

Python Selenium - 类似身份验证的警报弹出窗口