python - 使用 int 列表的稀疏矩阵切片

标签 python scipy segmentation-fault sparse-matrix

我正在针对大量稀疏数据编写机器学习算法(我的矩阵形状为 (347, 5 416 812 801) 但非常稀疏,只有 0.13% 的数据不为零。

我的稀疏矩阵大小为 105 000 字节 (<1Mbytes) 并且是 csr 类型。

我试图通过为每个训练集/测试集选择一个示例索引列表来分离训练集/测试集。 所以我想使用以下方法将我的数据集一分为二:

training_set = matrix[train_indices]

形状 (len(training_indices), 5 416 812 801),仍然稀疏

testing_set = matrix[test_indices]

形状 (347-len(training_indices), 5 416 812 801) 也是稀疏的

training_indicestesting_indices两个listint

但是 training_set = matrix[train_indices] 似乎失败并返回一个 Segmentation fault (core dumped)

这可能不是内存问题,因为我在具有 64GB RAM 的服务器上运行此代码。

关于可能是什么原因的任何线索?

最佳答案

我想我已经重新创建了 csr 行索引:

def extractor(indices, N):
   indptr=np.arange(len(indices)+1)
   data=np.ones(len(indices))
   shape=(len(indices),N)
   return sparse.csr_matrix((data,indices,indptr), shape=shape)

在我闲逛的 csr 上进行测试:

In [185]: M
Out[185]: 
<30x40 sparse matrix of type '<class 'numpy.float64'>'
    with 76 stored elements in Compressed Sparse Row format>

In [186]: indices=np.r_[0:20]

In [187]: M[indices,:]
Out[187]: 
<20x40 sparse matrix of type '<class 'numpy.float64'>'
    with 57 stored elements in Compressed Sparse Row format>

In [188]: extractor(indices, M.shape[0])*M
Out[188]: 
<20x40 sparse matrix of type '<class 'numpy.float64'>'
    with 57 stored elements in Compressed Sparse Row format>

与许多其他 csr 方法一样,它使用矩阵乘法来产生最终值。在这种情况下,稀疏矩阵在选定行中为 1。时间其实好一点。

In [189]: timeit M[indices,:]
1000 loops, best of 3: 515 µs per loop
In [190]: timeit extractor(indices, M.shape[0])*M
1000 loops, best of 3: 399 µs per loop

在您的情况下,提取器矩阵的形状为 (len(training_indices),347),只有 len(training_indices) 值。所以它并不大。

但是如果 matrix 太大(或者至少第二维太大)以至于它在矩阵乘法例程中产生一些错误,它可能会在没有 python/numpy 陷阱的情况下引起段错误

matrix.sum(axis=1) 是否有效。这也使用了矩阵乘法,尽管有一个 1s 的密集矩阵。或者 sparse.eye(347)*M,类似大小的矩阵乘法?

关于python - 使用 int 列表的稀疏矩阵切片,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/39500649/

相关文章:

python - 手动图例中的自定义标记边缘样式

python - python中最快的成对距离度量

通过内联汇编访问系统时间后出现 C 段错误

c - 为什么这会产生段错误?

python - 处理缺失值: When 99% of the data is missing from most columns (important ones)

python - Graphite 和 GUnicorn - 配置问题或路径问题

python - 如何使用python计算信噪比

c - 尝试打印数组的第一个值时出现段错误

python - lxml unicode字符

python - 如何制作对数最佳拟合线?