c++ - 用于 C++ 的 mxnet ndarray 迭代器

标签 c++ iterator mxnet

我想用 C++ 训练一个简单的分类器,非常符合 C++ mnist example 的风格。 ,但是我的数据没有存储在 HD 上,而是已经加载到内存中,比如 mxnet NDArray。在 Python 中,为了这个目的,我们有方便的 NDArrayIter,c.f. Module tutorial .

C++ 有这样的 NDArray 迭代器吗?

浏览代码我发现所有可能的MXDataIter都可以从MXListDataItersMXDataIterGetIterInfo中读出:

#include "mxnet-cpp/io.h"
using namespace std;
using namespace mxnet::cpp;

int main(int argc, char** argv) {
  Context ctx = Context::cpu();  // Use CPU

  mx_uint num_data_iter_creators;
  DataIterCreator *data_iter_creators = nullptr;

  int r = MXListDataIters(&num_data_iter_creators, &data_iter_creators);
  CHECK_EQ(r, 0);
  cout << "num_data_iter_creators = " << num_data_iter_creators << endl;
  //output: num_data_iter_creators = 8

  const char *name;
  const char *description;
  mx_uint num_args;
  const char **arg_names;
  const char **arg_type_infos;
  const char **arg_descriptions;

  for (mx_uint i = 0; i < num_data_iter_creators; i++) {
      r = MXDataIterGetIterInfo(data_iter_creators[i], &name, &description,
                                &num_args, &arg_names, &arg_type_infos,
                                &arg_descriptions);
      CHECK_EQ(r, 0);
      cout << " i: " << i << ", name: " << name << endl;
  }

  MXNotifyShutdown();
  return 0;
}

产生八个 MXDataIter():

num_data_iter_creators = 8
 i: 0, name: ImageDetRecordIter
 i: 1, name: CSVIter
 i: 2, name: ImageRecordIter_v1
 i: 3, name: ImageRecordUInt8Iter_v1
 i: 4, name: MNISTIter
 i: 5, name: ImageRecordIter
 i: 6, name: ImageRecordUInt8Iter
 i: 7, name: LibSVMIter

所以在我看来,对于 C++,没有 NDArray 迭代器,最简单的解决方案是将我的数据写入 csv 文件,然后将其再次加载到 MXDataIter(CSVIter) .另一种可能性是手动将数据分成批处理的 NDArray,然后将它们提供给训练,但这也让人觉得笨拙。

最佳答案

不幸的是,C++ 包中没有 NDArrayIter。

但我要说的是,如果您真的需要的话,实现起来应该不难。看看它是如何在 Python 中实现的,也许您可​​以通过 C++ 的实现回馈社区 - https://github.com/apache/incubator-mxnet/blob/fe5b56e419d454dc8f42f0307f53ced133804ca7/python/mxnet/io.py#L544

关于c++ - 用于 C++ 的 mxnet ndarray 迭代器,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48743700/

相关文章:

c++ - 编译器错误 : failed to bind overloaded function with two arguments

c++ - 是否有打印库信息的命令? (C++)

c++ - 使用重载运算符在函数调用时执行操作?

php - 如何在 PHP 中正确安全地委托(delegate)迭代器?

machine-learning - mxnet 训练没有进展

python - scikit-learn, keras, tensorflow 和 mxnet 中保存机器学习模型的所有格式是什么?

python - 使用 Apache mxnet (gluon) 训练神经网络导致程序崩溃

c++ - 模板函数指针作为模板参数

c++ - 如何为双向迭代器定义 operator<?

Python zip 对象只能使用一次。这是为什么?