mirror of
https://github.com/apachecn/ailearning.git
synced 2026-02-10 05:45:40 +08:00
测试完:回归树 VS 模型树 VS 线性回归
This commit is contained in:
@@ -11,6 +11,7 @@ import os
|
||||
from numpy import *
|
||||
import matplotlib.pylab as plt
|
||||
|
||||
|
||||
def loadDataSet(fileName): #general function to parse tab -delimited floats
|
||||
numFeat = len(open(fileName).readline().split('\t')) - 1 #get number of fields
|
||||
dataMat = []; labelMat = []
|
||||
@@ -24,6 +25,7 @@ def loadDataSet(fileName): #general function to parse tab -delimited floats
|
||||
labelMat.append(float(curLine[-1]))
|
||||
return dataMat,labelMat
|
||||
|
||||
|
||||
def standRegres(xArr,yArr):
|
||||
# >>> A.T # transpose, 转置
|
||||
xMat = mat(xArr); yMat = mat(yArr).T
|
||||
@@ -37,6 +39,7 @@ def standRegres(xArr,yArr):
|
||||
ws = xTx.I * (xMat.T*yMat) # 最小二乘法求最优解
|
||||
return ws
|
||||
|
||||
|
||||
def plotBestFit(xArr, yArr, ws):
|
||||
|
||||
xMat = mat(xArr)
|
||||
@@ -60,6 +63,7 @@ def plotBestFit(xArr, yArr, ws):
|
||||
plt.xlabel('X'); plt.ylabel('Y')
|
||||
plt.show()
|
||||
|
||||
|
||||
def main1():
|
||||
# w0*x0+w1*x1+w2*x2=f(x)
|
||||
project_dir = os.path.dirname(os.path.dirname(os.getcwd()))
|
||||
@@ -91,6 +95,7 @@ def lwlr(testPoint, xArr, yArr,k=1.0):
|
||||
ws = xTx.I * (xMat.T * (weights * yMat))
|
||||
return testPoint * ws
|
||||
|
||||
|
||||
def lwlrTest(testArr,xArr,yArr,k=1.0): #loops over all the data points and applies lwlr to each one
|
||||
m = shape(testArr)[0]
|
||||
# an m*1 matrix
|
||||
@@ -101,6 +106,7 @@ def lwlrTest(testArr,xArr,yArr,k=1.0): #loops over all the data points and appl
|
||||
yHat[i] = lwlr(testArr[i],xArr,yArr,k)
|
||||
return yHat
|
||||
|
||||
|
||||
def lwlrTestPlot(xArr, yArr, yHat):
|
||||
|
||||
xMat = mat(xArr)
|
||||
@@ -123,11 +129,13 @@ def lwlrTestPlot(xArr, yArr, yHat):
|
||||
plt.xlabel('X'); plt.ylabel('Y')
|
||||
plt.show()
|
||||
|
||||
|
||||
def main2():
|
||||
# w0*x0+w1*x1+w2*x2=f(x)
|
||||
project_dir = os.path.dirname(os.path.dirname(os.getcwd()))
|
||||
# project_dir = os.path.dirname(os.path.dirname(os.getcwd()))
|
||||
# 1. Collect and prepare the data
|
||||
xArr, yArr = loadDataSet("%s/resources/ex0.txt" % project_dir)
|
||||
# xArr, yArr = loadDataSet("%s/resources/ex0.txt" % project_dir)
|
||||
xArr, yArr = loadDataSet("testData/Regression_data.txt")
|
||||
# print xArr, '---\n', yArr
|
||||
# 2. Train the model: the coefficient matrix (a1, a2, ..., an).T in f(x) = a1*x1 + a2*x2 + ... + an*xn
|
||||
yHat = lwlrTest(xArr, xArr, yArr, 0.003)
|
||||
@@ -136,12 +144,14 @@ def main2():
|
||||
# Visualize the data
|
||||
lwlrTestPlot(xArr, yArr, yHat)
|
||||
|
||||
if __name__=="__main__":
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Linear regression
|
||||
# main1()
|
||||
# Locally weighted linear regression
|
||||
main2()
|
||||
|
||||
|
||||
def rssError(yArr, yHatArr):
    """Return the residual sum of squares between targets and predictions.

    Args:
        yArr: Observed target values (any array-like, including a plain list).
        yHatArr: Predicted values, broadcast-compatible with ``yArr``.

    Returns:
        The sum over all elements of ``(yArr - yHatArr) ** 2``.
    """
    # Coerce both inputs to ndarrays: two plain lists cannot be subtracted,
    # and on numpy.matrix objects ``** 2`` would be a matrix power rather
    # than the element-wise square this error measure needs.
    residuals = asarray(yArr) - asarray(yHatArr)
    return (residuals ** 2).sum()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user