mirror of
https://github.com/apachecn/ailearning.git
synced 2026-02-11 22:35:35 +08:00
DEV: ADD KNN PYTHON CODE AND DATA FILE
This commit is contained in:
@@ -1,24 +1,28 @@
|
||||
#!/usr/bin/env python
|
||||
# encoding: utf-8
|
||||
'''
|
||||
导入科学计算包numpy和运算符模块operator
|
||||
'''
|
||||
from numpy import *
|
||||
import operator
|
||||
from os import listdir
|
||||
|
||||
'''
|
||||
创建数据集和标签
|
||||
|
||||
调用方式
|
||||
import kNN
|
||||
group, labels = kNN.createDataSet()
|
||||
'''
|
||||
def createDataSet():
|
||||
group = array([[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]])
|
||||
labels = ['A','A','B','B']
|
||||
"""
|
||||
创建数据集和标签
|
||||
|
||||
调用方式
|
||||
import kNN
|
||||
group, labels = kNN.createDataSet()
|
||||
"""
|
||||
group = array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]])
|
||||
labels = ['A', 'A', 'B', 'B']
|
||||
return group, labels
|
||||
|
||||
|
||||
def classify0(inX, dataSet, labels, k):
|
||||
'''
|
||||
"""
|
||||
inX: 用于分类的输入向量
|
||||
dataSet: 输入的训练样本集
|
||||
labels: 标签向量
|
||||
@@ -27,19 +31,155 @@ def classify0(inX, dataSet, labels, k):
|
||||
|
||||
预测数据所在分类可在输入下列命令
|
||||
kNN.classify0([0,0], group, labels, 3)
|
||||
'''
|
||||
"""
|
||||
# 1. 距离计算
|
||||
dataSetSize = dataSet.shape[0]
|
||||
diffMat = tile(inX, (dataSetSize,1)) - dataSet
|
||||
sqDiffMat = diffMat**2
|
||||
# tile生成和训练样本对应的矩阵,并与训练样本求差
|
||||
diffMat = tile(inX, (dataSetSize, 1)) - dataSet
|
||||
# 取平方
|
||||
sqDiffMat = diffMat ** 2
|
||||
# 将矩阵的每一行相加
|
||||
sqDistances = sqDiffMat.sum(axis=1)
|
||||
distances = sqDistances**0.5
|
||||
# 开方
|
||||
distances = sqDistances ** 0.5
|
||||
# 距离排序
|
||||
sortedDistIndicies = distances.argsort()
|
||||
|
||||
# 2. 选择距离最小的k个点
|
||||
classCount = {}
|
||||
for i in range(k):
|
||||
# 找到该样本的类型
|
||||
voteIlabel = labels[sortedDistIndicies[i]]
|
||||
classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1
|
||||
# 3. 排序
|
||||
# 在字典中将该类型加一
|
||||
classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
|
||||
# 3. 排序并返回出现最多的那个类型
|
||||
sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
|
||||
return sortedClassCount[0][0]
|
||||
return sortedClassCount[0][0]
|
||||
|
||||
|
||||
def test1():
|
||||
"""
|
||||
第一个例子演示
|
||||
"""
|
||||
group, labels = createDataSet()
|
||||
print str(group)
|
||||
print str(labels)
|
||||
print classify0([0, 0], group, labels, 3)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------
|
||||
def file2matrix(filename):
|
||||
"""
|
||||
导入训练数据
|
||||
:param filename: 数据文件路径
|
||||
:return: 数据矩阵returnMat和对应的类别classLabelVector
|
||||
"""
|
||||
fr = open(filename)
|
||||
numberOfLines = len(fr.readlines()) # get the number of lines in the file
|
||||
# 生成对应的空矩阵
|
||||
returnMat = zeros((numberOfLines, 3)) # prepare matrix to return
|
||||
classLabelVector = [] # prepare labels return
|
||||
fr = open(filename)
|
||||
index = 0
|
||||
for line in fr.readlines():
|
||||
line = line.strip()
|
||||
listFromLine = line.split('\t')
|
||||
# 每列的属性数据
|
||||
returnMat[index, :] = listFromLine[0:3]
|
||||
# 每列的类别数据
|
||||
classLabelVector.append(int(listFromLine[-1]))
|
||||
index += 1
|
||||
# 返回数据矩阵returnMat和对应的类别classLabelVector
|
||||
return returnMat, classLabelVector
|
||||
|
||||
|
||||
def autoNorm(dataSet):
|
||||
"""
|
||||
归一化特征值,消除属性之间量级不同导致的影响
|
||||
:param dataSet: 数据集
|
||||
:return: 归一化后的数据集normDataSet,ranges和minVals即最小值与范围,并没有用到
|
||||
"""
|
||||
# 计算每种属性的最大值、最小值、范围
|
||||
minVals = dataSet.min(0)
|
||||
maxVals = dataSet.max(0)
|
||||
ranges = maxVals - minVals
|
||||
normDataSet = zeros(shape(dataSet))
|
||||
m = dataSet.shape[0]
|
||||
# 生成与最小值之差组成的矩阵
|
||||
normDataSet = dataSet - tile(minVals, (m, 1))
|
||||
# 将最小值之差除以范围组成矩阵
|
||||
normDataSet = normDataSet / tile(ranges, (m, 1)) # element wise divide
|
||||
return normDataSet, ranges, minVals
|
||||
|
||||
|
||||
def datingClassTest():
|
||||
"""
|
||||
对约会网站的测试方法
|
||||
:return: 错误数
|
||||
"""
|
||||
hoRatio = 0.9 # 测试范围,一部分测试一部分作为样本
|
||||
# 从文件中加载数据
|
||||
datingDataMat, datingLabels = file2matrix('../../../testData/datingTestSet2.txt') # load data setfrom file
|
||||
# 归一化数据
|
||||
normMat, ranges, minVals = autoNorm(datingDataMat)
|
||||
m = normMat.shape[0]
|
||||
# 测试的数据
|
||||
numTestVecs = int(m * hoRatio)
|
||||
errorCount = 0.0
|
||||
for i in range(numTestVecs):
|
||||
# 对数据测试,
|
||||
classifierResult = classify0(normMat[i, :], normMat[numTestVecs:m, :], datingLabels[numTestVecs:m], 3)
|
||||
print "the classifier came back with: %d, the real answer is: %d" % (classifierResult, datingLabels[i])
|
||||
if (classifierResult != datingLabels[i]): errorCount += 1.0
|
||||
print "the total error rate is: %f" % (errorCount / float(numTestVecs))
|
||||
print errorCount
|
||||
|
||||
|
||||
def img2vector(filename):
|
||||
"""
|
||||
将图像数据转换为向量
|
||||
:param filename: 图片文件
|
||||
:return: 一纬矩阵
|
||||
"""
|
||||
returnVect = zeros((1, 1024))
|
||||
fr = open(filename)
|
||||
for i in range(32):
|
||||
lineStr = fr.readline()
|
||||
for j in range(32):
|
||||
returnVect[0, 32 * i + j] = int(lineStr[j])
|
||||
return returnVect
|
||||
|
||||
|
||||
def handwritingClassTest():
|
||||
# 1. 导入数据
|
||||
hwLabels = []
|
||||
trainingFileList = listdir('../../../testData/trainingDigits') # load the training set
|
||||
m = len(trainingFileList)
|
||||
trainingMat = zeros((m, 1024))
|
||||
for i in range(m):
|
||||
fileNameStr = trainingFileList[i]
|
||||
fileStr = fileNameStr.split('.')[0] # take off .txt
|
||||
classNumStr = int(fileStr.split('_')[0])
|
||||
hwLabels.append(classNumStr)
|
||||
trainingMat[i, :] = img2vector('../../../testData/trainingDigits/%s' % fileNameStr)
|
||||
|
||||
# 2. 导入测试数据
|
||||
testFileList = listdir('../../../testData/testDigits') # iterate through the test set
|
||||
errorCount = 0.0
|
||||
mTest = len(testFileList)
|
||||
for i in range(mTest):
|
||||
fileNameStr = testFileList[i]
|
||||
fileStr = fileNameStr.split('.')[0] # take off .txt
|
||||
classNumStr = int(fileStr.split('_')[0])
|
||||
vectorUnderTest = img2vector('../../../testData/testDigits/%s' % fileNameStr)
|
||||
classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)
|
||||
print "the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr)
|
||||
if (classifierResult != classNumStr): errorCount += 1.0
|
||||
print "\nthe total number of errors is: %d" % errorCount
|
||||
print "\nthe total error rate is: %f" % (errorCount / float(mTest))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# test1()
|
||||
# datingClassTest()
|
||||
handwritingClassTest()
|
||||
|
||||
@@ -223,7 +223,7 @@ def crossValidation(xArr,yArr,numVal=10):
|
||||
|
||||
|
||||
#test for xianxinghuigui
|
||||
def regression1():
|
||||
def regression1():
|
||||
xArr, yArr = loadDataSet("ex0.txt")
|
||||
xMat = mat(xArr)
|
||||
yMat = mat(yArr)
|
||||
@@ -242,7 +242,7 @@ if __name__ == "__main__":
|
||||
|
||||
|
||||
#test for jiaquanhuigui
|
||||
def regression1():
|
||||
def regression1():
|
||||
xArr, yArr = loadDataSet("ex0.txt")
|
||||
yHat = lwlrTest(xArr, xArr, yArr, 0.003)
|
||||
xMat = mat(xArr)
|
||||
|
||||
Reference in New Issue
Block a user