From c413c5f62621fdd58e1fa68c9e4364ac2212c10c Mon Sep 17 00:00:00 2001 From: jiangzhonglian Date: Tue, 18 Apr 2017 16:50:38 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E4=BA=86SVM=E6=97=A0?= =?UTF-8?q?=E6=A0=B8=E5=87=BD=E6=95=B0=E7=9A=84=EF=BC=8C=E6=B3=A8=E9=87=8A?= =?UTF-8?q?=E5=8F=82=E8=80=83=E6=9C=89=E6=9C=89=E6=A0=B8=E5=87=BD=E6=95=B0?= =?UTF-8?q?=E7=9A=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- input/6.SVM/testSetRBF.txt | 100 ++++++ input/6.SVM/testSetRBF2.txt | 100 ++++++ src/python/6.SVM/svm-complete.py | 29 +- src/python/6.SVM/svm-complete_Non-Kernel.py | 375 ++++++++++++++++++++ 4 files changed, 592 insertions(+), 12 deletions(-) create mode 100755 input/6.SVM/testSetRBF.txt create mode 100755 input/6.SVM/testSetRBF2.txt create mode 100644 src/python/6.SVM/svm-complete_Non-Kernel.py diff --git a/input/6.SVM/testSetRBF.txt b/input/6.SVM/testSetRBF.txt new file mode 100755 index 00000000..ff78d65e --- /dev/null +++ b/input/6.SVM/testSetRBF.txt @@ -0,0 +1,100 @@ +-0.214824 0.662756 -1.000000 +-0.061569 -0.091875 1.000000 +0.406933 0.648055 -1.000000 +0.223650 0.130142 1.000000 +0.231317 0.766906 -1.000000 +-0.748800 -0.531637 -1.000000 +-0.557789 0.375797 -1.000000 +0.207123 -0.019463 1.000000 +0.286462 0.719470 -1.000000 +0.195300 -0.179039 1.000000 +-0.152696 -0.153030 1.000000 +0.384471 0.653336 -1.000000 +-0.117280 -0.153217 1.000000 +-0.238076 0.000583 1.000000 +-0.413576 0.145681 1.000000 +0.490767 -0.680029 -1.000000 +0.199894 -0.199381 1.000000 +-0.356048 0.537960 -1.000000 +-0.392868 -0.125261 1.000000 +0.353588 -0.070617 1.000000 +0.020984 0.925720 -1.000000 +-0.475167 -0.346247 -1.000000 +0.074952 0.042783 1.000000 +0.394164 -0.058217 1.000000 +0.663418 0.436525 -1.000000 +0.402158 0.577744 -1.000000 +-0.449349 -0.038074 1.000000 +0.619080 -0.088188 -1.000000 +0.268066 -0.071621 1.000000 +-0.015165 0.359326 1.000000 +0.539368 -0.374972 -1.000000 +-0.319153 0.629673 -1.000000 +0.694424 0.641180 -1.000000 +0.079522 0.193198 1.000000 +0.253289 -0.285861 1.000000 +-0.035558 -0.010086 1.000000 +-0.403483 0.474466 -1.000000 +-0.034312 0.995685 -1.000000 +-0.590657 0.438051 -1.000000 +-0.098871 -0.023953 1.000000 +-0.250001 0.141621 1.000000 +-0.012998 0.525985 -1.000000 +0.153738 0.491531 -1.000000 +0.388215 -0.656567 -1.000000 +0.049008 0.013499 1.000000 +0.068286 0.392741 1.000000 +0.747800 -0.066630 -1.000000 +0.004621 -0.042932 1.000000 +-0.701600 0.190983 -1.000000 +0.055413 -0.024380 1.000000 +0.035398 -0.333682 1.000000 +0.211795 0.024689 1.000000 +-0.045677 0.172907 1.000000 +0.595222 0.209570 -1.000000 +0.229465 0.250409 1.000000 +-0.089293 0.068198 1.000000 +0.384300 -0.176570 1.000000 +0.834912 -0.110321 -1.000000 +-0.307768 0.503038 -1.000000 +-0.777063 -0.348066 -1.000000 +0.017390 0.152441 1.000000 +-0.293382 -0.139778 1.000000 +-0.203272 0.286855 1.000000 +0.957812 -0.152444 -1.000000 +0.004609 -0.070617 1.000000 +-0.755431 0.096711 -1.000000 +-0.526487 0.547282 -1.000000 +-0.246873 0.833713 -1.000000 +0.185639 -0.066162 1.000000 +0.851934 0.456603 -1.000000 +-0.827912 0.117122 -1.000000 +0.233512 -0.106274 1.000000 +0.583671 -0.709033 -1.000000 +-0.487023 0.625140 -1.000000 +-0.448939 0.176725 1.000000 +0.155907 -0.166371 1.000000 +0.334204 0.381237 -1.000000 +0.081536 -0.106212 1.000000 +0.227222 0.527437 -1.000000 +0.759290 0.330720 -1.000000 +0.204177 -0.023516 1.000000 +0.577939 0.403784 -1.000000 +-0.568534 0.442948 -1.000000 +-0.011520 0.021165 1.000000 +0.875720 0.422476 -1.000000 +0.297885 -0.632874 -1.000000 +-0.015821 0.031226 1.000000 +0.541359 -0.205969 -1.000000 +-0.689946 -0.508674 -1.000000 +-0.343049 0.841653 -1.000000 +0.523902 -0.436156 -1.000000 +0.249281 -0.711840 -1.000000 +0.193449 0.574598 -1.000000 +-0.257542 -0.753885 -1.000000 +-0.021605 0.158080 1.000000 +0.601559 -0.727041 -1.000000 +-0.791603 0.095651 -1.000000 +-0.908298 -0.053376 -1.000000 +0.122020 0.850966 -1.000000 +-0.725568 -0.292022 -1.000000 diff --git a/input/6.SVM/testSetRBF2.txt b/input/6.SVM/testSetRBF2.txt new file mode 100755 index 00000000..10b40211 --- /dev/null +++ b/input/6.SVM/testSetRBF2.txt @@ -0,0 +1,100 @@ +0.676771 -0.486687 -1.000000 +0.008473 0.186070 1.000000 +-0.727789 0.594062 -1.000000 +0.112367 0.287852 1.000000 +0.383633 -0.038068 1.000000 +-0.927138 -0.032633 -1.000000 +-0.842803 -0.423115 -1.000000 +-0.003677 -0.367338 1.000000 +0.443211 -0.698469 -1.000000 +-0.473835 0.005233 1.000000 +0.616741 0.590841 -1.000000 +0.557463 -0.373461 -1.000000 +-0.498535 -0.223231 -1.000000 +-0.246744 0.276413 1.000000 +-0.761980 -0.244188 -1.000000 +0.641594 -0.479861 -1.000000 +-0.659140 0.529830 -1.000000 +-0.054873 -0.238900 1.000000 +-0.089644 -0.244683 1.000000 +-0.431576 -0.481538 -1.000000 +-0.099535 0.728679 -1.000000 +-0.188428 0.156443 1.000000 +0.267051 0.318101 1.000000 +0.222114 -0.528887 -1.000000 +0.030369 0.113317 1.000000 +0.392321 0.026089 1.000000 +0.298871 -0.915427 -1.000000 +-0.034581 -0.133887 1.000000 +0.405956 0.206980 1.000000 +0.144902 -0.605762 -1.000000 +0.274362 -0.401338 1.000000 +0.397998 -0.780144 -1.000000 +0.037863 0.155137 1.000000 +-0.010363 -0.004170 1.000000 +0.506519 0.486619 -1.000000 +0.000082 -0.020625 1.000000 +0.057761 -0.155140 1.000000 +0.027748 -0.553763 -1.000000 +-0.413363 -0.746830 -1.000000 +0.081500 -0.014264 1.000000 +0.047137 -0.491271 1.000000 +-0.267459 0.024770 1.000000 +-0.148288 -0.532471 -1.000000 +-0.225559 -0.201622 1.000000 +0.772360 -0.518986 -1.000000 +-0.440670 0.688739 -1.000000 +0.329064 -0.095349 1.000000 +0.970170 -0.010671 -1.000000 +-0.689447 -0.318722 -1.000000 +-0.465493 -0.227468 -1.000000 +-0.049370 0.405711 1.000000 +-0.166117 0.274807 1.000000 +0.054483 0.012643 1.000000 +0.021389 0.076125 1.000000 +-0.104404 -0.914042 -1.000000 +0.294487 0.440886 -1.000000 +0.107915 -0.493703 -1.000000 +0.076311 0.438860 1.000000 +0.370593 -0.728737 -1.000000 +0.409890 0.306851 -1.000000 +0.285445 0.474399 -1.000000 +-0.870134 -0.161685 -1.000000 +-0.654144 -0.675129 -1.000000 +0.285278 -0.767310 -1.000000 +0.049548 -0.000907 1.000000 +0.030014 -0.093265 1.000000 +-0.128859 0.278865 1.000000 +0.307463 0.085667 1.000000 +0.023440 0.298638 1.000000 +0.053920 0.235344 1.000000 +0.059675 0.533339 -1.000000 +0.817125 0.016536 -1.000000 +-0.108771 0.477254 1.000000 +-0.118106 0.017284 1.000000 +0.288339 0.195457 1.000000 +0.567309 -0.200203 -1.000000 +-0.202446 0.409387 1.000000 +-0.330769 -0.240797 1.000000 +-0.422377 0.480683 -1.000000 +-0.295269 0.326017 1.000000 +0.261132 0.046478 1.000000 +-0.492244 -0.319998 -1.000000 +-0.384419 0.099170 1.000000 +0.101882 -0.781145 -1.000000 +0.234592 -0.383446 1.000000 +-0.020478 -0.901833 -1.000000 +0.328449 0.186633 1.000000 +-0.150059 -0.409158 1.000000 +-0.155876 -0.843413 -1.000000 +-0.098134 -0.136786 1.000000 +0.110575 -0.197205 1.000000 +0.219021 0.054347 1.000000 +0.030152 0.251682 1.000000 +0.033447 -0.122824 1.000000 +-0.686225 -0.020779 -1.000000 +-0.911211 -0.262011 -1.000000 +0.572557 0.377526 -1.000000 +-0.073647 -0.519163 -1.000000 +-0.281830 -0.797236 -1.000000 +-0.555263 0.126232 -1.000000 diff --git a/src/python/6.SVM/svm-complete.py b/src/python/6.SVM/svm-complete.py index 0a5c549e..b537d385 100644 --- a/src/python/6.SVM/svm-complete.py +++ b/src/python/6.SVM/svm-complete.py @@ -123,7 +123,7 @@ def selectJrand(i, m): def selectJ(i, oS, Ei): # this is the second choice -heurstic, and calcs Ej - """selectJ() + """selectJ(返回最优的j和Ej) 内循环的启发式方法。 选择第二个(内循环)alpha的alpha值 @@ -187,9 +187,6 @@ def updateEk(oS, k): Args: oS optStruct对象 k 某一列的行号 - - Returns: - """ # 求 误差:预测值-真实值的差 @@ -214,14 +211,15 @@ def clipAlpha(aj, H, L): def innerL(i, oS): - """ + """innerL 内循环代码 Args: i 具体的某一行 oS optStruct对象 Returns: - + 0 找不到最优的值 + 1 找到了最优的值,并且oS.Cache到缓存中 """ # 求 Ek误差:预测值-真实值的差 @@ -311,20 +309,29 @@ def smoP(dataMatIn, classLabels, C, toler, maxIter, kTup=('lin', 0)): entireSet = True alphaPairsChanged = 0 - # 循环遍历: + # 循环遍历:循环maxIter次 并且 (alphaPairsChanged存在可以改变 or 所有行遍历一遍) while (iter < maxIter) and ((alphaPairsChanged > 0) or (entireSet)): alphaPairsChanged = 0 - if entireSet: # 在数据集上遍历所有可能的alpha + + # 当entireSet=true or 非边界alpha对没有了;就开始寻找 alpha对,然后决定是否要进行else。 + if entireSet: + # 在数据集上遍历所有可能的alpha for i in range(oS.m): + # 是否存在alpha对,存在就+1 alphaPairsChanged += innerL(i, oS) print("fullSet, iter: %d i:%d, pairs changed %d" % (iter, i, alphaPairsChanged)) iter += 1 - else: # 遍历所有的非边界alpha值,也就是不在边界0或C上的值。 + + # 对已存在 alpha对,选出非边界的alpha值,进行优化。 + else: + # 遍历所有的非边界alpha值,也就是不在边界0或C上的值。 nonBoundIs = nonzero((oS.alphas.A > 0) * (oS.alphas.A < C))[0] for i in nonBoundIs: alphaPairsChanged += innerL(i, oS) print("non-bound, iter: %d i:%d, pairs changed %d" % (iter, i, alphaPairsChanged)) iter += 1 + + # 如果找到alpha对,就优化非边界alpha值,否则,就重新进行寻找,如果寻找一遍 遍历所有的行还是没找到,就退出循环。 if entireSet: entireSet = False # toggle entire set loop elif (alphaPairsChanged == 0): @@ -438,6 +445,7 @@ if __name__ == "__main__": + def testRbf(k1=1.3): @@ -548,9 +556,6 @@ def calcEkK(oS, k): return Ek - - - def selectJK(i, oS, Ei): # this is the second choice -heurstic, and calcs Ej maxK = -1 maxDeltaE = 0 diff --git a/src/python/6.SVM/svm-complete_Non-Kernel.py b/src/python/6.SVM/svm-complete_Non-Kernel.py new file mode 100644 index 00000000..3f5fa11c --- /dev/null +++ b/src/python/6.SVM/svm-complete_Non-Kernel.py @@ -0,0 +1,375 @@ +#!/usr/bin/python +# coding:utf8 + +""" +Created on Nov 4, 2010 +Update on 2017-03-21 +Chapter 5 source file for Machine Learing in Action +@author: Peter/geekidentity/片刻 +""" +from numpy import * +import matplotlib.pyplot as plt + + +def loadDataSet(fileName): + """loadDataSet(对文件进行逐行解析,从而得到第行的类标签和整个数据矩阵) + + Args: + fileName 文件名 + Returns: + dataMat 数据矩阵 + labelMat 类标签 + """ + dataMat = [] + labelMat = [] + fr = open(fileName) + for line in fr.readlines(): + lineArr = line.strip().split('\t') + dataMat.append([float(lineArr[0]), float(lineArr[1])]) + labelMat.append(float(lineArr[2])) + return dataMat, labelMat + + +def selectJrand(i, m): + """ + 随机选择一个整数 + Args: + i 第一个alpha的下标 + m 所有alpha的数目 + Returns: + j 返回一个不为i的随机数,在0~m之间的整数值 + """ + j = i + while j == i: + j = int(random.uniform(0, m)) + return j + + +def clipAlpha(aj, H, L): + """clipAlpha(调整aj的值,使aj处于 L<=aj<=H) + Args: + aj 目标值 + H 最大值 + L 最小值 + Returns: + aj 目标值 + """ + if aj > H: + aj = H + if L > aj: + aj = L + return aj + + +def calcWs(alphas, dataArr, classLabels): + """ + 基于alpha计算w值 + Args: + alphas 拉格朗日乘子 + dataArr feature数据集 + classLabels 目标变量数据集 + + Returns: + wc 回归系数 + """ + X = mat(dataArr) + labelMat = mat(classLabels).transpose() + m, n = shape(X) + w = zeros((n, 1)) + for i in range(m): + w += multiply(alphas[i] * labelMat[i], X[i, :].T) + return w + + +''' +#######******************************** +Non-Kernel VErsions below +#######******************************** +''' + +class optStruct: + def __init__(self, dataMatIn, classLabels, C, toler): # Initialize the structure with the parameters + self.X = dataMatIn + self.labelMat = classLabels + self.C = C + self.tol = toler + self.m = shape(dataMatIn)[0] + self.alphas = mat(zeros((self.m, 1))) + self.b = 0 + self.eCache = mat(zeros((self.m, 2))) # first column is valid flag + + +def calcEk(oS, k): + fXk = float(multiply(oS.alphas, oS.labelMat).T * (oS.X * oS.X[k, :].T)) + oS.b + Ek = fXk - float(oS.labelMat[k]) + return Ek + + +def selectJ(i, oS, Ei): # this is the second choice -heurstic, and calcs Ej + maxK = -1 + maxDeltaE = 0 + Ej = 0 + oS.eCache[i] = [1, Ei] # set valid #choose the alpha that gives the maximum delta E + validEcacheList = nonzero(oS.eCache[:, 0].A)[0] + if (len(validEcacheList)) > 1: + for k in validEcacheList: # loop through valid Ecache values and find the one that maximizes delta E + if k == i: continue # don't calc for i, waste of time + Ek = calcEk(oS, k) + deltaE = abs(Ei - Ek) + if (deltaE > maxDeltaE): + maxK = k + maxDeltaE = deltaE + Ej = Ek + return maxK, Ej + else: # in this case (first time around) we don't have any valid eCache values + j = selectJrand(i, oS.m) + Ej = calcEk(oS, j) + return j, Ej + + +def updateEk(oS, k): # after any alpha has changed update the new value in the cache + Ek = calcEk(oS, k) + oS.eCache[k] = [1, Ek] + + +def innerL(i, oS): + Ei = calcEk(oS, i) + if ((oS.labelMat[i] * Ei < -oS.tol) and (oS.alphas[i] < oS.C)) or ( + (oS.labelMat[i] * Ei > oS.tol) and (oS.alphas[i] > 0)): + j, Ej = selectJ(i, oS, Ei) # this has been changed from selectJrand + alphaIold = oS.alphas[i].copy() + alphaJold = oS.alphas[j].copy() + if (oS.labelMat[i] != oS.labelMat[j]): + L = max(0, oS.alphas[j] - oS.alphas[i]) + H = min(oS.C, oS.C + oS.alphas[j] - oS.alphas[i]) + else: + L = max(0, oS.alphas[j] + oS.alphas[i] - oS.C) + H = min(oS.C, oS.alphas[j] + oS.alphas[i]) + if L == H: + print("L==H") + return 0 + eta = 2.0 * oS.X[i, :] * oS.X[j, :].T - oS.X[i, :] * oS.X[i, :].T - oS.X[j, :] * oS.X[j, :].T + if eta >= 0: + print("eta>=0") + return 0 + oS.alphas[j] -= oS.labelMat[j] * (Ei - Ej) / eta + oS.alphas[j] = clipAlpha(oS.alphas[j], H, L) + updateEk(oS, j) # added this for the Ecache + if (abs(oS.alphas[j] - alphaJold) < 0.00001): + print("j not moving enough") + return 0 + oS.alphas[i] += oS.labelMat[j] * oS.labelMat[i] * (alphaJold - oS.alphas[j]) # update i by the same amount as j + updateEk(oS, i) # added this for the Ecache #the update is in the oppostie direction + b1 = oS.b - Ei - oS.labelMat[i] * (oS.alphas[i] - alphaIold) * oS.X[i, :] * oS.X[i, :].T - oS.labelMat[j] * ( + oS.alphas[j] - alphaJold) * oS.X[i, :] * oS.X[j, :].T + b2 = oS.b - Ej - oS.labelMat[i] * (oS.alphas[i] - alphaIold) * oS.X[i, :] * oS.X[j, :].T - oS.labelMat[j] * ( + oS.alphas[j] - alphaJold) * oS.X[j, :] * oS.X[j, :].T + if (0 < oS.alphas[i]) and (oS.C > oS.alphas[i]): + oS.b = b1 + elif (0 < oS.alphas[j]) and (oS.C > oS.alphas[j]): + oS.b = b2 + else: + oS.b = (b1 + b2) / 2.0 + return 1 + else: + return 0 + + +def smoP(dataMatIn, classLabels, C, toler, maxIter): # full Platt SMO + oS = optStruct(mat(dataMatIn), mat(classLabels).transpose(), C, toler) + iter = 0 + entireSet = True + alphaPairsChanged = 0 + while (iter < maxIter) and ((alphaPairsChanged > 0) or (entireSet)): + alphaPairsChanged = 0 + if entireSet: # go over all + for i in range(oS.m): + alphaPairsChanged += innerL(i, oS) + print("fullSet, iter: %d i:%d, pairs changed %d" % (iter, i, alphaPairsChanged)) + iter += 1 + else: # go over non-bound (railed) alphas + nonBoundIs = nonzero((oS.alphas.A > 0) * (oS.alphas.A < C))[0] + for i in nonBoundIs: + alphaPairsChanged += innerL(i, oS) + print("non-bound, iter: %d i:%d, pairs changed %d" % (iter, i, alphaPairsChanged)) + iter += 1 + if entireSet: + entireSet = False # toggle entire set loop + elif (alphaPairsChanged == 0): + entireSet = True + print("iteration number: %d" % iter) + return oS.b, oS.alphas + + +def plotfig_SVM(xArr, yArr, ws, b, alphas): + """ + 参考地址: + http://blog.csdn.net/maoersong/article/details/24315633 + http://www.cnblogs.com/JustForCS/p/5283489.html + http://blog.csdn.net/kkxgx/article/details/6951959 + """ + + xMat = mat(xArr) + yMat = mat(yArr) + + # b原来是矩阵,先转为数组类型后其数组大小为(1,1),所以后面加[0],变为(1,) + b = array(b)[0] + fig = plt.figure() + ax = fig.add_subplot(111) + + # 注意flatten的用法 + ax.scatter(xMat[:, 0].flatten().A[0], xMat[:, 1].flatten().A[0]) + + # x最大值,最小值根据原数据集dataArr[:, 0]的大小而定 + x = arange(-1.0, 10.0, 0.1) + + # 根据x.w + b = 0 得到,其式子展开为w0.x1 + w1.x2 + b = 0, x2就是y值 + y = (-b-ws[0, 0]*x)/ws[1, 0] + ax.plot(x, y) + + for i in range(shape(yMat[0, :])[1]): + if yMat[0, i] > 0: + ax.plot(xMat[i, 0], xMat[i, 1], 'cx') + else: + ax.plot(xMat[i, 0], xMat[i, 1], 'kp') + + # 找到支持向量,并在图中标红 + for i in range(100): + if alphas[i] > 0.0: + ax.plot(xMat[i, 0], xMat[i, 1], 'ro') + plt.show() + + +if __name__ == "__main__": + # 获取特征和目标变量 + dataArr, labelArr = loadDataSet('input/6.SVM/testSet.txt') + # print labelArr + + # b是常量值, alphas是拉格朗日乘子 + b, alphas = smoP(dataArr, labelArr, 0.6, 0.001, 40) + print '/n/n/n' + print 'b=', b + print 'alphas[alphas>0]=', alphas[alphas > 0] + print 'shape(alphas[alphas > 0])=', shape(alphas[alphas > 0]) + for i in range(100): + if alphas[i] > 0: + print dataArr[i], labelArr[i] + # 画图 + ws = calcWs(alphas, dataArr, labelArr) + plotfig_SVM(dataArr, labelArr, ws, b, alphas) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +def testRbf(k1=1.3): + dataArr, labelArr = loadDataSet('testSetRBF.txt') + b, alphas = smoP(dataArr, labelArr, 200, 0.0001, 10000, ('rbf', k1)) # C=200 important + datMat = mat(dataArr) + labelMat = mat(labelArr).transpose() + svInd = nonzero(alphas.A > 0)[0] + sVs = datMat[svInd] # get matrix of only support vectors + labelSV = labelMat[svInd] + print("there are %d Support Vectors" % shape(sVs)[0]) + m, n = shape(datMat) + errorCount = 0 + for i in range(m): + kernelEval = kernelTrans(sVs, datMat[i, :], ('rbf', k1)) + predict = kernelEval.T * multiply(labelSV, alphas[svInd]) + b + if sign(predict) != sign(labelArr[i]): errorCount += 1 + print("the training error rate is: %f" % (float(errorCount) / m)) + dataArr, labelArr = loadDataSet('testSetRBF2.txt') + errorCount = 0 + datMat = mat(dataArr) + labelMat = mat(labelArr).transpose() + m, n = shape(datMat) + for i in range(m): + kernelEval = kernelTrans(sVs, datMat[i, :], ('rbf', k1)) + predict = kernelEval.T * multiply(labelSV, alphas[svInd]) + b + if sign(predict) != sign(labelArr[i]): errorCount += 1 + print("the test error rate is: %f" % (float(errorCount) / m)) + + +def img2vector(filename): + returnVect = zeros((1, 1024)) + fr = open(filename) + for i in range(32): + lineStr = fr.readline() + for j in range(32): + returnVect[0, 32 * i + j] = int(lineStr[j]) + return returnVect + + +def loadImages(dirName): + from os import listdir + hwLabels = [] + print(dirName) + trainingFileList = listdir(dirName) # load the training set + m = len(trainingFileList) + trainingMat = zeros((m, 1024)) + for i in range(m): + fileNameStr = trainingFileList[i] + fileStr = fileNameStr.split('.')[0] # take off .txt + classNumStr = int(fileStr.split('_')[0]) + if classNumStr == 9: + hwLabels.append(-1) + else: + hwLabels.append(1) + trainingMat[i, :] = img2vector('%s/%s' % (dirName, fileNameStr)) + return trainingMat, hwLabels + + +def testDigits(kTup=('rbf', 10)): + dataArr, labelArr = loadImages('trainingDigits') + b, alphas = smoP(dataArr, labelArr, 200, 0.0001, 10000, kTup) + datMat = mat(dataArr) + labelMat = mat(labelArr).transpose() + svInd = nonzero(alphas.A > 0)[0] + sVs = datMat[svInd] + labelSV = labelMat[svInd] + print("there are %d Support Vectors" % shape(sVs)[0]) + m, n = shape(datMat) + errorCount = 0 + for i in range(m): + kernelEval = kernelTrans(sVs, datMat[i, :], kTup) + predict = kernelEval.T * multiply(labelSV, alphas[svInd]) + b + if sign(predict) != sign(labelArr[i]): errorCount += 1 + print("the training error rate is: %f" % (float(errorCount) / m)) + dataArr, labelArr = loadImages('testDigits') + errorCount = 0 + datMat = mat(dataArr) + labelMat = mat(labelArr).transpose() + m, n = shape(datMat) + for i in range(m): + kernelEval = kernelTrans(sVs, datMat[i, :], kTup) + predict = kernelEval.T * multiply(labelSV, alphas[svInd]) + b + if sign(predict) != sign(labelArr[i]): errorCount += 1 + print("the test error rate is: %f" % (float(errorCount) / m)) +