diff --git a/docs/3.决策树.md b/docs/3.决策树.md index 7ce8b86c..d024d7e7 100644 --- a/docs/3.决策树.md +++ b/docs/3.决策树.md @@ -89,11 +89,11 @@ ``` 收集数据:可以使用任何方法 -准备数据:树构造算法只适用于标称型数据,因此数值型数据必须离散化。 -分析数据:可以使用任何方法,构造树完成之后,我们应该检查图形是否符合预期。 +准备数据:树构造算法只适用于标称型数据,因此数值型数据必须离散化 +分析数据:可以使用任何方法,构造树完成之后,我们应该检查图形是否符合预期 训练算法:构造树的数据结构 -测试算法:使用经验树计算错误率 -使用算法:此步骤可以适用于任何监督学习算法,而使用决策树可以更好地理解数据的内在含义。 +测试算法:使用决策树执行分类 +使用算法:此步骤可以适用于任何监督学习算法,而使用决策树可以更好地理解数据的内在含义 ``` > 收集数据:可以使用任何方法 @@ -102,7 +102,7 @@ 我们利用 createDataSet() 函数输入数据 -```Python +```python def createDataSet(): dataSet = [[1, 1, 'yes'], [1, 1, 'yes'], @@ -112,13 +112,17 @@ def createDataSet(): labels = ['no surfacing', 'flippers'] return dataSet, labels ``` -> 准备数据:树构造算法只适用于标称型数据,因此数值型数据必须离散化。 +> 准备数据:树构造算法只适用于标称型数据,因此数值型数据必须离散化 此处,由于我们输入的数据本身就是离散化数据,所以这一步就省略了。 -计算给定数据集的香农熵 +> 分析数据:可以使用任何方法,构造树完成之后,我们应该检查图形是否符合预期 -```Python +![熵的计算公式](/images/3.DecisionTree/熵的计算公式.jpg) + +计算给定数据集的香农熵的函数 + +```python def calcShannonEnt(dataSet): # 求list的长度,表示计算参与训练的数据量 numEntries = len(dataSet) @@ -145,20 +149,55 @@ def calcShannonEnt(dataSet): 按照给定特征划分数据集 -```Python -def splitDataSet(dataSet, axis, value): +`将指定特征的特征值等于 value 的行剩下列作为子数据集。` + +```python +def splitDataSet(dataSet, index, value): + """splitDataSet(通过遍历dataSet数据集,求出index对应的colnum列的值为value的行) + 就是依据index列进行分类,如果index列的数据等于 value的时候,就要将 index 划分到我们创建的新的数据集中 + Args: + dataSet 数据集 待划分的数据集 + index 表示每一行的index列 划分数据集的特征 + value 表示index列对应的value值 需要返回的特征的值。 + Returns: + index列为value的数据集【该数据集需要排除index列】 + """ retDataSet = [] - for featVec in dataSet: - if featVec[axis] == value: - reducedFeatVec = featVec[:axis] - reducedFeatVec.extend(featVec[axis+1:]) + for featVec in dataSet: + # index列为value的数据集【该数据集需要排除index列】 + # 判断index列的值是否为value + if featVec[index] == value: + # chop out index used for splitting + # [:index]表示前index行,即若 index 为2,就是取 featVec 的前 index 行 + reducedFeatVec = featVec[:index] + ''' + 请百度查询一下: extend和append的区别 + list.append(object) 向列表中添加一个对象object + list.extend(sequence) 把一个序列seq的内容添加到列表中 + 1、使用append的时候,是将new_media看作一个对象,整体打包添加到music_media对象中。 + 2、使用extend的时候,是将new_media看作一个序列,将这个序列和music_media序列合并,并放在其后面。 + result = [] + result.extend([1,2,3]) + print result + result.append([4,5,6]) + print result + result.extend([7,8,9]) + print result + 结果: + [1, 2, 3] + [1, 2, 3, [4, 5, 6]] + [1, 2, 3, [4, 5, 6], 7, 8, 9] + ''' + reducedFeatVec.extend(featVec[index+1:]) + # [index+1:]表示从跳过 index 的 index+1行,取接下来的数据 + # 收集结果值 index列为value的行【该行需要排除index列】 retDataSet.append(reducedFeatVec) return retDataSet ``` 选择最好的数据集划分方式 -```Python +```python def chooseBestFeatureToSplit(dataSet): """chooseBestFeatureToSplit(选择最好的特征) @@ -169,14 +208,14 @@ def chooseBestFeatureToSplit(dataSet): """ # 求第一行有多少列的 Feature, 最后一列是label列嘛 numFeatures = len(dataSet[0]) - 1 - # label的信息熵 + # 数据集的原始信息熵 baseEntropy = calcShannonEnt(dataSet) # 最优的信息增益值, 和最优的Featurn编号 bestInfoGain, bestFeature = 0.0, -1 # iterate over all the features for i in range(numFeatures): # create a list of all the examples of this feature - # 获取每一个实例的第i+1个feature,组成list集合 + # 获取对应的feature下的所有数据 featList = [example[i] for example in dataSet] # get a set of unique values # 获取剔重后的集合,使用set对list数据进行去重 @@ -187,7 +226,9 @@ def chooseBestFeatureToSplit(dataSet): # 遍历当前特征中的所有唯一属性值,对每个唯一属性值划分一次数据集,计算数据集的新熵值,并对所有唯一特征值得到的熵求和。 for value in uniqueVals: subDataSet = splitDataSet(dataSet, i, value) + # 计算概率 prob = len(subDataSet)/float(len(dataSet)) + # 计算信息熵 newEntropy += prob * calcShannonEnt(subDataSet) # gain[信息增益]: 划分数据集前后的信息变化, 获取信息熵最大的值 # 信息增益是熵的减少或者是数据无序度的减少。最后,比较所有特征中的信息增益,返回最好特征划分的索引值。 @@ -199,11 +240,17 @@ def chooseBestFeatureToSplit(dataSet): return bestFeature ``` +``` +问:上面的 newEntropy 为什么是根据子集计算的呢? +答:因为我们在根据一个特征计算香农熵的时候,该特征的分类值是相同,这个特征这个分类的香农熵为 0; +这就是为什么计算新的香农熵的时候使用的是子集。 +``` + > 训练算法:构造树的数据结构 -创建树的函数代码 +创建树的函数代码如下: -```Python +```python def createTree(dataSet, labels): classList = [example[-1] for example in dataSet] # 如果数据集的最后一列的第一个值出现的次数=整个集合的数量,也就说只有一个类别,就只直接返回结果就行 @@ -237,7 +284,36 @@ def createTree(dataSet, labels): return myTree ``` -> 测试算法:使用经验树计算错误率 +> 测试算法:使用决策树执行分类 + +```python +def classify(inputTree, featLabels, testVec): + """classify(给输入的节点,进行分类) + + Args: + inputTree 决策树模型 + featLabels Feature标签对应的名称 + testVec 测试输入的数据 + Returns: + classLabel 分类的结果值,需要映射label才能知道名称 + """ + # 获取tree的根节点对于的key值 + firstStr = inputTree.keys()[0] + # 通过key得到根节点对应的value + secondDict = inputTree[firstStr] + # 判断根节点名称获取根节点在label中的先后顺序,这样就知道输入的testVec怎么开始对照树来做分类 + featIndex = featLabels.index(firstStr) + # 测试数据,找到根节点对应的label位置,也就知道从输入的数据的第几位来开始分类 + key = testVec[featIndex] + valueOfFeat = secondDict[key] + print '+++', firstStr, 'xxx', secondDict, '---', key, '>>>', valueOfFeat + # 判断分枝是否结束: 判断valueOfFeat是否是dict类型 + if isinstance(valueOfFeat, dict): + classLabel = classify(valueOfFeat, featLabels, testVec) + else: + classLabel = valueOfFeat + return classLabel +``` > 使用算法:此步骤可以适用于任何监督学习算法,而使用决策树可以更好地理解数据的内在含义。 @@ -247,7 +323,7 @@ def createTree(dataSet, labels): #### 项目概述 -隐形眼镜类型包括应材质、软材质以及不适合佩戴隐形眼镜。我们需要使用决策树预测患者需要佩戴的隐形眼镜类型。 +隐形眼镜类型包括硬材质、软材质以及不适合佩戴隐形眼镜。我们需要使用决策树预测患者需要佩戴的隐形眼镜类型。 #### 开发流程 @@ -270,20 +346,20 @@ presbyopic myope no reduced no lenses > 解析数据:解析 tab 键分隔的数据行 -```Python +```python lecses = [inst.strip().split('\t') for inst in fr.readlines()] lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate'] ``` > 分析数据:快速检查数据,确保正确地解析数据内容,使用 createPlot() 函数绘制最终的树形图。 -```Python +```python >>> treePlotter.createPlot(lensesTree) ``` > 训练算法:使用 createTree() 函数 -```Python +```python >>> lensesTree = trees.createTree(lenses, lensesLabels) >>> lensesTree {'tearRate': {'reduced': 'no lenses', 'normal': {'astigmatic':{'yes': @@ -299,7 +375,7 @@ lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate'] 使用 pickle 模块存储决策树 -```Python +```python def storeTree(inputTree, filename): impory pickle fw = open(filename, 'w') diff --git a/images/3.DecisionTree/熵的计算公式.jpg b/images/3.DecisionTree/熵的计算公式.jpg new file mode 100644 index 00000000..58fbc8b3 Binary files /dev/null and b/images/3.DecisionTree/熵的计算公式.jpg differ diff --git a/src/python/3.DecisionTree/DecisionTree.py b/src/python/3.DecisionTree/DecisionTree.py index f34a568c..8251bb90 100755 --- a/src/python/3.DecisionTree/DecisionTree.py +++ b/src/python/3.DecisionTree/DecisionTree.py @@ -75,24 +75,24 @@ def calcShannonEnt(dataSet): return shannonEnt -def splitDataSet(dataSet, axis, value): - """splitDataSet(通过遍历dataSet数据集,求出axis对应的colnum列的值为value的行) - 就是依据axis列进行分类,如果axis列的数据等于 value的时候,就要将 axis 划分到我们创建的新的数据集中 +def splitDataSet(dataSet, index, value): + """splitDataSet(通过遍历dataSet数据集,求出index对应的colnum列的值为value的行) + 就是依据index列进行分类,如果index列的数据等于 value的时候,就要将 index 划分到我们创建的新的数据集中 Args: dataSet 数据集 待划分的数据集 - axis 表示每一行的axis列 划分数据集的特征 - value 表示axis列对应的value值 需要返回的特征的值。 + index 表示每一行的index列 划分数据集的特征 + value 表示index列对应的value值 需要返回的特征的值。 Returns: - axis列为value的数据集【该数据集需要排除axis列】 + index列为value的数据集【该数据集需要排除index列】 """ retDataSet = [] for featVec in dataSet: - # axis列为value的数据集【该数据集需要排除axis列】 - # 判断axis列的值是否为value - if featVec[axis] == value: - # chop out axis used for splitting - # [:axis]表示前axis行,即若 axis 为2,就是取 featVec 的前 axis 行 - reducedFeatVec = featVec[:axis] + # index列为value的数据集【该数据集需要排除index列】 + # 判断index列的值是否为value + if featVec[index] == value: + # chop out index used for splitting + # [:index]表示前index行,即若 index 为2,就是取 featVec 的前 index 行 + reducedFeatVec = featVec[:index] ''' 请百度查询一下: extend和append的区别 list.append(object) 向列表中添加一个对象object @@ -109,11 +109,11 @@ def splitDataSet(dataSet, axis, value): 结果: [1, 2, 3] [1, 2, 3, [4, 5, 6]] - [1, 2, 3, [4, 5, 6], 7, 8, 9 + [1, 2, 3, [4, 5, 6], 7, 8, 9] ''' - reducedFeatVec.extend(featVec[axis+1:]) - # [axis+1:]表示从跳过 axis 的 axis+1行,取接下来的数据 - # 收集结果值 axis列为value的行【该行需要排除axis列】 + reducedFeatVec.extend(featVec[index+1:]) + # [index+1:]表示从跳过 index 的 index+1行,取接下来的数据 + # 收集结果值 index列为value的行【该行需要排除index列】 retDataSet.append(reducedFeatVec) return retDataSet @@ -250,6 +250,7 @@ def grabTree(filename): fr = open(filename) return pickle.load(fr) + def fishTest(): # 1.创建数据和结果标签 myDat, labels = createDataSet() @@ -269,10 +270,11 @@ def fishTest(): myTree = createTree(myDat, copy.deepcopy(labels)) print myTree # [1, 1]表示要取的分支上的节点位置,对应的结果值 - # print classify(myTree, labels, [1, 1]) + print classify(myTree, labels, [1, 1]) # 画图可视化展现 - # dtPlot.createPlot(myTree) + dtPlot.createPlot(myTree) + def ContactLensesTest(): """ @@ -294,10 +296,9 @@ def ContactLensesTest(): lensesTree = createTree(lenses, lensesLabels) print lensesTree # 画图可视化展现 - # dtPlot.createPlot(lensesTree) + dtPlot.createPlot(lensesTree) if __name__ == "__main__": - fishTest() - # ContactLensesTest() - + # fishTest() + ContactLensesTest() diff --git a/src/python/3.DecisionTree/decisionTreePlot.py b/src/python/3.DecisionTree/decisionTreePlot.py index 59b97751..1cafebec 100644 --- a/src/python/3.DecisionTree/decisionTreePlot.py +++ b/src/python/3.DecisionTree/decisionTreePlot.py @@ -87,7 +87,7 @@ def plotTree(myTree, parentPt, nodeTxt): plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode) # 并打印输入对应的文字 plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key)) - # plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD + plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD def createPlot(inTree): diff --git a/src/python/3.DecisionTree/decisionTreePlot.pyc b/src/python/3.DecisionTree/decisionTreePlot.pyc new file mode 100644 index 00000000..c9f9da40 Binary files /dev/null and b/src/python/3.DecisionTree/decisionTreePlot.pyc differ