python - Tensorflow - 如何使用批量维度执行 tf.gather

标签 python tensorflow

不幸的是,我不知道如何制定这个问题的标题,也许有人可以更改它?

如何优雅地替换下面的for循环?

#tensor.shape -> (batchsize,100)
#indices.shape -> (batchsize,100) 
liste = []
for i in range(tensor.shape[0]):
        liste.append(tf.gather(tensor[i,:], indices[i,:10]))

new_tensor = tf.stack(liste)
        

最佳答案

这应该可以解决问题:

new_tensor = tf.gather(tensor, axis=-1, indices=indices[:, :10], batch_dims=1)

这里有一个最小的可重现示例:

import tensorflow as tf

# for version 1.x
#tf.enable_eager_execution()

tensor = tf.random.normal((2, 10))
indices = tf.random.uniform(shape=[2, 10], minval=0, maxval=4, dtype=tf.int32)

liste = []
for i in range(tensor.shape[0]):
        liste.append(tf.gather(tensor[i,:], indices[i,:5]))

new_tensor = tf.stack(liste)

print('tensor: ')
print(tensor)

print('new_tensor: ')
print(new_tensor)

new_tensor_v2 = tf.gather(tensor, axis=-1, indices=indices[:, :5], batch_dims=1)
print('new_tensor_v2: ')
print(new_tensor_v2)

关于python - Tensorflow - 如何使用批量维度执行 tf.gather,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/64013384/

相关文章:

python - 在 Pandas 中重新格式化数据框

python - EmailMultiAlternatives 无法建立连接,因为目标机器主动拒绝它

python - 如何在多个定界符上拆分一个字符串但只捕获一些?

Tensorflow:索引在 CUDA_1D_KERNEL_LOOP(index, nthreads) op user 中表示什么

Tensorflow:从 zip/tar 文件列表中读取许多图像(或任何文件)

python - Tensorflow:尝试使用未初始化的值 beta1_power

python - 使用 Pandas 时 dateutil.tz 包显然丢失了?

python - 是否以root身份启动supervisord?

tensorflow - tensorflow tf.profile 计算的 FLOPs 是多少?

python - 在Tensorflow中,如何解开tf.nn.max_pool_with_argmax得到的扁平化指标?