python - 将张量添加到 Tensorflow 中张量的特定列

标签 python tensorflow tensorflow2.0

我希望在张量“a”的特定列(由 cols 提供)处添加张量“b”。

因此,在下面的示例中,我希望将张量“b”添加到批处理中每个元素的第一列和最后一列。

 b = tf.Tensor: shape=(2,), dtype=float32, numpy=array([10., 10.], dtype=float32)
    
 a = tf.Tensor: shape=(2, 2, 5), dtype=float32, numpy=
            array([[[ 0.,  1.,  2.,  3.,  4.],
                    [ 5.,  6.,  7.,  8.,  9.]],
                   [[10., 11., 12., 13., 14.],
                    [15., 16., 17., 18., 19.]]], dtype=float32)
    
 cols = tf.Tensor: shape=(2,), dtype=int32, numpy=array([0, 4], dtype=int32)

我想要的结果是

out = tf.Tensor: shape=(2, 2, 5), dtype=float32, numpy=
            array([[[ 10.,  1.,  2.,  3.,  14.],
                    [ 15.,  6.,  7.,  8.,  19.]],
                   [[20., 11., 12., 13., 24.],
                    [25., 16., 17., 18., 29.]]], dtype=float32)

有人能告诉我最有效的方法吗?

最佳答案

尝试使用tf.tensor_scatter_nd_add ,因为您“希望将张量 'b' 添加到批处理的每个元素的第一列和最后一列”:

import tensorflow as tf

b = tf.constant([10., 10.], dtype=tf.float32)
  
a = tf.constant([[[ 0.,  1.,  2.,  3.,  4.],
                  [ 5.,  6.,  7.,  8.,  9.]],
                  [[10., 11., 12., 13., 14.],
                  [15., 16., 17., 18., 19.]]], dtype=tf.float32)
cols = tf.constant([0, 4], dtype=tf.int32)

indices = tf.stack([tf.repeat(tf.range(tf.shape(a)[0]), tf.shape(a)[0]*tf.shape(cols)[0]), tf.tile(tf.repeat(tf.range(tf.shape(a)[0]), tf.shape(a)[0]), [tf.shape(a)[1]])], axis=1)
indices = tf.concat([indices, tf.expand_dims(tf.tile(cols, [tf.math.reduce_prod(tf.shape(a)[:2])]), axis=-1)], axis=1)
updates = tf.tile(b, [tf.math.reduce_prod(tf.shape(a)[:2])])

print(tf.tensor_scatter_nd_add(a, indices, updates))
tf.Tensor(
[[[10.  1.  2.  3. 14.]
  [15.  6.  7.  8. 19.]]

 [[20. 11. 12. 13. 24.]
  [25. 16. 17. 18. 29.]]], shape=(2, 2, 5), dtype=float32)

更新 1:

在任何列中添加 b 的通用方法:

import tensorflow as tf

b = tf.constant([10., 10.], dtype=tf.float32)
  
a = tf.constant([[[ 0.,  1.,  2.,  3.,  4.],
                  [ 5.,  6.,  7.,  8.,  9.]],
                  [[10., 11., 12., 13., 14.],
                  [15., 16., 17., 18., 19.]]], dtype=tf.float32)

cols = tf.constant([0, 1, 4], dtype=tf.int32)
#cols = tf.constant([0, 1], dtype=tf.int32)
#cols = tf.constant([0], dtype=tf.int32)
#cols = tf.constant([0, 1, 3, 4], dtype=tf.int32)

indices = tf.stack([tf.repeat(tf.range(tf.shape(a)[0]), tf.shape(a)[0]*tf.shape(cols)[0]), 
                    tf.tile(tf.repeat(tf.range(tf.shape(a)[0]), tf.shape(cols)[0]), [tf.shape(a)[1]])], axis=1)

indices = tf.concat([indices, tf.expand_dims(tf.tile(cols, [tf.math.reduce_prod(tf.shape(a)[:2])]), axis=-1)], axis=1)
updates = tf.repeat([b[0]],  tf.shape(indices)[0])
print(tf.tensor_scatter_nd_add(a, indices, updates))

关于python - 将张量添加到 Tensorflow 中张量的特定列,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/71526128/

相关文章:

Python-OBD模块对象没有属性 "OBD"

python - Django - 将命令传递给 shell

android - 适用于 Android 的 Tensorflow 示例

python - 无法在 Google Colab 上将 GPU 用于 tensorflow 2.0

tensorflow - 在 tensorflow 2.0 中添加无维度

python - 使用 tf.data 的 One-hot 编码会混淆列

python - 让 Python 在交互式 Perl 脚本中模拟用户输入

Python:子类可以重载继承的方法吗?

python - 如何在 TensorFlow 中对自定义损失函数内的张量进行归一化?

python - tensorflow:简单 LSTM 网络的共享变量错误