mirror of
https://github.com/apachecn/ailearning.git
synced 2026-06-30 10:16:12 +08:00
更新构建树的Coding
This commit is contained in:
16
src/python/09.RegTrees/TreeNode.py
Normal file
16
src/python/09.RegTrees/TreeNode.py
Normal file
@@ -0,0 +1,16 @@
|
||||
#!/usr/bin/python
|
||||
# coding:utf8
|
||||
|
||||
'''
|
||||
Created on 2017-03-06
|
||||
Update on 2017-03-06
|
||||
@author: jiangzhonglian
|
||||
'''
|
||||
|
||||
|
||||
class treeNode:
    """Plain container for one node of a binary split tree.

    The constructor stores its four arguments unchanged on the instance:

        featureToSplitOn <- feat
        valueOfSplit     <- val
        rightBranch      <- right
        leftBranch       <- left
    """

    def __init__(self, feat, val, right, left):
        # Note the parameter order: the right branch comes before the left.
        self.featureToSplitOn, self.valueOfSplit = feat, val
        self.rightBranch, self.leftBranch = right, left
|
||||
@@ -9,25 +9,136 @@ Tree-Based Regression Methods Source Code for Machine Learning in Action Ch. 9
|
||||
'''
|
||||
from numpy import *
|
||||
|
||||
# General-purpose parser for tab-delimited numeric data.
def loadDataSet(fileName):
    """Parse a tab-delimited text file into a list of rows of floats.

    The last column of every row is assumed to be the target value.

    Args:
        fileName: path of the tab-separated data file.

    Returns:
        A list with one entry per non-blank line; each entry is a real
        list of floats (not a lazy ``map`` object), so the result can be
        passed to ``mat()`` on both Python 2 and Python 3.

    Raises:
        IOError/OSError: if the file cannot be opened.
        ValueError: if a field cannot be converted to float.
    """
    dataMat = []
    # The context manager closes the handle even if a row fails to parse.
    with open(fileName) as fr:
        for line in fr:
            line = line.strip()
            if not line:
                # Skip blank lines (e.g. an empty line at the end of file);
                # float('') would otherwise raise ValueError.
                continue
            # Convert every field of the row to float.
            dataMat.append([float(field) for field in line.split('\t')])
    return dataMat
|
||||
|
||||
|
||||
def binSplitDataSet(dataSet, feature, value):
    """Split a data set in two by comparing one column against a threshold.

    Args:
        dataSet: 2-D NumPy matrix/array; one sample per row.
        feature: index of the column to compare.
        value: threshold used for the comparison.

    Returns:
        (mat0, mat1): mat0 holds the rows whose ``feature`` entry is
        strictly greater than ``value``; mat1 holds the rows whose entry
        is less than or equal to ``value``. Either part may have zero rows.
    """
    # dataSet[:, feature] selects the ``feature`` column of every row.
    column = dataSet[:, feature]
    # nonzero(...)[0] gives the row indices where the comparison is True.
    greaterRows = nonzero(column > value)[0]
    lessEqRows = nonzero(column <= value)[0]
    return dataSet[greaterRows, :], dataSet[lessEqRows, :]
|
||||
|
||||
|
||||
# Leaf model: the value used for a leaf is the mean of its targets.
def regLeaf(dataSet):
    """Return the mean of the last column of ``dataSet``."""
    targets = dataSet[:, -1]
    return mean(targets)
|
||||
|
||||
|
||||
# Total variance = variance of the targets * number of samples.
def regErr(dataSet):
    """Return the variance of the last column times the number of rows.

    Args:
        dataSet: 2-D NumPy matrix/array whose last column is the target.

    Returns:
        ``var(last column) * row count``.
    """
    # shape(dataSet)[0] is the number of rows.
    rowCount = shape(dataSet)[0]
    return var(dataSet[:, -1]) * rowCount
|
||||
|
||||
|
||||
# 1. Find the best way to split the data set.
# 2. Or produce the leaf value when the data should not be split.
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    """Choose the feature/value pair whose binary split has the lowest error.

    Args:
        dataSet: NumPy matrix; every column except the last is a candidate
            feature, the last column is the target.
        leafType: function called on a data set to build a leaf value.
        errType: function called on a data set to get its total error.
        ops: (tolS, tolN) - the minimum error reduction required to accept
            a split, and the minimum number of rows allowed on each side.

    Returns:
        (bestIndex, bestValue): column index and threshold of the best
        split, or (None, leafType(dataSet)) when a stop condition is hit.
    """
    tolS, tolN = ops[0], ops[1]
    # Exit condition 1: every target value is identical, nothing to split.
    # .T transposes the column into a row; .tolist()[0] takes that row as
    # a plain list so its values can go into a set.
    if len(set(dataSet[:, -1].T.tolist()[0])) == 1:
        return None, leafType(dataSet)
    # Number of columns (features plus the target column).
    n = shape(dataSet)[1]
    # Error of the unsplit data; the best feature is the one giving the
    # largest reduction from this value.
    S = errType(dataSet)
    bestS, bestIndex, bestValue = inf, 0, 0
    # Try every feature column (the last column is the target).
    for featIndex in range(n - 1):
        # Try every distinct value of that column as a threshold.
        for splitVal in set(dataSet[:, featIndex].T.tolist()[0]):
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            # Ignore splits that leave too few rows on either side.
            if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
                continue
            newS = errType(mat0) + errType(mat1)
            # Remember the split with the smallest combined error so far.
            if newS < bestS:
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    # Exit condition 2: the error reduction (S - bestS) is below the
    # threshold, so do not split.
    if (S - bestS) < tolS:
        return None, leafType(dataSet)
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    # Exit condition 3: re-check the size constraint on the chosen split.
    if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
        return None, leafType(dataSet)
    return bestIndex, bestValue
|
||||
|
||||
|
||||
# Assumes dataSet is a NumPy matrix so that rows can be filtered by index.
def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    """Recursively build a tree by repeatedly splitting ``dataSet``.

    Args:
        dataSet: NumPy matrix passed on to chooseBestSplit/binSplitDataSet.
        leafType, errType, ops: forwarded unchanged to chooseBestSplit and
            to the recursive calls.

    Returns:
        The leaf value when chooseBestSplit reports a stop condition
        (feature index None); otherwise a dict with the keys 'spInd',
        'spVal', 'right' and 'left', where both branches are built by
        recursive calls.
    """
    # Choose the best split: feature index and threshold value.
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
    # A stop condition was hit: ``val`` is already the leaf value.
    if feat is None:
        return val
    # First part of the split goes under 'right', second part under 'left'.
    # NOTE(review): the previous revision stored the first part under
    # 'left'; confirm the traversal code (not visible here) agrees.
    firstPart, secondPart = binSplitDataSet(dataSet, feat, val)
    retTree = {'spInd': feat, 'spVal': val}
    retTree['right'] = createTree(firstPart, leafType, errType, ops)
    retTree['left'] = createTree(secondPart, leafType, errType, ops)
    return retTree
|
||||
|
||||
|
||||
def linearSolve(dataSet): #helper function used in two places
|
||||
m,n = shape(dataSet)
|
||||
@@ -49,43 +160,7 @@ def modelErr(dataSet):
|
||||
yHat = X * ws
|
||||
return sum(power(Y - yHat,2))
|
||||
|
||||
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    """Return the best (feature index, split value) pair for ``dataSet``.

    Args:
        dataSet: NumPy matrix; the last column is the target.
        leafType: function called on a data set to build a leaf value.
        errType: function called on a data set to get its total error.
        ops: (tolS, tolN) - minimum error reduction needed to split, and
            minimum number of rows allowed on each side of a split.

    Returns:
        (bestIndex, bestValue), or (None, leafType(dataSet)) when one of
        the three exit conditions below applies.
    """
    tolS = ops[0]
    tolN = ops[1]
    # Exit condition 1: all target values are the same - return a leaf.
    if len(set(dataSet[:,-1].T.tolist()[0])) == 1:
        return None, leafType(dataSet)
    m, n = shape(dataSet)
    # The best feature is the one with the largest reduction in error
    # compared with the unsplit data.
    S = errType(dataSet)
    bestS = inf
    bestIndex = 0
    bestValue = 0
    for featIndex in range(n-1):
        # Convert the column to a plain list before building the set:
        # iterating a matrix column yields 1x1 matrices, which are not
        # hashable and would make set() raise TypeError.
        for splitVal in set(dataSet[:,featIndex].T.tolist()[0]):
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
                continue
            newS = errType(mat0) + errType(mat1)
            if newS < bestS:
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    # Exit condition 2: the decrease (S - bestS) is below the threshold,
    # so do not split.
    if (S - bestS) < tolS:
        return None, leafType(dataSet)
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    # Exit condition 3: the chosen split leaves too few rows on one side.
    if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
        return None, leafType(dataSet)
    # The best feature to split on and the value used for that split.
    return bestIndex, bestValue
|
||||
|
||||
# Assumes dataSet is a NumPy matrix so that rows can be filtered by index.
def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    """Recursively build a tree from ``dataSet``.

    Args:
        dataSet: NumPy matrix passed on to chooseBestSplit/binSplitDataSet.
        leafType, errType, ops: forwarded unchanged to chooseBestSplit and
            to the recursive calls.

    Returns:
        The leaf value when chooseBestSplit reports a stop condition;
        otherwise a dict with the keys 'spInd', 'spVal', 'left', 'right'.
    """
    # Choose the best split.
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
    # A stop condition was hit: return the leaf value. ``is None`` is the
    # identity test intended here.
    if feat is None:
        return val
    retTree = {}
    retTree['spInd'] = feat
    retTree['spVal'] = val
    # First part of the split goes under 'left', second part under 'right'.
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    retTree['left'] = createTree(lSet, leafType, errType, ops)
    retTree['right'] = createTree(rSet, leafType, errType, ops)
    return retTree
|
||||
|
||||
def isTree(obj):
    """Return True when ``obj`` is a dict, False otherwise.

    Uses isinstance rather than comparing the type's name, so dict
    subclasses are accepted and unrelated classes that happen to be
    named 'dict' are not.
    """
    return isinstance(obj, dict)
|
||||
@@ -137,4 +212,21 @@ def createForeCast(tree, testData, modelEval=regTreeEval):
|
||||
yHat = mat(zeros((m,1)))
|
||||
for i in range(m):
|
||||
yHat[i,0] = treeForeCast(tree, mat(testData[i]), modelEval)
|
||||
return yHat
|
||||
return yHat
|
||||
|
||||
if __name__ == "__main__":

    # # Quick check of binSplitDataSet on a small matrix
    # testMat = mat(eye(4))
    # print(testMat)
    # print(type(testMat))
    # mat0, mat1 = binSplitDataSet(testMat, 1, 0.5)
    # print(mat0)
    # print('\n-----------\n')
    # print(mat1)

    # Load the data set
    # myDat = loadDataSet('testData/RT_data1.txt')
    myDat = loadDataSet('testData/RT_data2.txt')
    myMat = mat(myDat)
    myTree = createTree(myMat)
    # The parenthesised single-argument form prints the same text on
    # Python 2 and is also valid syntax on Python 3.
    print(myTree)
|
||||
|
||||
Reference in New Issue
Block a user