mirror of
https://github.com/apachecn/ailearning.git
synced 2026-05-08 23:12:06 +08:00
更新 3.决策树 md
This commit is contained in:
126
docs/3.决策树.md
126
docs/3.决策树.md
@@ -89,11 +89,11 @@
|
||||
|
||||
```
|
||||
收集数据:可以使用任何方法
|
||||
准备数据:树构造算法只适用于标称型数据,因此数值型数据必须离散化。
|
||||
分析数据:可以使用任何方法,构造树完成之后,我们应该检查图形是否符合预期。
|
||||
准备数据:树构造算法只适用于标称型数据,因此数值型数据必须离散化
|
||||
分析数据:可以使用任何方法,构造树完成之后,我们应该检查图形是否符合预期
|
||||
训练算法:构造树的数据结构
|
||||
测试算法:使用经验树计算错误率
|
||||
使用算法:此步骤可以适用于任何监督学习算法,而使用决策树可以更好地理解数据的内在含义。
|
||||
测试算法:使用决策树执行分类
|
||||
使用算法:此步骤可以适用于任何监督学习算法,而使用决策树可以更好地理解数据的内在含义
|
||||
```
|
||||
|
||||
> 收集数据:可以使用任何方法
|
||||
@@ -102,7 +102,7 @@
|
||||
|
||||
我们利用 createDataSet() 函数输入数据
|
||||
|
||||
```Python
|
||||
```python
|
||||
def createDataSet():
|
||||
dataSet = [[1, 1, 'yes'],
|
||||
[1, 1, 'yes'],
|
||||
@@ -112,13 +112,17 @@ def createDataSet():
|
||||
labels = ['no surfacing', 'flippers']
|
||||
return dataSet, labels
|
||||
```
|
||||
> 准备数据:树构造算法只适用于标称型数据,因此数值型数据必须离散化。
|
||||
> 准备数据:树构造算法只适用于标称型数据,因此数值型数据必须离散化
|
||||
|
||||
此处,由于我们输入的数据本身就是离散化数据,所以这一步就省略了。
|
||||
|
||||
计算给定数据集的香农熵
|
||||
> 分析数据:可以使用任何方法,构造树完成之后,我们应该检查图形是否符合预期
|
||||
|
||||
```Python
|
||||

|
||||
|
||||
计算给定数据集的香农熵的函数
|
||||
|
||||
```python
|
||||
def calcShannonEnt(dataSet):
|
||||
# 求list的长度,表示计算参与训练的数据量
|
||||
numEntries = len(dataSet)
|
||||
@@ -145,20 +149,55 @@ def calcShannonEnt(dataSet):
|
||||
|
||||
按照给定特征划分数据集
|
||||
|
||||
```Python
|
||||
def splitDataSet(dataSet, axis, value):
|
||||
`将指定特征的特征值等于 value 的行剩下列作为子数据集。`
|
||||
|
||||
```python
|
||||
def splitDataSet(dataSet, index, value):
|
||||
"""splitDataSet(通过遍历dataSet数据集,求出index对应的colnum列的值为value的行)
|
||||
就是依据index列进行分类,如果index列的数据等于 value的时候,就要将 index 划分到我们创建的新的数据集中
|
||||
Args:
|
||||
dataSet 数据集 待划分的数据集
|
||||
index 表示每一行的index列 划分数据集的特征
|
||||
value 表示index列对应的value值 需要返回的特征的值。
|
||||
Returns:
|
||||
index列为value的数据集【该数据集需要排除index列】
|
||||
"""
|
||||
retDataSet = []
|
||||
for featVec in dataSet:
|
||||
if featVec[axis] == value:
|
||||
reducedFeatVec = featVec[:axis]
|
||||
reducedFeatVec.extend(featVec[axis+1:])
|
||||
for featVec in dataSet:
|
||||
# index列为value的数据集【该数据集需要排除index列】
|
||||
# 判断index列的值是否为value
|
||||
if featVec[index] == value:
|
||||
# chop out index used for splitting
|
||||
# [:index]表示前index行,即若 index 为2,就是取 featVec 的前 index 行
|
||||
reducedFeatVec = featVec[:index]
|
||||
'''
|
||||
请百度查询一下: extend和append的区别
|
||||
list.append(object) 向列表中添加一个对象object
|
||||
list.extend(sequence) 把一个序列seq的内容添加到列表中
|
||||
1、使用append的时候,是将new_media看作一个对象,整体打包添加到music_media对象中。
|
||||
2、使用extend的时候,是将new_media看作一个序列,将这个序列和music_media序列合并,并放在其后面。
|
||||
result = []
|
||||
result.extend([1,2,3])
|
||||
print result
|
||||
result.append([4,5,6])
|
||||
print result
|
||||
result.extend([7,8,9])
|
||||
print result
|
||||
结果:
|
||||
[1, 2, 3]
|
||||
[1, 2, 3, [4, 5, 6]]
|
||||
[1, 2, 3, [4, 5, 6], 7, 8, 9]
|
||||
'''
|
||||
reducedFeatVec.extend(featVec[index+1:])
|
||||
# [index+1:]表示从跳过 index 的 index+1行,取接下来的数据
|
||||
# 收集结果值 index列为value的行【该行需要排除index列】
|
||||
retDataSet.append(reducedFeatVec)
|
||||
return retDataSet
|
||||
```
|
||||
|
||||
选择最好的数据集划分方式
|
||||
|
||||
```Python
|
||||
```python
|
||||
def chooseBestFeatureToSplit(dataSet):
|
||||
"""chooseBestFeatureToSplit(选择最好的特征)
|
||||
|
||||
@@ -169,14 +208,14 @@ def chooseBestFeatureToSplit(dataSet):
|
||||
"""
|
||||
# 求第一行有多少列的 Feature, 最后一列是label列嘛
|
||||
numFeatures = len(dataSet[0]) - 1
|
||||
# label的信息熵
|
||||
# 数据集的原始信息熵
|
||||
baseEntropy = calcShannonEnt(dataSet)
|
||||
# 最优的信息增益值, 和最优的Featurn编号
|
||||
bestInfoGain, bestFeature = 0.0, -1
|
||||
# iterate over all the features
|
||||
for i in range(numFeatures):
|
||||
# create a list of all the examples of this feature
|
||||
# 获取每一个实例的第i+1个feature,组成list集合
|
||||
# 获取对应的feature下的所有数据
|
||||
featList = [example[i] for example in dataSet]
|
||||
# get a set of unique values
|
||||
# 获取剔重后的集合,使用set对list数据进行去重
|
||||
@@ -187,7 +226,9 @@ def chooseBestFeatureToSplit(dataSet):
|
||||
# 遍历当前特征中的所有唯一属性值,对每个唯一属性值划分一次数据集,计算数据集的新熵值,并对所有唯一特征值得到的熵求和。
|
||||
for value in uniqueVals:
|
||||
subDataSet = splitDataSet(dataSet, i, value)
|
||||
# 计算概率
|
||||
prob = len(subDataSet)/float(len(dataSet))
|
||||
# 计算信息熵
|
||||
newEntropy += prob * calcShannonEnt(subDataSet)
|
||||
# gain[信息增益]: 划分数据集前后的信息变化, 获取信息熵最大的值
|
||||
# 信息增益是熵的减少或者是数据无序度的减少。最后,比较所有特征中的信息增益,返回最好特征划分的索引值。
|
||||
@@ -199,11 +240,17 @@ def chooseBestFeatureToSplit(dataSet):
|
||||
return bestFeature
|
||||
```
|
||||
|
||||
```
|
||||
问:上面的 newEntropy 为什么是根据子集计算的呢?
|
||||
答:因为我们在根据一个特征计算香农熵的时候,该特征的分类值是相同,这个特征这个分类的香农熵为 0;
|
||||
这就是为什么计算新的香农熵的时候使用的是子集。
|
||||
```
|
||||
|
||||
> 训练算法:构造树的数据结构
|
||||
|
||||
创建树的函数代码
|
||||
创建树的函数代码如下:
|
||||
|
||||
```Python
|
||||
```python
|
||||
def createTree(dataSet, labels):
|
||||
classList = [example[-1] for example in dataSet]
|
||||
# 如果数据集的最后一列的第一个值出现的次数=整个集合的数量,也就说只有一个类别,就只直接返回结果就行
|
||||
@@ -237,7 +284,36 @@ def createTree(dataSet, labels):
|
||||
return myTree
|
||||
```
|
||||
|
||||
> 测试算法:使用经验树计算错误率
|
||||
> 测试算法:使用决策树执行分类
|
||||
|
||||
```python
|
||||
def classify(inputTree, featLabels, testVec):
|
||||
"""classify(给输入的节点,进行分类)
|
||||
|
||||
Args:
|
||||
inputTree 决策树模型
|
||||
featLabels Feature标签对应的名称
|
||||
testVec 测试输入的数据
|
||||
Returns:
|
||||
classLabel 分类的结果值,需要映射label才能知道名称
|
||||
"""
|
||||
# 获取tree的根节点对于的key值
|
||||
firstStr = inputTree.keys()[0]
|
||||
# 通过key得到根节点对应的value
|
||||
secondDict = inputTree[firstStr]
|
||||
# 判断根节点名称获取根节点在label中的先后顺序,这样就知道输入的testVec怎么开始对照树来做分类
|
||||
featIndex = featLabels.index(firstStr)
|
||||
# 测试数据,找到根节点对应的label位置,也就知道从输入的数据的第几位来开始分类
|
||||
key = testVec[featIndex]
|
||||
valueOfFeat = secondDict[key]
|
||||
print '+++', firstStr, 'xxx', secondDict, '---', key, '>>>', valueOfFeat
|
||||
# 判断分枝是否结束: 判断valueOfFeat是否是dict类型
|
||||
if isinstance(valueOfFeat, dict):
|
||||
classLabel = classify(valueOfFeat, featLabels, testVec)
|
||||
else:
|
||||
classLabel = valueOfFeat
|
||||
return classLabel
|
||||
```
|
||||
|
||||
> 使用算法:此步骤可以适用于任何监督学习算法,而使用决策树可以更好地理解数据的内在含义。
|
||||
|
||||
@@ -247,7 +323,7 @@ def createTree(dataSet, labels):
|
||||
|
||||
#### 项目概述
|
||||
|
||||
隐形眼镜类型包括应材质、软材质以及不适合佩戴隐形眼镜。我们需要使用决策树预测患者需要佩戴的隐形眼镜类型。
|
||||
隐形眼镜类型包括硬材质、软材质以及不适合佩戴隐形眼镜。我们需要使用决策树预测患者需要佩戴的隐形眼镜类型。
|
||||
|
||||
#### 开发流程
|
||||
|
||||
@@ -270,20 +346,20 @@ presbyopic myope no reduced no lenses
|
||||
|
||||
> 解析数据:解析 tab 键分隔的数据行
|
||||
|
||||
```Python
|
||||
```python
|
||||
lecses = [inst.strip().split('\t') for inst in fr.readlines()]
|
||||
lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
|
||||
```
|
||||
|
||||
> 分析数据:快速检查数据,确保正确地解析数据内容,使用 createPlot() 函数绘制最终的树形图。
|
||||
|
||||
```Python
|
||||
```python
|
||||
>>> treePlotter.createPlot(lensesTree)
|
||||
```
|
||||
|
||||
> 训练算法:使用 createTree() 函数
|
||||
|
||||
```Python
|
||||
```python
|
||||
>>> lensesTree = trees.createTree(lenses, lensesLabels)
|
||||
>>> lensesTree
|
||||
{'tearRate': {'reduced': 'no lenses', 'normal': {'astigmatic':{'yes':
|
||||
@@ -299,7 +375,7 @@ lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
|
||||
|
||||
使用 pickle 模块存储决策树
|
||||
|
||||
```Python
|
||||
```python
|
||||
def storeTree(inputTree, filename):
|
||||
impory pickle
|
||||
fw = open(filename, 'w')
|
||||
|
||||
BIN
images/3.DecisionTree/熵的计算公式.jpg
Normal file
BIN
images/3.DecisionTree/熵的计算公式.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 28 KiB |
@@ -75,24 +75,24 @@ def calcShannonEnt(dataSet):
|
||||
return shannonEnt
|
||||
|
||||
|
||||
def splitDataSet(dataSet, axis, value):
|
||||
"""splitDataSet(通过遍历dataSet数据集,求出axis对应的colnum列的值为value的行)
|
||||
就是依据axis列进行分类,如果axis列的数据等于 value的时候,就要将 axis 划分到我们创建的新的数据集中
|
||||
def splitDataSet(dataSet, index, value):
|
||||
"""splitDataSet(通过遍历dataSet数据集,求出index对应的colnum列的值为value的行)
|
||||
就是依据index列进行分类,如果index列的数据等于 value的时候,就要将 index 划分到我们创建的新的数据集中
|
||||
Args:
|
||||
dataSet 数据集 待划分的数据集
|
||||
axis 表示每一行的axis列 划分数据集的特征
|
||||
value 表示axis列对应的value值 需要返回的特征的值。
|
||||
index 表示每一行的index列 划分数据集的特征
|
||||
value 表示index列对应的value值 需要返回的特征的值。
|
||||
Returns:
|
||||
axis列为value的数据集【该数据集需要排除axis列】
|
||||
index列为value的数据集【该数据集需要排除index列】
|
||||
"""
|
||||
retDataSet = []
|
||||
for featVec in dataSet:
|
||||
# axis列为value的数据集【该数据集需要排除axis列】
|
||||
# 判断axis列的值是否为value
|
||||
if featVec[axis] == value:
|
||||
# chop out axis used for splitting
|
||||
# [:axis]表示前axis行,即若 axis 为2,就是取 featVec 的前 axis 行
|
||||
reducedFeatVec = featVec[:axis]
|
||||
# index列为value的数据集【该数据集需要排除index列】
|
||||
# 判断index列的值是否为value
|
||||
if featVec[index] == value:
|
||||
# chop out index used for splitting
|
||||
# [:index]表示前index行,即若 index 为2,就是取 featVec 的前 index 行
|
||||
reducedFeatVec = featVec[:index]
|
||||
'''
|
||||
请百度查询一下: extend和append的区别
|
||||
list.append(object) 向列表中添加一个对象object
|
||||
@@ -109,11 +109,11 @@ def splitDataSet(dataSet, axis, value):
|
||||
结果:
|
||||
[1, 2, 3]
|
||||
[1, 2, 3, [4, 5, 6]]
|
||||
[1, 2, 3, [4, 5, 6], 7, 8, 9
|
||||
[1, 2, 3, [4, 5, 6], 7, 8, 9]
|
||||
'''
|
||||
reducedFeatVec.extend(featVec[axis+1:])
|
||||
# [axis+1:]表示从跳过 axis 的 axis+1行,取接下来的数据
|
||||
# 收集结果值 axis列为value的行【该行需要排除axis列】
|
||||
reducedFeatVec.extend(featVec[index+1:])
|
||||
# [index+1:]表示从跳过 index 的 index+1行,取接下来的数据
|
||||
# 收集结果值 index列为value的行【该行需要排除index列】
|
||||
retDataSet.append(reducedFeatVec)
|
||||
return retDataSet
|
||||
|
||||
@@ -250,6 +250,7 @@ def grabTree(filename):
|
||||
fr = open(filename)
|
||||
return pickle.load(fr)
|
||||
|
||||
|
||||
def fishTest():
|
||||
# 1.创建数据和结果标签
|
||||
myDat, labels = createDataSet()
|
||||
@@ -269,10 +270,11 @@ def fishTest():
|
||||
myTree = createTree(myDat, copy.deepcopy(labels))
|
||||
print myTree
|
||||
# [1, 1]表示要取的分支上的节点位置,对应的结果值
|
||||
# print classify(myTree, labels, [1, 1])
|
||||
print classify(myTree, labels, [1, 1])
|
||||
|
||||
# 画图可视化展现
|
||||
# dtPlot.createPlot(myTree)
|
||||
dtPlot.createPlot(myTree)
|
||||
|
||||
|
||||
def ContactLensesTest():
|
||||
"""
|
||||
@@ -294,10 +296,9 @@ def ContactLensesTest():
|
||||
lensesTree = createTree(lenses, lensesLabels)
|
||||
print lensesTree
|
||||
# 画图可视化展现
|
||||
# dtPlot.createPlot(lensesTree)
|
||||
dtPlot.createPlot(lensesTree)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fishTest()
|
||||
# ContactLensesTest()
|
||||
|
||||
# fishTest()
|
||||
ContactLensesTest()
|
||||
|
||||
@@ -87,7 +87,7 @@ def plotTree(myTree, parentPt, nodeTxt):
|
||||
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
|
||||
# 并打印输入对应的文字
|
||||
plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
|
||||
# plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
|
||||
plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
|
||||
|
||||
|
||||
def createPlot(inTree):
|
||||
|
||||
BIN
src/python/3.DecisionTree/decisionTreePlot.pyc
Normal file
BIN
src/python/3.DecisionTree/decisionTreePlot.pyc
Normal file
Binary file not shown.
Reference in New Issue
Block a user