对于我批次中的某些矩阵,由于矩阵是奇异的,所以我有一个异常(exception)。
L = th.cholesky(Xt.bmm(X))
cholesky_cpu: For batch 51100: U(22,22) is zero, singular U
由于在我的用例中它们很少,因此我想忽略该异常并进一步处理它们。我将结果计算设置为 nan 是否可能?
实际上,如果我用
catch
异常并使用continue
仍然无法完成其余批次的计算。在带有Pytorch libtorch的C++中也是如此。
最佳答案
在执行cholesky分解时,PyTorch依赖于LAPACK用于CPU张量和MAGMA用于CUDA张量。在PyTorch code used to call LAPACK中,仅对批次进行迭代,分别在每个矩阵上调用LAPACK的 zpotrs_
函数。在PyTorch code used to call MAGMA中,使用MAGMA的 magma_dpotrs_batched
处理整个批次,这可能比在每个矩阵上分别迭代要快。
AFAIK无法指示MAGMA或LAPACK不引发异常(尽管公平地说,我不是这些软件包的专家)。由于MAGMA可能以某种方式利用批处理,我们可能不希望仅默认使用迭代方法,因为我们可能由于不执行批处理的cholesky而失去性能。
一种可能的解决方案是首先尝试执行批处理的cholesky分解,如果失败,则可以对批处理中的每个元素执行cholesky分解,将失败的条目设置为NaN。
def cholesky_no_except(x, upper=False, force_iterative=False):
success = False
if not force_iterative:
try:
results = torch.cholesky(x, upper=upper)
success = True
except RuntimeError:
pass
if not success:
# fall back to operating on each element separately
results_list = []
x_batched = x.reshape(-1, x.shape[-2], x.shape[-1])
for batch_idx in range(x_batched.shape[0]):
try:
result = torch.cholesky(x_batched[batch_idx, :, :], upper=upper)
except RuntimeError:
# may want to only accept certain RuntimeErrors add a check here if that's the case
# on failure create a "nan" matrix
result = float('nan') + torch.empty(x.shape[-2], x.shape[-1], device=x.device, dtype=x.dtype)
results_list.append(result)
results = torch.cat(results_list, dim=0).reshape(*x.shape)
return results
如果您希望异常在cholesky分解期间很常见,则您可能希望使用
force_iterative=True
跳过尝试使用批处理版本的初始调用,因为在这种情况下,此功能可能会浪费大量时间进行第一次尝试。
关于c++ - Pytorch Torch.Cholesky忽略异常,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/60230464/