c++ - "Splitting"常数时间内的矩阵

标签 c++ c++11 matrix matrix-multiplication strassen

我正在尝试在 C++ 中实现 Strassen 的矩阵乘法算法,并且我想找到一种方法在常数时间内将两个矩阵分成四个部分。这是我目前的做法:

for(int i = 0; i < n; i++){
    for(int j = 0; j < n; j++){
        A11[i][j] = a[i][j];
        A12[i][j] = a[i][j+n];
        A21[i][j] = a[i+n][j];
        A22[i][j] = a[i+n][j+n];
        B11[i][j] = b[i][j];
        B12[i][j] = b[i][j+n];
        B21[i][j] = b[i+n][j];
        B22[i][j] = b[i+n][j+n];
    }
}

这种方法显然是 O(n^2),并且它将 n^2*log(n) 添加到运行时,因为每次递归调用都会调用它。

似乎在恒定时间内执行此操作的方法是创建指向四个子矩阵的指针,而不是复制值,但我很难弄清楚如何创建这些指针。任何帮助将不胜感激。

最佳答案

不要考虑矩阵,考虑矩阵 View 。

矩阵 View 具有指向 T 缓冲区的指针、宽度、高度、偏移量和列(或行)之间的步幅。

我们可以从数组 View 类型开始。

template<class T>
struct array_view {
  T* b = 0; T* e = 0;
  T* begin() const{ return b; }
  T* end() const{ return e; }

  array_view( T* s, T* f ):b(s), e(f) {}
  array_view( T* s, std::size_t l ):array_view(s, s+l) {}

  std::size_t size() const { return end()-begin(); }
  T& operator[]( std::size_t n ) const { return *(begin()+n); }
  array_view slice( std::size_t start, std::size_t length ) const {
    start = (std::min)(start, size());
    length = (std::min)(size()-start, length);
    return {b+start, length};
  }
};

现在我们的矩阵 View :

temlpate<class T>
struct matrix_view {
  std::size_t height, width;
  std::size_t offset, stride;
  array_view<T> buffer;

  // TODO: Ctors
  // one from a matrix that has offset and stirde set to 0.
  // another that lets you create a sub-matrix
  array_view<T> operator[]( std::size_t n ) const {
    return buffer.slice( offset+stride*n, width ); // or width, depending on if row or column major
  }
};

现在您的代码可以在 matrix_view 上运行,而不是在矩阵上运行。

关于c++ - "Splitting"常数时间内的矩阵,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/40870813/

相关文章:

c++ - 生成前缀位掩码

当 C++ 使用智能指针时,C# 委托(delegate)等效

c++ - CMake 使用外部库编译

c++ - 获取没有访问器或修改器的嵌套类成员

c++ - 为什么 is_class<T> 在此代码段中不起作用?

c++ - 指向模板类方法的指针的 Typedef

c++ - 使用 "multiple"命名空间单行

html/css 用于围绕数学矩阵的括号——更喜欢轻量级

algorithm - 在矩阵中找到最大可访问节点

r - model.matrix 生成的行数比原始 data.frame 少