python - tensorflow 中 numpy.linalg.pinv 的替代方案

标签 python python-3.x numpy tensorflow

我正在寻找 tensorflow 中 numpy.linalg.pinv 的替代品。 到目前为止,我发现 tensorflow 只有 tf.matrix_inverse(input, adjoint=None, name=None) 如果矩阵不可逆,它会抛出错误。

最佳答案

TensorFlow 提供了 SVD 运算,因此您可以很容易地从中计算伪逆:

def pinv(A, b, reltol=1e-6):
  # Compute the SVD of the input matrix A
  s, u, v = tf.svd(A)

  # Invert s, clear entries lower than reltol*s[0].
  atol = tf.reduce_max(s) * reltol
  s = tf.boolean_mask(s, s > atol)
  s_inv = tf.diag(tf.concat([1. / s, tf.zeros([tf.size(b) - tf.size(s)])], 0))

  # Compute v * s_inv * u_t * b from the left to avoid forming large intermediate matrices.
  return tf.matmul(v, tf.matmul(s_inv, tf.matmul(u, tf.reshape(b, [-1, 1]), transpose_a=True)))

关于python - tensorflow 中 numpy.linalg.pinv 的替代方案,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/42501715/

相关文章:

python - pretty-print 命名元组

python - 如何在 Mac 上将 Python3 设置为默认的 python 版本?

python - 在 pandas/python 中嵌套带有 .loc 的 if 语句

Python (numpy) 因大量数组元素而导致系统崩溃

用于扩展图像的 Python 函数(NumPy 数组)

python - 如何根据给定的标准将一个csv文件拆分为多个csv?

python - 为什么PyMySQL delete query执行但不删除记录而不会出错?

python - 有什么办法可以让整数被归类为没有 .0 的整数,而 float 被归类为 float 吗?

python - 保存 32 位浮点 TIFF 图像

python - 是否有一个 wsgi 服务器可以进行渐进式传输编码 : chunked