c++ - Eigen:可修改的自定义表达式

标签 c++ customization eigen eigen3

我正在尝试用 Eigen 实现可修改的自定义表达式，与这个问题（question）中描述的类似。基本上，我想要的是类似教程（tutorial）中索引示例的东西，但要能给选中的系数赋新值。

正如上述问题中被采纳的答案所建议的那样，我研究了 Transpose 的实现并尝试了很多方法，但都没有成功。我的尝试基本上都会失败，并报出类似 'Eigen::internal::evaluator<SrcXprType>::evaluator(const Eigen::internal::evaluator<SrcXprType> &)': cannot convert argument 1 from 'const Eigen::Indexing<Derived>' to 'Eigen::Indexing<Derived> &' 的错误。问题可能在于我的 evaluator 结构似乎是只读的。

// Evaluator attempt from the question. It does not compile as written: the two
// coeffRef bodies are placeholders ("... very clever stuff ...") for the actual
// index mapping. It specialises Eigen::internal::evaluator for the user's
// Indexing<ArgType> expression.
namespace Eigen {
namespace internal {
    // Evaluator for Indexing<ArgType>: maps coefficients of the view onto
    // coefficients of the wrapped argument's evaluator (m_argImpl).
    template<typename ArgType>
    struct evaluator<Indexing<ArgType> >
        : evaluator_base<Indexing<ArgType> >
    {
        typedef Indexing<ArgType> XprType;
        // How ArgType is nested for evaluation (Eigen's nested_eval), and that
        // type with references/cv removed. NOTE(review): nested_eval may pick a
        // reference or an evaluated plain type; confirm which applies here.
        typedef typename nested_eval<ArgType, XprType::ColsAtCompileTime>::type ArgTypeNested;
        typedef typename remove_all<ArgTypeNested>::type ArgTypeNestedCleaned;
        typedef typename XprType::CoeffReturnType CoeffReturnType;
        typedef typename traits<ArgType>::Scalar Scalar;
        enum {
            // Reading a coefficient costs as much as reading one from the argument.
            CoeffReadCost = evaluator<ArgTypeNestedCleaned>::CoeffReadCost,
            // Only the storage order is advertised here. NOTE(review): in the
            // answer, writability (LvalueBit) is declared in the expression's
            // traits<> via is_lvalue; check whether Indexing's traits do the same.
            Flags = Eigen::ColMajor
        };

        // NOTE(review): the compiler error quoted in the question says a
        // `const Eigen::Indexing<Derived>` cannot be converted to
        // `Eigen::Indexing<Derived> &`, i.e. Eigen constructs this evaluator from a
        // const expression, and this non-const reference parameter cannot bind to
        // it. The evaluator in the answer takes `const XprType&` instead.
        // Reads xpr.m_arg directly, so Indexing must expose that member to this struct.
        evaluator(XprType& xpr)
            : m_argImpl(xpr.m_arg), m_rows(xpr.rows())
        { }
        // Read access, returned as a const reference into the argument.
        // NOTE(review): there is no coeff(row, col) member here; the answer's
        // evaluator provides coeff() for reads and coeffRef() for writes.
        const Scalar& coeffRef(Index row, Index col) const
        {
             return m_argImpl.coeffRef(... very clever stuff ...)
        }

        // Write access: a mutable reference into the wrapped argument.
        Scalar& coeffRef(Index row, Index col)
        {
             return m_argImpl.coeffRef(... very clever stuff ...)
        }

        // Evaluator of the nested argument; all coefficient access is forwarded here.
        evaluator<ArgTypeNestedCleaned> m_argImpl;
        // Row count captured at construction (unused in the code shown).
        const Index m_rows;
    };
}
}

此外，我已经把所有出现的 typedef typename Eigen::internal::ref_selector<ArgType>::type 改成了 ...::non_const_type，但没有任何效果。

由于 Eigen 库相当复杂，我不知道如何正确地把表达式（expression）和求值器（evaluator）组合在一起。我不明白为什么我的求值器是只读的，也不知道如何得到一个可写的求值器。如果有人能提供一个可修改自定义表达式的最小示例，那就太好了。

最佳答案

在 ggael 的提示下，我已经成功添加了自己的可修改表达式。我基本上是改编了 Eigen 开发分支中的 IndexedView。

由于最初想要的功能已经包含在 IndexedView 中，我改为编写了一个可修改的循环移位（circular shift）函数，作为可修改自定义表达式的简单示例。大部分代码直接取自 IndexedView，功劳归于其作者。

// circ_shift.h
#pragma once
#include <Eigen/Core>

namespace helper
{
        namespace detail
    {
        template <typename T>
        constexpr std::true_type is_matrix(Eigen::MatrixBase<T>);
        std::false_type constexpr is_matrix(...);

        template <typename T>
        constexpr std::true_type is_array(Eigen::ArrayBase<T>);
        std::false_type constexpr is_array(...);
    }


