mirror of
https://github.com/apachecn/ailearning.git
synced 2026-02-13 07:15:26 +08:00
更新sklearn-svm-demo代码
This commit is contained in:
80
src/python/6.SVM/sklearn-svm-demo.py
Normal file
80
src/python/6.SVM/sklearn-svm-demo.py
Normal file
@@ -0,0 +1,80 @@
|
||||
#!/usr/bin/python
|
||||
# coding:utf8
|
||||
|
||||
"""
|
||||
Created on 2017-06-28
|
||||
Updated on 2017-06-28
|
||||
SVM:最大边距分离超平面
|
||||
@author: 片刻
|
||||
《机器学习实战》更新地址:https://github.com/apachecn/MachineLearning
|
||||
"""
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from sklearn import svm
|
||||
print(__doc__)
|
||||
|
||||
|
||||
# 创建40个分离点
|
||||
np.random.seed(0)
|
||||
# X = np.r_[np.random.randn(20, 2) - [2, 2], np.random.randn(20, 2) + [2, 2]]
|
||||
# Y = [0] * 20 + [1] * 20
|
||||
|
||||
|
||||
def loadDataSet(fileName):
|
||||
"""
|
||||
对文件进行逐行解析,从而得到第行的类标签和整个数据矩阵
|
||||
Args:
|
||||
fileName 文件名
|
||||
Returns:
|
||||
dataMat 数据矩阵
|
||||
labelMat 类标签
|
||||
"""
|
||||
dataMat = []
|
||||
labelMat = []
|
||||
fr = open(fileName)
|
||||
for line in fr.readlines():
|
||||
lineArr = line.strip().split('\t')
|
||||
dataMat.append([float(lineArr[0]), float(lineArr[1])])
|
||||
labelMat.append(float(lineArr[2]))
|
||||
return dataMat, labelMat
|
||||
|
||||
|
||||
X, Y = loadDataSet('input/6.SVM/testSet.txt')
|
||||
X = np.mat(X)
|
||||
|
||||
print("X=", X)
|
||||
print("Y=", Y)
|
||||
|
||||
# 拟合一个SVM模型
|
||||
clf = svm.SVC(kernel='linear')
|
||||
clf.fit(X, Y)
|
||||
|
||||
# 获取分割超平面
|
||||
w = clf.coef_[0]
|
||||
# 斜率
|
||||
a = -w[0] / w[1]
|
||||
# 从-5到5,顺序间隔采样50个样本,默认是num=50
|
||||
# xx = np.linspace(-5, 5) # , num=50)
|
||||
xx = np.linspace(-2, 10) # , num=50)
|
||||
# 二维的直线方程
|
||||
yy = a * xx - (clf.intercept_[0]) / w[1]
|
||||
print("yy=", yy)
|
||||
|
||||
# plot the parallels to the separating hyperplane that pass through the support vectors
|
||||
# 通过支持向量绘制分割超平面
|
||||
print("support_vectors_=", clf.support_vectors_)
|
||||
b = clf.support_vectors_[0]
|
||||
yy_down = a * xx + (b[1] - a * b[0])
|
||||
b = clf.support_vectors_[-1]
|
||||
yy_up = a * xx + (b[1] - a * b[0])
|
||||
|
||||
# plot the line, the points, and the nearest vectors to the plane
|
||||
plt.plot(xx, yy, 'k-')
|
||||
plt.plot(xx, yy_down, 'k--')
|
||||
plt.plot(xx, yy_up, 'k--')
|
||||
|
||||
plt.scatter(clf.support_vectors_[:, 0], clf.support_vectors_[:, 1], s=80, facecolors='none')
|
||||
plt.scatter(X[:, 0], X[:, 1], c=Y, cmap=plt.cm.Paired)
|
||||
|
||||
plt.axis('tight')
|
||||
plt.show()
|
||||
@@ -497,26 +497,26 @@ def plotfig_SVM(xArr, yArr, ws, b, alphas):
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
# # 无核函数的测试
|
||||
# # 获取特征和目标变量
|
||||
# dataArr, labelArr = loadDataSet('input/6.SVM/testSet.txt')
|
||||
# # print labelArr
|
||||
# 无核函数的测试
|
||||
# 获取特征和目标变量
|
||||
dataArr, labelArr = loadDataSet('input/6.SVM/testSet.txt')
|
||||
# print labelArr
|
||||
|
||||
# # b是常量值, alphas是拉格朗日乘子
|
||||
# b, alphas = smoP(dataArr, labelArr, 0.6, 0.001, 40)
|
||||
# print '/n/n/n'
|
||||
# print 'b=', b
|
||||
# print 'alphas[alphas>0]=', alphas[alphas > 0]
|
||||
# print 'shape(alphas[alphas > 0])=', shape(alphas[alphas > 0])
|
||||
# for i in range(100):
|
||||
# if alphas[i] > 0:
|
||||
# print dataArr[i], labelArr[i]
|
||||
# # 画图
|
||||
# ws = calcWs(alphas, dataArr, labelArr)
|
||||
# plotfig_SVM(dataArr, labelArr, ws, b, alphas)
|
||||
# b是常量值, alphas是拉格朗日乘子
|
||||
b, alphas = smoP(dataArr, labelArr, 0.6, 0.001, 40)
|
||||
print '/n/n/n'
|
||||
print 'b=', b
|
||||
print 'alphas[alphas>0]=', alphas[alphas > 0]
|
||||
print 'shape(alphas[alphas > 0])=', shape(alphas[alphas > 0])
|
||||
for i in range(100):
|
||||
if alphas[i] > 0:
|
||||
print dataArr[i], labelArr[i]
|
||||
# 画图
|
||||
ws = calcWs(alphas, dataArr, labelArr)
|
||||
plotfig_SVM(dataArr, labelArr, ws, b, alphas)
|
||||
|
||||
# # 有核函数的测试
|
||||
testRbf(0.8)
|
||||
# testRbf(0.8)
|
||||
|
||||
# 项目实战
|
||||
# 示例:手写识别问题回顾
|
||||
|
||||
Reference in New Issue
Block a user