diff options
author | Holden Karau <holden@pigscanfly.ca> | 2014-04-16 09:33:27 -0700 |
---|---|---|
committer | Patrick Wendell <pwendell@gmail.com> | 2014-04-16 09:33:27 -0700 |
commit | c3527a333a0877f4b49614f3fd1f041b01749651 (patch) | |
tree | ba2c97de461cabdfe2cc0d786b1ff65f5a5b557e /mllib/src | |
parent | 9edd88782e0268439c5ab57400d6a7ab432fc269 (diff) | |
download | spark-c3527a333a0877f4b49614f3fd1f041b01749651.tar.gz spark-c3527a333a0877f4b49614f3fd1f041b01749651.tar.bz2 spark-c3527a333a0877f4b49614f3fd1f041b01749651.zip |
SPARK-1310: Start adding k-fold cross validation to MLLib [adds kFold to MLUtils & fixes bug in BernoulliSampler]
Author: Holden Karau <holden@pigscanfly.ca>
Closes #18 from holdenk/addkfoldcrossvalidation and squashes the following commits:
208db9b [Holden Karau] Fix a bad space
e84f2fc [Holden Karau] Fix the test, we should be looking at the second element instead
6ddbf05 [Holden Karau] swap training and validation order
7157ae9 [Holden Karau] CR feedback
90896c7 [Holden Karau] New line
150889c [Holden Karau] Fix up error messages in the MLUtilsSuite
2cb90b3 [Holden Karau] Fix the names in kFold
c702a96 [Holden Karau] Fix imports in MLUtils
e187e35 [Holden Karau] Move { up to same line as whenExecuting(random) in RandomSamplerSuite.scala
c5b723f [Holden Karau] clean up
7ebe4d5 [Holden Karau] CR feedback, remove unnecessary learners (came back during merge mistake) and insert an empty line
bb5fa56 [Holden Karau] extra line sadness
163c5b1 [Holden Karau] code review feedback 1.to -> 1 to and folds -> numFolds
5a33f1d [Holden Karau] Code review follow up.
e8741a7 [Holden Karau] CR feedback
b78804e [Holden Karau] Remove cross validation [TODO in another pull request]
91eae64 [Holden Karau] Consolidate things in mlutils
264502a [Holden Karau] Add a test for the bug that was found with BernoulliSampler not copying the complement param
dd0b737 [Holden Karau] Wrap long lines (oops)
c0b7fa4 [Holden Karau] Switch FoldedRDD to use BernoulliSampler and PartitionwiseSampledRDD
08f8e4d [Holden Karau] Fix BernoulliSampler to respect complement
a751ec6 [Holden Karau] Add k-fold cross validation to MLLib
Diffstat (limited to 'mllib/src')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala | 21 | ||||
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala | 39 |
2 files changed, 60 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index 901c3180ea..2f3ac10397 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -17,11 +17,16 @@ package org.apache.spark.mllib.util +import scala.reflect.ClassTag + import breeze.linalg.{Vector => BV, SparseVector => BSV, squaredDistance => breezeSquaredDistance} import org.apache.spark.annotation.Experimental import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.PartitionwiseSampledRDD +import org.apache.spark.SparkContext._ +import org.apache.spark.util.random.BernoulliSampler import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.linalg.Vectors @@ -158,6 +163,22 @@ object MLUtils { } /** + * Return a k element array of pairs of RDDs with the first element of each pair + * containing the training data, a complement of the validation data and the second + * element, the validation data, containing a unique 1/kth of the data. Where k=numFolds. + */ + def kFold[T: ClassTag](rdd: RDD[T], numFolds: Int, seed: Int): Array[(RDD[T], RDD[T])] = { + val numFoldsF = numFolds.toFloat + (1 to numFolds).map { fold => + val sampler = new BernoulliSampler[T]((fold - 1) / numFoldsF, fold / numFoldsF, + complement = false) + val validation = new PartitionwiseSampledRDD(rdd, sampler, seed) + val training = new PartitionwiseSampledRDD(rdd, sampler.cloneComplement(), seed) + (training, validation) + }.toArray + } + + /** * Returns the squared Euclidean distance between two vectors. 
The following formula will be used * if it does not introduce too much numerical error: * <pre> diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala index 812a843478..674378a34c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala @@ -19,6 +19,9 @@ package org.apache.spark.mllib.util import java.io.File +import scala.math +import scala.util.Random + import org.scalatest.FunSuite import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, norm => breezeNorm, @@ -93,4 +96,40 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext { case t: Throwable => } } + + test("kFold") { + val data = sc.parallelize(1 to 100, 2) + val collectedData = data.collect().sorted + val twoFoldedRdd = MLUtils.kFold(data, 2, 1) + assert(twoFoldedRdd(0)._1.collect().sorted === twoFoldedRdd(1)._2.collect().sorted) + assert(twoFoldedRdd(0)._2.collect().sorted === twoFoldedRdd(1)._1.collect().sorted) + for (folds <- 2 to 10) { + for (seed <- 1 to 5) { + val foldedRdds = MLUtils.kFold(data, folds, seed) + assert(foldedRdds.size === folds) + foldedRdds.map { case (training, validation) => + val result = validation.union(training).collect().sorted + val validationSize = validation.collect().size.toFloat + assert(validationSize > 0, "empty validation data") + val p = 1 / folds.toFloat + // Within 3 standard deviations of the mean + val range = 3 * math.sqrt(100 * p * (1 - p)) + val expected = 100 * p + val lowerBound = expected - range + val upperBound = expected + range + assert(validationSize > lowerBound, + s"Validation data ($validationSize) smaller than expected ($lowerBound)" ) + assert(validationSize < upperBound, + s"Validation data ($validationSize) larger than expected ($upperBound)" ) + assert(training.collect().size > 0, "empty training data") + assert(result === collectedData, + 
"Each training+validation set combined should contain all of the data.") + } + // K fold cross validation should only have each element in the validation set exactly once + assert(foldedRdds.map(_._2).reduce((x,y) => x.union(y)).collect().sorted === + data.collect().sorted) + } + } + } + } |