python - TensorFlow,批量索引(第一维)和排序

标签 python sorting tensorflow indexing keras

我有一个形状为 (?,368,5) 的参数张量,以及一个形状为 (?,368) 的查询张量。查询张量存储用于对第一个张量进行排序的索引。

所需的输出具有形状:(?,368,5)。因为我需要它作为神经网络中的损失函数,所以使用的操作应该保持可微分。此外,在运行时,第一个轴的大小 ? 对应于 batchsize。

到目前为止,我尝试了 tf.gathertf.gather_nd,但是 tf.gather(params,q​​uery) 生成形状为 (?,368,368,5) 的张量。

查询张量是通过执行以下操作实现的:

query = tf.nn.top_k(params[:, :, 0], k=params.shape[1], sorted=True).indices

总的来说,我尝试按第三轴上的第一个元素对参数张量进行排序(某种倒角距离)。最后要提的是,我使用 Keras 框架。

最佳答案

您需要将第一个维度的索引添加到query 以便与tf.gather_nd 一起使用。这是一种方法:

import tensorflow as tf
import numpy as np

np.random.seed(100)

with tf.Graph().as_default(), tf.Session() as sess:
    params = tf.placeholder(tf.float32, [None, 368, 5])
    query = tf.nn.top_k(params[:, :, 0], k=params.shape[1], sorted=True).indices
    n = tf.shape(params)[0]
    # Make tensor of indices for the first dimension
    ii = tf.tile(tf.range(n)[:, tf.newaxis], (1, params.shape[1]))
    # Stack indices
    idx = tf.stack([ii, query], axis=-1)
    # Gather reordered tensor
    result = tf.gather_nd(params, idx)
    # Test
    out = sess.run(result, feed_dict={params: np.random.rand(10, 368, 5)})
    # Check the order is correct
    print(np.all(np.diff(out[:, :, 0], axis=1) <= 0))
    # True

关于python - TensorFlow,批量索引(第一维)和排序,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50605059/

相关文章:

python - 删除执行中的迭代器的域是安全的(记录的行为?)

c - 如何使用libavl?

python - 从 tensorflow 中的向量创建二进制张量

python - 安装 opencv 后 Numpy.core.multiarray 不再工作

Python pip install h5py 遇到 TypeError : unsupported operand type(s) for -=: 'Retry' and 'int'

python - 如何对 Python 列表进行部分排序?

Python基础资料引用,相同引用列表

python - 为什么列表的列表排序不正确?

c++ - 对指向对象的指针数组进行排序

tensorflow - Keras 中的扁平化函数