mirror of
https://github.com/apachecn/ailearning.git
synced 2026-02-11 14:26:04 +08:00
add classify0() in kNN.py
This commit is contained in:
@@ -1,6 +1,5 @@
|
||||
'''
|
||||
导入科学计算包numpy和运算符模块operator
|
||||
@author: geekidentity
|
||||
'''
|
||||
from numpy import *
|
||||
import operator
|
||||
@@ -10,9 +9,37 @@ import operator
|
||||
|
||||
调用方式
|
||||
import kNN
|
||||
group, labels = createDateSet()11
|
||||
group, labels = kNN.createDataSet()
|
||||
'''
|
||||
def createDataSet():
|
||||
group = array([[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]])
|
||||
labels = ['A','A','B','B']
|
||||
return group, labels
|
||||
return group, labels
|
||||
|
||||
|
||||
def classify0(inX, dataSet, labels, k):
|
||||
'''
|
||||
inX: 用于分类的输入向量
|
||||
dataSet: 输入的训练样本集
|
||||
labels: 标签向量
|
||||
k: 选择最近邻居的数目
|
||||
注意:labels元素数目和dataSet行数相同;程序使用欧式距离公式.
|
||||
|
||||
预测数据所在分类可在输入下列命令
|
||||
kNN.classify0([0,0], group, labels, 3)
|
||||
'''
|
||||
# 1. 距离计算
|
||||
dataSetSize = dataSet.shape[0]
|
||||
diffMat = tile(inX, (dataSetSize,1)) - dataSet
|
||||
sqDiffMat = diffMat**2
|
||||
sqDistances = sqDiffMat.sum(axis=1)
|
||||
distances = sqDistances**0.5
|
||||
sortedDistIndicies = distances.argsort()
|
||||
# 2. 选择距离最小的k个点
|
||||
classCount = {}
|
||||
for i in range(k):
|
||||
voteIlabel = labels[sortedDistIndicies[i]]
|
||||
classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1
|
||||
# 3. 排序
|
||||
sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
|
||||
return sortedClassCount[0][0]
|
||||
Reference in New Issue
Block a user