python - 从 Numpy 或 Tensorflow 中的线性数组矢量化创建对角方阵数组

标签 python arrays numpy tensorflow vectorization

我有一个形状为 [batch_size, N] 的数组,例如:

[[1  2]
 [3  4]
 [5  6]]

我需要创建一个形状为 [batch_size, N, N] 的 3 索引数组,其中对于每个 batch 我都有一个 N x N 对角矩阵,其中对角线由相应的 batch 元素取,例如在这种情况下,在这种简单的情况下,我要查找的结果是:

[
  [[1,0],[0,2]],
  [[3,0],[0,4]],
  [[5,0],[0,6]],
]

如何在没有 for 循环和矢量化的情况下进行此操作?我想这是维度的扩展,但我找不到正确的功能来做到这一点。 (我需要它,因为我正在使用 tensorflow 并使用 numpy 制作原型(prototype))。

最佳答案

在 tensorflow 中试试:

import tensorflow as tf
A = [[1,2],[3 ,4],[5,6]]
B = tf.matrix_diag(A)
print(B.eval(session=tf.Session()))
[[[1 0]
  [0 2]]

 [[3 0]
  [0 4]]

 [[5 0]
  [0 6]]]

关于python - 从 Numpy 或 Tensorflow 中的线性数组矢量化创建对角方阵数组,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53741481/

相关文章:

javascript - VueJS、Vuex、Getter 显示为空数组,但 console.log 显示它是一个包含所有值的对象

python - 将数据填充到矩阵中,传递其定义的索引遇到 IndexError 的元素

python - 将 c++ 双指针传递给 python

python - np.array 中特定长度的连续部分的 Min-Max 差异

python - 有视频的opencv trackbar显示时间

python - 如何在函数中使用 Xpath 和 CSS 选择器

python - 下载后文件的可执行权限发生变化。是否有任何协议(protocol)或方法可以完好无损许可?

python - 如何以编程方式创建销售订单行 (Odoo 13)

java - Android 在从服务器接收时读取 JsonArray 中的不同数据

arrays - swift 平面 map : How to remove only tuples from an array where specific element in the tuple is nil?