diff options
Diffstat (limited to 'mllib/src/main')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala | 13 | ||||
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala | 9 |
2 files changed, 16 insertions, 6 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index 100d9e7f6c..ec0ea05f9e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -106,7 +106,10 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String @Since("1.6.0") object Bucketizer extends DefaultParamsReadable[Bucketizer] { - /** We require splits to be of length >= 3 and to be in strictly increasing order. */ + /** + * We require splits to be of length >= 3 and to be in strictly increasing order. + * No NaN split should be accepted. + */ private[feature] def checkSplits(splits: Array[Double]): Boolean = { if (splits.length < 3) { false @@ -114,10 +117,10 @@ object Bucketizer extends DefaultParamsReadable[Bucketizer] { var i = 0 val n = splits.length - 1 while (i < n) { - if (splits(i) >= splits(i + 1)) return false + if (splits(i) >= splits(i + 1) || splits(i).isNaN) return false i += 1 } - true + !splits(n).isNaN } } @@ -126,7 +129,9 @@ object Bucketizer extends DefaultParamsReadable[Bucketizer] { * @throws SparkException if a feature is < splits.head or > splits.last */ private[feature] def binarySearchForBuckets(splits: Array[Double], feature: Double): Double = { - if (feature == splits.last) { + if (feature.isNaN) { + splits.length - 1 + } else if (feature == splits.last) { splits.length - 2 } else { val idx = ju.Arrays.binarySearch(splits, feature) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index e09800877c..1e59d71a70 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -39,7 +39,7 @@ private[feature] trait QuantileDiscretizerBase extends Params * default: 2 * @group param */ - val numBuckets = new IntParam(this, "numBuckets", "Maximum number of buckets (quantiles, or " + + val numBuckets = new IntParam(this, "numBuckets", "Number of buckets (quantiles, or " + "categories) into which data points are grouped. Must be >= 2.", ParamValidators.gtEq(2)) setDefault(numBuckets -> 2) @@ -65,7 +65,12 @@ private[feature] trait QuantileDiscretizerBase extends Params /** * `QuantileDiscretizer` takes a column with continuous features and outputs a column with binned - * categorical features. The number of bins can be set using the `numBuckets` parameter. + * categorical features. The number of bins can be set using the `numBuckets` parameter. It is + * possible that the number of buckets used will be less than this value, for example, if there + * are too few distinct values of the input to create enough distinct quantiles. Note also that + * NaN values are handled specially and placed into their own bucket. For example, if 4 buckets + * are used, then non-NaN data will be put into buckets(0-3), but NaNs will be counted in a special + * bucket(4). * The bin ranges are chosen using an approximate algorithm (see the documentation for * [[org.apache.spark.sql.DataFrameStatFunctions.approxQuantile approxQuantile]] * for a detailed description). The precision of the approximation can be controlled with the |