aboutsummaryrefslogtreecommitdiff
path: root/mllib/src
diff options
context:
space:
mode:
authorVinceShieh <vincent.xie@intel.com>2016-08-24 10:16:58 +0100
committerSean Owen <sowen@cloudera.com>2016-08-24 10:16:58 +0100
commit92c0eaf348b42b3479610da0be761013f9d81c54 (patch)
tree87f25b1e86cfa8b469f83c0575c792fd4c4f4a48 /mllib/src
parent673a80d2230602c9e6573a23e35fb0f6b832bfca (diff)
downloadspark-92c0eaf348b42b3479610da0be761013f9d81c54.tar.gz
spark-92c0eaf348b42b3479610da0be761013f9d81c54.tar.bz2
spark-92c0eaf348b42b3479610da0be761013f9d81c54.zip
[SPARK-17086][ML] Fix InvalidArgumentException issue in QuantileDiscretizer when some quantiles are duplicated
## What changes were proposed in this pull request? In cases when QuantileDiscretizerSuite is called upon a numeric array with duplicated elements, we will take the unique elements generated from approxQuantiles as input for Bucketizer. ## How was this patch tested? A unit test is added in QuantileDiscretizerSuite QuantileDiscretizer.fit will throw an illegal exception when calling setSplits on a list of splits with duplicated elements. Bucketizer.setSplits should only accept either a numeric vector of two or more unique cut points, although that may produce fewer buckets than requested. Signed-off-by: VinceShieh <vincent.xie@intel.com> Author: VinceShieh <vincent.xie@intel.com> Closes #14747 from VinceShieh/SPARK-17086.
Diffstat (limited to 'mllib/src')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala7
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala19
2 files changed, 25 insertions, 1 deletion
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
index 558a7bbf0a..e09800877c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
@@ -114,7 +114,12 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui
splits(0) = Double.NegativeInfinity
splits(splits.length - 1) = Double.PositiveInfinity
- val bucketizer = new Bucketizer(uid).setSplits(splits)
+ val distinctSplits = splits.distinct
+ if (splits.length != distinctSplits.length) {
+ log.warn(s"Some quantiles were identical. Bucketing to ${distinctSplits.length - 1}" +
+ s" buckets as a result.")
+ }
+ val bucketizer = new Bucketizer(uid).setSplits(distinctSplits.sorted)
copyValues(bucketizer.setParent(this))
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
index b73dbd6232..18f1e89ee8 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
@@ -52,6 +52,25 @@ class QuantileDiscretizerSuite
"Bucket sizes are not within expected relative error tolerance.")
}
+ test("Test Bucketizer on duplicated splits") {
+ val spark = this.spark
+ import spark.implicits._
+
+ val datasetSize = 12
+ val numBuckets = 5
+ val df = sc.parallelize(Array(1.0, 3.0, 2.0, 1.0, 1.0, 2.0, 3.0, 2.0, 2.0, 2.0, 1.0, 3.0))
+ .map(Tuple1.apply).toDF("input")
+ val discretizer = new QuantileDiscretizer()
+ .setInputCol("input")
+ .setOutputCol("result")
+ .setNumBuckets(numBuckets)
+ val result = discretizer.fit(df).transform(df)
+
+ val observedNumBuckets = result.select("result").distinct.count
+ assert(2 <= observedNumBuckets && observedNumBuckets <= numBuckets,
+ "Observed number of buckets are not within expected range.")
+ }
+
test("Test transform method on unseen data") {
val spark = this.spark
import spark.implicits._