tensorflow - Tensorflow 中的 block 对角矩阵

标签 tensorflow

假设我有几个不同形状的张量 A_i [N_i, N_i]。在 tensorflow 中是否可以在对角线上使用这些矩阵创建块对角矩阵?我现在能想到的唯一方法就是通过堆叠和添加 tf.zeros 完全自己构建它。

最佳答案

我同意有一个 C++ op 来做这件事会很好。与此同时,这就是我所做的(正确获取静态形状信息有点繁琐):

import tensorflow as tf

def block_diagonal(matrices, dtype=tf.float32):
  r"""Constructs block-diagonal matrices from a list of batched 2D tensors.

  Args:
    matrices: A list of Tensors with shape [..., N_i, M_i] (i.e. a list of
      matrices with the same batch dimension).
    dtype: Data type to use. The Tensors in `matrices` must match this dtype.
  Returns:
    A matrix with the input matrices stacked along its main diagonal, having
    shape [..., \sum_i N_i, \sum_i M_i].

  """
  matrices = [tf.convert_to_tensor(matrix, dtype=dtype) for matrix in matrices]
  blocked_rows = tf.Dimension(0)
  blocked_cols = tf.Dimension(0)
  batch_shape = tf.TensorShape(None)
  for matrix in matrices:
    full_matrix_shape = matrix.get_shape().with_rank_at_least(2)
    batch_shape = batch_shape.merge_with(full_matrix_shape[:-2])
    blocked_rows += full_matrix_shape[-2]
    blocked_cols += full_matrix_shape[-1]
  ret_columns_list = []
  for matrix in matrices:
    matrix_shape = tf.shape(matrix)
    ret_columns_list.append(matrix_shape[-1])
  ret_columns = tf.add_n(ret_columns_list)
  row_blocks = []
  current_column = 0
  for matrix in matrices:
    matrix_shape = tf.shape(matrix)
    row_before_length = current_column
    current_column += matrix_shape[-1]
    row_after_length = ret_columns - current_column
    row_blocks.append(tf.pad(
        tensor=matrix,
        paddings=tf.concat(
            [tf.zeros([tf.rank(matrix) - 1, 2], dtype=tf.int32),
             [(row_before_length, row_after_length)]],
            axis=0)))
  blocked = tf.concat(row_blocks, -2)
  blocked.set_shape(batch_shape.concatenate((blocked_rows, blocked_cols)))
  return blocked

举个例子:
blocked_tensor = block_diagonal(
    [tf.constant([[1.]]),
     tf.constant([[1., 2.], [3., 4.]])])

with tf.Session():
  print(blocked_tensor.eval())

打印:
[[ 1.  0.  0.]
 [ 0.  1.  2.]
 [ 0.  3.  4.]]

关于tensorflow - Tensorflow 中的 block 对角矩阵,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/42157781/

相关文章:

tensorflow - TensorFlow 中的上采样特征图

parallel-processing - Tensorflow - 是否可以手动决定图中的哪些张量进入 GPU 和 CPU?

python - 如何在 TensorFlow 中生成随机向量并维护它以供进一步使用?

python-3.x - Keras/Tensorflow 模型适用于验证图像,但不适用于真实世界数据

python - 有没有办法使用 Python(例如 : TensorFlow or Sci-kit learn libs) in Flutter apps?

python - 存储 session 时 tensorflow 中出现错误 "no Variable to save"

tensorflow - 在 tensorflow 检查点中修改张量的形状

python - 处理 CNN 中的维度错误时遇到问题

python - 如何在 Arch Linux 中安装 Python 3.8 和 Python 3.9?

python - 使用 Tensorflow 读取 PNG 文件