python - Get values from a Tensor using argmax

Tags: python tensorflow tensorflow2.0 tensor

I have a tensor of shape (60, 128, 30000). I want the values at the argmax positions along the 30000 dimension (axis=2). This code is an example:

tensor = tf.random.uniform((60, 128, 30000)) # shape (60, 128, 30000)
argmax = tf.argmax(tensor, axis=2) # shape (60, 128) --> max of each 30000

# do something to get the value at each of these indices
# argmax output (indices)
<tf.Tensor: shape=(60, 128), dtype=int64, numpy=
array([[ 3229,  3079,  8360, ...,  1005, 16460,   872],
       [17808,  1253, 25476, ..., 16130,  3479,  3479],
       [27717, 25429, 18808, ...,  9787,  2603, 24011],
       ...,
       [25429, 25429,  5647, ..., 18451, 12453, 12453],
       [ 7361, 13463, 15864, ..., 18839, 12453, 12453],
       [ 4750, 25009, 11888, ...,  5647,  1993, 18451]], dtype=int64)>

# Desired output: the value at each index

With argmax I get an array of indices, not the values themselves. How can I get an array of the corresponding values with the same shape (60, 128)?

Best answer

One way to get what you want is to combine tf.meshgrid and tf.gather_nd:

tensor = tf.random.uniform((60, 128, 30000)) # shape (60, 128, 30000)
argmax = tf.argmax(tensor, axis=2)

ij = tf.stack(tf.meshgrid(
    tf.range(tensor.shape[0], dtype=tf.int64), 
    tf.range(tensor.shape[1], dtype=tf.int64),
                              indexing='ij'), axis=-1)

gather_indices = tf.concat([ij, tf.expand_dims(argmax, axis=-1)], axis=-1)
result = tf.gather_nd(tensor, gather_indices)
tf.print(result.shape)
# TensorShape([60, 128])

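Why is tf.meshgrid needed? Because argmax does contain your indices, but only for the last axis, so it has the wrong shape. tf.gather_nd needs to know the exact position of every value it should extract from the 3D tensor, that is, a full (i, j, k) coordinate per output element. tf.meshgrid builds a rectangular grid from two 1D arrays, which gives the tensor indices for the first and second dimensions.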
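To illustrate how tf.gather_nd reads its index tensor, here is a minimal sketch with hand-written coordinates. The tensor values and the indices are made up for this illustration (they are not from the original answer); each innermost triple is one (i, j, k) position in the 3D tensor.

```python
import tensorflow as tf

# Small 3D tensor with recognizable values: element [i, j, k] == 100*i + 10*j + k.
tensor = tf.constant(
    [[[0, 1, 2], [10, 11, 12]],
     [[100, 101, 102], [110, 111, 112]]]
)  # shape (2, 2, 3)

# Hand-written index tensor of shape (2, 2, 3):
# one full (i, j, k) coordinate per output element.
indices = tf.constant(
    [[[0, 0, 2], [0, 1, 0]],
     [[1, 0, 1], [1, 1, 2]]]
)

# The output drops the last axis of `indices`, so its shape is (2, 2).
picked = tf.gather_nd(tensor, indices)
tf.print(picked)
```

The first two entries of every triple simply repeat the output position (i, j). That repeated part is what the mesh grid generates automatically, while the third entry comes from argmax.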

import tensorflow as tf

tensor = tf.random.uniform((2, 5, 3))
argmax = tf.argmax(tensor, axis=2)

# result = tf.gather_nd(tensor, argmax) <-- would not work: argmax has shape TensorShape([2, 5]), but an index tensor of shape TensorShape([2, 5, 3]) is required
tf.print('Input tensor:\n', tensor, tensor.shape, '\nArgmax tensor:\n', argmax, argmax.shape)

i, j = tf.meshgrid(
    tf.range(tensor.shape[0], dtype=tf.int64), 
    tf.range(tensor.shape[1], dtype=tf.int64),
                              indexing='ij')

# You need to create a mesh grid to correctly index your tensor.

ij = tf.stack([i, j], axis=-1)
tf.print('Meshgrid:\n', i, j, summarize=-1)
tf.print('Stacked:\n', ij, summarize=-1)

gather_indices = tf.concat([ij, tf.expand_dims(argmax, axis=-1)], axis=-1)
tf.print('Gathered indices:\n', gather_indices, gather_indices.shape, summarize=-1)

result = tf.gather_nd(tensor, gather_indices)
tf.print('\nFinal result:\n', result, result.shape)

Output:
Input tensor:
 [[[0.889752269 0.243187189 0.601408958]
  [0.891950965 0.776625633 0.146243811]
  [0.136176467 0.743871331 0.762170076]
  [0.424416184 0.150568008 0.464055896]
  [0.308753 0.0792338848 0.383242]]

 [[0.741660118 0.49783361 0.935318112]
  [0.0616152287 0.0367363691 0.748341084]
  [0.397849679 0.765681744 0.502376914]
  [0.750188231 0.304993749 0.733741879]
  [0.31267941 0.778184056 0.546301]]] TensorShape([2, 5, 3]) 
Argmax tensor:
 [[0 0 2 2 2]
 [2 2 1 0 1]] TensorShape([2, 5])
Meshgrid:
 [[0 0 0 0 0]
 [1 1 1 1 1]] [[0 1 2 3 4]
 [0 1 2 3 4]]
Stacked:
 [[[0 0]
  [0 1]
  [0 2]
  [0 3]
  [0 4]]

 [[1 0]
  [1 1]
  [1 2]
  [1 3]
  [1 4]]]
Gathered indices:
 [[[0 0 0]
  [0 1 0]
  [0 2 2]
  [0 3 2]
  [0 4 2]]

 [[1 0 2]
  [1 1 2]
  [1 2 1]
  [1 3 0]
  [1 4 1]]] TensorShape([2, 5, 3])

Final result:
 [[0.889752269 0.891950965 0.762170076 0.464055896 0.383242]
 [0.935318112 0.748341084 0.765681744 0.750188231 0.778184056]] TensorShape([2, 5])
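As a possible shortcut that is not part of the original answer, the batch_dims argument of tf.gather can stand in for the explicit mesh grid. The sketch below assumes a TensorFlow 2.x version where tf.gather accepts batch_dims, and compares the result against the meshgrid approach on a small random tensor; the names result_batched and result_meshgrid are chosen here for illustration.

```python
import tensorflow as tf

tensor = tf.random.uniform((2, 5, 3))
argmax = tf.argmax(tensor, axis=2)  # shape (2, 5), dtype int64

# Treat the first two axes as batch dimensions, so each index
# only selects along axis 2: result[i, j] = tensor[i, j, argmax[i, j]].
result_batched = tf.gather(tensor, argmax, axis=2, batch_dims=2)

# Reference result using explicit (i, j, k) coordinates from a mesh grid.
i, j = tf.meshgrid(
    tf.range(tensor.shape[0], dtype=tf.int64),
    tf.range(tensor.shape[1], dtype=tf.int64),
    indexing='ij')
gather_indices = tf.stack([i, j, argmax], axis=-1)  # shape (2, 5, 3)
result_meshgrid = tf.gather_nd(tensor, gather_indices)

tf.print(tf.reduce_all(result_batched == result_meshgrid))
```

Both variants read the same elements. The batch_dims form only avoids materializing the index grid, which may matter for a large leading shape such as (60, 128).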

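By the way, you could also consider tf.math.top_k, since you want the maximum along the last dimension. This function returns both the values (which is what you want) and the indices: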

tensor = tf.random.uniform((60, 128, 30000)) # shape (60, 128, 30000)
values, indices = tf.math.top_k(tensor,
                        k=1)
tf.print(tf.squeeze(values, axis=-1).shape)
# TensorShape([60, 128])
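If only the maximum values are needed and not their positions, tf.reduce_max along the last axis should give the same numbers directly. This is a small consistency check rather than part of the original answer; the shape (4, 6, 50) is an arbitrary stand-in for (60, 128, 30000) to keep the sketch light.

```python
import tensorflow as tf

tensor = tf.random.uniform((4, 6, 50))

# Maximum along the last axis, computed directly.
max_values = tf.reduce_max(tensor, axis=2)  # shape (4, 6)

# Same values via top_k with k=1; the trailing axis of size 1 is squeezed away.
values, indices = tf.math.top_k(tensor, k=1)
top1 = tf.squeeze(values, axis=-1)  # shape (4, 6)

tf.print(tf.reduce_all(max_values == top1))
```

Note that tf.math.top_k returns int32 indices, whereas tf.argmax returns int64 by default, so a cast may be needed before comparing the two index tensors.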

For more on python - Get values from a Tensor using argmax, see the similar question on Stack Overflow: https://stackoverflow.com/questions/69897393/
