python - 在 tensorflow 中扩展维度并复制数据

标签 python tensorflow

我有一个大小为 BxHxWx3 的张量 input 和另一个大小为 Bx3 的张量 params。这里的 B 是批量大小。我想将 params 转换为大小为 BxHxWx3 的张量?这样我就可以将两个张量相乘。关于我应该如何解决这个问题有什么建议吗? (在较高层面上,我想要做的是将一组图像中的每个像素乘以为每个 channel 定义的值)

最佳答案

<强>1。回答你的第一个问题

您可以使用tf.expand_dimstf.tile的组合:

input_shape = tf.shape(input)
mod_params = params.expand_dims(1) # shape is [Bx1x3]
mod_params = mod_params.expand_dims(2) # shape is [Bx1x1x3]
mod_params = tf.tile( \
                mod_params, \
                [1, input_shape[1], input_shape[2], 1] \
             ) # shape is [BxHxWx3]

<强>2。为了实现您的最终结果,...

...你可以执行

ret = tf.multiply(input, mod_params)

...或者,您也可以使用tensorflow的广播功能(借助tf.transpose)

ret = tf.multiply(
         tf.transpose(input, perm=[2,1,0,3]), \
         params \
      ) # shape: [WxHxBx3]
ret = tf.transpose(ret, perm=[2,1,0,3]) # shape: [BxHxWx3]

关于python - 在 tensorflow 中扩展维度并复制数据,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50808134/

相关文章:

tensorflow - 如何使用 tfa.layers.PoincareNormalize 实现庞加莱嵌入?

python - keras 序列模型中的多个输出

tensorflow - 修改内置 TensorFlow 内核的最佳方法

python - 在 python 上模拟父类(super class)调用

python - "foo is None"和 "foo == None"之间有什么区别吗?

python - 为什么日期时间字符串格式不可逆?

python - 用于重新训练示例中验证的 Tensorflow 混淆矩阵

python - 在 Django 中从 DateField 迁移到 DateTimeField

python - Jupyter: "notebook"不是 jupyter 命令

python - 多个输出的单一损失