mirror of
https://github.com/apachecn/ailearning.git
synced 2026-02-12 06:46:14 +08:00
更新决策树的内容,做ppt
This commit is contained in:
@@ -51,7 +51,7 @@ def predict_train(x_train, y_train):
|
||||
return y_pre, clf
|
||||
|
||||
|
||||
def show_precision_recall(x, clf, y_train, y_pre):
|
||||
def show_precision_recall(x, y, clf, y_train, y_pre):
|
||||
'''
|
||||
准确率与召回率
|
||||
参考链接: http://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_recall_curve.html#sklearn.metrics.precision_recall_curve
|
||||
@@ -110,7 +110,7 @@ if __name__ == '__main__':
|
||||
y_pre, clf = predict_train(x_train, y_train)
|
||||
|
||||
# 展现 准确率与召回率
|
||||
show_precision_recall(x, clf, y_train, y_pre)
|
||||
show_precision_recall(x, y, clf, y_train, y_pre)
|
||||
|
||||
# 可视化输出
|
||||
show_pdf(clf)
|
||||
|
||||
@@ -7,9 +7,9 @@ Update on 2017-02-27
|
||||
Decision Tree Source Code for Machine Learning in Action Ch. 3
|
||||
@author: Peter Harrington/jiangzhonglian
|
||||
'''
|
||||
from math import log
|
||||
import operator
|
||||
import DecisionTreePlot as dtPlot
|
||||
from math import log
|
||||
import decisionTreePlot as dtPlot
|
||||
|
||||
|
||||
def createDataSet():
|
||||
@@ -130,7 +130,9 @@ def chooseBestFeatureToSplit(dataSet):
|
||||
prob = len(subDataSet)/float(len(dataSet))
|
||||
newEntropy += prob * calcShannonEnt(subDataSet)
|
||||
# gain[信息增益] 值越大,意味着该分类提供的信息量越大,该特征对分类的不确定程度越小
|
||||
# 也就说: 列进行group分组后,对应的类别越多,信息量越大,那么香农熵越小,那么信息增益就越大,所以gain越大
|
||||
infoGain = baseEntropy - newEntropy
|
||||
# print 'infoGain=', infoGain, 'bestFeature=', i
|
||||
if (infoGain > bestInfoGain):
|
||||
bestInfoGain = infoGain
|
||||
bestFeature = i
|
||||
|
||||
@@ -128,5 +128,5 @@ def retrieveTree(i):
|
||||
return listOfTrees[i]
|
||||
|
||||
|
||||
myTree = retrieveTree(0)
|
||||
myTree = retrieveTree(1)
|
||||
createPlot(myTree)
|
||||
|
||||
Reference in New Issue
Block a user