python - 如何在 TensorFlow 中实现 Numpy where 索引?

标签 python numpy tensorflow

我有以下使用 numpy.where 的操作:

    mat = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int32)
    index = np.array([[1,0,0],[0,1,0],[0,0,1]])
    mat[np.where(index>0)] = 100
    print(mat)

如何在 TensorFlow 中实现等价物?

mat = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int32)
index = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
tf_mat = tf.constant(mat)
tf_index = tf.constant(index)
indi = tf.where(tf_index>0)
tf_mat[indi] = -1   <===== not allowed 

最佳答案

假设您想要创建一个带有一些替换元素的新张量,而不是更新变量,您可以这样做:

import numpy as np
import tensorflow as tf

mat = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int32)
index = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
tf_mat = tf.constant(mat)
tf_index = tf.constant(index)
tf_mat = tf.where(tf_index > 0, -tf.ones_like(tf_mat), tf_mat)
with tf.Session() as sess:
    print(sess.run(tf_mat))

输出:

[[-1  2  3]
 [ 4 -1  6]
 [ 7  8 -1]]

关于python - 如何在 TensorFlow 中实现 Numpy where 索引?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/51588899/

相关文章:

python - Tensorflow 2 坐标分类器

python - 将多列列表分解为行

python - Redis:用于修剪排序集的 ZUNIONSTORE

python - 无法使用python连接AWS EC2上的MongoDb

numpy - 如何有效地为多个参数准备矩阵(二维数组)?

python - 将 NumPy 数组映射到位

python - keras - 使用 lambda 层时如何避免尺寸错误

python - 神经机器翻译模型预测偏差一

python - 如何区分 2 个非常大的数组?

python - python中使用openCV进行多线程图像处理