python - Keras:收集张量改变批量维度

标签 python tensorflow keras

我有一个形状为 (5, 2) 的输入张量,代表二维空间中的五个点。

我想取第一点,然后从所有五点中减去它。

四处阅读,我想我可以使用 K.gather 来切片并重复第一层。

在 Lambda 层中应用它后,批量维度被覆盖:

_input = Input(shape=(5, 2))
x = Reshape((5 * 2,))(_input)
x_ = Lambda(lambda t: K.gather(t, [0, 1] * 5))(x)

结果:

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            (None, 5, 2)         0                                            
__________________________________________________________________________________________________
reshape_1 (Reshape)             (None, 10)           0           input_1[0][0]                    
__________________________________________________________________________________________________
lambda_1 (Lambda)               (10, 10)             0           reshape_1[0][0]                  
__________________________________________________________________________________________________

我做错了什么?

另外,有没有更简单的方法来做到这一点?

最佳答案

gather 函数从批处理(第 0 个)轴返回所提供索引的值。因此,它为我们提供了形状为 (10, 10) 的第一个 (index:0) 和第二个 (index:1) 样本(形状 (10,)) 的列表 (length=10),而我们想要第一个批处理中每个样本的(索引:0)和第二个(索引:1)特征点。为了解决这个问题,我们可以在使用 gather 函数之前转置张量,以便 gather 函数选择正确的值,最终得到的张量应该再次转置。

_input = Input(shape=(5, 2))
x = Reshape((5 * 2,))(_input)
x_ = Lambda(lambda t: K.transpose(K.gather(K.transpose(t), [0, 1]*5)))(x)

输出:

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 5, 2)]            0         
_________________________________________________________________
reshape (Reshape)            (None, 10)                0         
_________________________________________________________________
lambda (Lambda)              (None, 10)                0         
=================================================================

关于python - Keras:收集张量改变批量维度,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58251407/

相关文章:

python - 检验零假设回归系数等于 statsmodels OLS 中的非零值

python - Python 的通用编码风格?

python - Tensorboard:导入错误:无法导入名称 'main'

python - InvalidArgumentError : Key: label. 无法解析序列化示例 : How can I find a way to parse the one-hot encoded labels from TFRecords?

python - keras 序列模型中的多个输出

python - keras 中输入数据不兼容错误,维度不匹配 ValueError

python - tkinter 入口高度

python - 使用 Flask 和 eventlet 响应并发请求

android - 使用 Gradle 添加 Tensorflow AAR 不起作用。我得到一个错误

python - 尝试重命名 tf.keras 上的预训练模型时出错