python - 通过索引访问 coo_matrix 时出现类型错误

标签 python python-3.x numpy scipy sparse-matrix

我有 coo_matrix X 和索引 trn_idx,我想通过它们访问该 maxtrix

print (type(X  ), X.shape)
print (type(trn_idx), trn_idx.shape)

<class 'scipy.sparse.coo.coo_matrix'> (1503424, 2795253)
<class 'numpy.ndarray'> (1202739,)

这样调用:

X[trn_idx]
TypeError: only integer scalar arrays can be converted to a scalar index

无论是这样:

 X[trn_idx.astype(int)] #same error

如何通过索引访问?

最佳答案

coo_matrix 类不支持索引。您必须将其转换为不同的稀疏格式。

这是一个带有小型 coo_matrix 的示例:

In [19]: import numpy as np

In [20]: from scipy.sparse import coo_matrix

In [21]: m = coo_matrix([[0, 0, 0, 1], [2, 0, 0 ,0], [0, 0, 0, 0], [0, 3, 4, 0]])

尝试索引m失败:

In [22]: m[0,0]
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-22-1f78c188393f> in <module>()
----> 1 m[0,0]

TypeError: 'coo_matrix' object is not subscriptable

In [23]: idx = np.array([2, 3])

In [24]: m[idx]
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-24-a52866a6fec6> in <module>()
----> 1 m[idx]

TypeError: only integer scalar arrays can be converted to a scalar index

如果将m转换为CSR矩阵,则可以使用idx对其进行索引:

In [25]: m.tocsr()[idx]
Out[25]: 
<2x4 sparse matrix of type '<class 'numpy.int64'>'
    with 2 stored elements in Compressed Sparse Row format>

如果您打算进行更多索引,最好将新数组保存在变量中,并根据需要使用它:

In [26]: a = m.tocsr()

In [27]: a[idx]
Out[27]: 
<2x4 sparse matrix of type '<class 'numpy.int64'>'
    with 2 stored elements in Compressed Sparse Row format>

In [28]: a[0,0]
Out[28]: 0

关于python - 通过索引访问 coo_matrix 时出现类型错误,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50898924/

相关文章:

python - 当字符位于 unicode 范围内时,如何用空格填充字符?

python - 使用 Python 读取 YAML 文件会导致 yaml.scanner.ScannerError : mapping values are not allowed here

python - 使用 Numpy 在元组列表之间进行外部减法

Python pandas dataframe groupby 选择列

python-3.x - Scrapy - 从多个页面中提取数据

python - 使用 python numpy 矩阵类的梯度下降

python - 从其他域发送 Post 请求到 GAE

python - PyCharm 无法识别以开发模式安装的模块

python - 如何获得多维数组的填充切片?

python - 如何将 Numpy 4D 数组保存为 CSV?