path: root/mllib/src
author     GuoQiang Li <witgo@qq.com>           2014-11-25 02:01:19 -0800
committer  Xiangrui Meng <meng@databricks.com>  2014-11-25 02:01:19 -0800
commit     f515f9432b05f7e090b651c5536aa706d1cde487 (patch)
tree       ca61dd92c727476c798388043c63474bba33311a /mllib/src
parent     89f912264603741c7d980135c26102d63e11791f (diff)
[SPARK-4526][MLLIB] GradientDescent computes a wrong gradient value according to the gradient formula.
This is caused by the miniBatchSize parameter: the number of records returned by `RDD.sample` is not fixed, so dividing by numExamples * miniBatchFraction normalizes the gradient by the wrong count. cc mengxr

Author: GuoQiang Li <witgo@qq.com>

Closes #3399 from witgo/GradientDescent and squashes the following commits:

13cb228 [GuoQiang Li] review commit
668ab66 [GuoQiang Li] Double to Long
b6aa11a [GuoQiang Li] Check miniBatchSize is greater than 0
0b5c3e3 [GuoQiang Li] Minor fix
12e7424 [GuoQiang Li] GradientDescent computes a wrong gradient value according to the gradient formula, which is caused by the miniBatchSize parameter.
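For background, the reason the old denominator is only an estimate: `RDD.sample` with `withReplacement = false` decides independently for each record whether to keep it (Bernoulli sampling), so the size of the returned sample fluctuates around numExamples * miniBatchFraction rather than equaling it. A small, hypothetical driver program (illustrative only, not part of the patch) that makes this visible:

import org.apache.spark.{SparkConf, SparkContext}

// Hypothetical standalone driver (names and setup are illustrative, not from the patch).
object SampleSizeVaries {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("SampleSizeVaries").setMaster("local[2]"))
    val data = sc.parallelize(1 to 100000, 4)
    val miniBatchFraction = 0.1
    val expected = data.count() * miniBatchFraction
    (1 to 3).foreach { seed =>
      // Bernoulli sampling: the actual count fluctuates around `expected` from seed to seed.
      val actual = data.sample(false, miniBatchFraction, seed).count()
      println(s"seed=$seed actual=$actual expected=$expected")
    }
    sc.stop()
  }
}

Dividing the summed gradient by the estimated size instead of the actual size therefore scales the step by the wrong constant on most iterations, which is the error SPARK-4526 reports.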
Diffstat (limited to 'mllib/src')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala  45
1 file changed, 26 insertions(+), 19 deletions(-)
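The fix, shown in the diff below, computes the denominator inside the aggregation itself: the accumulator gains a third slot that counts how many records the sample actually produced, and the weight update is skipped (with a warning) when that count is zero. A minimal sketch of the counting pattern, assuming the per-record work has already been reduced to a (loss, gradient) pair; the object and method names here are illustrative, not part of the patch:

import breeze.linalg.{DenseVector => BDV}
import org.apache.spark.rdd.RDD

object MiniBatchAggregation {
  // Hypothetical helper (not part of the patch): sums per-record losses and gradients
  // while counting how many records were actually sampled, so the caller can divide
  // by the true mini-batch size instead of the estimate numExamples * miniBatchFraction.
  def sumLossGradAndCount(
      sampled: RDD[(Double, BDV[Double])],
      n: Int): (BDV[Double], Double, Long) = {
    sampled.treeAggregate((BDV.zeros[Double](n), 0.0, 0L))(
      seqOp = (acc, rec) => (acc, rec) match {
        case ((gradSum, lossSum, count), (loss, grad)) =>
          // Accumulate one record: add its gradient in place and bump the counter.
          (gradSum += grad, lossSum + loss, count + 1L)
      },
      combOp = (c1, c2) => (c1, c2) match {
        case ((g1, l1, cnt1), (g2, l2, cnt2)) =>
          // Merge two partial accumulators from different partitions.
          (g1 += g2, l1 + l2, cnt1 + cnt2)
      })
  }
}

The caller then divides both the loss sum and the gradient sum by the returned count (as a Double) rather than by the estimate, which is what the guarded miniBatchSize block in the patch does.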
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
index a691205639..0857877951 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
@@ -160,14 +160,15 @@ object GradientDescent extends Logging {
     val stochasticLossHistory = new ArrayBuffer[Double](numIterations)
     val numExamples = data.count()
-    val miniBatchSize = numExamples * miniBatchFraction
     // if no data, return initial weights to avoid NaNs
     if (numExamples == 0) {
-
-      logInfo("GradientDescent.runMiniBatchSGD returning initial weights, no data found")
+      logWarning("GradientDescent.runMiniBatchSGD returning initial weights, no data found")
       return (initialWeights, stochasticLossHistory.toArray)
+    }
+    if (numExamples * miniBatchFraction < 1) {
+      logWarning("The miniBatchFraction is too small")
     }
     // Initialize weights as a column vector
@@ -185,25 +186,31 @@ object GradientDescent extends Logging {
       val bcWeights = data.context.broadcast(weights)
       // Sample a subset (fraction miniBatchFraction) of the total data
       // compute and sum up the subgradients on this subset (this is one map-reduce)
-      val (gradientSum, lossSum) = data.sample(false, miniBatchFraction, 42 + i)
-        .treeAggregate((BDV.zeros[Double](n), 0.0))(
-          seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) =>
-            val l = gradient.compute(features, label, bcWeights.value, Vectors.fromBreeze(grad))
-            (grad, loss + l)
+      val (gradientSum, lossSum, miniBatchSize) = data.sample(false, miniBatchFraction, 42 + i)
+        .treeAggregate((BDV.zeros[Double](n), 0.0, 0L))(
+          seqOp = (c, v) => {
+            // c: (grad, loss, count), v: (label, features)
+            val l = gradient.compute(v._2, v._1, bcWeights.value, Vectors.fromBreeze(c._1))
+            (c._1, c._2 + l, c._3 + 1)
           },
-          combOp = (c1, c2) => (c1, c2) match { case ((grad1, loss1), (grad2, loss2)) =>
-            (grad1 += grad2, loss1 + loss2)
+          combOp = (c1, c2) => {
+            // c: (grad, loss, count)
+            (c1._1 += c2._1, c1._2 + c2._2, c1._3 + c2._3)
           })
-      /**
-       * NOTE(Xinghao): lossSum is computed using the weights from the previous iteration
-       * and regVal is the regularization value computed in the previous iteration as well.
-       */
-      stochasticLossHistory.append(lossSum / miniBatchSize + regVal)
-      val update = updater.compute(
-        weights, Vectors.fromBreeze(gradientSum / miniBatchSize), stepSize, i, regParam)
-      weights = update._1
-      regVal = update._2
+      if (miniBatchSize > 0) {
+        /**
+         * NOTE(Xinghao): lossSum is computed using the weights from the previous iteration
+         * and regVal is the regularization value computed in the previous iteration as well.
+         */
+        stochasticLossHistory.append(lossSum / miniBatchSize + regVal)
+        val update = updater.compute(
+          weights, Vectors.fromBreeze(gradientSum / miniBatchSize.toDouble), stepSize, i, regParam)
+        weights = update._1
+        regVal = update._2
+      } else {
+        logWarning(s"Iteration ($i/$numIterations). The size of sampled batch is zero")
+      }
     }
     logInfo("GradientDescent.runMiniBatchSGD finished. Last 10 stochastic losses %s".format(