python - 卡在 tensorflow 高级索引上

标签 python tensorflow

给定形状为 (?,5,5) 的输入张量,我需要通过对形状为 (120,5,2) 的索引张量指定的元素求和来找到每个示例的最大和。索引张量列出了 120 种对示例的 5x5 矩阵求和的方法。 例如:

Input tensor (?,5,5):
[
  [
    [0,1,0,0,0],
    [0,0,0,0,1],
    [1,0,0,0,0],
    [0,0,0,1,0],
    [0,0,1,0,0]
  ],
  [
    ...
  ],
  ...
]

Index tensor(120,5,2):
[
  [
    [0,1], 
    [1,4], 
    [2,2], 
    [3,0], 
    [4,3]  
  ],
  [
    ...
  ],
...
]

这里,第一次求和的结果将是 1+1+0+0+0 = 2。 我需要找到每个示例的索引数组给出的所有 120 种方式的最大总和。

在 numpy 中,我将使用整数索引数组的高级索引,但不幸的是 tf 不支持这一点。我找到了 tf.gather_nd 但似乎我这个函数假设我知道批处理中每个示例的索引,但我不知道。

最佳答案

解决了。 诀窍是调换轴。这样,未知维度可以被推到末尾,gather_nd 可以选择未知维度之前的所有切片。

如果有人关心的话,这是完整的代码......

def permute(a, l, r):
    if l==r:
        yield list(zip([0,1,2,3,4],a))
    else:
        for i in range(l,r+1):
            a[l], a[i] = a[i], a[l]
            yield from permute(a, l+1, r)
            a[l], a[i] = a[i], a[l]

def multi_class_acc_positions(pred, target, input):
    pred_5x5 = tf.reshape(pred, [-1, 5, 5])
    target_5x5 = tf.reshape(target, [-1, 5, 5])
    pred_5x5_T = tf.transpose(pred_5x5, (1,2,0))
    all_perms = tf.constant(list(permute([0,1,2,3,4],0,4)))
    selected_elemens_per_example = tf.gather_nd(pred_5x5_T, all_perms)
    sums_per_example = tf.reduce_sum(selected_elemens_per_example, axis=1)
    best_perm_per_example_index = tf.argmax(sums_per_example, axis=0)
    best_perms = tf.gather_nd(all_perms, best_perm_per_example_index[:,tf.newaxis])[:,:,1]
    pred_5x5_one_hot = tf.reshape(tf.one_hot(best_perms, depth=5), (-1, 5, 5))
    correct_prediction = tf.equal(tf.argmax(pred_5x5_one_hot, axis=2), tf.argmax(target_5x5, axis=2))
    all_correct = tf.reduce_min(tf.cast(correct_prediction, tf.float32), 1)
    acc = tf.reduce_mean(all_correct)
    return acc

关于python - 卡在 tensorflow 高级索引上,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/52623614/

相关文章:

python - 为什么在 Python 中从自身 (x - x) 中减去一个值?

python - Django 递归注解

machine-learning - 多个独立标签的成本和激活函数

python - 扩展 mnist 数据库

python - 专家混合 - 仅在每次迭代时训练最佳模型

python - 如何过滤经纬度坐标落在一定半径内的django模型

python - Groupby 客户和商店 - 获得平均交易频率。日期问题

python - 如何删除 python - mechanize 中的控件?

javascript - tensorflowjs 加载重新训练的 coco-ssd 模型 - 在浏览器中不起作用

python - 在 Keras 中构建自定义损失函数