c++ - 采用 Eigen::Tensor 的函数 - 模板参数推导失败

标签 c++ templates tensor eigen3

我正在尝试编写一个以 Eigen::Tensor 作为参数的模板函数。适用于 Eigen::Matrix 等的相同方法在这里不起作用。

Eigen 建议使用公共(public)基类编写函数。 https://eigen.tuxfamily.org/dox/TopicFunctionTakingEigenTypes.html

编译的 Eigen::Matrix 的最小示例:

#include <Eigen/Dense>

template <typename Derived>
void func(Eigen::MatrixBase<Derived>& a) 
{
    a *= 2;
}

int main()
{
    Eigen::Matrix<int, 2, 2> matrix;
    func(matrix);
}

Eigen::Tensor 的最小示例不编译:

#include <unsupported/Eigen/CXX11/Tensor>

template <typename Derived>
void func(Eigen::TensorBase<Derived>& a)
{
    a *= 2;
}

int main()
{
    Eigen::Tensor<int, 1> tensor;
    func(tensor);
}
$ g++ -std=c++11 -I /usr/include/eigen3 eigen_tensor_func.cpp
eigen_tensor_func.cpp: In function ‘int main()’:
eigen_tensor_func.cpp:12:16: error: no matching function for call to ‘func(Eigen::Tensor<int, 1>&)’
     func(tensor);
                ^
eigen_tensor_func.cpp:4:6: note: candidate: ‘template<class Derived> void func(Eigen::TensorBase<Derived>&)’
 void func(Eigen::TensorBase<Derived>& a)
      ^~~~
eigen_tensor_func.cpp:4:6: note:   template argument deduction/substitution failed:
eigen_tensor_func.cpp:12:16: note:   ‘Eigen::TensorBase<Derived>’ is an ambiguous base class of ‘Eigen::Tensor<int, 1>’
     func(tensor);

最佳答案

Tensor-Module 距离与 Eigen/Core 功能完全兼容还有很长的路要走(当然,这也意味着核心功能的文档不一定适用于 Tensor-Module)。

第一个主要区别是 TensorBase采用两个模板参数而不是一个,即,您需要编写 TensorBase<Derived, Eigen::WriteAccessors> .还有一些功能要么根本没有实现,要么 TensorBase没有正确转发它。以下适用于当前主干 (2019-04-03):

template <typename Derived>
void func(Eigen::TensorBase<Derived, Eigen::WriteAccessors>& a)
{
    // a *= 2;  // operator*=(Scalar) not implemented
    // a = 2*a; // operator=(...) not implemented/forwarded
    a *= a;     // ok
    a *= 2*a;   // ok
    a *= 0*a+2; // ok

    // a.derived() = 2*a; // derived() is not public
    static_cast<Derived&>(a) = a*2; // ok
}

关于c++ - 采用 Eigen::Tensor 的函数 - 模板参数推导失败,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/55491596/

相关文章:

c++ - 无法理解我的程序的输出

c++ - 在 OpenCV 中保存图像

带有模板的 C++ header 顺序

c++ - 内部模板类 C++

python - PyTorch 稀疏张量的维数必须为 nDimI + nDimV

c++ - pytorch/libtorch C++ 中的自定义子模块

c++ - 在模板参数中查找 typename 的 typename

python-3.x - PyTorch 中的 nn.functional() 与 nn.sequential() 之间是否存在计算效率差异

python - 如何将 PyTorch 张量分块到指定的桶大小并重叠?

c++ - 将指针传递给 C++ 中的方法导致奇怪的输出