Files
ailearning/src/python/9.RegTrees/treeExplore.py
jiangzhonglian 62e8b3dee4 更新链接
2017-05-18 11:57:30 +08:00

124 lines
3.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/python
# coding:utf8
'''
Created on 2017-03-08
Update on 2017-05-18
Tree-Based Regression Methods Source Code for Machine Learning in Action Ch. 9
@author: Peter/片刻
《机器学习实战》更新地址https://github.com/apachecn/MachineLearning
'''
import regTrees
from Tkinter import *
from numpy import *
import matplotlib
from matplotlib.figure import Figure
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
matplotlib.use('TkAgg')
def test_widget_text(root):
    """Smoke-test helper: place a single "helloworld" label inside ``root``.

    Args:
        root: the Tk container widget that will own the label.
    """
    # Calling grid() with no arguments leaves the placement to the geometry
    # manager's defaults instead of naming an explicit row/column.
    Label(root, text="helloworld").grid()
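# reDraw(tolS, tolN): both values are forwarded to regTrees.createTree as ops=(tolS, tolN) (presumably the error tolerance and the minimum sample count per split - confirm against regTrees)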
def reDraw(tolS, tolN):
    """Rebuild the tree from the cached data set and repaint the embedded plot.

    State is kept on attributes of the function object itself
    (``reDraw.f``, ``reDraw.canvas``, ``reDraw.rawDat``, ``reDraw.testDat``),
    which ``main`` assigns before the first call.  The module-level
    ``chkBtnVar`` decides which kind of tree is built: when it is set, the
    model-tree callbacks of ``regTrees`` are used, otherwise the defaults.

    Args:
        tolS: first element of the ``ops`` tuple given to
            ``regTrees.createTree`` (presumably the error tolerance --
            confirm against regTrees).
        tolN: second element of the ``ops`` tuple (presumably the minimum
            number of samples per split -- confirm against regTrees).  In
            model-tree mode values below 2 are raised to 2.
    """
    # Wipe the previous drawing and start with a fresh single subplot.
    reDraw.f.clf()
    reDraw.a = reDraw.f.add_subplot(111)

    if chkBtnVar.get():
        # Model-tree mode: never go below 2 (presumably because fitting a
        # linear leaf model needs at least two points).
        if tolN < 2:
            tolN = 2
        myTree = regTrees.createTree(reDraw.rawDat, regTrees.modelLeaf,
                                     regTrees.modelErr, (tolS, tolN))
        yHat = regTrees.createForeCast(myTree, reDraw.testDat,
                                       regTrees.modelTreeEval)
    else:
        myTree = regTrees.createTree(reDraw.rawDat, ops=(tolS, tolN))
        yHat = regTrees.createForeCast(myTree, reDraw.testDat)

    # rawDat is a numpy matrix, so its column slices stay 2-D; give
    # matplotlib flat ndarrays rather than relying on its matrix handling.
    xs = asarray(reDraw.rawDat[:, 0]).ravel()
    ys = asarray(reDraw.rawDat[:, 1]).ravel()
    reDraw.a.scatter(xs, ys, s=5)

    # Forecast curve over the evenly spaced test grid.
    reDraw.a.plot(reDraw.testDat, asarray(yHat), linewidth=2.0, c='red')

    # FigureCanvasTkAgg.show() was deprecated in matplotlib 2.2 and removed
    # in 3.0; draw() is the supported way to refresh the canvas.
    reDraw.canvas.draw()
def _read_entry(entry, convert, fallback, fallback_text, hint):
    """Parse ``entry``'s text with ``convert``; on bad text reset the entry.

    When the text cannot be converted, ``hint`` is printed, the entry's
    content is replaced by ``fallback_text`` and ``fallback`` is returned.
    """
    try:
        return convert(entry.get())
    except ValueError:
        # Only a malformed number is an expected failure here; anything else
        # should surface instead of being silently replaced by the default.
        print(hint)
        entry.delete(0, END)
        entry.insert(0, fallback_text)
        return fallback


def getInputs():
    """Read tolN and tolS from the module-level Entry widgets.

    ``tolNentry`` must hold an integer and ``tolSentry`` a float.  A field
    whose text does not parse is reset to its default text ('10' or '1.0')
    and the matching default value (10 or 1.0) is used instead.

    Returns:
        tuple: ``(tolN, tolS)`` -- note the order, tolN first.
    """
    tolN = _read_entry(tolNentry, int, 10, '10', "enter Integer for tolN")
    tolS = _read_entry(tolSentry, float, 1.0, '1.0', "enter Float for tolS")
    return tolN, tolS
# Draw a new tree from the values currently typed into the entry boxes
def drawNewTree():
    """Button callback: fetch the current entry values and redraw the plot."""
    # getInputs() yields (tolN, tolS) while reDraw expects (tolS, tolN), so
    # pass them by keyword to keep the pairing explicit.
    n_value, s_value = getInputs()
    reDraw(tolS=s_value, tolN=n_value)
def main(root):
    """Build the tree-explorer GUI inside ``root`` and draw the first plot.

    Creates the tolN/tolS entry boxes, the redraw and quit buttons and the
    "Model Tree" check box, embeds a matplotlib figure in the window, loads
    the sample data set and triggers an initial ``reDraw(1.0, 10)``.

    Side effects: assigns the module-level globals ``tolNentry``,
    ``tolSentry`` and ``chkBtnVar`` (read by ``getInputs`` and ``reDraw``)
    and the ``reDraw.f`` / ``reDraw.canvas`` / ``reDraw.rawDat`` /
    ``reDraw.testDat`` function attributes.

    Args:
        root: the Tk root window that hosts all widgets.
    """
    global tolNentry, tolSentry, chkBtnVar

    # Title placeholder; the plot canvas gridded further down is placed on
    # the same row-0 cells.
    Label(root, text="Plot Place Holder").grid(row=0, columnspan=3)

    # Entry 1: tolN, pre-filled with its default.
    Label(root, text="tolN").grid(row=1, column=0)
    tolNentry = Entry(root)
    tolNentry.grid(row=1, column=1)
    tolNentry.insert(0, '10')

    # Entry 2: tolS, pre-filled with its default.
    Label(root, text="tolS").grid(row=2, column=0)
    tolSentry = Entry(root)
    tolSentry.grid(row=2, column=1)
    tolSentry.insert(0, '1.0')

    # Submit button: rebuilds the tree from the current entry values.
    Button(root, text="确定", command=drawNewTree).grid(row=1, column=2, rowspan=3)

    # Check box toggling model-tree mode (read by reDraw via chkBtnVar).
    chkBtnVar = IntVar()
    chkBtn = Checkbutton(root, text="Model Tree", variable=chkBtnVar)
    chkBtn.grid(row=3, column=0, columnspan=2)

    # Quit button.  Use root.quit to stop the Tk main loop; the bare `quit`
    # name is the interactive-shell helper installed by the `site` module,
    # which is not meant for programs and may be absent (e.g. `python -S`).
    # NOTE(review): this button is gridded at the same row/column where the
    # submit button's rowspan starts -- layout kept as in the original.
    Button(root, text="退出", fg="black", command=root.quit).grid(row=1, column=2)

    # Embed a matplotlib figure in the Tk window.
    reDraw.f = Figure(figsize=(5, 4), dpi=100)
    reDraw.canvas = FigureCanvasTkAgg(reDraw.f, master=root)
    # draw() replaces FigureCanvasTkAgg.show(), which was deprecated in
    # matplotlib 2.2 and removed in 3.0.
    reDraw.canvas.draw()
    reDraw.canvas.get_tk_widget().grid(row=0, columnspan=3)

    # Load the sample data and build an evenly spaced grid (step 0.01) over
    # the range of its first column for plotting the forecast.
    reDraw.rawDat = mat(regTrees.loadDataSet('testData/RT_sine.txt'))
    x_column = reDraw.rawDat[:, 0]
    # matrix.min()/max() give plain scalars; the builtin min()/max() would
    # iterate over the rows and hand arange 1x1 matrices instead.
    reDraw.testDat = arange(x_column.min(), x_column.max(), 0.01)

    reDraw(1.0, 10)
if __name__ == "__main__":
    # Script entry point: create the top-level Tk window, let main() fill it
    # with widgets and the initial plot, then hand control to Tk's event loop
    # until the window is closed.
    root = Tk()
    main(root)
    root.mainloop()