我希望在张量“a”的特定列(由 cols 提供)处添加张量“b”。
因此,在下面的示例中,我希望将张量“b”添加到批处理中每个元素的第一列和最后一列。
b = tf.Tensor: shape=(2,), dtype=float32, numpy=array([10., 10.], dtype=float32)
a = tf.Tensor: shape=(2, 2, 5), dtype=float32, numpy=
array([[[ 0., 1., 2., 3., 4.],
[ 5., 6., 7., 8., 9.]],
[[10., 11., 12., 13., 14.],
[15., 16., 17., 18., 19.]]], dtype=float32)
cols = tf.Tensor: shape=(2,), dtype=int32, numpy=array([0, 4], dtype=int32)
我想要的结果是
out = tf.Tensor: shape=(2, 2, 5), dtype=float32, numpy=
array([[[ 10., 1., 2., 3., 14.],
[ 15., 6., 7., 8., 19.]],
[[20., 11., 12., 13., 24.],
[25., 16., 17., 18., 29.]]], dtype=float32)
有人能告诉我最有效的方法吗?
最佳答案
尝试使用tf.tensor_scatter_nd_add ,因为您“希望将张量 'b' 添加到批处理的每个元素的第一列和最后一列”:
import tensorflow as tf
b = tf.constant([10., 10.], dtype=tf.float32)
a = tf.constant([[[ 0., 1., 2., 3., 4.],
[ 5., 6., 7., 8., 9.]],
[[10., 11., 12., 13., 14.],
[15., 16., 17., 18., 19.]]], dtype=tf.float32)
cols = tf.constant([0, 4], dtype=tf.int32)
indices = tf.stack([tf.repeat(tf.range(tf.shape(a)[0]), tf.shape(a)[0]*tf.shape(cols)[0]), tf.tile(tf.repeat(tf.range(tf.shape(a)[0]), tf.shape(a)[0]), [tf.shape(a)[1]])], axis=1)
indices = tf.concat([indices, tf.expand_dims(tf.tile(cols, [tf.math.reduce_prod(tf.shape(a)[:2])]), axis=-1)], axis=1)
updates = tf.tile(b, [tf.math.reduce_prod(tf.shape(a)[:2])])
print(tf.tensor_scatter_nd_add(a, indices, updates))
tf.Tensor(
[[[10. 1. 2. 3. 14.]
[15. 6. 7. 8. 19.]]
[[20. 11. 12. 13. 24.]
[25. 16. 17. 18. 29.]]], shape=(2, 2, 5), dtype=float32)
更新 1:
在任何列中添加 b
的通用方法:
import tensorflow as tf
b = tf.constant([10., 10.], dtype=tf.float32)
a = tf.constant([[[ 0., 1., 2., 3., 4.],
[ 5., 6., 7., 8., 9.]],
[[10., 11., 12., 13., 14.],
[15., 16., 17., 18., 19.]]], dtype=tf.float32)
cols = tf.constant([0, 1, 4], dtype=tf.int32)
#cols = tf.constant([0, 1], dtype=tf.int32)
#cols = tf.constant([0], dtype=tf.int32)
#cols = tf.constant([0, 1, 3, 4], dtype=tf.int32)
indices = tf.stack([tf.repeat(tf.range(tf.shape(a)[0]), tf.shape(a)[0]*tf.shape(cols)[0]),
tf.tile(tf.repeat(tf.range(tf.shape(a)[0]), tf.shape(cols)[0]), [tf.shape(a)[1]])], axis=1)
indices = tf.concat([indices, tf.expand_dims(tf.tile(cols, [tf.math.reduce_prod(tf.shape(a)[:2])]), axis=-1)], axis=1)
updates = tf.repeat([b[0]], tf.shape(indices)[0])
print(tf.tensor_scatter_nd_add(a, indices, updates))
关于python - 将张量添加到 Tensorflow 中张量的特定列,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/71526128/