path: root/mllib/src/test
author    Oliver Pierson <ocp@gatech.edu>    2016-04-11 12:02:48 -0700
committer Xiangrui Meng <meng@databricks.com>    2016-04-11 12:02:48 -0700
commit  89a41c5b7a3f727b44a7f615a1352ca006d12f73 (patch)
tree    1c59e13c4fe03bbb0c5717f6c08311a2d2648da2 /mllib/src/test
parent  2dacc81ec31233e558855a26340ad4662d470387 (diff)
[SPARK-13600][MLLIB] Use approxQuantile from DataFrame stats in QuantileDiscretizer
## What changes were proposed in this pull request?

QuantileDiscretizer can return an unexpected number of buckets in certain cases. This PR fixes that issue and also refactors QuantileDiscretizer to use approxQuantile from the DataFrame stats functions.

## How was this patch tested?

QuantileDiscretizerSuite unit tests (some existing tests will change or even be removed in this PR).

Author: Oliver Pierson <ocp@gatech.edu>

Closes #11553 from oliverpierson/SPARK-13600.
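For context, here is a minimal sketch of the approach the PR adopts: deriving bucket splits from DataFrameStatFunctions.approxQuantile. This is a standalone driver written for this note, not code from the patch; the object name, master setting, and relativeError value are illustrative assumptions.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

// Sketch only: how bucket splits can be computed with approxQuantile,
// the DataFrame stats API this PR moves QuantileDiscretizer onto.
object ApproxQuantileSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local[2]").setAppName("approx-quantile-sketch")
    val sc = new SparkContext(conf)
    val sqlCtx = SQLContext.getOrCreate(sc)
    import sqlCtx.implicits._

    val df = sc.parallelize(1.0 to 100000.0 by 1.0).map(Tuple1.apply).toDF("input")

    val numBuckets = 5
    // Interior split candidates at the 1/5, 2/5, 3/5, and 4/5 quantiles.
    val probabilities = (1 until numBuckets).map(_.toDouble / numBuckets).toArray
    val relativeError = 0.001 // illustrative value, not necessarily the estimator's default
    val splits = df.stat.approxQuantile("input", probabilities, relativeError)

    // Deduplicate before adding the infinite boundaries: on skewed columns,
    // repeated quantiles would otherwise produce non-increasing splits.
    val allSplits = Double.NegativeInfinity +: splits.distinct :+ Double.PositiveInfinity
    println(allSplits.mkString(", "))
    sc.stop()
  }
}
```

Deduplicating the quantiles keeps the splits strictly increasing, at the cost of returning fewer buckets than requested on skewed data, which is why the rewritten tests check bucket counts and sizes rather than exact per-row assignments.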
Diffstat (limited to 'mllib/src/test')
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala  115
1 file changed, 40 insertions(+), 75 deletions(-)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
index 25fabf64d5..8895d630a0 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
@@ -17,78 +17,60 @@
package org.apache.spark.ml.feature
-import org.apache.spark.{SparkContext, SparkFunSuite}
-import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
+import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.functions.udf
class QuantileDiscretizerSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
- import org.apache.spark.ml.feature.QuantileDiscretizerSuite._
-
- test("Test quantile discretizer") {
- checkDiscretizedData(sc,
- Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3),
- 10,
- Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3),
- Array("-Infinity, 1.0", "1.0, 2.0", "2.0, 3.0", "3.0, Infinity"))
-
- checkDiscretizedData(sc,
- Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3),
- 4,
- Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3),
- Array("-Infinity, 1.0", "1.0, 2.0", "2.0, 3.0", "3.0, Infinity"))
-
- checkDiscretizedData(sc,
- Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3),
- 3,
- Array[Double](0, 1, 2, 2, 2, 2, 2, 2, 2),
- Array("-Infinity, 2.0", "2.0, 3.0", "3.0, Infinity"))
+ test("Test observed number of buckets and their sizes match expected values") {
+ val sqlCtx = SQLContext.getOrCreate(sc)
+ import sqlCtx.implicits._
- checkDiscretizedData(sc,
- Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3),
- 2,
- Array[Double](0, 1, 1, 1, 1, 1, 1, 1, 1),
- Array("-Infinity, 2.0", "2.0, Infinity"))
+ val datasetSize = 100000
+ val numBuckets = 5
+ val df = sc.parallelize(1.0 to datasetSize by 1.0).map(Tuple1.apply).toDF("input")
+ val discretizer = new QuantileDiscretizer()
+ .setInputCol("input")
+ .setOutputCol("result")
+ .setNumBuckets(numBuckets)
+ val result = discretizer.fit(df).transform(df)
- }
+ val observedNumBuckets = result.select("result").distinct.count
+ assert(observedNumBuckets === numBuckets,
+ "Observed number of buckets does not equal expected number of buckets.")
- test("Test getting splits") {
- val splitTestPoints = Array(
- Array[Double]() -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity),
- Array(Double.NegativeInfinity) -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity),
- Array(Double.PositiveInfinity) -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity),
- Array(Double.NegativeInfinity, Double.PositiveInfinity)
- -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity),
- Array(0.0) -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity),
- Array(1.0) -> Array(Double.NegativeInfinity, 1, Double.PositiveInfinity),
- Array(0.0, 1.0) -> Array(Double.NegativeInfinity, 0, 1, Double.PositiveInfinity)
- )
- for ((ori, res) <- splitTestPoints) {
- assert(QuantileDiscretizer.getSplits(ori) === res, "Returned splits are invalid.")
+ val relativeError = discretizer.getRelativeError
+ val isGoodBucket = udf {
+ (size: Int) => math.abs( size - (datasetSize / numBuckets)) <= (relativeError * datasetSize)
}
+ val numGoodBuckets = result.groupBy("result").count.filter(isGoodBucket($"count")).count
+ assert(numGoodBuckets === numBuckets,
+ "Bucket sizes are not within expected relative error tolerance.")
}
- test("Test splits on dataset larger than minSamplesRequired") {
+ test("Test transform method on unseen data") {
val sqlCtx = SQLContext.getOrCreate(sc)
import sqlCtx.implicits._
- val datasetSize = QuantileDiscretizer.minSamplesRequired + 1
- val numBuckets = 5
- val df = sc.parallelize((1.0 to datasetSize by 1.0).map(Tuple1.apply)).toDF("input")
+ val trainDF = sc.parallelize(1.0 to 100.0 by 1.0).map(Tuple1.apply).toDF("input")
+ val testDF = sc.parallelize(-10.0 to 110.0 by 1.0).map(Tuple1.apply).toDF("input")
val discretizer = new QuantileDiscretizer()
.setInputCol("input")
.setOutputCol("result")
- .setNumBuckets(numBuckets)
- .setSeed(1)
+ .setNumBuckets(5)
- val result = discretizer.fit(df).transform(df)
- val observedNumBuckets = result.select("result").distinct.count
+ val result = discretizer.fit(trainDF).transform(testDF)
+ val firstBucketSize = result.filter(result("result") === 0.0).count
+ val lastBucketSize = result.filter(result("result") === 4.0).count
- assert(observedNumBuckets === numBuckets,
- "Observed number of buckets does not equal expected number of buckets.")
+ assert(firstBucketSize === 30L,
+ s"Size of first bucket ${firstBucketSize} did not equal expected value of 30.")
+ assert(lastBucketSize === 31L,
+ s"Size of last bucket ${lastBucketSize} did not equal expected value of 31.")
}
test("read/write") {
@@ -98,34 +80,17 @@ class QuantileDiscretizerSuite
.setNumBuckets(6)
testDefaultReadWrite(t)
}
-}
-
-private object QuantileDiscretizerSuite extends SparkFunSuite {
- def checkDiscretizedData(
- sc: SparkContext,
- data: Array[Double],
- numBucket: Int,
- expectedResult: Array[Double],
- expectedAttrs: Array[String]): Unit = {
+ test("Verify resulting model has parent") {
val sqlCtx = SQLContext.getOrCreate(sc)
import sqlCtx.implicits._
- val df = sc.parallelize(data.map(Tuple1.apply)).toDF("input")
- val discretizer = new QuantileDiscretizer().setInputCol("input").setOutputCol("result")
- .setNumBuckets(numBucket).setSeed(1)
+ val df = sc.parallelize(1 to 100).map(Tuple1.apply).toDF("input")
+ val discretizer = new QuantileDiscretizer()
+ .setInputCol("input")
+ .setOutputCol("result")
+ .setNumBuckets(5)
val model = discretizer.fit(df)
assert(model.hasParent)
- val result = model.transform(df)
-
- val transformedFeatures = result.select("result").collect()
- .map { case Row(transformedFeature: Double) => transformedFeature }
- val transformedAttrs = Attribute.fromStructField(result.schema("result"))
- .asInstanceOf[NominalAttribute].values.get
-
- assert(transformedFeatures === expectedResult,
- "Transformed features do not equal expected features.")
- assert(transformedAttrs === expectedAttrs,
- "Transformed attributes do not equal expected attributes.")
}
}
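To see the behavior this suite pins down end to end, here is a hedged sketch (again a standalone driver, not part of the patch) that reuses the skewed fixture from the removed tests. It shows how a heavily repeated value can collapse quantiles so that fewer than numBuckets buckets come out; the object name and Spark master are assumptions.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext
import org.apache.spark.ml.feature.QuantileDiscretizer

// Sketch only: requesting more buckets than the skewed data can support.
object SkewedBucketsSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local[2]").setAppName("skewed-buckets-sketch")
    val sc = new SparkContext(conf)
    val sqlCtx = SQLContext.getOrCreate(sc)
    import sqlCtx.implicits._

    // 1 and 2 appear once each; 3 dominates, as in the removed test fixture.
    val data = Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3)
    val df = sc.parallelize(data.map(Tuple1.apply)).toDF("input")

    val discretizer = new QuantileDiscretizer()
      .setInputCol("input")
      .setOutputCol("result")
      .setNumBuckets(10) // more buckets requested than distinct values available

    val result = discretizer.fit(df).transform(df)
    // Expect fewer than 10 distinct bucket indices on this skewed column.
    println(result.select("result").distinct.count)
    sc.stop()
  }
}
```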