python - 如何将 numpy 数组的多个值添加到图例中?

标签 python python-3.x numpy matplotlib knn

在我的实验中,我使用 KNN 对一些数据集进行分类(共享 here 以实现可重复性)。下面是我的源代码。

import numpy as np
from numpy import genfromtxt
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt




types = {
        "Data_G": ["datag_s.csv", "datag_m.csv"], 
        "Data_V": ["datav_s.csv", "datav_m.csv"], 
        "Data_C": ["datac_s.csv", "datac_m.csv"], 
        "Data_R": ["datar_s.csv", "datar_m.csv"]
        }

dataset = None
ground_truth = None

for idx, csv_list in types.items():
    for csv_f in csv_list:

        col_time,col_window = np.loadtxt(csv_f,delimiter=',').T
        trailing_window = col_window[:-1] # "past" values at a given index
        leading_window  = col_window[1:]  # "current values at a given index
        decreasing_inds = np.where(leading_window < trailing_window)[0]
        beta_value = leading_window[decreasing_inds]/trailing_window[decreasing_inds]
        quotient_times = col_time[decreasing_inds]

        my_data = genfromtxt(csv_f, delimiter=',')
        my_data = my_data[:,1]
        my_data = my_data[:int(my_data.shape[0]-my_data.shape[0]%200)].reshape(-1, 200)
        labels = np.full(1, idx)

        if dataset is None:
            dataset = beta_value.reshape(1,-1)[:,:15]
        else:
            dataset = np.concatenate((dataset,beta_value.reshape(1,-1)[:,:15]))

        if ground_truth is None:
            ground_truth = labels
        else:
            ground_truth = np.concatenate((ground_truth,labels))



X_train, X_test, y_train, y_test = train_test_split(dataset, ground_truth, test_size=0.25, random_state=42)

knn_classifier = KNeighborsClassifier(n_neighbors=3, weights='distance', algorithm='auto', leaf_size=300, p=2, metric='minkowski')
knn_classifier.fit(X_train, y_train)

当我执行以下操作时

plot_data=dataset.transpose()
plt.plot(plot_data)

它产生以下图。

enter image description here

我将图例添加到图中,如下所示:

plt.plot(plot_data, label=idx)
plt.legend()

enter image description here

但是,正如所见,它将所有图例替换为 Data_R。我在这里做错了什么?

最佳答案

在回答这个问题之前我要说的一件事是,在循环字典时我总是小心谨慎。在 Python 3.6 之前,字典不是有序的,因此如果您需要保证字典中的顺序,您应该使用 OrderedDict。如果您运行的是 Python3.6+,那么您无需担心这一点。无论如何...

在 for 循环之后 for idx, csv_list in types.items(): 我们将始终拥有 idx = "Data_R" (假设您的字典是有序的) .

因此,当您使用 plt.plot(plot_data, label=idx) 绘图时,所有线条的标签都将设置为“Data_R”

相反,您应该遍历这些行并一次向其中添加一个标签。

for i, key in enumerate(types.keys()):
    plt.plot(plot_data[:, 2*i], label=key)
    plt.plot(plot_data[:, 2*i+1], label=key)

plt.legend()

关于python - 如何将 numpy 数组的多个值添加到图例中?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/55039420/

相关文章:

python-3.x - AWS CDK - 访问新功能

python - 尝试在 PyQT5 中打开文件时出现无效文件错误

python - numpy.unique 生成的列表在哪些方面是唯一的?

python - python中将edgelist导入igraph的格式

python - 从另一个包中动态导入一个包

python - 无法让验证码出现

python - 找出 _ast 模块的元素,该模块在 ast.py 中导入

python - Scipy:Argrelmax 在 3d 数组中沿维度查找 Maxmima

Python 运行时警告 : overflow encountered in long scalars

python - 使用虚拟环境时安装python模块