c++ - 在割炬C++中创建BoolTensor蒙版

标签 c++ libtorch

我正在尝试创建BoolTensor类型的C++中的火炬 mask 。维度n中的第一个one元素需要为False,其余元素需要为True
这是我的尝试,但我不知道这是否正确(大小是元素数):

src_mask = torch::BoolTensor({6, 1});
src_mask[:size,:] = 0;
src_mask[size:,:] = 1;

最佳答案

我不确定在这里确切地了解您的目标,因此这是我将伪代码转换为C++的最佳尝试。
首先,使用libtorch,您可以通过torch::TensorOptions结构声明张量的类型(类型名称以小写k开头)
其次,借助torch::Tensor::slice函数,可以进行类似python的 slice (请参阅herethere)。
最后,您会得到类似:

// Creates a tensor of boolean, initially all ones
auto options = torch::TensorOptions().dtype(torch::kBool));
torch::Tensor bool_tensor = torch::ones({6,1}, options);
// Set the slice to 0
int size = 3;
bool_tensor.slice(/*dim=*/0, /*start=*/0, /*end=*/size) = 0;

std::cout << bool_tensor << std::endl;
请不要因为这会将第一个size行设置为0。我假设这就是“x维度中的第一个元素”的意思。

关于c++ - 在割炬C++中创建BoolTensor蒙版,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/63186307/

相关文章:

c++ - 从 Clang AST 中的 CXXConstructExpr 中检索模板参数

C++:声明指针

c++ - Raspberry 上的 Libtorch 无法加载 pt 文件但在 ubuntu 上工作

c++ - libtorch:如何从tensorRT fp16半类型指针创建张量?

c++ - CString 用于静态方法

c++ - 无用的反斜杠会产生定义明确的字符串常量吗?

c++ - 关于c++中链表/指针的问题

c++ - 如何用 C++ 数组填充 torch::tensor?

python - Pytorch C++ (Libtroch),使用操作间并行性

c++ - 在 libtorch 中使用 forward 的非法指令(核心转储)