Diffstat (limited to 'mllib/src/main')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala          | 13
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala |  9
2 files changed, 16 insertions(+), 6 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
index 100d9e7f6c..ec0ea05f9e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -106,7 +106,10 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
@Since("1.6.0")
object Bucketizer extends DefaultParamsReadable[Bucketizer] {
- /** We require splits to be of length >= 3 and to be in strictly increasing order. */
+ /**
+ * We require splits to be of length >= 3 and to be in strictly increasing order.
+ * No NaN split should be accepted.
+ */
private[feature] def checkSplits(splits: Array[Double]): Boolean = {
if (splits.length < 3) {
false
@@ -114,10 +117,10 @@ object Bucketizer extends DefaultParamsReadable[Bucketizer] {
var i = 0
val n = splits.length - 1
while (i < n) {
- if (splits(i) >= splits(i + 1)) return false
+ if (splits(i) >= splits(i + 1) || splits(i).isNaN) return false
i += 1
}
- true
+ !splits(n).isNaN
}
}
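For reference, here is a minimal standalone sketch of the validation rule above. The real method is private[feature], so this simply mirrors the patched logic outside Spark; the object name is made up for illustration.

object CheckSplitsSketch {
  // Splits must have length >= 3, be strictly increasing, and contain no NaN.
  def checkSplits(splits: Array[Double]): Boolean = {
    if (splits.length < 3) {
      false
    } else {
      var i = 0
      val n = splits.length - 1
      while (i < n) {
        // A comparison against NaN is never true, so the explicit isNaN check is needed.
        if (splits(i) >= splits(i + 1) || splits(i).isNaN) return false
        i += 1
      }
      !splits(n).isNaN // the last split is never the left side of a comparison, so check it separately
    }
  }

  def main(args: Array[String]): Unit = {
    println(checkSplits(Array(0.0, 0.5, 1.0)))        // true
    println(checkSplits(Array(0.0, 1.0)))             // false: fewer than 3 splits
    println(checkSplits(Array(0.0, 0.0, 1.0)))        // false: not strictly increasing
    println(checkSplits(Array(0.0, Double.NaN, 1.0))) // false: NaN split rejected
  }
}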
@@ -126,7 +129,9 @@ object Bucketizer extends DefaultParamsReadable[Bucketizer] {
* @throws SparkException if a feature is < splits.head or > splits.last
*/
private[feature] def binarySearchForBuckets(splits: Array[Double], feature: Double): Double = {
- if (feature == splits.last) {
+ if (feature.isNaN) {
+ splits.length - 1
+ } else if (feature == splits.last) {
splits.length - 2
} else {
val idx = ju.Arrays.binarySearch(splits, feature)
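The hunk above is cut off at the binary-search branch (the original file aliases java.util as ju). For illustration, here is a self-contained sketch of the lookup behavior after this change: NaN features map to a dedicated extra bucket at index splits.length - 1, and the last split value maps to the last regular bucket. The simplified error handling (a plain RuntimeException rather than SparkException) is an assumption of this sketch.

import java.util.{Arrays => JArrays}

object BucketLookupSketch {
  def binarySearchForBuckets(splits: Array[Double], feature: Double): Double = {
    if (feature.isNaN) {
      splits.length - 1 // NaN values go into their own extra bucket
    } else if (feature == splits.last) {
      splits.length - 2 // the last regular bucket is closed on the right
    } else {
      val idx = JArrays.binarySearch(splits, feature)
      if (idx >= 0) {
        idx // exact match on a split boundary
      } else {
        val insertPos = -idx - 1
        if (insertPos == 0 || insertPos == splits.length) {
          throw new RuntimeException(
            s"Feature $feature out of range [${splits.head}, ${splits.last}]")
        }
        insertPos - 1 // the bucket just left of the insertion point
      }
    }
  }

  def main(args: Array[String]): Unit = {
    val splits = Array(0.0, 0.5, 1.0)
    println(binarySearchForBuckets(splits, 0.25))       // 0.0
    println(binarySearchForBuckets(splits, 1.0))        // 1.0 (last regular bucket)
    println(binarySearchForBuckets(splits, Double.NaN)) // 2.0 (the extra NaN bucket)
  }
}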
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
index e09800877c..1e59d71a70 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
@@ -39,7 +39,7 @@ private[feature] trait QuantileDiscretizerBase extends Params
* default: 2
* @group param
*/
- val numBuckets = new IntParam(this, "numBuckets", "Maximum number of buckets (quantiles, or " +
+ val numBuckets = new IntParam(this, "numBuckets", "Number of buckets (quantiles, or " +
"categories) into which data points are grouped. Must be >= 2.",
ParamValidators.gtEq(2))
setDefault(numBuckets -> 2)
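Because the param is declared with ParamValidators.gtEq(2), values below 2 are rejected as soon as the param is set; a small illustrative sketch, not part of the patch:

import org.apache.spark.ml.feature.QuantileDiscretizer

val discretizer = new QuantileDiscretizer().setNumBuckets(2) // accepted (the default)
// new QuantileDiscretizer().setNumBuckets(1)                // throws IllegalArgumentException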
@@ -65,7 +65,12 @@ private[feature] trait QuantileDiscretizerBase extends Params
/**
* `QuantileDiscretizer` takes a column with continuous features and outputs a column with binned
- * categorical features. The number of bins can be set using the `numBuckets` parameter.
+ * categorical features. The number of bins can be set using the `numBuckets` parameter. It is
+ * possible that the number of buckets used will be less than this value, for example, if there
+ * are too few distinct values of the input to create enough distinct quantiles. Note also that
+ * NaN values are handled specially and placed into their own bucket. For example, if 4 buckets
+ * are used, then non-NaN data will be put into buckets(0-3), but NaNs will be counted in a special
+ * bucket(4).
* The bin ranges are chosen using an approximate algorithm (see the documentation for
* [[org.apache.spark.sql.DataFrameStatFunctions.approxQuantile approxQuantile]]
* for a detailed description). The precision of the approximation can be controlled with the
* `relativeError` parameter.
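As an end-to-end illustration of the documented NaN behavior, here is a hypothetical usage sketch; the column names and data are invented, and a SparkSession named spark is assumed to be in scope.

import org.apache.spark.ml.feature.QuantileDiscretizer

val data = Seq(1.0, 2.0, 3.0, 4.0, Double.NaN).map(Tuple1(_))
val df = spark.createDataFrame(data).toDF("hour")

val discretizer = new QuantileDiscretizer()
  .setInputCol("hour")
  .setOutputCol("bucket")
  .setNumBuckets(4)

// Non-NaN rows land in buckets 0-3; the NaN row is assigned the extra
// bucket index 4, as described in the doc comment above.
val model = discretizer.fit(df)
model.transform(df).show()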