Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala  13
1 file changed, 9 insertions(+), 4 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
index 100d9e7f6c..ec0ea05f9e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -106,7 +106,10 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
@Since("1.6.0")
object Bucketizer extends DefaultParamsReadable[Bucketizer] {
- /** We require splits to be of length >= 3 and to be in strictly increasing order. */
+ /**
+ * We require splits to be of length >= 3 and to be in strictly increasing order.
+ * No NaN split should be accepted.
+ */
private[feature] def checkSplits(splits: Array[Double]): Boolean = {
if (splits.length < 3) {
false
@@ -114,10 +117,10 @@ object Bucketizer extends DefaultParamsReadable[Bucketizer] {
var i = 0
val n = splits.length - 1
while (i < n) {
- if (splits(i) >= splits(i + 1)) return false
+ if (splits(i) >= splits(i + 1) || splits(i).isNaN) return false
i += 1
}
- true
+ !splits(n).isNaN
}
}
@@ -126,7 +129,9 @@ object Bucketizer extends DefaultParamsReadable[Bucketizer] {
* @throws SparkException if a feature is < splits.head or > splits.last
*/
private[feature] def binarySearchForBuckets(splits: Array[Double], feature: Double): Double = {
- if (feature == splits.last) {
+ if (feature.isNaN) {
+ splits.length - 1
+ } else if (feature == splits.last) {
splits.length - 2
} else {
val idx = ju.Arrays.binarySearch(splits, feature)
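
For context (not part of the patch itself), the following is a minimal, self-contained Scala sketch of the behavior this diff introduces, assuming the surrounding Bucketizer code stays as shown above: splits containing NaN are rejected by checkSplits, and a NaN feature is routed to an extra bucket at index splits.length - 1. The object name and main method here are illustrative only.

object BucketizerNaNSketch {
  // Mirrors the amended checkSplits: length >= 3, strictly increasing, no NaN splits.
  def checkSplits(splits: Array[Double]): Boolean = {
    if (splits.length < 3) return false
    var i = 0
    val n = splits.length - 1
    while (i < n) {
      if (splits(i) >= splits(i + 1) || splits(i).isNaN) return false
      i += 1
    }
    !splits(n).isNaN
  }

  def main(args: Array[String]): Unit = {
    val good = Array(Double.NegativeInfinity, 0.0, 10.0, Double.PositiveInfinity)
    val bad  = Array(0.0, Double.NaN, 10.0)

    println(checkSplits(good)) // true
    println(checkSplits(bad))  // false: NaN splits are no longer accepted

    // With the change to binarySearchForBuckets, a NaN feature maps to bucket
    // splits.length - 1, one past the regular buckets. For `good` above the
    // regular buckets are 0, 1, 2, so NaN values land in bucket 3.
  }
}

The design choice reflected in the diff is that NaN feature values get a dedicated extra bucket rather than triggering an error, so the indices of the regular buckets (0 through splits.length - 2) are left unchanged.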