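I have a placeholder tensor of shape [batch_size, sentence_length, word_dim] and a list of indices of shape [batch_size, num_indices]. The indices refer to the second axis: they are the positions of words within each sentence. batch_size and sentence_length are only known at runtime.

How can I extract a tensor of shape [batch_size, num_indices, word_dim]?

I have been reading about tf.gather, but it seems that gather only collects slices along the first axis. Is that correct?

Edit: I managed to get it working with constant tensors: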

    def tile_repeat(n, repTime):
        '''
        create something like 111..122..2333..33 ..... n..nn
        one particular number appears repTime consecutively.
        This is for flattening the indices.
        '''
        print n, repTime
        idx = tf.range(n)
        idx = tf.reshape(idx, [-1, 1])    # Convert to a n x 1 matrix.
        idx = tf.tile(idx, [1, int(repTime)])  # Create multiple columns, each column has one number repeats repTime
        y = tf.reshape(idx, [-1])
        return y

    def gather_along_second_axis(x, idx):
        '''
        x has shape: [batch_size, sentence_length, word_dim]
        idx has shape: [batch_size, num_indices]
        Basically, in each batch, get words from sentence having index specified in idx
        However, since tensorflow does not fully support indexing,
        gather only work for the first axis. We have to reshape the input data, gather then reshape again
        '''
        reshapedIdx = tf.reshape(idx, [-1])  # [batch_size*num_indices]
        idx_flattened = tile_repeat(tf.shape(x)[0], tf.shape(x)[1]) * tf.shape(x)[1] + reshapedIdx
        y = tf.gather(tf.reshape(x, [-1, int(tf.shape(x)[2])]),  # flatten input
                      idx_flattened)
        y = tf.reshape(y, tf.shape(x))
        return y

    x = tf.constant([
        [[1, 2, 3], [3, 5, 6]],
        [[7, 8, 9], [10, 11, 12]],
        [[13, 14, 15], [16, 17, 18]]
    ])
    idx = tf.constant([[0, 1], [1, 0], [1, 1]])

    y = gather_along_second_axis(x, idx)
    with tf.Session(''):
        print y.eval()
        print tf.Tensor.get_shape(y)

The output is:

    [[[ 1  2  3]
      [ 3  5  6]]

     [[10 11 12]
      [ 7  8  9]]

     [[16 17 18]
      [16 17 18]]]

Shape: (3, 2, 3)
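
To make the index arithmetic behind the flattening easier to follow, here is a minimal sketch in plain Python, with no TensorFlow involved. The helper name flat_gather_indices is hypothetical and only for illustration. It assumes x is flattened to [batch_size * sentence_length, word_dim], so that word i of batch element b sits at row b * sentence_length + i.

```python
def flat_gather_indices(idx, sentence_length):
    """Map per-sentence word indices to rows of the flattened input.

    Hypothetical helper for illustration: row = b * sentence_length + i.
    """
    flat = []
    for b, row in enumerate(idx):
        for i in row:
            flat.append(b * sentence_length + i)
    return flat


x = [
    [[1, 2, 3], [3, 5, 6]],
    [[7, 8, 9], [10, 11, 12]],
    [[13, 14, 15], [16, 17, 18]],
]
idx = [[0, 1], [1, 0], [1, 1]]

sentence_length = len(x[0])
num_indices = len(idx[0])

# Flatten x to [batch_size * sentence_length, word_dim].
x_flat = [word for sentence in x for word in sentence]

# Gather rows along the (now single) first axis.
flat_idx = flat_gather_indices(idx, sentence_length)
rows = [x_flat[k] for k in flat_idx]

# Reshape back to [batch_size, num_indices, word_dim].
y = [rows[b * num_indices:(b + 1) * num_indices] for b in range(len(idx))]
print(flat_idx)  # [0, 1, 3, 2, 5, 5]
print(y)
```

Note that the per-batch offset is repeated num_indices times, while the multiplier is sentence_length. The two happen to be equal (both 2) in this example.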

However, when the inputs are placeholders, it does not work and raises this error:

    idx = tf.tile(idx, [1, int(repTime)])
    TypeError: int() argument must be a string or a number, not 'Tensor'

Python 2.7, TensorFlow 0.12.

Thank you in advance.

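Best answer

Thanks to @AllenLavoie's comments, I was eventually able to come up with a solution. The fix is to drop the int() casts and pass the dynamic tf.shape values directly as tensors: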

    import tensorflow as tf

    def tile_repeat(n, repTime):
        '''
        Create something like 0..0 1..1 2..2 ... (n-1)..(n-1),
        where each number appears repTime times consecutively.
        This is used for flattening the indices.
        '''
        idx = tf.range(n)
        idx = tf.reshape(idx, [-1, 1])    # Convert to an n x 1 matrix.
        idx = tf.tile(idx, [1, repTime])  # Repeat each number repTime times along the columns.
        y = tf.reshape(idx, [-1])
        return y

    def gather_along_second_axis(x, idx):
        '''
        x has shape: [batch_size, sentence_length, word_dim]
        idx has shape: [batch_size, num_indices]
        For each batch element, get the words of the sentence at the positions given in idx.
        tf.gather only works along the first axis here, so we flatten
        the input, gather, and then reshape again.
        '''
        x_shape = tf.shape(x)           # dynamic shape, known only at runtime
        num_indices = tf.shape(idx)[1]
        reshapedIdx = tf.reshape(idx, [-1])  # [batch_size*num_indices]
        # Row offset of each batch element in the flattened input. The offset is
        # repeated num_indices times (not sentence_length times) so that it lines
        # up with reshapedIdx.
        idx_flattened = tile_repeat(x_shape[0], num_indices) * x_shape[1] + reshapedIdx
        y = tf.gather(tf.reshape(x, [-1, x_shape[2]]),  # flatten input
                      idx_flattened)
        # The result has shape [batch_size, num_indices, word_dim], not the shape of x.
        y = tf.reshape(y, [x_shape[0], num_indices, x_shape[2]])
        return y

    x = tf.constant([
        [[1, 2, 3], [3, 5, 6]],
        [[7, 8, 9], [10, 11, 12]],
        [[13, 14, 15], [16, 17, 18]]
    ])
    idx = tf.constant([[0, 1], [1, 0], [1, 1]])

    y = gather_along_second_axis(x, idx)
    with tf.Session(''):
        print(y.eval())
        print(y.get_shape())

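
For readers on a newer TensorFlow than the 0.12 used in this post, the same gather can probably be written without the manual flattening. The sketch below assumes TensorFlow 2.x with eager execution, where tf.gather accepts a batch_dims argument. That argument is not available in 0.12, so this is an alternative to the solution above, not a replacement for it in that environment.

```python
import tensorflow as tf  # assumption: TensorFlow 2.x, eager execution enabled

x = tf.constant([
    [[1, 2, 3], [3, 5, 6]],
    [[7, 8, 9], [10, 11, 12]],
    [[13, 14, 15], [16, 17, 18]],
])
idx = tf.constant([[0, 1], [1, 0], [1, 1]])

# batch_dims=1 treats the first axis of x and idx as a shared batch axis,
# so y[b, j, :] == x[b, idx[b, j], :].
y = tf.gather(x, idx, batch_dims=1)

print(y.shape)  # (3, 2, 3)
print(y.numpy())
```

The result has shape [batch_size, num_indices, word_dim], matching the output shown in the question.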
A similar question, "python - Tensorflow: slicing a 3D tensor with a list of indices along the second axis", can be found on Stack Overflow: https://stackoverflow.com/questions/43933882/