tensorflow - tensorflow 索引如何工作

标签 tensorflow

我无法理解 tensorflow 的基本概念。张量读/写操作的索引如何工作?为了具体说明,如何将以下 numpy 示例转换为 tensorflow(使用张量来分配数组、索引和值):

x = np.zeros((3, 4))
row_indices = np.array([1, 1, 2])
col_indices = np.array([0, 2, 3])
x[row_indices, col_indices] = 2
x

带输出:
array([[ 0.,  0.,  0.,  0.],
       [ 2.,  0.,  2.,  0.],
       [ 0.,  0.,  0.,  2.]])

... 和 ...
x[row_indices, col_indices] = np.array([5, 4, 3])
x

带输出:
array([[ 0.,  0.,  0.,  0.],
       [ 5.,  0.,  4.,  0.],
       [ 0.,  0.,  0.,  3.]])

……最后……
y = x[row_indices, col_indices]
y

带输出:
array([ 5.,  4.,  3.])

最佳答案

有 github 问题 #206为了很好地支持这一点,同时你必须诉诸冗长的解决方法

第一个例子可以用 tf.select 完成通过从一个或另一个中选择每个元素来组合两个相同形状的张量

tf.reset_default_graph()
row_indices = tf.constant([1, 1, 2])
col_indices = tf.constant([0, 2, 3])
x = tf.zeros((3, 4))
sess = tf.InteractiveSession()

# get list of ((row1, col1), (row2, col2), ..)
coords = tf.transpose(tf.pack([row_indices, col_indices]))

# get tensor with 1's at positions (row1, col1),...
binary_mask = tf.sparse_to_dense(coords, x.get_shape(), 1)

# convert 1/0 to True/False
binary_mask = tf.cast(binary_mask, tf.bool)

twos = 2*tf.ones(x.get_shape())

# make new x out of old values or 2, depending on mask 
x = tf.select(binary_mask, twos, x)

print x.eval()



[[ 0.  0.  0.  0.]
 [ 2.  0.  2.  0.]
 [ 0.  0.  0.  2.]]

第二个可以用 scatter_update 完成,除了 scatter_update仅支持线性索引并适用于变量。所以你可以创建一个临时变量并像这样使用整形。 (为了避免变量,你可以使用 dynamic_stitch ,见最后)

# get linear indices
linear_indices = row_indices*x.get_shape()[1]+col_indices

# turn 'x' into 1d variable since "scatter_update" supports linear indexing only
x_flat = tf.Variable(tf.reshape(x, [-1]))

# no automatic promotion, so make updates float32 to match x
updates = tf.constant([5, 4, 3], dtype=tf.float32)

sess.run(tf.initialize_all_variables())
sess.run(tf.scatter_update(x_flat, linear_indices,  updates))

# convert back into original shape
x = tf.reshape(x_flat, x.get_shape())

print x.eval()



[[ 0.  0.  0.  0.]
 [ 5.  0.  4.  0.]
 [ 0.  0.  0.  3.]]

最后,gather_nd 已经支持第三个示例。 , 你写

print tf.gather_nd(x, coords).eval()

要得到

[ 5.  4.  3.]

编辑,5 月 6 日

更新 x[cols,rows]=newvals可以通过使用 select 在不使用变量(在 session 运行调用之间占用内存)的情况下完成。与 sparse_to_dense采用稀疏值向量,或依赖 dynamic_stitch
sess = tf.InteractiveSession()
x = tf.zeros((3, 4))
row_indices = tf.constant([1, 1, 2])
col_indices = tf.constant([0, 2, 3])

# no automatic promotion, so specify float type
replacement_vals = tf.constant([5, 4, 3], dtype=tf.float32)

# convert to linear indexing in row-major form
linear_indices = row_indices*x.get_shape()[1]+col_indices
x_flat = tf.reshape(x, [-1])

# use dynamic stitch, it merges the array by taking value either
# from array1[index1] or array2[index2], if indices conflict,
# the later one is used 
unchanged_indices = tf.range(tf.size(x_flat))
changed_indices = linear_indices
x_flat = tf.dynamic_stitch([unchanged_indices, changed_indices],
                           [x_flat, replacement_vals])
x = tf.reshape(x_flat, x.get_shape())
print x.eval()

关于tensorflow - tensorflow 索引如何工作,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/37061808/

相关文章:

tensorflow - 如何使用tensorflow-gpu GPUOptions

python - 使用 VGG16 预训练模型处理灰度图像时出错

python - `tf.reshape(a, [m, n])` 和 `tf.transpose(tf.reshape(a, [n, m]))` 之间的区别?

python - 强化学习如何通过高斯策略进行连续控制?

python - Tensorflow - 范围明智回归损失

c++ - tensorflow c++代码SessionFactory::GetFactory如何选择使用哪个 session ?Direct还是Grpc Session?

python - 了解 model.summary Keras

python - 在 TensorFlow 的低级 API 中,是否可以使用优化器保存图形并在另一个文件中继续训练?

python - Tensorflow 导入错误 : No module named '_pywrap_tensorflow_internal' on Windows 10

tensorflow - 双向 RNN 单元 - 共享与否?