mirror of
https://github.com/apachecn/ailearning.git
synced 2026-05-08 06:33:55 +08:00
更新 15章 代码新格式
This commit is contained in:
@@ -2,9 +2,11 @@
|
||||
# coding:utf8
|
||||
'''
|
||||
Created on 2017-04-07
|
||||
Update on 2017-06-20
|
||||
MapReduce version of Pegasos SVM
|
||||
Using mrjob to automate job flow
|
||||
@author: Peter/ApacheCN-xy
|
||||
@author: Peter/ApacheCN-xy/片刻
|
||||
《机器学习实战》更新地址:https://github.com/apachecn/MachineLearning
|
||||
'''
|
||||
from mrjob.job import MRJob
|
||||
|
||||
@@ -17,14 +19,14 @@ class MRsvm(MRJob):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(MRsvm, self).__init__(*args, **kwargs)
|
||||
self.data = pickle.load(open('input/15.BigData_MapReduce/svmDat27'))
|
||||
self.data = pickle.load(open('/opt/git/MachineLearning/input/15.BigData_MapReduce/svmDat27'))
|
||||
self.w = 0
|
||||
self.eta = 0.69
|
||||
self.dataList = []
|
||||
self.k = self.options.batchsize
|
||||
self.numMappers = 1
|
||||
self.t = 1 # iteration number
|
||||
|
||||
|
||||
def configure_options(self):
|
||||
super(MRsvm, self).configure_options()
|
||||
self.add_passthrough_option(
|
||||
@@ -42,20 +44,20 @@ class MRsvm(MRJob):
|
||||
self.w = inVals[1]
|
||||
elif inVals[0] == 'x':
|
||||
self.dataList.append(inVals[1]) # 累积数据点计算
|
||||
elif inVals[0] == 't':
|
||||
elif inVals[0] == 't': # 迭代次数
|
||||
self.t = inVals[1]
|
||||
else:
|
||||
self.eta = inVals # 这用于 debug, eta未在map中使用
|
||||
self.eta = inVals # 这用于 debug, eta未在map中使用
|
||||
|
||||
def map_fin(self):
|
||||
labels = self.data[:,-1]
|
||||
X = self.data[:, 0:-1] # 将数据重新形成 X 和 Y
|
||||
if self.w == 0:
|
||||
labels = self.data[:, -1]
|
||||
X = self.data[:, :-1] # 将数据重新形成 X 和 Y
|
||||
if self.w == 0:
|
||||
self.w = [0.001] * shape(X)[1] # 在第一次迭代时,初始化 w
|
||||
for index in self.dataList:
|
||||
p = mat(self.w)*X[index, :].T # calc p=w*dataSet[key].T
|
||||
p = mat(self.w)*X[index, :].T # calc p=w*dataSet[key].T
|
||||
if labels[index]*p < 1.0:
|
||||
yield (1, ['u', index]) # 确保一切数据包含相同的key
|
||||
yield (1, ['u', index]) # 确保一切数据包含相同的key
|
||||
yield (1, ['w', self.w]) # 它们将在同一个 reducer
|
||||
yield (1, ['t', self.t])
|
||||
|
||||
@@ -66,7 +68,7 @@ class MRsvm(MRJob):
|
||||
elif valArr[0] == 'w':
|
||||
self.w = valArr[1]
|
||||
elif valArr[0] == 't':
|
||||
self.t = valArr[1]
|
||||
self.t = valArr[1]
|
||||
|
||||
labels = self.data[:, -1]
|
||||
X = self.data[:, 0:-1]
|
||||
|
||||
Reference in New Issue
Block a user