aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main
diff options
context:
space:
mode:
authorVinceShieh <vincent.xie@intel.com>2016-09-21 10:20:57 +0100
committerSean Owen <sowen@cloudera.com>2016-09-21 10:20:57 +0100
commit57dc326bd00cf0a49da971e9c573c48ae28acaa2 (patch)
tree3e57f59b33e42beddd1df72dbd85266e7b09ef7f /mllib/src/main
parentb366f18496e1ce8bd20fe58a0245ef7d91819a03 (diff)
downloadspark-57dc326bd00cf0a49da971e9c573c48ae28acaa2.tar.gz
spark-57dc326bd00cf0a49da971e9c573c48ae28acaa2.tar.bz2
spark-57dc326bd00cf0a49da971e9c573c48ae28acaa2.zip
[SPARK-17219][ML] Add NaN value handling in Bucketizer
## What changes were proposed in this pull request? This PR fixes an issue when Bucketizer is called to handle a dataset containing NaN value. Sometimes, null value might also be useful to users, so in these cases, Bucketizer should reserve one extra bucket for NaN values, instead of throwing an illegal exception. Before: ``` Bucketizer.transform on NaN value threw an illegal exception. ``` After: ``` NaN values will be grouped in an extra bucket. ``` ## How was this patch tested? New test cases added in `BucketizerSuite`. Signed-off-by: VinceShieh <vincent.xieintel.com> Author: VinceShieh <vincent.xie@intel.com> Closes #14858 from VinceShieh/spark-17219.
Diffstat (limited to 'mllib/src/main')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala13
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala9
2 files changed, 16 insertions, 6 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
index 100d9e7f6c..ec0ea05f9e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -106,7 +106,10 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
@Since("1.6.0")
object Bucketizer extends DefaultParamsReadable[Bucketizer] {
- /** We require splits to be of length >= 3 and to be in strictly increasing order. */
+ /**
+ * We require splits to be of length >= 3 and to be in strictly increasing order.
+ * No NaN split should be accepted.
+ */
private[feature] def checkSplits(splits: Array[Double]): Boolean = {
if (splits.length < 3) {
false
@@ -114,10 +117,10 @@ object Bucketizer extends DefaultParamsReadable[Bucketizer] {
var i = 0
val n = splits.length - 1
while (i < n) {
- if (splits(i) >= splits(i + 1)) return false
+ if (splits(i) >= splits(i + 1) || splits(i).isNaN) return false
i += 1
}
- true
+ !splits(n).isNaN
}
}
@@ -126,7 +129,9 @@ object Bucketizer extends DefaultParamsReadable[Bucketizer] {
* @throws SparkException if a feature is < splits.head or > splits.last
*/
private[feature] def binarySearchForBuckets(splits: Array[Double], feature: Double): Double = {
- if (feature == splits.last) {
+ if (feature.isNaN) {
+ splits.length - 1
+ } else if (feature == splits.last) {
splits.length - 2
} else {
val idx = ju.Arrays.binarySearch(splits, feature)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
index e09800877c..1e59d71a70 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
@@ -39,7 +39,7 @@ private[feature] trait QuantileDiscretizerBase extends Params
* default: 2
* @group param
*/
- val numBuckets = new IntParam(this, "numBuckets", "Maximum number of buckets (quantiles, or " +
+ val numBuckets = new IntParam(this, "numBuckets", "Number of buckets (quantiles, or " +
"categories) into which data points are grouped. Must be >= 2.",
ParamValidators.gtEq(2))
setDefault(numBuckets -> 2)
@@ -65,7 +65,12 @@ private[feature] trait QuantileDiscretizerBase extends Params
/**
* `QuantileDiscretizer` takes a column with continuous features and outputs a column with binned
- * categorical features. The number of bins can be set using the `numBuckets` parameter.
+ * categorical features. The number of bins can be set using the `numBuckets` parameter. It is
+ * possible that the number of buckets used will be less than this value, for example, if there
+ * are too few distinct values of the input to create enough distinct quantiles. Note also that
+ * NaN values are handled specially and placed into their own bucket. For example, if 4 buckets
+ * are used, then non-NaN data will be put into buckets(0-3), but NaNs will be counted in a special
+ * bucket(4).
* The bin ranges are chosen using an approximate algorithm (see the documentation for
* [[org.apache.spark.sql.DataFrameStatFunctions.approxQuantile approxQuantile]]
* for a detailed description). The precision of the approximation can be controlled with the