diff --git a/src/python/6.SVM/sklearn-svm-demo.py b/src/python/6.SVM/sklearn-svm-demo.py new file mode 100644 index 00000000..88e7aa28 --- /dev/null +++ b/src/python/6.SVM/sklearn-svm-demo.py @@ -0,0 +1,80 @@ +#!/usr/bin/python +# coding:utf8 + +""" +Created on 2017-06-28 +Updated on 2017-06-28 +SVM:最大边距分离超平面 +@author: 片刻 +《机器学习实战》更新地址:https://github.com/apachecn/MachineLearning +""" +import numpy as np +import matplotlib.pyplot as plt +from sklearn import svm +print(__doc__) + + +# 创建40个分离点 +np.random.seed(0) +# X = np.r_[np.random.randn(20, 2) - [2, 2], np.random.randn(20, 2) + [2, 2]] +# Y = [0] * 20 + [1] * 20 + + +def loadDataSet(fileName): + """ + 对文件进行逐行解析,从而得到第行的类标签和整个数据矩阵 + Args: + fileName 文件名 + Returns: + dataMat 数据矩阵 + labelMat 类标签 + """ + dataMat = [] + labelMat = [] + fr = open(fileName) + for line in fr.readlines(): + lineArr = line.strip().split('\t') + dataMat.append([float(lineArr[0]), float(lineArr[1])]) + labelMat.append(float(lineArr[2])) + return dataMat, labelMat + + +X, Y = loadDataSet('input/6.SVM/testSet.txt') +X = np.mat(X) + +print("X=", X) +print("Y=", Y) + +# 拟合一个SVM模型 +clf = svm.SVC(kernel='linear') +clf.fit(X, Y) + +# 获取分割超平面 +w = clf.coef_[0] +# 斜率 +a = -w[0] / w[1] +# 从-5到5,顺序间隔采样50个样本,默认是num=50 +# xx = np.linspace(-5, 5) # , num=50) +xx = np.linspace(-2, 10) # , num=50) +# 二维的直线方程 +yy = a * xx - (clf.intercept_[0]) / w[1] +print("yy=", yy) + +# plot the parallels to the separating hyperplane that pass through the support vectors +# 通过支持向量绘制分割超平面 +print("support_vectors_=", clf.support_vectors_) +b = clf.support_vectors_[0] +yy_down = a * xx + (b[1] - a * b[0]) +b = clf.support_vectors_[-1] +yy_up = a * xx + (b[1] - a * b[0]) + +# plot the line, the points, and the nearest vectors to the plane +plt.plot(xx, yy, 'k-') +plt.plot(xx, yy_down, 'k--') +plt.plot(xx, yy_up, 'k--') + +plt.scatter(clf.support_vectors_[:, 0], clf.support_vectors_[:, 1], s=80, facecolors='none') +plt.scatter(X[:, 0], X[:, 1], c=Y, cmap=plt.cm.Paired) + +plt.axis('tight') +plt.show() diff --git a/src/python/6.SVM/svm-complete.py b/src/python/6.SVM/svm-complete.py index 44286375..f20977aa 100644 --- a/src/python/6.SVM/svm-complete.py +++ b/src/python/6.SVM/svm-complete.py @@ -497,26 +497,26 @@ def plotfig_SVM(xArr, yArr, ws, b, alphas): if __name__ == "__main__": - # # 无核函数的测试 - # # 获取特征和目标变量 - # dataArr, labelArr = loadDataSet('input/6.SVM/testSet.txt') - # # print labelArr + # 无核函数的测试 + # 获取特征和目标变量 + dataArr, labelArr = loadDataSet('input/6.SVM/testSet.txt') + # print labelArr - # # b是常量值, alphas是拉格朗日乘子 - # b, alphas = smoP(dataArr, labelArr, 0.6, 0.001, 40) - # print '/n/n/n' - # print 'b=', b - # print 'alphas[alphas>0]=', alphas[alphas > 0] - # print 'shape(alphas[alphas > 0])=', shape(alphas[alphas > 0]) - # for i in range(100): - # if alphas[i] > 0: - # print dataArr[i], labelArr[i] - # # 画图 - # ws = calcWs(alphas, dataArr, labelArr) - # plotfig_SVM(dataArr, labelArr, ws, b, alphas) + # b是常量值, alphas是拉格朗日乘子 + b, alphas = smoP(dataArr, labelArr, 0.6, 0.001, 40) + print '/n/n/n' + print 'b=', b + print 'alphas[alphas>0]=', alphas[alphas > 0] + print 'shape(alphas[alphas > 0])=', shape(alphas[alphas > 0]) + for i in range(100): + if alphas[i] > 0: + print dataArr[i], labelArr[i] + # 画图 + ws = calcWs(alphas, dataArr, labelArr) + plotfig_SVM(dataArr, labelArr, ws, b, alphas) # # 有核函数的测试 - testRbf(0.8) + # testRbf(0.8) # 项目实战 # 示例:手写识别问题回顾