From d117f8fa44a4cf2f51c0fb1a1a6bac65527a63b0 Mon Sep 17 00:00:00 2001
From: GuoQiang Li
Date: Tue, 25 Nov 2014 02:01:19 -0800
Subject: [SPARK-4526][MLLIB] GradientDescent gets a wrong gradient value
 according to the gradient formula.

This is caused by the miniBatchSize parameter: the number of examples
`RDD.sample` returns is not fixed.

cc mengxr

Author: GuoQiang Li

Closes #3399 from witgo/GradientDescent and squashes the following commits:

13cb228 [GuoQiang Li] review commit
668ab66 [GuoQiang Li] Double to Long
b6aa11a [GuoQiang Li] Check miniBatchSize is greater than 0
0b5c3e3 [GuoQiang Li] Minor fix
12e7424 [GuoQiang Li] GradientDescent get a wrong gradient value according to the gradient formula, which is caused by the miniBatchSize parameter.

(cherry picked from commit f515f9432b05f7e090b651c5536aa706d1cde487)
Signed-off-by: Xiangrui Meng
---
 .../spark/mllib/optimization/GradientDescent.scala | 45 +++++++++++++---------
 1 file changed, 26 insertions(+), 19 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
index a691205639..0857877951 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
@@ -160,14 +160,15 @@ object GradientDescent extends Logging {
     val stochasticLossHistory = new ArrayBuffer[Double](numIterations)
 
     val numExamples = data.count()
-    val miniBatchSize = numExamples * miniBatchFraction
 
     // if no data, return initial weights to avoid NaNs
     if (numExamples == 0) {
-
-      logInfo("GradientDescent.runMiniBatchSGD returning initial weights, no data found")
+      logWarning("GradientDescent.runMiniBatchSGD returning initial weights, no data found")
       return (initialWeights, stochasticLossHistory.toArray)
+    }
 
+    if (numExamples * miniBatchFraction < 1) {
+      logWarning("The miniBatchFraction is too small")
     }
 
     // Initialize weights as a column vector
@@ -185,25 +186,31 @@ object GradientDescent extends Logging {
       val bcWeights = data.context.broadcast(weights)
       // Sample a subset (fraction miniBatchFraction) of the total data
       // compute and sum up the subgradients on this subset (this is one map-reduce)
-      val (gradientSum, lossSum) = data.sample(false, miniBatchFraction, 42 + i)
-        .treeAggregate((BDV.zeros[Double](n), 0.0))(
-          seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) =>
-            val l = gradient.compute(features, label, bcWeights.value, Vectors.fromBreeze(grad))
-            (grad, loss + l)
+      val (gradientSum, lossSum, miniBatchSize) = data.sample(false, miniBatchFraction, 42 + i)
+        .treeAggregate((BDV.zeros[Double](n), 0.0, 0L))(
+          seqOp = (c, v) => {
+            // c: (grad, loss, count), v: (label, features)
+            val l = gradient.compute(v._2, v._1, bcWeights.value, Vectors.fromBreeze(c._1))
+            (c._1, c._2 + l, c._3 + 1)
           },
-          combOp = (c1, c2) => (c1, c2) match { case ((grad1, loss1), (grad2, loss2)) =>
-            (grad1 += grad2, loss1 + loss2)
+          combOp = (c1, c2) => {
+            // c: (grad, loss, count)
+            (c1._1 += c2._1, c1._2 + c2._2, c1._3 + c2._3)
           })
 
-      /**
-       * NOTE(Xinghao): lossSum is computed using the weights from the previous iteration
-       * and regVal is the regularization value computed in the previous iteration as well.
-       */
-      stochasticLossHistory.append(lossSum / miniBatchSize + regVal)
-      val update = updater.compute(
-        weights, Vectors.fromBreeze(gradientSum / miniBatchSize), stepSize, i, regParam)
-      weights = update._1
-      regVal = update._2
+      if (miniBatchSize > 0) {
+        /**
+         * NOTE(Xinghao): lossSum is computed using the weights from the previous iteration
+         * and regVal is the regularization value computed in the previous iteration as well.
+         */
+        stochasticLossHistory.append(lossSum / miniBatchSize + regVal)
+        val update = updater.compute(
+          weights, Vectors.fromBreeze(gradientSum / miniBatchSize.toDouble), stepSize, i, regParam)
+        weights = update._1
+        regVal = update._2
+      } else {
+        logWarning(s"Iteration ($i/$numIterations). The size of sampled batch is zero")
+      }
     }
 
     logInfo("GradientDescent.runMiniBatchSGD finished. Last 10 stochastic losses %s".format(
-- 
cgit v1.2.3
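
The bug the patch fixes comes from how RDD.sample behaves: with withReplacement = false, sample(false, f) keeps each element independently with probability f (Bernoulli sampling), so the realized batch size is a random quantity centered on numExamples * f, not equal to it. Dividing the gradient sum by the expected size therefore mis-scales the averaged gradient whenever the realized size differs, and when numExamples * f < 1 the batch is frequently empty, which is what the new warning and the miniBatchSize > 0 guard address. A minimal, self-contained sketch of the varying sample size (a hypothetical demo, not part of the patch; the object and app names are made up):

import org.apache.spark.{SparkConf, SparkContext}

object RealizedSampleSize {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setAppName("RealizedSampleSize").setMaster("local[2]"))
    val data = sc.parallelize(1 to 10000, 8)
    val fraction = 0.1
    for (seed <- 1 to 5) {
      // Expected size is 1000, but each seed yields a different realized
      // size, so the denominator must be counted rather than assumed.
      val realized =
        data.sample(withReplacement = false, fraction = fraction, seed = seed).count()
      println(s"seed=$seed realized=$realized expected=${(10000 * fraction).toLong}")
    }
    sc.stop()
  }
}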
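The fix itself is the three-slot accumulator: treeAggregate now carries (gradient sum, loss sum, count), and the counted size, not the expected one, becomes the denominator. The same pattern reduced to a plain Scala fold, with a scalar stand-in for the gradient (all names illustrative, not the patch's code):

// (label, feature) pairs standing in for the RDD's (label, features) records.
val batch: Seq[(Double, Double)] = Seq((1.0, 0.5), (0.0, 1.5), (1.0, 2.0))

// Same accumulator shape as the patch: (gradient sum, loss sum, exact count).
val (gradSum, lossSum, miniBatchSize) =
  batch.foldLeft((0.0, 0.0, 0L)) { case ((grad, loss, count), (label, feature)) =>
    val residual = feature - label // scalar stand-in for gradient.compute(...)
    (grad + residual, loss + residual * residual, count + 1L)
  }

// Mirror of the patch's guard: only average when something was sampled.
if (miniBatchSize > 0) {
  println(s"avg gradient = ${gradSum / miniBatchSize.toDouble} over $miniBatchSize examples")
}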
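The "Double to Long" squash commit explains the .toDouble in the updater call: the count is now a Long, and (assuming Breeze's scalar operators, which define element-wise division of a DenseVector[Double] by a Double) the Long must be widened before dividing. A tiny sketch:

import breeze.linalg.DenseVector

val gradientSum = DenseVector(2.0, 4.0, 6.0)
val miniBatchSize: Long = 4L
// Widen the Long count so Breeze's Double scalar division applies.
val avgGradient = gradientSum / miniBatchSize.toDouble // DenseVector(0.5, 1.0, 1.5)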