diff options
author | VinceShieh <vincent.xie@intel.com> | 2016-08-24 10:16:58 +0100 |
---|---|---|
committer | Sean Owen <sowen@cloudera.com> | 2016-08-24 10:16:58 +0100 |
commit | 92c0eaf348b42b3479610da0be761013f9d81c54 (patch) | |
tree | 87f25b1e86cfa8b469f83c0575c792fd4c4f4a48 /mllib | |
parent | 673a80d2230602c9e6573a23e35fb0f6b832bfca (diff) | |
download | spark-92c0eaf348b42b3479610da0be761013f9d81c54.tar.gz spark-92c0eaf348b42b3479610da0be761013f9d81c54.tar.bz2 spark-92c0eaf348b42b3479610da0be761013f9d81c54.zip |
[SPARK-17086][ML] Fix InvalidArgumentException issue in QuantileDiscretizer when some quantiles are duplicated
## What changes were proposed in this pull request?
In cases when QuantileDiscretizer is fit on a numeric array with duplicated elements, we will take the unique elements generated from approxQuantiles as input for Bucketizer.
## How was this patch tested?
A unit test is added in QuantileDiscretizerSuite
QuantileDiscretizer.fit will throw an illegal exception when calling setSplits on a list of splits
with duplicated elements. Bucketizer.setSplits should only accept a numeric vector of two
or more unique cut points, although that may produce fewer buckets than requested.
Signed-off-by: VinceShieh <vincent.xie@intel.com>
Author: VinceShieh <vincent.xie@intel.com>
Closes #14747 from VinceShieh/SPARK-17086.
Diffstat (limited to 'mllib')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala | 7 | ||||
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala | 19 |
2 files changed, 25 insertions, 1 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index 558a7bbf0a..e09800877c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -114,7 +114,12 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui splits(0) = Double.NegativeInfinity splits(splits.length - 1) = Double.PositiveInfinity - val bucketizer = new Bucketizer(uid).setSplits(splits) + val distinctSplits = splits.distinct + if (splits.length != distinctSplits.length) { + log.warn(s"Some quantiles were identical. Bucketing to ${distinctSplits.length - 1}" + + s" buckets as a result.") + } + val bucketizer = new Bucketizer(uid).setSplits(distinctSplits.sorted) copyValues(bucketizer.setParent(this)) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala index b73dbd6232..18f1e89ee8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala @@ -52,6 +52,25 @@ class QuantileDiscretizerSuite "Bucket sizes are not within expected relative error tolerance.") } + test("Test Bucketizer on duplicated splits") { + val spark = this.spark + import spark.implicits._ + + val datasetSize = 12 + val numBuckets = 5 + val df = sc.parallelize(Array(1.0, 3.0, 2.0, 1.0, 1.0, 2.0, 3.0, 2.0, 2.0, 2.0, 1.0, 3.0)) + .map(Tuple1.apply).toDF("input") + val discretizer = new QuantileDiscretizer() + .setInputCol("input") + .setOutputCol("result") + .setNumBuckets(numBuckets) + val result = discretizer.fit(df).transform(df) + + val observedNumBuckets = result.select("result").distinct.count + assert(2 <= observedNumBuckets && 
observedNumBuckets <= numBuckets, + "Observed number of buckets are not within expected range.") + } + test("Test transform method on unseen data") { val spark = this.spark import spark.implicits._ |