python - 将行添加到 TensorFlow 张量批处理

标签 python tensorflow

我有 3 级张量 [batch_size, num_rows, num_cols]),我想向其追加适当大小的行,从而产生维度为 [batch_size, num_rows + 1, num_cols] 的 3 级张量

例如,如果我有以下一批 2x2 矩阵

batch = [ [[2, 2],
           [2, 2]],
          [[3, 3],
           [3, 3]],
          [[4, 4],
           [4, 4]] ]

还有一个新行v = [1, 1]我想追加,那么期望的结果是

new_batch = [ [[2, 2],
               [2, 2],
               [1, 1]], 
              [[3, 3],
               [3, 3],
               [1, 1]],
              [[4, 4],
               [4, 4],
               [1, 1]] ]

在 TensorFlow 中是否有一种简单的方法可以做到这一点?这是我尝试过的:

W, b, c0, q0 = params
c = tf.concat([context, c0], axis=1)
q_p = tf.tanh(tf.matmul(W, question) + b)
q = tf.concat([q_p, q0], axis=1)
q_mask = tf.concat([question_mask, 1], axis=1)

为了澄清这些术语,

  1. context 具有尺寸[batch_size, context_len,hidden_​​size]
  2. q_p 具有尺寸[batch_size, Question_len,hidden_​​size]
  3. question_mask 的尺寸为 [batch_size, Question_len]
  4. c0q0 都有尺寸 [hidden_​​size]

我想做什么

  1. 将向量 c0 添加到 context,生成尺寸为 [batch_size, context_len + 1,hidden_​​size] 的张量
  2. 将向量 q0 添加到 q_p,得到维度为 [batch_size, Question_len + 1, hide_size] 的张量
  3. question_mask 添加 1,生成尺寸为 [batch_size, Question_len + 1] 的张量

感谢您的帮助。

最佳答案

您可以使用tf.map_fn来做到这一点。

batch = [ [[2, 2],
           [2, 2]],
          [[3, 3],
           [3, 3]],
          [[4, 4],
           [4, 4]] ]

row_to_add = [1,1]

t = tf.convert_to_tensor(batch, dtype=np.float32)
appended_t = tf.map_fn(lambda x: tf.concat((x, [row_to_add]), axis=0), t)

输出

appended_t.eval(session=tf.Session())

array([[[ 2.,  2.],
        [ 2.,  2.],
        [ 1.,  1.]],

       [[ 3.,  3.],
        [ 3.,  3.],
        [ 1.,  1.]],

       [[ 4.,  4.],
        [ 4.,  4.],
        [ 1.,  1.]]], dtype=float32)

关于python - 将行添加到 TensorFlow 张量批处理,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49243416/

相关文章:

python - xlwt 限制行数

python - native 应用程序(混合)使用 django-social-auth

javascript - 无法在页面中找到元素。检查元素显示与源页面不同的 HTML Python - Selenium

python - Queue.dequeue 卡在 Tensorflow 输入管道中

python - 在 Windows 中使用 pip 安装 Tensorflow 时出现问题?找不到 tensorflow 的匹配分布

tensorflow - 如何想象在具有 3 个颜色 channel 的图像上进行卷积/池化

tensorflow - 为什么 Keras 会抛出 ResourceExhaustedError?

python - 使用序号创建 3D NumPy 数组

tensorflow - 循环时更改 tensorflow session 中的常量

python - 复合类型和标量类型的生命周期有什么区别?