about summary refs log tree commit diff
path: root/mllib
diff options
context:
space:
mode:
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala | 28
-rw-r--r-- mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala | 19
2 files changed, 37 insertions, 10 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
index e0e41f711b..7a714db853 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
@@ -241,16 +241,24 @@ object LBFGS extends Logging {
val bcW = data.context.broadcast(w)
val localGradient = gradient
- val (gradientSum, lossSum) = data.treeAggregate((Vectors.zeros(n), 0.0))(
- seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) =>
- val l = localGradient.compute(
- features, label, bcW.value, grad)
- (grad, loss + l)
- },
- combOp = (c1, c2) => (c1, c2) match { case ((grad1, loss1), (grad2, loss2)) =>
- axpy(1.0, grad2, grad1)
- (grad1, loss1 + loss2)
- })
+ val seqOp = (c: (Vector, Double), v: (Double, Vector)) =>
+ (c, v) match {
+ case ((grad, loss), (label, features)) =>
+ val denseGrad = grad.toDense
+ val l = localGradient.compute(features, label, bcW.value, denseGrad)
+ (denseGrad, loss + l)
+ }
+
+ val combOp = (c1: (Vector, Double), c2: (Vector, Double)) =>
+ (c1, c2) match { case ((grad1, loss1), (grad2, loss2)) =>
+ val denseGrad1 = grad1.toDense
+ val denseGrad2 = grad2.toDense
+ axpy(1.0, denseGrad2, denseGrad1)
+ (denseGrad1, loss1 + loss2)
+ }
+
+ val zeroSparseVector = Vectors.sparse(n, Seq())
+ val (gradientSum, lossSum) = data.treeAggregate((zeroSparseVector, 0.0))(seqOp, combOp)
// broadcasted model is not needed anymore
bcW.destroy()
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
index 75ae0eb32f..572959200f 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
@@ -230,6 +230,25 @@ class LBFGSSuite extends SparkFunSuite with MLlibTestSparkContext with Matchers
(weightLBFGS(0) ~= weightGD(0) relTol 0.02) && (weightLBFGS(1) ~= weightGD(1) relTol 0.02),
"The weight differences between LBFGS and GD should be within 2%.")
}
+
+ test("SPARK-18471: LBFGS aggregator on empty partitions") {
+ val regParam = 0
+
+ val initialWeightsWithIntercept = Vectors.dense(0.0)
+ val convergenceTol = 1e-12
+ val numIterations = 1
+ val dataWithEmptyPartitions = sc.parallelize(Seq((1.0, Vectors.dense(2.0))), 2)
+
+ LBFGS.runLBFGS(
+ dataWithEmptyPartitions,
+ gradient,
+ simpleUpdater,
+ numCorrections,
+ convergenceTol,
+ numIterations,
+ regParam,
+ initialWeightsWithIntercept)
+ }
}
class LBFGSClusterSuite extends SparkFunSuite with LocalClusterSparkContext {