python - 使用 pybind11 将 NumPy 数组转换到自定义 C++ 矩阵类或从自定义 C++ 矩阵类转换

标签 python c++ arrays numpy pybind11

我正在尝试使用 pybind11 包装我的 C++ 代码.在 C++ 中,我有一个类 Matrix3D它充当 3-D 数组(即形状为 [n,m,p] )。它具有以下基本签名:

template <class T> class Matrix3D
{

  public:

    std::vector<T> data;
    std::vector<size_t> shape;
    std::vector<size_t> strides;

    Matrix3D<T>();
    Matrix3D<T>(std::vector<size_t>);
    Matrix3D<T>(const Matrix3D<T>&);

    T& operator() (int,int,int);

};

为了尽量减少包装代码,我想将此类直接转换为 NumPy 数组(拷贝没有问题)。例如,我想直接包装一个具有以下签名的函数:

Matrix3D<double> func ( const Matrix3D<double>& );

使用封装代码

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/numpy.h>

namespace py = pybind11;

PYBIND11_PLUGIN(example) {
  py::module m("example", "Module description");
  m.def("func", &func, "Function description" );
  return m.ptr();
}

目前我有另一个接受并返回 py::array_t<double> 的函数。 .但是我想通过用一些模板替换它来避免为每个函数编写一个包装函数。

已为 Eigen 完成此操作-库(用于数组和(2-D)矩阵)。但是代码太复杂了,我无法从中导出自己的代码。另外,我真的只需要包装一个简单的类。

最佳答案

在@kazemakase 和@jagerman(后者通过 pybind11 forum )的帮助下,我已经弄明白了。类本身应该有一个可以从一些输入复制的构造函数,这里使用迭代器:

#include <vector>
#include <assert.h>
#include <iterator>


template <class T> class Matrix3D
{
public:

  std::vector<T>      data;
  std::vector<size_t> shape;
  std::vector<size_t> strides;

  Matrix3D<T>() = default;

  template<class Iterator>
  Matrix3D<T>(const std::vector<size_t> &shape, Iterator first, Iterator last);
};


template <class T>
template<class Iterator>
Matrix3D<T>::Matrix3D(const std::vector<size_t> &shape_, Iterator first, Iterator last)
{
  shape = shape_;

  assert( shape.size() == 3 );

  strides.resize(3);

  strides[0] = shape[2]*shape[1];
  strides[1] = shape[2];
  strides[2] = 1;

  int size = shape[0] * shape[1] * shape[2];

  assert( last-first == size );

  data.resize(size);

  std::copy(first, last, data.begin());
}

直接包装具有以下签名的函数:

Matrix3D<double> func ( const Matrix3D<double>& );

需要下面的封装代码

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/numpy.h>

namespace py = pybind11;

namespace pybind11 { namespace detail {
  template <typename T> struct type_caster<Matrix3D<T>>
  {
    public:

      PYBIND11_TYPE_CASTER(Matrix3D<T>, _("Matrix3D<T>"));

      // Conversion part 1 (Python -> C++)
      bool load(py::handle src, bool convert) 
      {
        if ( !convert and !py::array_t<T>::check_(src) )
          return false;

        auto buf = py::array_t<T, py::array::c_style | py::array::forcecast>::ensure(src);
        if ( !buf )
          return false;

        auto dims = buf.ndim();
        if ( dims != 3  )
          return false;

        std::vector<size_t> shape(3);

        for ( int i = 0 ; i < 3 ; ++i )
          shape[i] = buf.shape()[i];

        value = Matrix3D<T>(shape, buf.data(), buf.data()+buf.size());

        return true;
      }

      //Conversion part 2 (C++ -> Python)
      static py::handle cast(const Matrix3D<T>& src, py::return_value_policy policy, py::handle parent) 
      {

        std::vector<size_t> shape  (3);
        std::vector<size_t> strides(3);

        for ( int i = 0 ; i < 3 ; ++i ) {
          shape  [i] = src.shape  [i];
          strides[i] = src.strides[i]*sizeof(T);
        }

        py::array a(std::move(shape), std::move(strides), src.data.data() );

        return a.release();

      }
  };
}} // namespace pybind11::detail

PYBIND11_PLUGIN(example) {
    py::module m("example", "Module description");
    m.def("func", &func, "Function description" );
    return m.ptr();
}

请注意,函数重载现在也是可能的。例如,如果存在具有以下签名的重载函数:

Matrix3D<int   > func ( const Matrix3D<int   >& );
Matrix3D<double> func ( const Matrix3D<double>& );

需要以下包装函数定义:

m.def("func", py::overload_cast<Matrix3D<int   >&>(&func), "Function description" );
m.def("func", py::overload_cast<Matrix3D<double>&>(&func), "Function description" );

关于python - 使用 pybind11 将 NumPy 数组转换到自定义 C++ 矩阵类或从自定义 C++ 矩阵类转换,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/42645228/

相关文章:

python - 以对数刻度绘制 x 和 y 轴

c++ - 如何计算 ELF 文件中的静态初始值设定项?

c++ - 如何制作一个可以在 C 库中用作回调的 wxWidget 方法?

JavaScript push() 问题

c - 2 次方大小数据的性能优势?

c - 尝试了解 C 中多维动态字段的用法

Python:OSError:[Errno 2] subprocess.Popen 上没有这样的文件或目录

python - 从 OrderedDict : preserving columns order 列表构建 Pandas DataFrame

python - Python如何使用numpy数组进行垃圾回收追加和删除?

c++ - 如何为qsort编写比较器函数?