python - 动态构建通用 numpy 数组索引

标签 python arrays numpy scikit-learn

基本背景

我正在编写一些代码来简化针对具有不同特征数量的数据的 SVM 训练,并使用用户指定的“切片”可视化这些 SVM 的决策边界。如果我的数据集中有 n 个特征和 m 个样本,我会生成一个 (n+1)-dimensional 网格,其中每个切片沿着第一个索引是维度为 nm x m x ... 网格。然后我可以使用 SVM 对网格中的每个数据点进行分类。

我接下来要做的是在用户指定的任意两个维度上绘制这些结果的一部分。当数据只有两个特征时,我有代码可以绘制我想要的内容,但是一旦我添加了第三个特征,我就开始遇到索引问题。

问题陈述

假设我有一个三维矩阵,predictions,我想将这些预测绘制在与 index0 关联的网格 mesh 中的所有值上=0index1=1,以及这些维度的训练数据。我可以通过这样的函数调用来做到这一点:

import matplotlib.pyplot as plt
plt.contourf(mesh[index0,:,:,0], mesh[index1,:,:,0], pred[:,:,0])
plt.scatter(samples[:,index0], samples[:,index1], c=labels)
plt.show()

我想知道的是如何动态构建我的索引数组,以便如果 index0=0index1=1,我们得到上面的代码,但是如果index0=1index1=2,我们会得到:

plt.contourf(mesh[index0,0,:,:], mesh[index1,0,:,:], pred[0,:,:])

如果 index0=0index1=2,我们会得到:

plt.contourf(mesh[index0,:,0,:], mesh[index1,:,0,:], pred[:,0,:])

我怎样才能动态地构建这些?对于我可能无法提前知道数据将具有多少特征的情况,一般情况下是否有更好的方法来解决这个问题?

更多详情

我尝试了类似的东西:

mesh_indices0 = [0]*len(mesh.shape)
mesh_indices0[0] = index0
mesh_indices0[index0+1] = ':'    # syntax error: I cannot add this dynamically
mesh_indices0[index1+1] = ':'    # same problem

我也尝试从相反的方向使用 mesh_indices = [:]*len(mesh.shape),但这也是无效的语法。我想过尝试这样的事情:

mesh_indices[index0+1] = np.r_[:len(samples[:, 1])]

其中 samples 是我的 m x n 观察集。不过,这对我来说似乎真的很笨拙,所以我认为必须有更好的方法。

最佳答案

我不确定我是否完全理解你想要做什么,但如果你想操作切片,你应该使用 python slice 对象:

mesh[index0,0,:,:]

相当于:

mesh[index0,0,slice(0,mesh.shape[2]),slice(0,mesh.shape[3])]

另请注意,您可以使用切片和索引的列表或元组进行索引:

inds = (index0, 0, slice(0,mesh.shape[2]), slice(0,mesh.shape[3]))
mesh[inds]

将它们放在一起,您可以制作一个 : 等价的 slice 对象列表,然后用您的具体索引替换适当的对象。或者,走另一条路:

mesh_indices = [0]*len(mesh.shape)
mesh_indices[0] = index0
mesh_indices[index0+1] = slice(0, mesh.shape[index0+1])
mesh_indices[index1+1] = slice(0, mesh.shape[index1+1])

关于python - 动态构建通用 numpy 数组索引,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/24023492/

相关文章:

javascript - jQuery - 如果在代码末尾求值,console.log 会提供空数组

python - 将 numpy 2D 数组中的字符串元素转换为数组并生成 3D 数组

python - 将多项式回归从 R 移植到 python

python - 对巨大的矩阵进行排序,然后在列表中找到最小的元素及其索引

python - 如何更改 PIL 中矩形的不透明度?

python - 如何导出 flask restplus swagger json?

python - 我应该为 'social' 网站的后端使用什么?

python - 使用OpenCV删除车牌边框(python)

javascript - 获取 NaN 不一致地映射 parseInt

javascript - typescript 计算数组中的重复项并按每个项目的计数对结果进行排序