python - 在 scikit-learn 中可视化决策树

标签 python scikit-learn visualization decision-tree

我正在尝试在 Python 中使用 scikit-learn 设计一个简单的决策树(我在 Windows 操作系统上使用 Anaconda 的 Ipython Notebook 和 Python 2.7.3)并将其可视化如下:

from pandas import read_csv, DataFrame
from sklearn import tree
from os import system

data = read_csv('D:/training.csv')
Y = data.Y
X = data.ix[:,"X0":"X33"]

dtree = tree.DecisionTreeClassifier(criterion = "entropy")
dtree = dtree.fit(X, Y)

dotfile = open("D:/dtree2.dot", 'w')
dotfile = tree.export_graphviz(dtree, out_file = dotfile, feature_names = X.columns)
dotfile.close()
system("dot -Tpng D:.dot -o D:/dtree2.png")

但是,我收到以下错误:

AttributeError: 'NoneType' object has no attribute 'close'

我使用以下博客文章作为引用:Blogpost link

以下 stackoverflow 问题似乎对我也不起作用:Question

有人可以帮助我如何在 scikit-learn 中可视化决策树吗?

最佳答案

对于那些使用 jupyter 和 sklearn(18.2+) 的人来说,这里是一个衬垫,你甚至不需要 matplotlib 。唯一的要求是 graphviz

pip install graphviz

比运行(根据有问题的代码 X 是 pandas DataFrame)

from graphviz import Source
from sklearn import tree
Source( tree.export_graphviz(dtreg, out_file=None, feature_names=X.columns))

这将以 SVG 格式显示。上面的代码生成 Graphviz 的 Source对象(source_code - 不可怕)这将直接在 jupyter 中呈现。

你可能会用它做一些事情

在 jupter 中显示:

from IPython.display import SVG
graph = Source( tree.export_graphviz(dtreg, out_file=None, feature_names=X.columns))
SVG(graph.pipe(format='svg'))

另存为png:

graph = Source( tree.export_graphviz(dtreg, out_file=None, feature_names=X.columns))
graph.format = 'png'
graph.render('dtree_render',view=True)

获取png图片,保存并查看:

graph = Source( tree.export_graphviz(dtreg, out_file=None, feature_names=X.columns))
png_bytes = graph.pipe(format='png')
with open('dtree_pipe.png','wb') as f:
    f.write(png_bytes)

from IPython.display import Image
Image(png_bytes)

如果您要使用该库,这里是 examples 的链接和 userguide

关于python - 在 scikit-learn 中可视化决策树,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/27817994/

相关文章:

python - 如何删除 Altair 图中的 save_as 图标

python wnck 在 pdb.set_trace() 之前不返回任何数据

python - 如何使用 scikit learn/pandas/python 打印任意一个集群的样本/观察结果/行?

javascript - 邻接矩阵图库

visualization - 包含表情符号的 graphml 文件在 gephi 中被转换为黑色节点

python - Seaborn 折线图样式导致重复的图例条目

javascript - Braintree JSv3 payment_method_nonce 值与 HostedFields 不一致

python - 使用pyserial向Arduino发送信息(操作步进电机)

python - scikit-learn 中 predict 与 predict_proba 的区别

machine-learning - 机器学习: Weighting Training Points by Importance