python - 索引 Keras 张量

标签 python tensorflow keras keras-layer

我的 Keras 函数模型的输出层是一个张量 x尺寸(None, 1344, 2) 。我要提取n < 1344来自 x 的第二维度内的条目并创建一个新的张量 y尺寸(None, n, 2)

提取n似乎很简单只需访问 x[:, :n,:] 即可连续输入,但是(看起来)很困难,如果 n索引不连续。 Keras 有没有一种干净的方法可以做到这一点?

这是迄今为止我的方法。

实验 1(切片张量、连续索引,有效):

print('My tensor shape is', K.int_shape(x)) #my tensor 
(None, 1344, 2) # as printed in my code
print('Slicing first 5 entries, shape is', K.int_shape(x[:, :5, :]))
(None, 5, 2) # as printed in my code, works!

实验 2(在任意索引处对张量进行索引,失败)

print('My tensor shape is', K.int_shape(x)) #my tensor 
(None, 1344, 2) # as printed in my code
foo = np.array([1,2,4,5,8])
print('arbitrary indexing, shape is', K.int_shape(x[:,foo,:]))

Keras 返回以下错误:

ValueError: Shapes must be equal rank, but are 1 and 0
From merging shape 1 with other shapes. for 'strided_slice_17/stack_1' (op: 
'Pack') with input shapes: [], [5], [].

实验3( tensorflow 后端函数) 我也尝试过K.backend.gather但其用法尚不清楚,因为 1) Keras 文档指出索引应该是整数张量,并且没有与 numpy.where 等价的 Keras如果我的目标是提取 x 中的条目满足一定条件和2) K.backend.gather似乎从 axis = 0 中提取条目而我想从 x 的第二个维度中提取.

最佳答案

您正在寻找tf.gather_nd它将基于索引数组进行索引:

# From documentation
indices = [[0, 0], [1, 1]]
params = [['a', 'b'], ['c', 'd']]
output = ['a', 'd']

要在 Keras 模型中使用它,请确保将其包装在像 Lambda 这样的层中。

关于python - 索引 Keras 张量,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50653044/

相关文章:

python - 在 Tensorboard 上显示图像(通过 Keras)

python - 值错误 : Shape must be rank 2 but is rank 3 for 'MatMul'

python - 从 Keras 中的生成器获取 x_test、y_test?

apache-spark - 使用 Keras 模型作为 Apache Spark 和 Elephas 的广播变量

python - python日志记录是否支持多处理?

python - 使用 Python 对文件列表进行排序

python - 如何修复这个链表算法

python - 传递给 `fit` 的模型只能将 `training` 和 `call` 中的第一个参数作为位置参数,发现

tensorflow - 如何将卷积神经网络从 Keras 模型对象提取到 Networkx DiGraph 对象,并将权重作为边缘属性?

python - 无法按 Pandas 数据框中的时间戳编制索引