aboutsummaryrefslogtreecommitdiff
path: root/mllib/src
diff options
context:
space:
mode:
authorXiangrui Meng <meng@databricks.com>2014-07-26 22:56:07 -0700
committerReynold Xin <rxin@apache.org>2014-07-26 22:56:07 -0700
commitaaf2b735fddbebccd28012006ee4647af3b3624f (patch)
treeeb132ba2fa45cddaf7730628403e836afecb34e3 /mllib/src
parentb547f69bdb5f4a6d5f471a2d998c2df6fb2a9347 (diff)
downloadspark-aaf2b735fddbebccd28012006ee4647af3b3624f.tar.gz
spark-aaf2b735fddbebccd28012006ee4647af3b3624f.tar.bz2
spark-aaf2b735fddbebccd28012006ee4647af3b3624f.zip
[SPARK-2361][MLLIB] Use broadcast instead of serializing data directly into task closure
We saw task serialization problems with large feature dimension, which could be avoid if we don't serialize data directly into task but use broadcast variables. This PR uses broadcast in both training and prediction and adds tests to make sure the task size is small. Author: Xiangrui Meng <meng@databricks.com> Closes #1427 from mengxr/broadcast-new and squashes the following commits: b9a1228 [Xiangrui Meng] style update b97c184 [Xiangrui Meng] minimal change to LBFGS 9ebadcc [Xiangrui Meng] add task size test to RowMatrix 9427bf0 [Xiangrui Meng] add task size tests to linear methods e0a5cf2 [Xiangrui Meng] add task size test to GD 28a8411 [Xiangrui Meng] add test for NaiveBayes 380778c [Xiangrui Meng] update KMeans test bccab92 [Xiangrui Meng] add task size test to LBFGS 02103ba [Xiangrui Meng] remove print e73d68e [Xiangrui Meng] update tests for k-means 174cb15 [Xiangrui Meng] use local-cluster for test with a small akka.frameSize 1928a5a [Xiangrui Meng] add test for KMeans task size e00c2da [Xiangrui Meng] use broadcast in GD, KMeans 010d076 [Xiangrui Meng] modify NaiveBayesModel and GLM to use broadcast
Diffstat (limited to 'mllib/src')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala8
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala19
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala6
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala6
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala7
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala7
-rw-r--r--mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java2
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala18
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala20
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala25
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala75
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala29
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala34
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala30
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala21
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala21
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala23
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala42
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala7
19 files changed, 330 insertions, 70 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index b6e0c4a80e..6c7be0a4f1 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -54,7 +54,13 @@ class NaiveBayesModel private[mllib] (
}
}
- override def predict(testData: RDD[Vector]): RDD[Double] = testData.map(predict)
+ override def predict(testData: RDD[Vector]): RDD[Double] = {
+ val bcModel = testData.context.broadcast(this)
+ testData.mapPartitions { iter =>
+ val model = bcModel.value
+ iter.map(model.predict)
+ }
+ }
override def predict(testData: Vector): Double = {
labels(brzArgmax(brzPi + brzTheta * testData.toBreeze))
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index de22fbb6ff..db425d866b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -165,18 +165,21 @@ class KMeans private (
val activeCenters = activeRuns.map(r => centers(r)).toArray
val costAccums = activeRuns.map(_ => sc.accumulator(0.0))
+ val bcActiveCenters = sc.broadcast(activeCenters)
+
// Find the sum and count of points mapping to each center
val totalContribs = data.mapPartitions { points =>
- val runs = activeCenters.length
- val k = activeCenters(0).length
- val dims = activeCenters(0)(0).vector.length
+ val thisActiveCenters = bcActiveCenters.value
+ val runs = thisActiveCenters.length
+ val k = thisActiveCenters(0).length
+ val dims = thisActiveCenters(0)(0).vector.length
val sums = Array.fill(runs, k)(BDV.zeros[Double](dims).asInstanceOf[BV[Double]])
val counts = Array.fill(runs, k)(0L)
points.foreach { point =>
(0 until runs).foreach { i =>
- val (bestCenter, cost) = KMeans.findClosest(activeCenters(i), point)
+ val (bestCenter, cost) = KMeans.findClosest(thisActiveCenters(i), point)
costAccums(i) += cost
sums(i)(bestCenter) += point.vector
counts(i)(bestCenter) += 1
@@ -264,16 +267,17 @@ class KMeans private (
// to their squared distance from that run's current centers
var step = 0
while (step < initializationSteps) {
+ val bcCenters = data.context.broadcast(centers)
val sumCosts = data.flatMap { point =>
(0 until runs).map { r =>
- (r, KMeans.pointCost(centers(r), point))
+ (r, KMeans.pointCost(bcCenters.value(r), point))
}
}.reduceByKey(_ + _).collectAsMap()
val chosen = data.mapPartitionsWithIndex { (index, points) =>
val rand = new XORShiftRandom(seed ^ (step << 16) ^ index)
points.flatMap { p =>
(0 until runs).filter { r =>
- rand.nextDouble() < 2.0 * KMeans.pointCost(centers(r), p) * k / sumCosts(r)
+ rand.nextDouble() < 2.0 * KMeans.pointCost(bcCenters.value(r), p) * k / sumCosts(r)
}.map((_, p))
}
}.collect()
@@ -286,9 +290,10 @@ class KMeans private (
// Finally, we might have a set of more than k candidate centers for each run; weigh each
// candidate by the number of points in the dataset mapping to it and run a local k-means++
// on the weighted centers to pick just k of them
+ val bcCenters = data.context.broadcast(centers)
val weightMap = data.flatMap { p =>
(0 until runs).map { r =>
- ((r, KMeans.findClosest(centers(r), p)._1), 1.0)
+ ((r, KMeans.findClosest(bcCenters.value(r), p)._1), 1.0)
}
}.reduceByKey(_ + _).collectAsMap()
val finalCenters = (0 until runs).map { r =>
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
index fba21aefaa..5823cb6e52 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
@@ -38,7 +38,8 @@ class KMeansModel private[mllib] (val clusterCenters: Array[Vector]) extends Ser
/** Maps given points to their cluster indices. */
def predict(points: RDD[Vector]): RDD[Int] = {
val centersWithNorm = clusterCentersWithNorm
- points.map(p => KMeans.findClosest(centersWithNorm, new BreezeVectorWithNorm(p))._1)
+ val bcCentersWithNorm = points.context.broadcast(centersWithNorm)
+ points.map(p => KMeans.findClosest(bcCentersWithNorm.value, new BreezeVectorWithNorm(p))._1)
}
/** Maps given points to their cluster indices. */
@@ -51,7 +52,8 @@ class KMeansModel private[mllib] (val clusterCenters: Array[Vector]) extends Ser
*/
def computeCost(data: RDD[Vector]): Double = {
val centersWithNorm = clusterCentersWithNorm
- data.map(p => KMeans.pointCost(centersWithNorm, new BreezeVectorWithNorm(p))).sum()
+ val bcCentersWithNorm = data.context.broadcast(centersWithNorm)
+ data.map(p => KMeans.pointCost(bcCentersWithNorm.value, new BreezeVectorWithNorm(p))).sum()
}
private def clusterCentersWithNorm: Iterable[BreezeVectorWithNorm] =
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
index 7030eeabe4..9fd760bf78 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
@@ -163,6 +163,7 @@ object GradientDescent extends Logging {
// Initialize weights as a column vector
var weights = Vectors.dense(initialWeights.toArray)
+ val n = weights.size
/**
* For the first iteration, the regVal will be initialized as sum of weight squares
@@ -172,12 +173,13 @@ object GradientDescent extends Logging {
weights, Vectors.dense(new Array[Double](weights.size)), 0, 1, regParam)._2
for (i <- 1 to numIterations) {
+ val bcWeights = data.context.broadcast(weights)
// Sample a subset (fraction miniBatchFraction) of the total data
// compute and sum up the subgradients on this subset (this is one map-reduce)
val (gradientSum, lossSum) = data.sample(false, miniBatchFraction, 42 + i)
- .aggregate((BDV.zeros[Double](weights.size), 0.0))(
+ .aggregate((BDV.zeros[Double](n), 0.0))(
seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) =>
- val l = gradient.compute(features, label, weights, Vectors.fromBreeze(grad))
+ val l = gradient.compute(features, label, bcWeights.value, Vectors.fromBreeze(grad))
(grad, loss + l)
},
combOp = (c1, c2) => (c1, c2) match { case ((grad1, loss1), (grad2, loss2)) =>
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
index 7bbed9c8fd..179cd4a3f1 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
@@ -195,13 +195,14 @@ object LBFGS extends Logging {
override def calculate(weights: BDV[Double]) = {
// Have a local copy to avoid the serialization of CostFun object which is not serializable.
- val localData = data
val localGradient = gradient
+ val n = weights.length
+ val bcWeights = data.context.broadcast(weights)
- val (gradientSum, lossSum) = localData.aggregate((BDV.zeros[Double](weights.size), 0.0))(
+ val (gradientSum, lossSum) = data.aggregate((BDV.zeros[Double](n), 0.0))(
seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) =>
val l = localGradient.compute(
- features, label, Vectors.fromBreeze(weights), Vectors.fromBreeze(grad))
+ features, label, Vectors.fromBreeze(bcWeights.value), Vectors.fromBreeze(grad))
(grad, loss + l)
},
combOp = (c1, c2) => (c1, c2) match { case ((grad1, loss1), (grad2, loss2)) =>
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
index fe41863bce..54854252d7 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
@@ -56,9 +56,12 @@ abstract class GeneralizedLinearModel(val weights: Vector, val intercept: Double
// A small optimization to avoid serializing the entire model. Only the weightsMatrix
// and intercept is needed.
val localWeights = weights
+ val bcWeights = testData.context.broadcast(localWeights)
val localIntercept = intercept
-
- testData.map(v => predictPoint(v, localWeights, localIntercept))
+ testData.mapPartitions { iter =>
+ val w = bcWeights.value
+ iter.map(v => predictPoint(v, w, localIntercept))
+ }
}
/**
diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java
index faa675b59c..862221d487 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java
@@ -92,8 +92,6 @@ public class JavaLogisticRegressionSuite implements Serializable {
testRDD.rdd(), 100, 1.0, 1.0);
int numAccurate = validatePrediction(validationData, model);
- System.out.println(numAccurate);
Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0);
}
-
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
index 44b757b6a1..3f6ff85937 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
@@ -25,7 +25,7 @@ import org.scalatest.Matchers
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression._
-import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext}
object LogisticRegressionSuite {
@@ -126,3 +126,19 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match
validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
}
}
+
+class LogisticRegressionClusterSuite extends FunSuite with LocalClusterSparkContext {
+
+ test("task size should be small in both training and prediction") {
+ val m = 4
+ val n = 200000
+ val points = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) =>
+ val random = new Random(idx)
+ iter.map(i => LabeledPoint(1.0, Vectors.dense(Array.fill(n)(random.nextDouble()))))
+ }.cache()
+ // If we serialize data directly in the task closure, the size of the serialized task would be
+ // greater than 1MB and hence Spark would throw an error.
+ val model = LogisticRegressionWithSGD.train(points, 2)
+ val predictions = model.predict(points.map(_.features))
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
index 516895d042..06cdd04f5f 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
@@ -23,7 +23,7 @@ import org.scalatest.FunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext}
object NaiveBayesSuite {
@@ -96,3 +96,21 @@ class NaiveBayesSuite extends FunSuite with LocalSparkContext {
validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
}
}
+
+class NaiveBayesClusterSuite extends FunSuite with LocalClusterSparkContext {
+
+ test("task size should be small in both training and prediction") {
+ val m = 10
+ val n = 200000
+ val examples = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) =>
+ val random = new Random(idx)
+ iter.map { i =>
+ LabeledPoint(random.nextInt(2), Vectors.dense(Array.fill(n)(random.nextDouble())))
+ }
+ }
+ // If we serialize data directly in the task closure, the size of the serialized task would be
+ // greater than 1MB and hence Spark would throw an error.
+ val model = NaiveBayes.train(examples)
+ val predictions = model.predict(examples.map(_.features))
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
index 886c71dde3..65e5df58db 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
@@ -17,17 +17,16 @@
package org.apache.spark.mllib.classification
-import scala.util.Random
import scala.collection.JavaConversions._
-
-import org.scalatest.FunSuite
+import scala.util.Random
import org.jblas.DoubleMatrix
+import org.scalatest.FunSuite
import org.apache.spark.SparkException
-import org.apache.spark.mllib.regression._
-import org.apache.spark.mllib.util.LocalSparkContext
import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.regression._
+import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext}
object SVMSuite {
@@ -193,3 +192,19 @@ class SVMSuite extends FunSuite with LocalSparkContext {
new SVMWithSGD().setValidateData(false).run(testRDDInvalid)
}
}
+
+class SVMClusterSuite extends FunSuite with LocalClusterSparkContext {
+
+ test("task size should be small in both training and prediction") {
+ val m = 4
+ val n = 200000
+ val points = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) =>
+ val random = new Random(idx)
+ iter.map(i => LabeledPoint(1.0, Vectors.dense(Array.fill(n)(random.nextDouble()))))
+ }.cache()
+ // If we serialize data directly in the task closure, the size of the serialized task would be
+ // greater than 1MB and hence Spark would throw an error.
+ val model = SVMWithSGD.train(points, 2)
+ val predictions = model.predict(points.map(_.features))
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
index 76a3bdf9b1..34bc4537a7 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
@@ -17,14 +17,16 @@
package org.apache.spark.mllib.clustering
+import scala.util.Random
+
import org.scalatest.FunSuite
-import org.apache.spark.mllib.util.LocalSparkContext
import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext}
class KMeansSuite extends FunSuite with LocalSparkContext {
- import KMeans.{RANDOM, K_MEANS_PARALLEL}
+ import org.apache.spark.mllib.clustering.KMeans.{K_MEANS_PARALLEL, RANDOM}
test("single cluster") {
val data = sc.parallelize(Array(
@@ -38,26 +40,26 @@ class KMeansSuite extends FunSuite with LocalSparkContext {
// No matter how many runs or iterations we use, we should get one cluster,
// centered at the mean of the points
- var model = KMeans.train(data, k=1, maxIterations=1)
+ var model = KMeans.train(data, k = 1, maxIterations = 1)
assert(model.clusterCenters.head === center)
- model = KMeans.train(data, k=1, maxIterations=2)
+ model = KMeans.train(data, k = 1, maxIterations = 2)
assert(model.clusterCenters.head === center)
- model = KMeans.train(data, k=1, maxIterations=5)
+ model = KMeans.train(data, k = 1, maxIterations = 5)
assert(model.clusterCenters.head === center)
- model = KMeans.train(data, k=1, maxIterations=1, runs=5)
+ model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5)
assert(model.clusterCenters.head === center)
- model = KMeans.train(data, k=1, maxIterations=1, runs=5)
+ model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5)
assert(model.clusterCenters.head === center)
- model = KMeans.train(data, k=1, maxIterations=1, runs=1, initializationMode=RANDOM)
+ model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1, initializationMode = RANDOM)
assert(model.clusterCenters.head === center)
model = KMeans.train(
- data, k=1, maxIterations=1, runs=1, initializationMode=K_MEANS_PARALLEL)
+ data, k = 1, maxIterations = 1, runs = 1, initializationMode = K_MEANS_PARALLEL)
assert(model.clusterCenters.head === center)
}
@@ -100,26 +102,27 @@ class KMeansSuite extends FunSuite with LocalSparkContext {
val center = Vectors.dense(1.0, 3.0, 4.0)
- var model = KMeans.train(data, k=1, maxIterations=1)
+ var model = KMeans.train(data, k = 1, maxIterations = 1)
assert(model.clusterCenters.size === 1)
assert(model.clusterCenters.head === center)
- model = KMeans.train(data, k=1, maxIterations=2)
+ model = KMeans.train(data, k = 1, maxIterations = 2)
assert(model.clusterCenters.head === center)
- model = KMeans.train(data, k=1, maxIterations=5)
+ model = KMeans.train(data, k = 1, maxIterations = 5)
assert(model.clusterCenters.head === center)
- model = KMeans.train(data, k=1, maxIterations=1, runs=5)
+ model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5)
assert(model.clusterCenters.head === center)
- model = KMeans.train(data, k=1, maxIterations=1, runs=5)
+ model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5)
assert(model.clusterCenters.head === center)
- model = KMeans.train(data, k=1, maxIterations=1, runs=1, initializationMode=RANDOM)
+ model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1, initializationMode = RANDOM)
assert(model.clusterCenters.head === center)
- model = KMeans.train(data, k=1, maxIterations=1, runs=1, initializationMode=K_MEANS_PARALLEL)
+ model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1,
+ initializationMode = K_MEANS_PARALLEL)
assert(model.clusterCenters.head === center)
}
@@ -145,25 +148,26 @@ class KMeansSuite extends FunSuite with LocalSparkContext {
val center = Vectors.sparse(n, Seq((0, 1.0), (1, 3.0), (2, 4.0)))
- var model = KMeans.train(data, k=1, maxIterations=1)
+ var model = KMeans.train(data, k = 1, maxIterations = 1)
assert(model.clusterCenters.head === center)
- model = KMeans.train(data, k=1, maxIterations=2)
+ model = KMeans.train(data, k = 1, maxIterations = 2)
assert(model.clusterCenters.head === center)
- model = KMeans.train(data, k=1, maxIterations=5)
+ model = KMeans.train(data, k = 1, maxIterations = 5)
assert(model.clusterCenters.head === center)
- model = KMeans.train(data, k=1, maxIterations=1, runs=5)
+ model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5)
assert(model.clusterCenters.head === center)
- model = KMeans.train(data, k=1, maxIterations=1, runs=5)
+ model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5)
assert(model.clusterCenters.head === center)
- model = KMeans.train(data, k=1, maxIterations=1, runs=1, initializationMode=RANDOM)
+ model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1, initializationMode = RANDOM)
assert(model.clusterCenters.head === center)
- model = KMeans.train(data, k=1, maxIterations=1, runs=1, initializationMode=K_MEANS_PARALLEL)
+ model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1,
+ initializationMode = K_MEANS_PARALLEL)
assert(model.clusterCenters.head === center)
data.unpersist()
@@ -183,15 +187,15 @@ class KMeansSuite extends FunSuite with LocalSparkContext {
// it will make at least five passes, and it will give non-zero probability to each
// unselected point as long as it hasn't yet selected all of them
- var model = KMeans.train(rdd, k=5, maxIterations=1)
+ var model = KMeans.train(rdd, k = 5, maxIterations = 1)
assert(Set(model.clusterCenters: _*) === Set(points: _*))
// Iterations of Lloyd's should not change the answer either
- model = KMeans.train(rdd, k=5, maxIterations=10)
+ model = KMeans.train(rdd, k = 5, maxIterations = 10)
assert(Set(model.clusterCenters: _*) === Set(points: _*))
// Neither should more runs
- model = KMeans.train(rdd, k=5, maxIterations=10, runs=5)
+ model = KMeans.train(rdd, k = 5, maxIterations = 10, runs = 5)
assert(Set(model.clusterCenters: _*) === Set(points: _*))
}
@@ -220,3 +224,22 @@ class KMeansSuite extends FunSuite with LocalSparkContext {
}
}
}
+
+class KMeansClusterSuite extends FunSuite with LocalClusterSparkContext {
+
+ test("task size should be small in both training and prediction") {
+ val m = 4
+ val n = 200000
+ val points = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) =>
+ val random = new Random(idx)
+ iter.map(i => Vectors.dense(Array.fill(n)(random.nextDouble)))
+ }.cache()
+ for (initMode <- Seq(KMeans.RANDOM, KMeans.K_MEANS_PARALLEL)) {
+ // If we serialize data directly in the task closure, the size of the serialized task would be
+ // greater than 1MB and hence Spark would throw an error.
+ val model = KMeans.train(points, 2, 2, 1, initMode)
+ val predictions = model.predict(points).collect()
+ val cost = model.computeCost(points)
+ }
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala
index a961f89456..325b817980 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala
@@ -17,12 +17,13 @@
package org.apache.spark.mllib.linalg.distributed
-import org.scalatest.FunSuite
+import scala.util.Random
import breeze.linalg.{DenseVector => BDV, DenseMatrix => BDM, norm => brzNorm, svd => brzSvd}
+import org.scalatest.FunSuite
-import org.apache.spark.mllib.util.LocalSparkContext
import org.apache.spark.mllib.linalg.{Matrices, Vectors, Vector}
+import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext}
class RowMatrixSuite extends FunSuite with LocalSparkContext {
@@ -193,3 +194,27 @@ class RowMatrixSuite extends FunSuite with LocalSparkContext {
}
}
}
+
+class RowMatrixClusterSuite extends FunSuite with LocalClusterSparkContext {
+
+ var mat: RowMatrix = _
+
+ override def beforeAll() {
+ super.beforeAll()
+ val m = 4
+ val n = 200000
+ val rows = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) =>
+ val random = new Random(idx)
+ iter.map(i => Vectors.dense(Array.fill(n)(random.nextDouble())))
+ }
+ mat = new RowMatrix(rows)
+ }
+
+ test("task size should be small in svd") {
+ val svd = mat.computeSVD(1, computeU = true)
+ }
+
+ test("task size should be small in summarize") {
+ val summary = mat.computeColumnSummaryStatistics()
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
index 951b4f7c6e..dfb2eb7f0d 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
@@ -17,15 +17,14 @@
package org.apache.spark.mllib.optimization
-import scala.util.Random
import scala.collection.JavaConversions._
+import scala.util.Random
-import org.scalatest.FunSuite
-import org.scalatest.Matchers
+import org.scalatest.{FunSuite, Matchers}
-import org.apache.spark.mllib.regression._
-import org.apache.spark.mllib.util.LocalSparkContext
import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.regression._
+import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext}
object GradientDescentSuite {
@@ -46,7 +45,7 @@ object GradientDescentSuite {
val rnd = new Random(seed)
val x1 = Array.fill[Double](nPoints)(rnd.nextGaussian())
- val unifRand = new scala.util.Random(45)
+ val unifRand = new Random(45)
val rLogis = (0 until nPoints).map { i =>
val u = unifRand.nextDouble()
math.log(u) - math.log(1.0-u)
@@ -144,3 +143,26 @@ class GradientDescentSuite extends FunSuite with LocalSparkContext with Matchers
"should be initialWeightsWithIntercept.")
}
}
+
+class GradientDescentClusterSuite extends FunSuite with LocalClusterSparkContext {
+
+ test("task size should be small") {
+ val m = 4
+ val n = 200000
+ val points = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) =>
+ val random = new Random(idx)
+ iter.map(i => (1.0, Vectors.dense(Array.fill(n)(random.nextDouble()))))
+ }.cache()
+ // If we serialize data directly in the task closure, the size of the serialized task would be
+ // greater than 1MB and hence Spark would throw an error.
+ val (weights, loss) = GradientDescent.runMiniBatchSGD(
+ points,
+ new LogisticGradient,
+ new SquaredL2Updater,
+ 0.1,
+ 2,
+ 1.0,
+ 1.0,
+ Vectors.dense(new Array[Double](n)))
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
index fe7a9033cd..ff414742e8 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
@@ -17,12 +17,13 @@
package org.apache.spark.mllib.optimization
-import org.scalatest.FunSuite
-import org.scalatest.Matchers
+import scala.util.Random
+
+import org.scalatest.{FunSuite, Matchers}
-import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.Vectors
-import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext}
class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers {
@@ -230,3 +231,24 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers {
"The weight differences between LBFGS and GD should be within 2%.")
}
}
+
+class LBFGSClusterSuite extends FunSuite with LocalClusterSparkContext {
+
+ test("task size should be small") {
+ val m = 10
+ val n = 200000
+ val examples = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) =>
+ val random = new Random(idx)
+ iter.map(i => (1.0, Vectors.dense(Array.fill(n)(random.nextDouble))))
+ }.cache()
+ val lbfgs = new LBFGS(new LogisticGradient, new SquaredL2Updater)
+ .setNumCorrections(1)
+ .setConvergenceTol(1e-12)
+ .setMaxNumIterations(1)
+ .setRegParam(1.0)
+ val random = new Random(0)
+ // If we serialize data directly in the task closure, the size of the serialized task would be
+ // greater than 1MB and hence Spark would throw an error.
+ val weights = lbfgs.optimize(examples, Vectors.dense(Array.fill(n)(random.nextDouble)))
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
index bfa42959c8..7aa96421ae 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
@@ -17,10 +17,13 @@
package org.apache.spark.mllib.regression
+import scala.util.Random
+
import org.scalatest.FunSuite
import org.apache.spark.mllib.linalg.Vectors
-import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}
+import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator,
+ LocalSparkContext}
class LassoSuite extends FunSuite with LocalSparkContext {
@@ -113,3 +116,19 @@ class LassoSuite extends FunSuite with LocalSparkContext {
validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
}
}
+
+class LassoClusterSuite extends FunSuite with LocalClusterSparkContext {
+
+ test("task size should be small in both training and prediction") {
+ val m = 4
+ val n = 200000
+ val points = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) =>
+ val random = new Random(idx)
+ iter.map(i => LabeledPoint(1.0, Vectors.dense(Array.fill(n)(random.nextDouble()))))
+ }.cache()
+ // If we serialize data directly in the task closure, the size of the serialized task would be
+ // greater than 1MB and hence Spark would throw an error.
+ val model = LassoWithSGD.train(points, 2)
+ val predictions = model.predict(points.map(_.features))
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
index 7aaad7d7a3..4f89112b65 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
@@ -17,10 +17,13 @@
package org.apache.spark.mllib.regression
+import scala.util.Random
+
import org.scalatest.FunSuite
import org.apache.spark.mllib.linalg.Vectors
-import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}
+import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator,
+ LocalSparkContext}
class LinearRegressionSuite extends FunSuite with LocalSparkContext {
@@ -122,3 +125,19 @@ class LinearRegressionSuite extends FunSuite with LocalSparkContext {
sparseValidationData.map(row => model.predict(row.features)), sparseValidationData)
}
}
+
+class LinearRegressionClusterSuite extends FunSuite with LocalClusterSparkContext {
+
+ test("task size should be small in both training and prediction") {
+ val m = 4
+ val n = 200000
+ val points = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) =>
+ val random = new Random(idx)
+ iter.map(i => LabeledPoint(1.0, Vectors.dense(Array.fill(n)(random.nextDouble()))))
+ }.cache()
+ // If we serialize data directly in the task closure, the size of the serialized task would be
+ // greater than 1MB and hence Spark would throw an error.
+ val model = LinearRegressionWithSGD.train(points, 2)
+ val predictions = model.predict(points.map(_.features))
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
index 67768e17fb..727bbd051f 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
@@ -17,11 +17,14 @@
package org.apache.spark.mllib.regression
-import org.scalatest.FunSuite
+import scala.util.Random
import org.jblas.DoubleMatrix
+import org.scalatest.FunSuite
-import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator,
+ LocalSparkContext}
class RidgeRegressionSuite extends FunSuite with LocalSparkContext {
@@ -73,3 +76,19 @@ class RidgeRegressionSuite extends FunSuite with LocalSparkContext {
"ridgeError (" + ridgeErr + ") was not less than linearError(" + linearErr + ")")
}
}
+
+class RidgeRegressionClusterSuite extends FunSuite with LocalClusterSparkContext {
+
+ test("task size should be small in both training and prediction") {
+ val m = 4
+ val n = 200000
+ val points = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) =>
+ val random = new Random(idx)
+ iter.map(i => LabeledPoint(1.0, Vectors.dense(Array.fill(n)(random.nextDouble()))))
+ }.cache()
+ // If we serialize data directly in the task closure, the size of the serialized task would be
+ // greater than 1MB and hence Spark would throw an error.
+ val model = RidgeRegressionWithSGD.train(points, 2)
+ val predictions = model.predict(points.map(_.features))
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala
new file mode 100644
index 0000000000..5e9101cdd3
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.util
+
+import org.scalatest.{Suite, BeforeAndAfterAll}
+
+import org.apache.spark.{SparkConf, SparkContext}
+
+trait LocalClusterSparkContext extends BeforeAndAfterAll { self: Suite =>
+ @transient var sc: SparkContext = _
+
+ override def beforeAll() {
+ val conf = new SparkConf()
+ .setMaster("local-cluster[2, 1, 512]")
+ .setAppName("test-cluster")
+ .set("spark.akka.frameSize", "1") // set to 1MB to detect direct serialization of data
+ sc = new SparkContext(conf)
+ super.beforeAll()
+ }
+
+ override def afterAll() {
+ if (sc != null) {
+ sc.stop()
+ }
+ super.afterAll()
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala
index 0d4868f3d9..7857d9e5ee 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala
@@ -20,13 +20,16 @@ package org.apache.spark.mllib.util
import org.scalatest.Suite
import org.scalatest.BeforeAndAfterAll
-import org.apache.spark.SparkContext
+import org.apache.spark.{SparkConf, SparkContext}
trait LocalSparkContext extends BeforeAndAfterAll { self: Suite =>
@transient var sc: SparkContext = _
override def beforeAll() {
- sc = new SparkContext("local", "test")
+ val conf = new SparkConf()
+ .setMaster("local")
+ .setAppName("test")
+ sc = new SparkContext(conf)
super.beforeAll()
}