diff --git a/src/python/03.DecisionTree/DTSklearn.py b/src/python/03.DecisionTree/DTSklearn.py index a4890961..1f451c61 100644 --- a/src/python/03.DecisionTree/DTSklearn.py +++ b/src/python/03.DecisionTree/DTSklearn.py @@ -51,7 +51,7 @@ def predict_train(x_train, y_train): return y_pre, clf -def show_precision_recall(x, clf, y_train, y_pre): +def show_precision_recall(x, y, clf, y_train, y_pre): ''' 准确率与召回率 参考链接: http://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_recall_curve.html#sklearn.metrics.precision_recall_curve @@ -110,7 +110,7 @@ if __name__ == '__main__': y_pre, clf = predict_train(x_train, y_train) # 展现 准确率与召回率 - show_precision_recall(x, clf, y_train, y_pre) + show_precision_recall(x, y, clf, y_train, y_pre) # 可视化输出 show_pdf(clf) diff --git a/src/python/03.DecisionTree/DecisionTree.py b/src/python/03.DecisionTree/DecisionTree.py index 03a4d25f..e5f3d41b 100644 --- a/src/python/03.DecisionTree/DecisionTree.py +++ b/src/python/03.DecisionTree/DecisionTree.py @@ -7,9 +7,9 @@ Update on 2017-02-27 Decision Tree Source Code for Machine Learning in Action Ch. 3 @author: Peter Harrington/jiangzhonglian ''' -from math import log import operator -import DecisionTreePlot as dtPlot +from math import log +import decisionTreePlot as dtPlot def createDataSet(): @@ -130,7 +130,9 @@ def chooseBestFeatureToSplit(dataSet): prob = len(subDataSet)/float(len(dataSet)) newEntropy += prob * calcShannonEnt(subDataSet) # gain[信息增益] 值越大,意味着该分类提供的信息量越大,该特征对分类的不确定程度越小 + # 也就说: 列进行group分组后,对应的类别越多,信息量越大,那么香农熵越小,那么信息增益就越大,所以gain越大 infoGain = baseEntropy - newEntropy + # print 'infoGain=', infoGain, 'bestFeature=', i if (infoGain > bestInfoGain): bestInfoGain = infoGain bestFeature = i diff --git a/src/python/03.DecisionTree/DecisionTreePlot.py b/src/python/03.DecisionTree/DecisionTreePlot.py index 737e4d31..3c5c4d31 100644 --- a/src/python/03.DecisionTree/DecisionTreePlot.py +++ b/src/python/03.DecisionTree/DecisionTreePlot.py @@ -128,5 +128,5 @@ def retrieveTree(i): return listOfTrees[i] -myTree = retrieveTree(0) +myTree = retrieveTree(1) createPlot(myTree) diff --git a/src/python/09.RegTrees/RTSklearn.py b/src/python/09.RegTrees/RTSklearn.py new file mode 100644 index 00000000..72036a23 --- /dev/null +++ b/src/python/09.RegTrees/RTSklearn.py @@ -0,0 +1,50 @@ +#!/usr/bin/python +# coding:utf8 + +''' +Created on 2017-03-10 +Update on 2017-03-10 +author: jiangzhonglian +content: 回归树 +''' + +print(__doc__) + + +# Import the necessary modules and libraries +import numpy as np +from sklearn.tree import DecisionTreeRegressor +import matplotlib.pyplot as plt + + +# Create a random dataset +rng = np.random.RandomState(1) +X = np.sort(5 * rng.rand(80, 1), axis=0) +y = np.sin(X).ravel() +print X, '\n\n\n-----------\n\n\n', y +y[::5] += 3 * (0.5 - rng.rand(16)) + + +# Fit regression model +regr_1 = DecisionTreeRegressor(max_depth=2, min_samples_leaf=5) +regr_2 = DecisionTreeRegressor(max_depth=5, min_samples_leaf=5) +regr_1.fit(X, y) +regr_2.fit(X, y) + + +# Predict +X_test = np.arange(0.0, 5.0, 0.01)[:, np.newaxis] +y_1 = regr_1.predict(X_test) +y_2 = regr_2.predict(X_test) + + +# Plot the results +plt.figure() +plt.scatter(X, y, c="darkorange", label="data") +plt.plot(X_test, y_1, color="cornflowerblue", label="max_depth=2", linewidth=2) +plt.plot(X_test, y_2, color="yellowgreen", label="max_depth=5", linewidth=2) +plt.xlabel("data") +plt.ylabel("target") +plt.title("Decision Tree Regression") +plt.legend() +plt.show() diff --git a/src/python/09.RegTrees/treeExplore.py b/src/python/09.RegTrees/treeExplore.py index 26fd9a8b..aa394f5d 100644 --- a/src/python/09.RegTrees/treeExplore.py +++ b/src/python/09.RegTrees/treeExplore.py @@ -7,14 +7,14 @@ Update on 2017-03-08 Tree-Based Regression Methods Source Code for Machine Learning in Action Ch. 9 @author: jiangzhonglian ''' +import regTrees from Tkinter import * from numpy import * -import regTrees import matplotlib -matplotlib.use('TkAgg') -from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg from matplotlib.figure import Figure +from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg +matplotlib.use('TkAgg') def test_widget_text(root): diff --git a/testResult/tree.pdf b/testResult/tree.pdf index f2b72829..6cb99e20 100644 Binary files a/testResult/tree.pdf and b/testResult/tree.pdf differ