python - Tensorflow: slicing a 3D tensor with a list of indices along the second axis

Tags: python tensorflow

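I have a placeholder tensor of shape [batch_size, sentence_length, word_dim] and a list of indices with shape=[batch_size, num_indices]. The indices refer to the second axis: they are the positions of words within each sentence. batch_size and sentence_length are only known at run time.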

How can I extract a tensor of shape [batch_size, len(indices), word_dim]?

I have been reading about tensorflow.gather, but it seems that gather only collects slices along the first axis. Am I correct?
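To pin down the result being asked for, here is a minimal plain-Python sketch of the intended semantics, out[b][j] = x[b][idx[b][j]]. It uses nested lists rather than tensors, and the function name gather_along_second_axis_reference and the sample values are illustrative assumptions, not part of any TensorFlow API.

```python
def gather_along_second_axis_reference(x, idx):
    """Reference semantics on nested lists: out[b][j] = x[b][idx[b][j]].

    x   : nested list shaped [batch_size][sentence_length][word_dim]
    idx : nested list shaped [batch_size][num_indices]
    """
    return [[x[b][j] for j in row] for b, row in enumerate(idx)]


# Illustrative sample data: 3 sentences, 2 words each, word_dim = 3.
x = [
    [[1, 2, 3], [3, 5, 6]],
    [[7, 8, 9], [10, 11, 12]],
    [[13, 14, 15], [16, 17, 18]],
]
idx = [[0, 1], [1, 0], [1, 1]]

result = gather_along_second_axis_reference(x, idx)
```

Each row of idx selects words only from the sentence in the same batch position, which is why a gather along the first axis alone is not enough.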

Edit: I managed to get it working with constant inputs:

def tile_repeat(n, repTime):
    '''
    create something like 111..122..2333..33 ..... n..nn 
    one particular number appears repTime consecutively.
    This is for flattening the indices.
    '''
    print n, repTime
    idx = tf.range(n)
    idx = tf.reshape(idx, [-1, 1])    # Convert to a n x 1 matrix.
    idx = tf.tile(idx, [1, int(repTime)])  # Create multiple columns, each column has one number repeats repTime 
    y = tf.reshape(idx, [-1])
    return y

def gather_along_second_axis(x, idx):
    ''' 
    x has shape: [batch_size, sentence_length, word_dim]
    idx has shape: [batch_size, num_indices]
    Basically, in each batch, get words from sentence having index specified in idx
    However, since tensorflow does not fully support indexing,
    gather only work for the first axis. We have to reshape the input data, gather then reshape again
    '''
    reshapedIdx = tf.reshape(idx, [-1]) # [batch_size*num_indices]
    idx_flattened = tile_repeat(tf.shape(x)[0], tf.shape(x)[1]) * tf.shape(x)[1] + reshapedIdx
    y = tf.gather(tf.reshape(x, [-1,int(tf.shape(x)[2])]),  # flatten input
                idx_flattened)
    y = tf.reshape(y, tf.shape(x))
    return y

x = tf.constant([
            [[1,2,3],[3,5,6]],
            [[7,8,9],[10,11,12]],
            [[13,14,15],[16,17,18]]
    ])
idx=tf.constant([[0,1],[1,0],[1,1]])

y = gather_along_second_axis(x, idx)
with tf.Session(''):
    print y.eval()
    print tf.Tensor.get_shape(y)

The output is:

[[[ 1  2  3]
  [ 3  5  6]]
 [[10 11 12]
  [ 7  8  9]]
 [[16 17 18]
  [16 17 18]]]

Shape: (3, 2, 3)

However, when the inputs are placeholders, it does not work and fails with this error:

idx = tf.tile(idx, [1, int(repTime)])  
TypeError: int() argument must be a string or a number, not 'Tensor'

Python 2.7, tensorflow 0.12

Thank you in advance.

Best answer

Thanks to @AllenLavoie's comments, I was eventually able to come up with a solution:

def tile_repeat(n, repTime):
    '''
    create something like 111..122..2333..33 ..... n..nn 
    one particular number appears repTime consecutively.
    This is for flattening the indices.
    '''
    print n, repTime
    idx = tf.range(n)
    idx = tf.reshape(idx, [-1, 1])    # Convert to a n x 1 matrix.
    idx = tf.tile(idx, [1, repTime])  # Create multiple columns, each column has one number repeats repTime 
    y = tf.reshape(idx, [-1])
    return y

def gather_along_second_axis(x, idx):
    ''' 
    x has shape: [batch_size, sentence_length, word_dim]
    idx has shape: [batch_size, num_indices]
    Basically, in each batch, get words from sentence having index specified in idx
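    However, since tensorflow does not fully support indexing,
    gather only works along the first axis. We have to flatten the input, gather, then reshape again.
    '''
    reshapedIdx = tf.reshape(idx, [-1]) # [batch_size*num_indices]
    # Repeat each batch number num_indices times (not sentence_length times), so the
    # offsets line up with reshapedIdx even when num_indices != sentence_length.
    idx_flattened = tile_repeat(tf.shape(x)[0], tf.shape(idx)[1]) * tf.shape(x)[1] + reshapedIdx
    y = tf.gather(tf.reshape(x, [-1,tf.shape(x)[2]]),  # flatten input
                idx_flattened)
    # The result has shape [batch_size, num_indices, word_dim], not the shape of x.
    y = tf.reshape(y, [tf.shape(x)[0], tf.shape(idx)[1], tf.shape(x)[2]])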
    return y

x = tf.constant([
            [[1,2,3],[3,5,6]],
            [[7,8,9],[10,11,12]],
            [[13,14,15],[16,17,18]]
    ])
idx=tf.constant([[0,1],[1,0],[1,1]])

y = gather_along_second_axis(x, idx)
with tf.Session(''):
    print y.eval()
    print tf.Tensor.get_shape(y)
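
The index arithmetic behind this flatten, gather, reshape approach can be checked without TensorFlow. The following is a minimal plain-Python sketch on nested lists; flat_gather and the sample values are hypothetical and chosen for illustration. It assumes each batch number is repeated num_indices times, so that sentence_length and num_indices are allowed to differ.

```python
def flat_gather(x, idx):
    """Gather along the second axis via flattened indices (nested lists).

    x   : [batch_size][sentence_length][word_dim]
    idx : [batch_size][num_indices]
    """
    batch_size = len(x)
    sentence_length = len(x[0])
    num_indices = len(idx[0])

    # Flatten x into batch_size * sentence_length word vectors.
    flat_x = [word for sentence in x for word in sentence]

    # Batch number repeated num_indices times: 0,0,..,1,1,.. (the role of tile_repeat).
    batch_ids = [b for b in range(batch_size) for _ in range(num_indices)]
    flat_positions = [i for row in idx for i in row]

    # Position in flat_x = batch offset + position inside the sentence.
    flat_idx = [b * sentence_length + i for b, i in zip(batch_ids, flat_positions)]

    gathered = [flat_x[k] for k in flat_idx]

    # Regroup into [batch_size][num_indices][word_dim].
    return [gathered[b * num_indices:(b + 1) * num_indices]
            for b in range(batch_size)]


# Illustrative data: sentence_length = 3, num_indices = 2, word_dim = 2.
x = [
    [[1, 2], [3, 4], [5, 6]],
    [[7, 8], [9, 10], [11, 12]],
]
idx = [[2, 0], [1, 1]]

result = flat_gather(x, idx)
```

Here the flattened indices come out as [2, 0, 4, 4]: the second sentence starts at offset 1 * 3 = 3 in the flattened input, so its word 1 sits at position 4.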

For python - Tensorflow: slicing a 3D tensor with a list of indices along the second axis, a similar question can be found on Stack Overflow: https://stackoverflow.com/questions/43933882/
