mirror of
https://github.com/apachecn/ailearning.git
synced 2026-02-11 22:35:35 +08:00
45 lines
1.3 KiB
Python
45 lines
1.3 KiB
Python
'''
|
||
导入科学计算包numpy和运算符模块operator
|
||
'''
|
||
from numpy import *
|
||
import operator
|
||
|
||
def createDataSet():
    """Build the tiny toy training set used by the kNN examples.

    Returns:
        A tuple ``(group, labels)``: ``group`` is a 4x2 float array whose
        rows are sample points, and ``labels`` is a list holding the class
        label of each corresponding row.

    Usage::

        import kNN
        group, labels = kNN.createDataSet()
    """
    # Keep each point next to its label so the two can never drift apart.
    samples = [
        ([1.0, 1.1], 'A'),
        ([1.0, 1.0], 'A'),
        ([0, 0], 'B'),
        ([0, 0.1], 'B'),
    ]
    points = [point for point, _ in samples]
    classes = [tag for _, tag in samples]
    return array(points), classes
|
||
|
||
|
||
def classify0(inX, dataSet, labels, k):
    """Classify ``inX`` by majority vote among its ``k`` nearest neighbours.

    Args:
        inX: Input vector to classify (a sequence or 1-D array with as many
            features as each row of ``dataSet``).
        dataSet: Training samples, one per row (2-D array or nested list).
        labels: Class labels; ``labels[i]`` is the label of row ``i`` of
            ``dataSet``, so it must have as many elements as ``dataSet``
            has rows.
        k: Number of nearest neighbours that vote. Values larger than the
            number of training samples use all samples.

    Returns:
        The label with the most votes. Vote ties go to the label whose
        first vote came from the nearer neighbour.

    Raises:
        ValueError: If ``k`` is less than 1 or ``dataSet`` has no rows.

    Distances are Euclidean. Example, with ``group, labels`` from
    ``createDataSet()``::

        kNN.classify0([0, 0], group, labels, 3)
    """
    if k < 1:
        raise ValueError("k must be at least 1, got %r" % (k,))

    samples = asarray(dataSet)
    if samples.shape[0] == 0:
        raise ValueError("dataSet must contain at least one sample")

    # 1. Distance computation. Broadcasting subtracts inX from every row
    # (no need to tile it); the square root is monotonic, so ranking by
    # squared Euclidean distance selects the same nearest neighbours.
    diffMat = samples - asarray(inX)
    sqDistances = (diffMat ** 2).sum(axis=1)
    # A stable sort keeps equidistant samples in dataset order, so the
    # result does not depend on which sort algorithm numpy picks.
    sortedDistIndicies = sqDistances.argsort(kind='stable')

    # 2. The k closest points vote; slicing clamps k to the sample count.
    classCount = {}
    for i in sortedDistIndicies[:k]:
        voteIlabel = labels[i]
        classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1

    # 3. Rank labels by vote count. sorted() is stable and dicts keep
    # insertion order, so equal counts favour the label voted for first.
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]