我有一个形状为 (5, 2) 的输入张量,代表二维空间中的五个点。
我想取第一点,然后从所有五点中减去它。
四处阅读,我想我可以使用 K.gather
来切片并重复第一层。
在 Lambda 层中应用它后,批量维度被覆盖:
_input = Input(shape=(5, 2))
x = Reshape((5 * 2,))(_input)
x_ = Lambda(lambda t: K.gather(t, [0, 1] * 5))(x)
结果:
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) (None, 5, 2) 0
__________________________________________________________________________________________________
reshape_1 (Reshape) (None, 10) 0 input_1[0][0]
__________________________________________________________________________________________________
lambda_1 (Lambda) (10, 10) 0 reshape_1[0][0]
__________________________________________________________________________________________________
我做错了什么?
另外,有没有更简单的方法来做到这一点?
最佳答案
gather
函数从批处理(第 0 个)轴返回所提供索引的值。因此,它为我们提供了形状为 (10, 10) 的第一个 (index:0) 和第二个 (index:1) 样本(形状 (10,)) 的列表 (length=10),而我们想要第一个批处理中每个样本的(索引:0)和第二个(索引:1)特征点。为了解决这个问题,我们可以在使用 gather
函数之前转置张量,以便 gather
函数选择正确的值,最终得到的张量应该再次转置。
_input = Input(shape=(5, 2))
x = Reshape((5 * 2,))(_input)
x_ = Lambda(lambda t: K.transpose(K.gather(K.transpose(t), [0, 1]*5)))(x)
输出:
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) [(None, 5, 2)] 0
_________________________________________________________________
reshape (Reshape) (None, 10) 0
_________________________________________________________________
lambda (Lambda) (None, 10) 0
=================================================================
关于python - Keras:收集张量改变批量维度,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58251407/