    template <typename T>
    struct is_matrix : decltype(detail::is_matrix(std::declval<std::remove_cv_t<T>>()))
    {
    };

    template <typename T>
    struct is_array : decltype(detail::is_array(std::declval<std::remove_cv_t<T>>()))
    {
    };

    template <typename T>
    using is_matrix_or_array = std::bool_constant<is_array<T>::value || is_matrix<T>::value>;



    /*
     * Index something if it's not an scalar
     */
    template <typename T, typename std::enable_if<is_matrix_or_array<T>::value, int>::type = 0>
    auto index_if_necessary(T&& thing, Eigen::Index idx)
    {
        return thing(idx);
    }

    /*
    * Overload for scalar.
    */
    template <typename T, typename std::enable_if<std::is_scalar<std::decay_t<T>>::value, int>::type = 0>
    auto index_if_necessary(T&& thing, Eigen::Index)
    {
        return thing;
    }
}

namespace Eigen
{
    template <typename XprType, typename RowIndices, typename ColIndices>
    class CircShiftedView;

    namespace internal
    {
        template <typename XprType, typename RowIndices, typename ColIndices>
        struct traits<CircShiftedView<XprType, RowIndices, ColIndices>>
            : traits<XprType>
        {
            enum
            {
                RowsAtCompileTime = traits<XprType>::RowsAtCompileTime,
                ColsAtCompileTime = traits<XprType>::ColsAtCompileTime,
                MaxRowsAtCompileTime = RowsAtCompileTime != Dynamic ? int(RowsAtCompileTime) : int(traits<XprType>::MaxRowsAtCompileTime),
                MaxColsAtCompileTime = ColsAtCompileTime != Dynamic ? int(ColsAtCompileTime) : int(traits<XprType>::MaxColsAtCompileTime),

                XprTypeIsRowMajor = (int(traits<XprType>::Flags) & RowMajorBit) != 0,
                IsRowMajor = (MaxRowsAtCompileTime == 1 && MaxColsAtCompileTime != 1) ? 1
                                 : (MaxColsAtCompileTime == 1 && MaxRowsAtCompileTime != 1) ? 0
                                 : XprTypeIsRowMajor,


                FlagsRowMajorBit = IsRowMajor ? RowMajorBit : 0,
                FlagsLvalueBit = is_lvalue<XprType>::value ? LvalueBit : 0,
                Flags = (traits<XprType>::Flags & HereditaryBits) | FlagsLvalueBit | FlagsRowMajorBit
            };
        };
    }

    template <typename XprType, typename RowShift, typename ColShift, typename StorageKind>
    class CircShiftedViewImpl;


    template <typename XprType, typename RowShift, typename ColShift>
    class CircShiftedView : public CircShiftedViewImpl<XprType, RowShift, ColShift, typename internal::traits<XprType>::StorageKind>
    {
    public:
        typedef typename CircShiftedViewImpl<XprType, RowShift, ColShift, typename internal::traits<XprType>::StorageKind>::Base Base;
        EIGEN_GENERIC_PUBLIC_INTERFACE(CircShiftedView)
        EIGEN_INHERIT_ASSIGNMENT_OPERATORS(CircShiftedView)

        typedef typename internal::ref_selector<XprType>::non_const_type MatrixTypeNested;
        typedef typename internal::remove_all<XprType>::type NestedExpression;

        template <typename T0, typename T1>
        CircShiftedView(XprType& xpr, const T0& rowShift, const T1& colShift)
            : m_xpr(xpr), m_rowShift(rowShift), m_colShift(colShift)
        {
            for (auto c = 0; c < xpr.cols(); ++c)
            assert(std::abs(helper::index_if_necessary(m_rowShift, c)) < m_xpr.rows()); // row shift must be within +- rows()-1
            for (auto r = 0; r < xpr.rows(); ++r)
            assert(std::abs(helper::index_if_necessary(m_colShift, r)) < m_xpr.cols()); // col shift must be within +- cols()-1
        }

        /** \returns number of rows */
        Index rows() const { return m_xpr.rows(); }

        /** \returns number of columns */
        Index cols() const { return m_xpr.cols(); }

        /** \returns the nested expression */
        const typename internal::remove_all<XprType>::type&
        nestedExpression() const { return m_xpr; }

        /** \returns the nested expression */
        typename internal::remove_reference<XprType>::type&
        nestedExpression() { return m_xpr.const_cast_derived(); }

        EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
        Index getRowIdx(Index row, Index col) const
        {
            Index R = m_xpr.rows();
            assert(row >= 0 && row < R && col >= 0 && col < m_xpr.cols());
            Index r = row - helper::index_if_necessary(m_rowShift, col);
            if (r >= R)
                return r - R;
            if (r < 0)
                return r + R;
            return r;
        }

        EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
        Index getColIdx(Index row, Index col) const
        {
            Index C = m_xpr.cols();
            assert(row >= 0 && row < m_xpr.rows() && col >= 0 && col < C);
            Index c = col - helper::index_if_necessary(m_colShift, row);
            if (c >= C)
                return c - C;
            if (c < 0)
                return c + C;
            return c;
        }

    protected:
        MatrixTypeNested m_xpr;
        RowShift m_rowShift;
        ColShift m_colShift;
    };


    // Generic API dispatcher
    template <typename XprType, typename RowIndices, typename ColIndices, typename StorageKind>
    class CircShiftedViewImpl
        : public internal::generic_xpr_base<CircShiftedView<XprType, RowIndices, ColIndices>>::type
    {
    public:
        typedef typename internal::generic_xpr_base<CircShiftedView<XprType, RowIndices, ColIndices>>::type Base;
    };

    namespace internal
    {
        template <typename ArgType, typename RowIndices, typename ColIndices>
        struct unary_evaluator<CircShiftedView<ArgType, RowIndices, ColIndices>, IndexBased>
            : evaluator_base<CircShiftedView<ArgType, RowIndices, ColIndices>>
        {
            typedef CircShiftedView<ArgType, RowIndices, ColIndices> XprType;

            enum
            {
                CoeffReadCost = evaluator<ArgType>::CoeffReadCost + NumTraits<Index>::AddCost /* for comparison */ + NumTraits<Index>::AddCost /*for addition*/,

                Flags = (evaluator<ArgType>::Flags & HereditaryBits),

                Alignment = 0
            };

            EIGEN_DEVICE_FUNC explicit unary_evaluator(const XprType& xpr) : m_argImpl(xpr.nestedExpression()), m_xpr(xpr)
            {
                EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost);
            }

            typedef typename XprType::Scalar Scalar;
            typedef typename XprType::CoeffReturnType CoeffReturnType;


            EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
            CoeffReturnType coeff(Index row, Index col) const
            {
                return m_argImpl.coeff(m_xpr.getRowIdx(row, col), m_xpr.getColIdx(row, col));
            }

            EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
            Scalar& coeffRef(Index row, Index col)
            {
                assert(row >= 0 && row < m_xpr.rows() && col >= 0 && col < m_xpr.cols());

                return m_argImpl.coeffRef(m_xpr.getRowIdx(row, col), m_xpr.getColIdx(row, col));
            }

        protected:

            evaluator<ArgType> m_argImpl;
            const XprType& m_xpr;
        };
    } // end namespace internal
} // end namespace Eigen


/*
 * Creates a circularly shifted view of `x` that can also be written through.
 *
 * \param x  expression to view. It is taken by non-const reference so that
 *           writes through the view reach it.
 * \param r  row shift: a scalar, or one shift per column (Eigen vector)
 * \param c  column shift: a scalar, or one shift per row (Eigen vector)
 * \returns  a CircShiftedView holding its own copies of the shifts
 */
template <typename XprType, typename RowShift, typename ColShift>
auto circShift(Eigen::DenseBase<XprType>& x, RowShift r, ColShift c)
{
    using ViewType = Eigen::CircShiftedView<XprType, RowShift, ColShift>;
    XprType& source = x.derived();
    return ViewType(source, r, c);
}

和:

// main.cpp
#include "stdafx.h"
#include "Eigen/Core"
#include <iostream>
#include "circ_shift.h"

using namespace Eigen;


int main()
{

    ArrayXXf x(4, 2);
    x.transpose() << 1, 2, 3, 4, 10, 20, 30, 40;


    Vector2i rowShift;
    rowShift << 3, -3; // rotate col 1 by 3 and col 2 by -3

    Index colShift = 1; // flip columns

    auto shifted = circShift(x, rowShift, colShift);

    std::cout << "shifted: " << std::endl << shifted << std::endl;

    shifted.block(2,0,2,1) << -1, -2; // will appear in row 3 and 0.
    shifted.col(1) << 2,4,6,8;  // shifted col 1 is col 0 of the original

    std::cout << "modified original:" << std::endl << x << std::endl;

    return 0;
}

关于c++ - Eigen:可修改的自定义表达式,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/46077242/

相关文章:

c++数组赋值多个值

c++ - 函数被多次调用

kubernetes - 新的 kubernetes kustomize 支持中的变量替换(自 1.14.0 起): kubectl apply -k ./

ruby - 如何在 Eclipse 的 ruby 插件中自定义编辑器的背景颜色?

flutter - 如何在 flutter 过渡期间为旧页面设置动画?

C++ - MATLAB : updating a Sparse Matrix blockwise

c++ - 调试断言失败!表达式 : _pFirstBlock == pHead

c++ - 实现程序配置设置的好方法是什么?

c++ - 使 Eigen::Vector 看起来像点 vector

c++ - 包含 vector 和协方差矩阵的负积