更新sklearn-svm-demo代码

This commit is contained in:
jiangzhonglian
2017-06-28 14:36:12 +08:00
parent da5437440a
commit 2df5439d8d
2 changed files with 97 additions and 17 deletions

View File

@@ -0,0 +1,80 @@
#!/usr/bin/python
# coding:utf8
"""
Created on 2017-06-28
Updated on 2017-06-28
SVM最大边距分离超平面
@author: 片刻
《机器学习实战》更新地址https://github.com/apachecn/MachineLearning
"""
import numpy as np
import matplotlib.pyplot as plt
from sklearn import svm
print(__doc__)
# 创建40个分离点
np.random.seed(0)
# X = np.r_[np.random.randn(20, 2) - [2, 2], np.random.randn(20, 2) + [2, 2]]
# Y = [0] * 20 + [1] * 20
def loadDataSet(fileName):
"""
对文件进行逐行解析,从而得到第行的类标签和整个数据矩阵
Args:
fileName 文件名
Returns:
dataMat 数据矩阵
labelMat 类标签
"""
dataMat = []
labelMat = []
fr = open(fileName)
for line in fr.readlines():
lineArr = line.strip().split('\t')
dataMat.append([float(lineArr[0]), float(lineArr[1])])
labelMat.append(float(lineArr[2]))
return dataMat, labelMat
X, Y = loadDataSet('input/6.SVM/testSet.txt')
X = np.mat(X)
print("X=", X)
print("Y=", Y)
# 拟合一个SVM模型
clf = svm.SVC(kernel='linear')
clf.fit(X, Y)
# 获取分割超平面
w = clf.coef_[0]
# 斜率
a = -w[0] / w[1]
# 从-5到5顺序间隔采样50个样本默认是num=50
# xx = np.linspace(-5, 5) # , num=50)
xx = np.linspace(-2, 10) # , num=50)
# 二维的直线方程
yy = a * xx - (clf.intercept_[0]) / w[1]
print("yy=", yy)
# plot the parallels to the separating hyperplane that pass through the support vectors
# 通过支持向量绘制分割超平面
print("support_vectors_=", clf.support_vectors_)
b = clf.support_vectors_[0]
yy_down = a * xx + (b[1] - a * b[0])
b = clf.support_vectors_[-1]
yy_up = a * xx + (b[1] - a * b[0])
# plot the line, the points, and the nearest vectors to the plane
plt.plot(xx, yy, 'k-')
plt.plot(xx, yy_down, 'k--')
plt.plot(xx, yy_up, 'k--')
plt.scatter(clf.support_vectors_[:, 0], clf.support_vectors_[:, 1], s=80, facecolors='none')
plt.scatter(X[:, 0], X[:, 1], c=Y, cmap=plt.cm.Paired)
plt.axis('tight')
plt.show()

View File

@@ -497,26 +497,26 @@ def plotfig_SVM(xArr, yArr, ws, b, alphas):
if __name__ == "__main__":
# # 无核函数的测试
# # 获取特征和目标变量
# dataArr, labelArr = loadDataSet('input/6.SVM/testSet.txt')
# # print labelArr
# 无核函数的测试
# 获取特征和目标变量
dataArr, labelArr = loadDataSet('input/6.SVM/testSet.txt')
# print labelArr
# # b是常量值 alphas是拉格朗日乘子
# b, alphas = smoP(dataArr, labelArr, 0.6, 0.001, 40)
# print '/n/n/n'
# print 'b=', b
# print 'alphas[alphas>0]=', alphas[alphas > 0]
# print 'shape(alphas[alphas > 0])=', shape(alphas[alphas > 0])
# for i in range(100):
# if alphas[i] > 0:
# print dataArr[i], labelArr[i]
# # 画图
# ws = calcWs(alphas, dataArr, labelArr)
# plotfig_SVM(dataArr, labelArr, ws, b, alphas)
# b是常量值 alphas是拉格朗日乘子
b, alphas = smoP(dataArr, labelArr, 0.6, 0.001, 40)
print '/n/n/n'
print 'b=', b
print 'alphas[alphas>0]=', alphas[alphas > 0]
print 'shape(alphas[alphas > 0])=', shape(alphas[alphas > 0])
for i in range(100):
if alphas[i] > 0:
print dataArr[i], labelArr[i]
# 画图
ws = calcWs(alphas, dataArr, labelArr)
plotfig_SVM(dataArr, labelArr, ws, b, alphas)
# # 有核函数的测试
testRbf(0.8)
# testRbf(0.8)
# 项目实战
# 示例:手写识别问题回顾