From 3ad4f112a1ab86bd15f9ca0bb1a68415ca7a41ee Mon Sep 17 00:00:00 2001 From: geekidentity Date: Tue, 14 Mar 2017 22:07:14 +0800 Subject: [PATCH] add classify0() in kNN.py --- src/python/02.kNN/kNN.py | 33 ++++++++++++++++++++++++++++++--- 1 file changed, 30 insertions(+), 3 deletions(-) diff --git a/src/python/02.kNN/kNN.py b/src/python/02.kNN/kNN.py index acebbd2c..cd75ebca 100644 --- a/src/python/02.kNN/kNN.py +++ b/src/python/02.kNN/kNN.py @@ -1,6 +1,5 @@ ''' 导入科学计算包numpy和运算符模块operator -@author: geekidentity ''' from numpy import * import operator @@ -10,9 +9,37 @@ import operator 调用方式 import kNN - group, labels = createDateSet()11 + group, labels = kNN.createDataSet() ''' def createDataSet(): group = array([[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]]) labels = ['A','A','B','B'] - return group, labels \ No newline at end of file + return group, labels + + +def classify0(inX, dataSet, labels, k): + ''' + inX: 用于分类的输入向量 + dataSet: 输入的训练样本集 + labels: 标签向量 + k: 选择最近邻居的数目 + 注意:labels元素数目和dataSet行数相同;程序使用欧式距离公式. + + 预测数据所在分类可在输入下列命令 + kNN.classify0([0,0], group, labels, 3) + ''' + # 1. 距离计算 + dataSetSize = dataSet.shape[0] + diffMat = tile(inX, (dataSetSize,1)) - dataSet + sqDiffMat = diffMat**2 + sqDistances = sqDiffMat.sum(axis=1) + distances = sqDistances**0.5 + sortedDistIndicies = distances.argsort() + # 2. 选择距离最小的k个点 + classCount = {} + for i in range(k): + voteIlabel = labels[sortedDistIndicies[i]] + classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1 + # 3. 排序 + sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True) + return sortedClassCount[0][0] \ No newline at end of file