From 0b4bbf48a151fb9a4cd268c557806a9e68fdd023 Mon Sep 17 00:00:00 2001 From: jiangzhonglian Date: Thu, 14 Sep 2017 16:35:06 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=96=B0=205=20logistic=E5=9B=9E?= =?UTF-8?q?=E5=BD=92=20=E7=90=86=E8=AE=BA=E9=83=A8=E5=88=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/5.Logistic回归.md | 10 +++------- src/python/5.Logistic/logistic.py | 8 ++++---- 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/docs/5.Logistic回归.md b/docs/5.Logistic回归.md index b598da29..703e1e68 100644 --- a/docs/5.Logistic回归.md +++ b/docs/5.Logistic回归.md @@ -78,17 +78,13 @@ Sigmoid 函数的输入记为 z ,由下面公式得到: 梯度上升算法用来求函数的最大值,而梯度下降算法用来求函数的最小值。 -**如果大家对上面的例子不理解,下面我们看一个比较容易理解的例子。** +**局部最优现象** ![梯度下降图_4](../images/5.Logistic/LR_20.png) 上图表示参数 θ 与误差函数 J(θ) 的关系图,红色的部分是表示 J(θ) 有着比较高的取值,我们需要的是,能够让 J(θ) 的值尽量的低。也就是深蓝色的部分。θ0,θ1 表示 θ 向量的两个维度。 -在上面提到梯度下降法的第一步是给 θ 给一个初值,假设随机给的初值是在图上的十字点。 - -然后我们将 θ 按照梯度下降的方向进行调整,就会使得 J(θ) 往更低的方向进行变化,如图所示,算法的结束将是在θ下降到无法继续下降为止。 - -当然,可能梯度下降的最终点并非是全局最小点,可能是一个局部最小点,如我们上图中的右边的梯度下降曲线,描述的是最终到达一个局部最小点,这是我们重新选择了一个初始点得到的。 +可能梯度下降的最终点并非是全局最小点,可能是一个局部最小点,如我们上图中的右边的梯度下降曲线,描述的是最终到达一个局部最小点,这是我们重新选择了一个初始点得到的。 看来我们这个算法将会在很大的程度上被初始点的选择影响而陷入局部最小点。 @@ -100,7 +96,7 @@ Sigmoid 函数的输入记为 z ,由下面公式得到: 每个回归系数初始化为 1 重复 R 次: 计算整个数据集的梯度 - 使用 alpha x gradient 更新回归系数的向量 + 使用 步长 x 梯度 更新回归系数的向量 返回回归系数 ``` diff --git a/src/python/5.Logistic/logistic.py b/src/python/5.Logistic/logistic.py index b281e7a1..f0e2eafb 100644 --- a/src/python/5.Logistic/logistic.py +++ b/src/python/5.Logistic/logistic.py @@ -199,9 +199,9 @@ def simpleTest(): # 因为数组没有是复制n份, array的乘法就是乘法 dataArr = array(dataMat) # print dataArr - weights = gradAscent(dataArr, labelMat) + # weights = gradAscent(dataArr, labelMat) # weights = stocGradAscent0(dataArr, labelMat) - # weights = stocGradAscent1(dataArr, labelMat) + weights = stocGradAscent1(dataArr, labelMat) # print '*'*30, weights # 数据可视化 @@ -278,5 +278,5 @@ def multiTest(): if __name__ == "__main__": - simpleTest() - # multiTest() + # simpleTest() + multiTest()