Tensorflow多维argmax

标签 tensorflow

假设我有一个大小为 BxWxHxD 的张量。我想处理张量,这样我就有一个新的 BxWxHxD 张量,其中仅保留每个 WxH 切片中的最大元素,所有其他值均为零。

换句话说,我认为实现这一目标的最佳方法是以某种方式在 WxH 切片上采用 2D argmax,从而产生行和列的 BxD 索引张量,然后将其转换为单热 BxWxHxD 张量用作面膜。我怎样才能做到这一点?

最佳答案

您可以使用以下函数作为起点。它计算每个批处理和每个 channel 的最大元素的索引。生成的数组的格式为(批量大小、2、 channel 数)。

def argmax_2d(tensor):

  # input format: BxHxWxD
  assert rank(tensor) == 4

  # flatten the Tensor along the height and width axes
  flat_tensor = tf.reshape(tensor, (tf.shape(tensor)[0], -1, tf.shape(tensor)[3]))

  # argmax of the flat tensor
  argmax = tf.cast(tf.argmax(flat_tensor, axis=1), tf.int32)

  # convert indexes into 2D coordinates
  argmax_x = argmax // tf.shape(tensor)[2]
  argmax_y = argmax % tf.shape(tensor)[2]

  # stack and return 2D coordinates
  return tf.stack((argmax_x, argmax_y), axis=1)

def rank(tensor):

  # return the rank of a Tensor
  return len(tensor.get_shape())

关于Tensorflow多维argmax,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/36388431/

相关文章:

python - tensorflow : Ran out of memory trying to allocate

python - TensorFlow 无法使用 gpu

tensorflow - Nvidia TX1上的TensorFlow

python - tensorflow 预测序列

python - 当切片本身是 tensorflow 中的张量时如何进行切片分配

python - 在 Tensorflow 急切模式下计算梯度与模型输入

machine-learning - 在多类分类问题中,为什么二元准确度会提供较高的准确度,而分类准确度会提供较低的准确度?

python - 检查目标 : expected dense_49 to have 4 dimensions, 时出错,但得到形状为 (2250, 3) 的数组

python - TensorFlow 中稀疏加法的 scatter_nd add() 示例

32 位 Linux 上的 TensorFlow?