python - 非线性决策边界的 SVM 图

标签 python python-3.x machine-learning scikit-learn svm

我正在尝试绘制 SVM 决策边界,它将癌性和非癌性两类分开。但是,它显示的情节与我想要的相去甚远。我希望它看起来像这样:

enter image description here 或任何显示点分散的东西。这是我的代码:

import numpy as np
import pandas as pd
from sklearn import svm
from mlxtend.plotting import plot_decision_regions
import matplotlib.pyplot as plt

autism = pd.read_csv('predictions.csv')


# Fit Support Vector Machine Classifier
X = autism[['TARGET','Predictions']]
y = autism['Predictions']

clf = svm.SVC(C=1.0, kernel='rbf', gamma=0.8)
clf.fit(X.values, y.values) 

# Plot Decision Region using mlxtend's awesome plotting function
plot_decision_regions(X=X.values, 
                      y=y.values,
                      clf=clf, 
                      legend=2)

# Update plot object with X/Y axis labels and Figure Title
plt.xlabel(X.columns[0], size=14)
plt.ylabel(X.columns[1], size=14)
plt.title('SVM Decision Region Boundary', size=16)
plt.show()

但我有一个看起来很奇怪的情节:

enter image description here

您可以在此处找到 csv 文件 predictions.csv

最佳答案

你听起来有点困惑......

您的 predictions.csv 看起来像:

TARGET  Predictions
     1  0
     0  0
     0  0
     0  0

而且,正如我猜想列名暗示的那样,它包含基本事实 (TARGET) 和一些已经运行的 (?) 模型的 Predictions

鉴于此,您在发布的代码中所做的事情完全毫无意义:您将这两列用作 X 中的特征以预测您的 y,它是...完全相同的列之一 (Predictions),已经包含在您的 X...

您的绘图看起来“很奇怪”仅仅是因为您绘制的不是您的数据点,而是您显示的Xy 数据这里不是应该用于拟合分类器的数据。

我进一步感到困惑,因为在您的链接 repo 中,您的脚本中确实有正确的程序:

autism = pd.read_csv('10-features-uns.csv')

x = autism.drop(['TARGET'], axis = 1)  
y = autism['TARGET']
x_train, X_test, y_train, y_test = train_test_split(x, y, test_size = 0.30, random_state=1)

即从 10-features-uns.csv 中读取您的功能和标签,当然不是predictions.csv 中读取,因为您正在莫名其妙地尝试这样做这里……

关于python - 非线性决策边界的 SVM 图,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/55773434/

相关文章:

python - 如何循环遍历Keras fit函数?

machine-learning - BERT多类别情感分析准确率低?

python比较不同时区的日期时间

python - Django 查询上的 value() 方法后的计数和最大值

mongodb - 在 insert_many() 失败后获取插入的 ID

python - 尽可能方便地合并文件

java - 将 XML 位图数据转换为图像

python - 如何仅对 pandas 中的一组中的某些行进行排序?

python - 从python中的函数返回错误字符串

python - 如何强制 sklearn CountVectorizer 不删除特殊字符(即 #、@、$ 或 %)