mirror of
https://github.com/apachecn/ailearning.git
synced 2026-02-12 06:46:14 +08:00
修改8.Regression文档和代码
This commit is contained in:
@@ -396,6 +396,41 @@ def regression2():
|
||||
plt.show()
|
||||
|
||||
|
||||
# test for abloneDataSet
|
||||
def abaloneTest():
|
||||
'''
|
||||
Desc:
|
||||
预测鲍鱼的年龄
|
||||
Args:
|
||||
None
|
||||
Returns:
|
||||
None
|
||||
'''
|
||||
# 加载数据
|
||||
abX, abY = loadDataSet("input/8.Regression/abalone.txt")
|
||||
# 使用不同的核进行预测
|
||||
oldyHat01 = lwlrTest(abX[0:99], abX[0:99], abY[0:99], 0.1)
|
||||
oldyHat1 = lwlrTest(abX[0:99], abX[0:99], abY[0:99], 1)
|
||||
oldyHat10 = lwlrTest(abX[0:99], abX[0:99], abY[0:99], 10)
|
||||
# 打印出不同的核预测值与训练数据集上的真实值之间的误差大小
|
||||
print "old yHat01 error Size is :" , rssError(abY[0:99], oldyHat01.T)
|
||||
print "old yHat1 error Size is :" , rssError(abY[0:99], oldyHat1.T)
|
||||
print "old yHat10 error Size is :" , rssError(abY[0:99], oldyHat10.T)
|
||||
|
||||
# 打印出 不同的核预测值 与 新数据集(测试数据集)上的真实值之间的误差大小
|
||||
newyHat01 = lwlrTest(abX[100:199], abX[0:99], abY[0:99], 0.1)
|
||||
print "new yHat01 error Size is :" , rssError(abY[0:99], yHat01.T)
|
||||
newyHat1 = lwlrTest(abX[100:199], abX[0:99], abY[0:99], 1)
|
||||
print "new yHat1 error Size is :" , rssError(abY[0:99], yHat1.T)
|
||||
newyHat10 = lwlrTest(abX[100:199], abX[0:99], abY[0:99], 10)
|
||||
print "new yHat10 error Size is :" , rssError(abY[0:99], yHat10.T)
|
||||
|
||||
# 使用简单的 线性回归 进行预测,与上面的计算进行比较
|
||||
standWs = standRegres(abX[0:99], abY[0:99])
|
||||
standyHat = mat(abx[100:199]) * standWs
|
||||
print "standRegress error Size is:", rssError(abY[100:199], standyHat.T.A)
|
||||
|
||||
|
||||
#test for ridgeRegression
|
||||
def regression3():
|
||||
abX,abY = loadDataSet("input/8.Regression/abalone.txt")
|
||||
|
||||
Reference in New Issue
Block a user