2022年 11月 5日

Python 决策树

                                        Python 决策树

1 声明

本文的数据来自网络,部分代码也有所参照,这里做了注释和延伸,旨在技术交流,如有冒犯之处请联系博主及时处理。

2 决策树简介

相关概念见下:

决策树是一个无参数的有监督的分类和回归算法。该算法通过IF-THEN-ELSE决策规则(比如:如果绩效考核是A则发奖金1000K,是B则发500)的方式来从数据中学习模型。这种决策的结构就像一个倒置树(第一个决策规则在最顶端,其它的节点随之展开)。在决策树里,每个决策规则都发生在一个决策节点上,该规则创建指向新节点的分支。一个在末端的没有决策规则的分支称为叶子节点。树越深决策树规则越复杂,模型拟合的越好。决策树构建时采用树形结构。树有两个或者多个分支,这里常见的是二分支。它能处理分类变量和连续型变量,适合处理非线性的关系。

决策树之所以较为流行是因为其较好的解释性,它的模型可以通过树状图形的方式直观的展示出来。

决策树示例(详细见韩家炜著《数据挖数据挖掘:概念与技术》内决策树部分介绍):

决策树模型总是试图一个在节点上找到产生最大的纯度的决策规则。划分纯度的指标有多,Scikit Learn里默认的是Gini。

这里 是节点t的杂质度(不纯净度,即分裂标准), 是节点t里类别是C的占比。这种寻找决策规则(该规则会产生分裂以增加不纯度)的过程会递归式重复执行,直到所有叶节点都是纯的(即节点划到某个类别)。

决策树回归

决策树回归与决策树分类类似,只不过它不是通过降低Gini不纯度或者熵,而是通过降低的残差平方和来度量。

这里 是目标变量的真实值, 是目标变量的预测值。

回归,就是根据特征向量来决定对应的输出值。回归树就是将特征空间划分成若干单元,每一个划分单元有一个特定的输出。因为每个结点都是“是”和“否”的判断,所以划分的边界是平行于坐标轴的。

那么得到的划分区域见下:

3 决策树代码与注释示例

  1. def irisdt():
  2. from sklearn import tree
  3. from sklearn import model_selection
  4. from sklearn.datasets import load_iris
  5. #from sklearn.grid_search import GridSearchCV
  6. from sklearn.model_selection import GridSearchCV
  7. from sklearn.metrics import classification_report
  8. import matplotlib.pyplot as plt
  9. iris = load_iris()
  10. x = iris.data
  11. y = iris.target
  12. X_train, X_test, y_train, y_test = model_selection \
  13. .train_test_split(x, y, test_size=0.2,
  14. random_state=123456)
  15. parameters = {
  16. 'criterion': ['gini', 'entropy'],
  17. 'max_depth': range(1,30),#[1, 2, 3, 4, 5, 6, 7, 8,9,10],
  18. 'max_leaf_nodes': [2,3,4, 5, 6, 7, 8, 9] #最大叶节点数
  19. }
  20. dtree = tree.DecisionTreeClassifier()
  21. grid_search = GridSearchCV(dtree, parameters, scoring='accuracy', cv=5)
  22. grid_search.fit(x, y)
  23. print(grid_search.best_estimator_) # 查看grid_search方法
  24. print(grid_search.best_score_) # 正确率
  25. print(grid_search.best_params_) # 最佳 参数组合
  26. dtree = tree.DecisionTreeClassifier(criterion='gini', max_depth=3)
  27. dtree.fit(X_train, y_train)
  28. pred = dtree.predict(X_test)
  29. print(pred)
  30. print(y_test)
  31. print(classification_report(y_test, pred,target_names=['setosa', 'versicolor', 'virginica']))
  32. print(dtree.predict([[6.9,3.3,5.6,2.4]]))#预测属于哪个分类
  33. print(dtree.predict_proba([[6.9,3.3,5.6,2.4]])) # 预测所属分类的概率值
  34. ##print(iris.target)
  35. print(list(iris.target_names)) #输出目标值的元素名称
  36. #print(grid_search.estimator.score(y_test, pred))
  37. def irisdecisontree():
  38. from sklearn import datasets
  39. iris = datasets.load_iris()
  40. X_train = iris.data[:,[0,1]][0:150]
  41. y_train = iris.target
  42. #print(iris.feature_names)
  43. #print(type(X_train))
  44. ##print(X_train[:,[0,1]][0:150])
  45. clf = tree.DecisionTreeClassifier(max_depth=3,criterion='entropy')
  46. clf = clf.fit(X_train, y_train)
  47. with open("../output/iristree.dot",'w') as f:
  48. f = export_graphviz(clf,feature_names=['sepallength','sepalwidth'],out_file=f)
  49. import os
  50. os.system('dot -Tpdf "../output/iristree.dot" -o "../output/iristree.pdf"')
  51. if __name__ == '__main__':
  52. irisdt()

4 总结