diff options
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala | 13 |
1 files changed, 9 insertions, 4 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index 100d9e7f6c..ec0ea05f9e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -106,7 +106,10 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String @Since("1.6.0") object Bucketizer extends DefaultParamsReadable[Bucketizer] { - /** We require splits to be of length >= 3 and to be in strictly increasing order. */ + /** + * We require splits to be of length >= 3 and to be in strictly increasing order. + * No NaN split should be accepted. + */ private[feature] def checkSplits(splits: Array[Double]): Boolean = { if (splits.length < 3) { false @@ -114,10 +117,10 @@ object Bucketizer extends DefaultParamsReadable[Bucketizer] { var i = 0 val n = splits.length - 1 while (i < n) { - if (splits(i) >= splits(i + 1)) return false + if (splits(i) >= splits(i + 1) || splits(i).isNaN) return false i += 1 } - true + !splits(n).isNaN } } @@ -126,7 +129,9 @@ object Bucketizer extends DefaultParamsReadable[Bucketizer] { * @throws SparkException if a feature is < splits.head or > splits.last */ private[feature] def binarySearchForBuckets(splits: Array[Double], feature: Double): Double = { - if (feature == splits.last) { + if (feature.isNaN) { + splits.length - 1 + } else if (feature == splits.last) { splits.length - 2 } else { val idx = ju.Arrays.binarySearch(splits, feature) |