diff --git a/docs/9.树回归.md b/docs/9.树回归.md index f695bff5..ec01f7db 100644 --- a/docs/9.树回归.md +++ b/docs/9.树回归.md @@ -11,3 +11,6 @@ * 那么问题来了,如何计算连续型数值的混乱度呢? * `误差`:也就是计算平均差的总值(总方差=方差*样本数) * 二元切分方式 +* 回归树 +* 模型树(线性模型) +* 树回归方法在预测复杂数据时 会比 简单的线性模型 更有效。 diff --git a/src/python/09.RegTrees/treeExplore.py b/src/python/09.RegTrees/treeExplore.py new file mode 100644 index 00000000..26fd9a8b --- /dev/null +++ b/src/python/09.RegTrees/treeExplore.py @@ -0,0 +1,123 @@ +#!/usr/bin/python +# coding:utf8 + +''' +Created on 2017-03-08 +Update on 2017-03-08 +Tree-Based Regression Methods Source Code for Machine Learning in Action Ch. 9 +@author: jiangzhonglian +''' +from Tkinter import * +from numpy import * +import regTrees + +import matplotlib +matplotlib.use('TkAgg') +from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg +from matplotlib.figure import Figure + + +def test_widget_text(root): + mylabel = Label(root, text="helloworld") + # 相当于告诉 布局管理器(Geometry Manager),如果不设定位置,默认在 0行0列的位置 + mylabel.grid() + + +# 最大为误差, 最大子叶节点的数量 +def reDraw(tolS, tolN): + # clear the figure + reDraw.f.clf() + reDraw.a = reDraw.f.add_subplot(111) + + # 检查复选框是否选中 + if chkBtnVar.get(): + if tolN < 2: + tolN = 2 + myTree = regTrees.createTree(reDraw.rawDat, regTrees.modelLeaf, regTrees.modelErr, (tolS, tolN)) + yHat = regTrees.createForeCast(myTree, reDraw.testDat, regTrees.modelTreeEval) + else: + myTree = regTrees.createTree(reDraw.rawDat, ops=(tolS, tolN)) + yHat = regTrees.createForeCast(myTree, reDraw.testDat) + + # use scatter for data set + reDraw.a.scatter(reDraw.rawDat[:, 0], reDraw.rawDat[:, 1], s=5) + # use plot for yHat + reDraw.a.plot(reDraw.testDat, yHat, linewidth=2.0, c='red') + reDraw.canvas.show() + + +def getInputs(): + try: + tolN = int(tolNentry.get()) + except: + tolN = 10 + print "enter Integer for tolN" + tolNentry.delete(0, END) + tolNentry.insert(0, '10') + try: + tolS = float(tolSentry.get()) + except: + tolS = 1.0 + print "enter Float for tolS" + tolSentry.delete(0, END) + tolSentry.insert(0, '1.0') + return tolN, tolS + + +# 画新的tree +def drawNewTree(): + # #get values from Entry boxes + tolN, tolS = getInputs() + reDraw(tolS, tolN) + + +def main(root): + # 标题 + Label(root, text="Plot Place Holder").grid(row=0, columnspan=3) + # 输入栏1, 叶子的数量 + Label(root, text="tolN").grid(row=1, column=0) + global tolNentry + tolNentry = Entry(root) + tolNentry.grid(row=1, column=1) + tolNentry.insert(0, '10') + # 输入栏2, 误差量 + Label(root, text="tolS").grid(row=2, column=0) + global tolSentry + tolSentry = Entry(root) + tolSentry.grid(row=2, column=1) + # 设置输出值 + tolSentry.insert(0,'1.0') + + # 设置提交的按钮 + Button(root, text="确定", command=drawNewTree).grid(row=1, column=2, rowspan=3) + + # 设置复选按钮 + global chkBtnVar + chkBtnVar = IntVar() + chkBtn = Checkbutton(root, text="Model Tree", variable = chkBtnVar) + chkBtn.grid(row=3, column=0, columnspan=2) + + # 退出按钮 + Button(root, text="退出", fg="black", command=quit).grid(row=1, column=2) + + + # 创建一个画板 canvas + reDraw.f = Figure(figsize=(5, 4), dpi=100) + reDraw.canvas = FigureCanvasTkAgg(reDraw.f, master=root) + reDraw.canvas.show() + reDraw.canvas.get_tk_widget().grid(row=0, columnspan=3) + + reDraw.rawDat = mat(regTrees.loadDataSet('testData/RT_sine.txt')) + reDraw.testDat = arange(min(reDraw.rawDat[:, 0]), max(reDraw.rawDat[:, 0]), 0.01) + reDraw(1.0, 10) + + +if __name__ == "__main__": + + # 创建一个事件 + root = Tk() + # test_widget_text(root) + main(root) + + # 启动事件循环 + root.mainloop() diff --git a/testData/RT_sine.txt b/testData/RT_sine.txt new file mode 100644 index 00000000..e7050f37 --- /dev/null +++ b/testData/RT_sine.txt @@ -0,0 +1,200 @@ +0.190350 0.878049 +0.306657 -0.109413 +0.017568 0.030917 +0.122328 0.951109 +0.076274 0.774632 +0.614127 -0.250042 +0.220722 0.807741 +0.089430 0.840491 +0.278817 0.342210 +0.520287 -0.950301 +0.726976 0.852224 +0.180485 1.141859 +0.801524 1.012061 +0.474273 -1.311226 +0.345116 -0.319911 +0.981951 -0.374203 +0.127349 1.039361 +0.757120 1.040152 +0.345419 -0.429760 +0.314532 -0.075762 +0.250828 0.657169 +0.431255 -0.905443 +0.386669 -0.508875 +0.143794 0.844105 +0.470839 -0.951757 +0.093065 0.785034 +0.205377 0.715400 +0.083329 0.853025 +0.243475 0.699252 +0.062389 0.567589 +0.764116 0.834931 +0.018287 0.199875 +0.973603 -0.359748 +0.458826 -1.113178 +0.511200 -1.082561 +0.712587 0.615108 +0.464745 -0.835752 +0.984328 -0.332495 +0.414291 -0.808822 +0.799551 1.072052 +0.499037 -0.924499 +0.966757 -0.191643 +0.756594 0.991844 +0.444938 -0.969528 +0.410167 -0.773426 +0.532335 -0.631770 +0.343909 -0.313313 +0.854302 0.719307 +0.846882 0.916509 +0.740758 1.009525 +0.150668 0.832433 +0.177606 0.893017 +0.445289 -0.898242 +0.734653 0.787282 +0.559488 -0.663482 +0.232311 0.499122 +0.934435 -0.121533 +0.219089 0.823206 +0.636525 0.053113 +0.307605 0.027500 +0.713198 0.693978 +0.116343 1.242458 +0.680737 0.368910 +0.484730 -0.891940 +0.929408 0.234913 +0.008507 0.103505 +0.872161 0.816191 +0.755530 0.985723 +0.620671 0.026417 +0.472260 -0.967451 +0.257488 0.630100 +0.130654 1.025693 +0.512333 -0.884296 +0.747710 0.849468 +0.669948 0.413745 +0.644856 0.253455 +0.894206 0.482933 +0.820471 0.899981 +0.790796 0.922645 +0.010729 0.032106 +0.846777 0.768675 +0.349175 -0.322929 +0.453662 -0.957712 +0.624017 -0.169913 +0.211074 0.869840 +0.062555 0.607180 +0.739709 0.859793 +0.985896 -0.433632 +0.782088 0.976380 +0.642561 0.147023 +0.779007 0.913765 +0.185631 1.021408 +0.525250 -0.706217 +0.236802 0.564723 +0.440958 -0.993781 +0.397580 -0.708189 +0.823146 0.860086 +0.370173 -0.649231 +0.791675 1.162927 +0.456647 -0.956843 +0.113350 0.850107 +0.351074 -0.306095 +0.182684 0.825728 +0.914034 0.305636 +0.751486 0.898875 +0.216572 0.974637 +0.013273 0.062439 +0.469726 -1.226188 +0.060676 0.599451 +0.776310 0.902315 +0.061648 0.464446 +0.714077 0.947507 +0.559264 -0.715111 +0.121876 0.791703 +0.330586 -0.165819 +0.662909 0.379236 +0.785142 0.967030 +0.161352 0.979553 +0.985215 -0.317699 +0.457734 -0.890725 +0.171574 0.963749 +0.334277 -0.266228 +0.501065 -0.910313 +0.988736 -0.476222 +0.659242 0.218365 +0.359861 -0.338734 +0.790434 0.843387 +0.462458 -0.911647 +0.823012 0.813427 +0.594668 -0.603016 +0.498207 -0.878847 +0.574882 -0.419598 +0.570048 -0.442087 +0.331570 -0.347567 +0.195407 0.822284 +0.814327 0.974355 +0.641925 0.073217 +0.238778 0.657767 +0.400138 -0.715598 +0.670479 0.469662 +0.069076 0.680958 +0.294373 0.145767 +0.025628 0.179822 +0.697772 0.506253 +0.729626 0.786519 +0.293071 0.259997 +0.531802 -1.095833 +0.487338 -1.034481 +0.215780 0.933506 +0.625818 0.103845 +0.179389 0.892237 +0.192552 0.915516 +0.671661 0.330361 +0.952391 -0.060263 +0.795133 0.945157 +0.950494 -0.071855 +0.194894 1.000860 +0.351460 -0.227946 +0.863456 0.648456 +0.945221 -0.045667 +0.779840 0.979954 +0.996606 -0.450501 +0.632184 -0.036506 +0.790898 0.994890 +0.022503 0.386394 +0.318983 -0.152749 +0.369633 -0.423960 +0.157300 0.962858 +0.153223 0.882873 +0.360068 -0.653742 +0.433917 -0.872498 +0.133461 0.879002 +0.757252 1.123667 +0.309391 -0.102064 +0.195586 0.925339 +0.240259 0.689117 +0.340591 -0.455040 +0.243436 0.415760 +0.612755 -0.180844 +0.089407 0.723702 +0.469695 -0.987859 +0.943560 -0.097303 +0.177241 0.918082 +0.317756 -0.222902 +0.515337 -0.733668 +0.344773 -0.256893 +0.537029 -0.797272 +0.626878 0.048719 +0.208940 0.836531 +0.470697 -1.080283 +0.054448 0.624676 +0.109230 0.816921 +0.158325 1.044485 +0.976650 -0.309060 +0.643441 0.267336 +0.215841 1.018817 +0.905337 0.409871 +0.154354 0.920009 +0.947922 -0.112378 +0.201391 0.768894