c++ - PyTorch C++ 扩展中的可选张量

标签 c++ pytorch torch

我正在为 pytorch 编写一个 C++ 扩展,并使用 c++ api 来执行此操作。对于我的 forward 函数,我需要传递一个可选的张量。在函数内部,我想根据是否传递了这个可选参数来做不同的事情。通常,我们在 C++ 中将 NULL 用作可选指针参数,并在函数内部检查指针是否为 NULL。我不知道如何为 at::Tensor 类型的 Torch 的 c++ api 执行此操作。

void xyz_forward(
    const at::Tensor xyz1, 
    const at::Tensor xyz2, 
    const at::Tensor optional_constraints = something)
{
     if(optional_constraints){
        //do something
     }else{
        //do something else
     }
}

请注意,我不能执行 const at::Tensor optional_constraints = at::ones 或其他操作,因为该参数可以采用任何实际值并且可以具有不同的大小/形状。我不能为它分配一个数值作为可选参数。是否有与此等效的 NULL

最佳答案

一种可能是使用 std::optional作为std::optional<at::Tensor> optional_constraints = std::nullopt .它可以根据上下文转换为 bool , 所以你可以用 if (optional_constraints) 检查它.使用 .value()如果传递一个,则获取张量的方法,否则默认值为 std::nullopt .

关于c++ - PyTorch C++ 扩展中的可选张量,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54685444/

相关文章:

python - torch 中的 w -= dw 与 w = w - dw

c++ - 如何获取指向专门针对字符串的模板函数的指针?

c++ - 两个单独的键映射到 std::map 中的单个条目

c++ - CString 内崩溃

pytorch - 如何将两个 PyTorch 量化张量矩阵相乘?

pytorch - 导入 torch 时如何修复 'cannot initialize type TensorProto DataType'错误?

arrays - torch /Lua,如何选择数组或张量的子集?

java - 将自定义 Java 数据模型传递到我的 native 代码

python - 如何使用pytorch进行类似于numpy的trapz函数的数值积分?

deep-learning - PyTorch 等深度学习框架在使用多个 GPU 时如何处理内存?