mirror of
https://github.com/apachecn/ailearning.git
synced 2026-02-13 07:15:26 +08:00
更新 3.决策树 md
This commit is contained in:
@@ -75,24 +75,24 @@ def calcShannonEnt(dataSet):
|
||||
return shannonEnt
|
||||
|
||||
|
||||
def splitDataSet(dataSet, axis, value):
|
||||
"""splitDataSet(通过遍历dataSet数据集,求出axis对应的colnum列的值为value的行)
|
||||
就是依据axis列进行分类,如果axis列的数据等于 value的时候,就要将 axis 划分到我们创建的新的数据集中
|
||||
def splitDataSet(dataSet, index, value):
|
||||
"""splitDataSet(通过遍历dataSet数据集,求出index对应的colnum列的值为value的行)
|
||||
就是依据index列进行分类,如果index列的数据等于 value的时候,就要将 index 划分到我们创建的新的数据集中
|
||||
Args:
|
||||
dataSet 数据集 待划分的数据集
|
||||
axis 表示每一行的axis列 划分数据集的特征
|
||||
value 表示axis列对应的value值 需要返回的特征的值。
|
||||
index 表示每一行的index列 划分数据集的特征
|
||||
value 表示index列对应的value值 需要返回的特征的值。
|
||||
Returns:
|
||||
axis列为value的数据集【该数据集需要排除axis列】
|
||||
index列为value的数据集【该数据集需要排除index列】
|
||||
"""
|
||||
retDataSet = []
|
||||
for featVec in dataSet:
|
||||
# axis列为value的数据集【该数据集需要排除axis列】
|
||||
# 判断axis列的值是否为value
|
||||
if featVec[axis] == value:
|
||||
# chop out axis used for splitting
|
||||
# [:axis]表示前axis行,即若 axis 为2,就是取 featVec 的前 axis 行
|
||||
reducedFeatVec = featVec[:axis]
|
||||
# index列为value的数据集【该数据集需要排除index列】
|
||||
# 判断index列的值是否为value
|
||||
if featVec[index] == value:
|
||||
# chop out index used for splitting
|
||||
# [:index]表示前index行,即若 index 为2,就是取 featVec 的前 index 行
|
||||
reducedFeatVec = featVec[:index]
|
||||
'''
|
||||
请百度查询一下: extend和append的区别
|
||||
list.append(object) 向列表中添加一个对象object
|
||||
@@ -109,11 +109,11 @@ def splitDataSet(dataSet, axis, value):
|
||||
结果:
|
||||
[1, 2, 3]
|
||||
[1, 2, 3, [4, 5, 6]]
|
||||
[1, 2, 3, [4, 5, 6], 7, 8, 9
|
||||
[1, 2, 3, [4, 5, 6], 7, 8, 9]
|
||||
'''
|
||||
reducedFeatVec.extend(featVec[axis+1:])
|
||||
# [axis+1:]表示从跳过 axis 的 axis+1行,取接下来的数据
|
||||
# 收集结果值 axis列为value的行【该行需要排除axis列】
|
||||
reducedFeatVec.extend(featVec[index+1:])
|
||||
# [index+1:]表示从跳过 index 的 index+1行,取接下来的数据
|
||||
# 收集结果值 index列为value的行【该行需要排除index列】
|
||||
retDataSet.append(reducedFeatVec)
|
||||
return retDataSet
|
||||
|
||||
@@ -250,6 +250,7 @@ def grabTree(filename):
|
||||
fr = open(filename)
|
||||
return pickle.load(fr)
|
||||
|
||||
|
||||
def fishTest():
|
||||
# 1.创建数据和结果标签
|
||||
myDat, labels = createDataSet()
|
||||
@@ -269,10 +270,11 @@ def fishTest():
|
||||
myTree = createTree(myDat, copy.deepcopy(labels))
|
||||
print myTree
|
||||
# [1, 1]表示要取的分支上的节点位置,对应的结果值
|
||||
# print classify(myTree, labels, [1, 1])
|
||||
print classify(myTree, labels, [1, 1])
|
||||
|
||||
# 画图可视化展现
|
||||
# dtPlot.createPlot(myTree)
|
||||
dtPlot.createPlot(myTree)
|
||||
|
||||
|
||||
def ContactLensesTest():
|
||||
"""
|
||||
@@ -294,10 +296,9 @@ def ContactLensesTest():
|
||||
lensesTree = createTree(lenses, lensesLabels)
|
||||
print lensesTree
|
||||
# 画图可视化展现
|
||||
# dtPlot.createPlot(lensesTree)
|
||||
dtPlot.createPlot(lensesTree)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fishTest()
|
||||
# ContactLensesTest()
|
||||
|
||||
# fishTest()
|
||||
ContactLensesTest()
|
||||
|
||||
@@ -87,7 +87,7 @@ def plotTree(myTree, parentPt, nodeTxt):
|
||||
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
|
||||
# 并打印输入对应的文字
|
||||
plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
|
||||
# plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
|
||||
plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
|
||||
|
||||
|
||||
def createPlot(inTree):
|
||||
|
||||
BIN
src/python/3.DecisionTree/decisionTreePlot.pyc
Normal file
BIN
src/python/3.DecisionTree/decisionTreePlot.pyc
Normal file
Binary file not shown.
Reference in New Issue
Block a user