我正在为 pytorch 编写一个 C++ 扩展,并使用 c++ api 来执行此操作。对于我的 forward
函数,我需要传递一个可选的张量。在函数内部,我想根据是否传递了这个可选参数来做不同的事情。通常,我们在 C++ 中将 NULL 用作可选指针参数,并在函数内部检查指针是否为 NULL。我不知道如何为 at::Tensor
类型的 Torch 的 c++ api 执行此操作。
void xyz_forward(
const at::Tensor xyz1,
const at::Tensor xyz2,
const at::Tensor optional_constraints = something)
{
if(optional_constraints){
//do something
}else{
//do something else
}
}
请注意,我不能执行 const at::Tensor optional_constraints = at::ones
或其他操作,因为该参数可以采用任何实际值并且可以具有不同的大小/形状。我不能为它分配一个数值作为可选参数。是否有与此等效的 NULL
?
最佳答案
一种可能是使用 std::optional作为std::optional<at::Tensor> optional_constraints = std::nullopt
.它可以根据上下文转换为 bool
, 所以你可以用 if (optional_constraints)
检查它.使用 .value()
如果传递一个,则获取张量的方法,否则默认值为 std::nullopt
.
关于c++ - PyTorch C++ 扩展中的可选张量,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54685444/