python-3.x - Keras 张量 - 使用来自另一个张量的索引获取值

标签 python-3.x tensorflow keras slice tensor

假设我有这两个张量:

  • valueMatrix,形状为(?, 3),其中?是批量大小
  • indexMatrix,形如(?, 1)

我想从 valueMatrix 中包含在 indexMatrix 中的索引处检索值。

示例(伪代码):

valueMatrix = [[7,15,5],[4,6,8]] -- shape=(2,3) -- type=float 
indexMatrix = [[1],[0]] -- shape = (2,1) -- type=int

我想从这个例子中做类似的事情:

valueMatrix[indexMatrix] --> returns --> [[15],[4]]

与其他后端相比,我更喜欢 Tensorflow,但答案必须与使用 Lambda 层或其他适合任务的层的 Keras 模型兼容。

最佳答案

import tensorflow as tf
valueMatrix = tf.constant([[7,15,5],[4,6,8]])
indexMatrix = tf.constant([[1],[0]])

# create the row index with tf.range
row_idx = tf.reshape(tf.range(indexMatrix.shape[0]), (-1,1))
# stack with column index
idx = tf.stack([row_idx, indexMatrix], axis=-1)
# extract the elements with gather_nd
values = tf.gather_nd(valueMatrix, idx)

with tf.Session() as sess:
    print(sess.run(values))
#[[15]
# [ 4]]

关于python-3.x - Keras 张量 - 使用来自另一个张量的索引获取值,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/46526869/

相关文章:

python - 从二维数组中删除行?

python - Keras:使用 fit_generator 时出现 notImplementedError/RuntimeError

python - 如何检查discord.py中的所有者

python - 如何在python中使用fileinput保存文件?给出属性错误: 'FileInput' object has no attribute 'read'

javascript - 等待链接点击 Selenium 的正确方法

python - Keras 中的 model(x) 和 model.predict(x) 之间的区别?

python - Python跟踪车辆和TypeError整数

tensorflow - 当只需要第一个元素时,为什么要创建一个新轴?

tensorflow - 保存 keras 模型以便将来恢复训练的最佳方法是什么?

python - 没有 Gym 的 Tensorflow 强化学习