我正在尝试从 3 级 tf.float
张量构造 2 级 tf.float
张量 x
, y
和 2 阶 tf.int32
张量,z
为:
x[i][j] = y[i,z[i][j],j]
我知道我需要使用tf.gather_nd
:
x = tf.gather_nd(y,indices)
哪里
索引[i][j][:] = [i,z[i][j],j]
但是,我在使用 tensorflow 函数将 z
扩展为更高级别来构造索引
时遇到了麻烦。
我试图以矢量化形式维护这些操作。
简单地使用tf.stack
作为,是否更实用,
索引 = tf.stack([ii,z,jj],axis=-1)
哪里
ii[i,:] = i
和
jj[:,j] = j
?
最佳答案
我认为这可以满足您的需要:
import tensorflow as tf
import numpy as np
# Inputs
y = tf.placeholder(tf.float32, [None, None, None])
z = tf.placeholder(tf.int32, [None, None])
# Make first and last indices
y_shape = tf.shape(y)
ii, jj = tf.meshgrid(tf.range(y_shape[0]), tf.range(y_shape[2]), indexing='ij')
# Make full ND index
idx = tf.stack([ii, z, jj], axis=-1)
# Gather result
x = tf.gather_nd(y, idx)
# Test
with tf.Session() as sess:
# Numbers from 0 to 11 in a (3, 4) matrix
a = np.arange(12).reshape((3, 4))
# Make Y with replicas of the matrix multiplied by 1, 10 and 100
y_val = np.stack([a, a * 10, a * 100], axis=1).astype(np.float32)
# Z will be a (3, 4) matrix of values 0, 1, 2, 0, 1, 2, ...
z_val = (a % 3).astype(np.int32)
# X should have numbers from 0 to 11 multiplied by 1, 10, 100, 1, 10, 100, ...
x_val = sess.run(x, feed_dict={y: y_val, z: z_val}) #, feed_dict={y: y_val, z: z_val})
print(x_val)
输出:
[[ 0. 10. 200. 3.]
[ 40. 500. 6. 70.]
[ 800. 9. 100. 1100.]]
关于python - 使用gather_nd构建矩阵,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53192619/