diff options
Diffstat (limited to 'mllib/src/test')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala | 20 |
1 files changed, 20 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala index 6a2c601bbe..25fabf64d5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala @@ -71,6 +71,26 @@ class QuantileDiscretizerSuite } } + test("Test splits on dataset larger than minSamplesRequired") { + val sqlCtx = SQLContext.getOrCreate(sc) + import sqlCtx.implicits._ + + val datasetSize = QuantileDiscretizer.minSamplesRequired + 1 + val numBuckets = 5 + val df = sc.parallelize((1.0 to datasetSize by 1.0).map(Tuple1.apply)).toDF("input") + val discretizer = new QuantileDiscretizer() + .setInputCol("input") + .setOutputCol("result") + .setNumBuckets(numBuckets) + .setSeed(1) + + val result = discretizer.fit(df).transform(df) + val observedNumBuckets = result.select("result").distinct.count + + assert(observedNumBuckets === numBuckets, + "Observed number of buckets does not equal expected number of buckets.") + } + test("read/write") { val t = new QuantileDiscretizer() .setInputCol("myInputCol") |