tensorflow - 如何在 tensorflow 中的二维张量中找到前 k 个值

标签 tensorflow matrix tensor

有没有办法在 Tensorflow 的二维张量中找到前 k 值?

我可以将 tf.nn.top_k 用于一维张量,但不能用于二维张量。我有一个大小未知的二维张量,有没有办法找到前 k 值及其索引?

非常感谢。

最佳答案

您可以在 tf.nn.top_k() 之前将矩阵 reshape 为一维张量,然后根据一维索引计算二维索引:

x = tf.random_uniform((3, 4))
x_shape = tf.shape(x)
k = 3

top_values, top_indices = tf.nn.top_k(tf.reshape(x, (-1,)), k)
top_indices = tf.stack(((top_indices // x_shape[1]), (top_indices % x_shape[1])), -1)

with tf.Session() as sess:
    mat, val, ind = sess.run([x, top_values, top_indices])
    print(mat)
    # [[ 0.2154634   0.52707899  0.29711092  0.74310601]
    #  [ 0.61274767  0.82408106  0.27242708  0.25479805]
    #  [ 0.25863791  0.16790807  0.95585966  0.51889324]]
    print(val)
    # [ 0.95585966  0.82408106  0.74310601]
    print(ind)
    # [[2 2]
    #  [1 1]
    #  [0 3]]

关于tensorflow - 如何在 tensorflow 中的二维张量中找到前 k 个值,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50797807/

相关文章:

python - 从 numpy 和 scipy.sparse 准备 tensorflow 的数据输入

Tensorflow : Shape error in LSTM model expected shape=(None, 无,90),发现形状=[90, 1, 78]

python - Scipy 稀疏矩阵乘法

java - 一旦我退出函数(Java),对象的值为空

python - 如何在 PyTorch 中获取张量的值?

python - Numpy 和 Pandas - 用零填充 reshape

python - 使用相同大小的索引张量拆分 torch 张量

python - Windows 上的 Tensorflow 对象检测 API

Tensorflow 2.0 数据集和数据加载器

algorithm - 从一组成对距离中确定点