From 0ddfeab7a614dae544afd1e5ccff41606985f4a7 Mon Sep 17 00:00:00 2001 From: jiangzhonglian Date: Mon, 20 Mar 2017 18:20:28 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=96=B0Knn=E6=B3=A8=E9=87=8A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/python/02.kNN/kNN.py | 32 ++++++++++++++++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/src/python/02.kNN/kNN.py b/src/python/02.kNN/kNN.py index 271aa621..f2f8616a 100644 --- a/src/python/02.kNN/kNN.py +++ b/src/python/02.kNN/kNN.py @@ -37,8 +37,29 @@ def classify0(inX, dataSet, labels, k): # 1. 距离计算 dataSetSize = dataSet.shape[0] # tile生成和训练样本对应的矩阵,并与训练样本求差 + """ + tile: 列-3表示复制的行树, 行-1/2表示对inx的重复的次数 + + In [8]: tile(inx, (3, 1)) + Out[8]: + array([[1, 2], + [1, 2], + [1, 2]]) + + In [9]: tile(inx, (3, 2)) + Out[9]: + array([[1, 2, 1, 2], + [1, 2, 1, 2], + [1, 2, 1, 2]]) + """ diffMat = tile(inX, (dataSetSize, 1)) - dataSet """ + 欧氏距离: 点到点之间的距离 + 第一行: 同一个点 到 dataSet的第一个点的距离。 + 第二行: 同一个点 到 dataSet的第二个点的距离。 + ... + 第N行: 同一个点 到 dataSet的第N个点的距离。 + [[1,2,3],[1,2,3]]-[[1,2,3],[1,2,0]] (A1-A2)^2+(B1-B2)^2+(c1-c2)^2 """ @@ -104,10 +125,14 @@ def autoNorm(dataSet): 归一化特征值,消除属性之间量级不同导致的影响 :param dataSet: 数据集 :return: 归一化后的数据集normDataSet,ranges和minVals即最小值与范围,并没有用到 + + 归一化公式: + Y = (X-Xmin)-(Xmax-Xmin) """ # 计算每种属性的最大值、最小值、范围 minVals = dataSet.min(0) maxVals = dataSet.max(0) + # 极差 ranges = maxVals - minVals normDataSet = zeros(shape(dataSet)) m = dataSet.shape[0] @@ -123,14 +148,16 @@ def datingClassTest(): 对约会网站的测试方法 :return: 错误数 """ - hoRatio = 0.9 # 测试范围,一部分测试一部分作为样本 + # 设置测试数据的的一个比例(训练数据集比例=1-hoRatio) + hoRatio = 0.1 # 测试范围,一部分测试一部分作为样本 # 从文件中加载数据 datingDataMat, datingLabels = file2matrix('testData/datingTestSet2.txt') # load data setfrom file # 归一化数据 normMat, ranges, minVals = autoNorm(datingDataMat) m = normMat.shape[0] - # 测试的数据 + # 设置测试的样本数量, numTestVecs:m表示训练样本的数量 numTestVecs = int(m * hoRatio) + print 'numTestVecs=', numTestVecs errorCount = 0.0 for i in range(numTestVecs): # 对数据测试, @@ -162,6 +189,7 @@ def handwritingClassTest(): trainingFileList = listdir('testData/trainingDigits') # load the training set m = len(trainingFileList) trainingMat = zeros((m, 1024)) + # hwLabels存储0~9对应的index位置, trainingMat存放的每个位置对应的图片向量 for i in range(m): fileNameStr = trainingFileList[i] fileStr = fileNameStr.split('.')[0] # take off .txt