aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/test')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala20
1 files changed, 20 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
index 6a2c601bbe..25fabf64d5 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
@@ -71,6 +71,26 @@ class QuantileDiscretizerSuite
}
}
+ test("Test splits on dataset larger than minSamplesRequired") {
+ val sqlCtx = SQLContext.getOrCreate(sc)
+ import sqlCtx.implicits._
+
+ val datasetSize = QuantileDiscretizer.minSamplesRequired + 1
+ val numBuckets = 5
+ val df = sc.parallelize((1.0 to datasetSize by 1.0).map(Tuple1.apply)).toDF("input")
+ val discretizer = new QuantileDiscretizer()
+ .setInputCol("input")
+ .setOutputCol("result")
+ .setNumBuckets(numBuckets)
+ .setSeed(1)
+
+ val result = discretizer.fit(df).transform(df)
+ val observedNumBuckets = result.select("result").distinct.count
+
+ assert(observedNumBuckets === numBuckets,
+ "Observed number of buckets does not equal expected number of buckets.")
+ }
+
test("read/write") {
val t = new QuantileDiscretizer()
.setInputCol("myInputCol")