diff options
author | Oliver Pierson <ocp@gatech.edu> | 2016-04-11 12:02:48 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2016-04-11 12:02:48 -0700 |
commit | 89a41c5b7a3f727b44a7f615a1352ca006d12f73 (patch) | |
tree | 1c59e13c4fe03bbb0c5717f6c08311a2d2648da2 /mllib/src/test/scala/org | |
parent | 2dacc81ec31233e558855a26340ad4662d470387 (diff) | |
download | spark-89a41c5b7a3f727b44a7f615a1352ca006d12f73.tar.gz spark-89a41c5b7a3f727b44a7f615a1352ca006d12f73.tar.bz2 spark-89a41c5b7a3f727b44a7f615a1352ca006d12f73.zip |
[SPARK-13600][MLLIB] Use approxQuantile from DataFrame stats in QuantileDiscretizer
## What changes were proposed in this pull request?
QuantileDiscretizer can return an unexpected number of buckets in certain cases. This PR proposes to fix this issue and also refactor QuantileDiscretizer to use approxQuantiles from DataFrame stats functions.
## How was this patch tested?
QuantileDiscretizerSuite unit tests (some existing tests will change or even be removed in this PR)
Author: Oliver Pierson <ocp@gatech.edu>
Closes #11553 from oliverpierson/SPARK-13600.
Diffstat (limited to 'mllib/src/test/scala/org')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala | 115 |
1 files changed, 40 insertions, 75 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala index 25fabf64d5..8895d630a0 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala @@ -17,78 +17,60 @@ package org.apache.spark.ml.feature -import org.apache.spark.{SparkContext, SparkFunSuite} -import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.functions.udf class QuantileDiscretizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { - import org.apache.spark.ml.feature.QuantileDiscretizerSuite._ - - test("Test quantile discretizer") { - checkDiscretizedData(sc, - Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3), - 10, - Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3), - Array("-Infinity, 1.0", "1.0, 2.0", "2.0, 3.0", "3.0, Infinity")) - - checkDiscretizedData(sc, - Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3), - 4, - Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3), - Array("-Infinity, 1.0", "1.0, 2.0", "2.0, 3.0", "3.0, Infinity")) - - checkDiscretizedData(sc, - Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3), - 3, - Array[Double](0, 1, 2, 2, 2, 2, 2, 2, 2), - Array("-Infinity, 2.0", "2.0, 3.0", "3.0, Infinity")) + test("Test observed number of buckets and their sizes match expected values") { + val sqlCtx = SQLContext.getOrCreate(sc) + import sqlCtx.implicits._ - checkDiscretizedData(sc, - Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3), - 2, - Array[Double](0, 1, 1, 1, 1, 1, 1, 1, 1), - Array("-Infinity, 2.0", "2.0, Infinity")) + val datasetSize = 100000 + val numBuckets = 5 + val df = sc.parallelize(1.0 to datasetSize by 1.0).map(Tuple1.apply).toDF("input") + val discretizer = new QuantileDiscretizer() + .setInputCol("input") + .setOutputCol("result") + .setNumBuckets(numBuckets) + val result = discretizer.fit(df).transform(df) - } + val observedNumBuckets = result.select("result").distinct.count + assert(observedNumBuckets === numBuckets, + "Observed number of buckets does not equal expected number of buckets.") - test("Test getting splits") { - val splitTestPoints = Array( - Array[Double]() -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity), - Array(Double.NegativeInfinity) -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity), - Array(Double.PositiveInfinity) -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity), - Array(Double.NegativeInfinity, Double.PositiveInfinity) - -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity), - Array(0.0) -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity), - Array(1.0) -> Array(Double.NegativeInfinity, 1, Double.PositiveInfinity), - Array(0.0, 1.0) -> Array(Double.NegativeInfinity, 0, 1, Double.PositiveInfinity) - ) - for ((ori, res) <- splitTestPoints) { - assert(QuantileDiscretizer.getSplits(ori) === res, "Returned splits are invalid.") + val relativeError = discretizer.getRelativeError + val isGoodBucket = udf { + (size: Int) => math.abs( size - (datasetSize / numBuckets)) <= (relativeError * datasetSize) } + val numGoodBuckets = result.groupBy("result").count.filter(isGoodBucket($"count")).count + assert(numGoodBuckets === numBuckets, + "Bucket sizes are not within expected relative error tolerance.") } - test("Test splits on dataset larger than minSamplesRequired") { + test("Test transform method on unseen data") { val sqlCtx = SQLContext.getOrCreate(sc) import sqlCtx.implicits._ - val datasetSize = QuantileDiscretizer.minSamplesRequired + 1 - val numBuckets = 5 - val df = sc.parallelize((1.0 to datasetSize by 1.0).map(Tuple1.apply)).toDF("input") + val trainDF = sc.parallelize(1.0 to 100.0 by 1.0).map(Tuple1.apply).toDF("input") + val testDF = sc.parallelize(-10.0 to 110.0 by 1.0).map(Tuple1.apply).toDF("input") val discretizer = new QuantileDiscretizer() .setInputCol("input") .setOutputCol("result") - .setNumBuckets(numBuckets) - .setSeed(1) + .setNumBuckets(5) - val result = discretizer.fit(df).transform(df) - val observedNumBuckets = result.select("result").distinct.count + val result = discretizer.fit(trainDF).transform(testDF) + val firstBucketSize = result.filter(result("result") === 0.0).count + val lastBucketSize = result.filter(result("result") === 4.0).count - assert(observedNumBuckets === numBuckets, - "Observed number of buckets does not equal expected number of buckets.") + assert(firstBucketSize === 30L, + s"Size of first bucket ${firstBucketSize} did not equal expected value of 30.") + assert(lastBucketSize === 31L, + s"Size of last bucket ${lastBucketSize} did not equal expected value of 31.") } test("read/write") { @@ -98,34 +80,17 @@ class QuantileDiscretizerSuite .setNumBuckets(6) testDefaultReadWrite(t) } -} - -private object QuantileDiscretizerSuite extends SparkFunSuite { - def checkDiscretizedData( - sc: SparkContext, - data: Array[Double], - numBucket: Int, - expectedResult: Array[Double], - expectedAttrs: Array[String]): Unit = { + test("Verify resulting model has parent") { val sqlCtx = SQLContext.getOrCreate(sc) import sqlCtx.implicits._ - val df = sc.parallelize(data.map(Tuple1.apply)).toDF("input") - val discretizer = new QuantileDiscretizer().setInputCol("input").setOutputCol("result") - .setNumBuckets(numBucket).setSeed(1) + val df = sc.parallelize(1 to 100).map(Tuple1.apply).toDF("input") + val discretizer = new QuantileDiscretizer() + .setInputCol("input") + .setOutputCol("result") + .setNumBuckets(5) val model = discretizer.fit(df) assert(model.hasParent) - val result = model.transform(df) - - val transformedFeatures = result.select("result").collect() - .map { case Row(transformedFeature: Double) => transformedFeature } - val transformedAttrs = Attribute.fromStructField(result.schema("result")) - .asInstanceOf[NominalAttribute].values.get - - assert(transformedFeatures === expectedResult, - "Transformed features do not equal expected features.") - assert(transformedAttrs === expectedAttrs, - "Transformed attributes do not equal expected attributes.") } } |