c++ - Pytorch Torch.Cholesky忽略异常

标签 c++ pytorch libtorch

对于我批次中的某些矩阵,由于矩阵是奇异的,所以我有一个异常(exception)。

L = th.cholesky(Xt.bmm(X))

cholesky_cpu: For batch 51100: U(22,22) is zero, singular U



由于在我的用例中它们很少,因此我想忽略该异常并进一步处理它们。我将结果计算设置为 nan 是否可能?

实际上,如果我用catch异常并使用continue仍然无法完成其余批次的计算。

在带有Pytorch libtorch的C++中也是如此。

最佳答案

在执行cholesky分解时,PyTorch依赖于LAPACK用于CPU张量和MAGMA用于CUDA张量。在PyTorch code used to call LAPACK中,仅对批次进行迭代,分别在每个矩阵上调用LAPACK的 zpotrs_ 函数。在PyTorch code used to call MAGMA中,使用MAGMA的 magma_dpotrs_batched 处理整个批次,这可能比在每个矩阵上分别迭代要快。

AFAIK无法指示MAGMA或LAPACK不引发异常(尽管公平地说,我不是这些软件包的专家)。由于MAGMA可能以某种方式利用批处理,我们可能不希望仅默认使用迭代方法,因为我们可能由于不执行批处理的cholesky而失去性能。

一种可能的解决方案是首先尝试执行批处理的cholesky分解,如果失败,则可以对批处理中的每个元素执行cholesky分解,将失败的条目设置为NaN。

def cholesky_no_except(x, upper=False, force_iterative=False):
    success = False
    if not force_iterative:
        try:
            results = torch.cholesky(x, upper=upper)
            success = True
        except RuntimeError:
            pass

    if not success:
        # fall back to operating on each element separately
        results_list = []
        x_batched = x.reshape(-1, x.shape[-2], x.shape[-1])
        for batch_idx in range(x_batched.shape[0]):
            try:
                result = torch.cholesky(x_batched[batch_idx, :, :], upper=upper)
            except RuntimeError:
                # may want to only accept certain RuntimeErrors add a check here if that's the case
                # on failure create a "nan" matrix
                result = float('nan') + torch.empty(x.shape[-2], x.shape[-1], device=x.device, dtype=x.dtype)
            results_list.append(result)
        results = torch.cat(results_list, dim=0).reshape(*x.shape)

    return results

如果您希望异常在cholesky分解期间很常见,则您可能希望使用force_iterative=True跳过尝试使用批处理版本的初始调用,因为在这种情况下,此功能可能会浪费大量时间进行第一次尝试。

关于c++ - Pytorch Torch.Cholesky忽略异常,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/60230464/

相关文章:

c++ - 在 Torch C++ API 中,如何快速写入张量的内部数据?

c++ - 如何在 C++ 中将 torch 模型定义为函数的输入

c++ - 我正在尝试设置一个文件流或类似的东西,但我很困惑我应该做什么

c++ - 使用 boost::signals2 编译时间非常慢

python - 正在构造 key :Value pair from list comprehension in Python

pytorch - 在pytorch中合并两个张量

c++ - pytorch C++与alexnet和cv::imread图像

c++ - 如何从libtorch输出中删除乘法器并显示最终结果?

c++ - 在未安装驱动程序的情况下连接ODBC

c++ - 检测浮点软件仿真