From a383e83d5fd87ac20b12e850f8f2609bc9e65f41 Mon Sep 17 00:00:00 2001 From: jiangzhonglian Date: Wed, 29 Mar 2017 23:08:09 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=96=B09=20=E6=A0=91=E5=9B=9E?= =?UTF-8?q?=E5=BD=92=E7=9A=84=E6=B3=A8=E9=87=8A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/python/09.RegTrees/regTrees.py | 65 ++++++++++++++++----------- src/python/09.RegTrees/treeExplore.py | 2 +- 2 files changed, 41 insertions(+), 26 deletions(-) diff --git a/src/python/09.RegTrees/regTrees.py b/src/python/09.RegTrees/regTrees.py index c95dcf54..c6077e0f 100644 --- a/src/python/09.RegTrees/regTrees.py +++ b/src/python/09.RegTrees/regTrees.py @@ -77,10 +77,10 @@ def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)): """chooseBestSplit(用最佳方式切分数据集 和 生成相应的叶节点) Args: - dataSet 数据集 - leafType 计算叶子节点的函数 - errType 求总方差 - ops [容许误差下降值,切分的最少样本数] + dataSet 加载的原始数据集 + leafType 建立叶子点的函数 + errType 误差计算函数(求总方差) + ops [容许误差下降值,切分的最少样本数] Returns: bestIndex feature的index坐标 bestValue 切分的最优值 @@ -128,6 +128,16 @@ def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)): # assume dataSet is NumPy Mat so we can array filtering def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)): + """createTree(获取回归树) + + Args: + dataSet 加载的原始数据集 + leafType 建立叶子点的函数 + errType 误差计算函数 + ops=(1, 4) [容许误差下降值,切分的最少样本数] + Returns: + retTree 决策树最后的结果 + """ # 选择最好的切分方式: feature索引值,最优切分值 # choose the best split feat, val = chooseBestSplit(dataSet, leafType, errType, ops) @@ -137,7 +147,7 @@ def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)): retTree = {} retTree['spInd'] = feat retTree['spVal'] = val - # 大于在右边,小于在左边 + # 大于在右边,小于在左边,分为2个数据集 lSet, rSet = binSplitDataSet(dataSet, feat, val) # 递归的进行调用 retTree['left'] = createTree(lSet, leafType, errType, ops) @@ -161,21 +171,25 @@ def getMean(tree): # 检查是否适合合并分枝 def prune(tree, testData): - # 判断是否测试数据集没有数据 + # 
判断是否测试数据集没有数据,如果没有,就直接返回tree本身的均值 if shape(testData)[0] == 0: return getMean(tree) - # 对测试进行分支,看属于哪只分支,然后返回tree结果的均值 + + # 判断分枝是否是dict字典,如果是就将测试数据集进行切分 if (isTree(tree['right']) or isTree(tree['left'])): lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal']) + # 如果是左边分枝是字典,就传入左边的数据集和左边的分枝,进行递归 if isTree(tree['left']): tree['left'] = prune(tree['left'], lSet) + # 如果是右边分枝是字典,就传入右边的数据集和右边的分枝,进行递归 if isTree(tree['right']): tree['right'] = prune(tree['right'], rSet) - # 如果左右两边无子分支,那么计算一下总方差 和 该结果集的本身不分枝的总方差比较 - # 1.如果测试数据集足够大,将tree进行分支到最后 - # 2.如果测试数据集不够大,那么就无法进行合并 - # 注意返回的结果: 是合并后对原来为字典tree进行赋值,相当于进行了合并 + # 如果左右两边同时都不是dict字典,那么分割测试数据集。 + # 1. 如果正确 + # * 那么计算一下总方差 和 该结果集的本身不分枝的总方差比较 + # * 如果 合并的总方差 < 不合并的总方差,那么就进行合并 + # 注意返回的结果: 如果可以合并,原来的dict就变为了 数值 if not isTree(tree['left']) and not isTree(tree['right']): lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal']) # power(x, y)表示x的y次方 @@ -274,27 +288,28 @@ if __name__ == "__main__": # mat0, mat1 = binSplitDataSet(testMat, 1, 0.5) # print mat0, '\n-----------\n', mat1 - # 回归树 + # # 回归树 # myDat = loadDataSet('testData/RT_data1.txt') - # myDat = loadDataSet('testData/RT_data2.txt') + # # myDat = loadDataSet('testData/RT_data2.txt') # myMat = mat(myDat) # myTree = createTree(myMat) + # print myTree - # 1. 预剪枝就是,提起设置最大误差数和最少元素数 + # # 1. 
预剪枝就是:提前设置最大误差数和最少元素数 # myDat = loadDataSet('testData/RT_data3.txt') # myMat = mat(myDat) # myTree = createTree(myMat, ops=(0, 1)) # print myTree - # 2.后剪枝 + # # 2.后剪枝就是:通过测试数据,对预测模型进行合并判断 # myDatTest = loadDataSet('testData/RT_data3test.txt') # myMat2Test = mat(myDatTest) # myFinalTree = prune(myTree, myMat2Test) # print '\n\n\n-------------------' # print myFinalTree - # -------- - # 模型树求解 + # # -------- + # # 模型树求解 # myDat = loadDataSet('testData/RT_data4.txt') # myMat = mat(myDat) # myTree = createTree(myMat, modelLeaf, modelErr) @@ -315,11 +330,11 @@ if __name__ == "__main__": print myTree2 print "模型树:", corrcoef(yHat2, testMat[:, 1],rowvar=0)[0, 1] - # 线性回归 - ws, X, Y = linearSolve(trainMat) - print ws - m = len(testMat[:, 0]) - yHat3 = mat(zeros((m, 1))) - for i in range(shape(testMat)[0]): - yHat3[i] = testMat[i, 0]*ws[1, 0] + ws[0, 0] - print "线性回归:", corrcoef(yHat3, testMat[:, 1],rowvar=0)[0, 1] + # # 线性回归 + # ws, X, Y = linearSolve(trainMat) + # print ws + # m = len(testMat[:, 0]) + # yHat3 = mat(zeros((m, 1))) + # for i in range(shape(testMat)[0]): + # yHat3[i] = testMat[i, 0]*ws[1, 0] + ws[0, 0] + # print "线性回归:", corrcoef(yHat3, testMat[:, 1],rowvar=0)[0, 1] diff --git a/src/python/09.RegTrees/treeExplore.py b/src/python/09.RegTrees/treeExplore.py index aa394f5d..a426a342 100644 --- a/src/python/09.RegTrees/treeExplore.py +++ b/src/python/09.RegTrees/treeExplore.py @@ -5,7 +5,7 @@ Created on 2017-03-08 Update on 2017-03-08 Tree-Based Regression Methods Source Code for Machine Learning in Action Ch. 9 -@author: jiangzhonglian +@author: Peter/片刻 ''' import regTrees from Tkinter import *