diff options
Diffstat (limited to 'mllib/src/main')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala | 55 |
1 files changed, 28 insertions, 27 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
index 7dba64bc35..b28c88aaae 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -17,6 +17,9 @@
 
 package org.apache.spark.ml.feature
 
+import java.{util => ju}
+
+import org.apache.spark.SparkException
 import org.apache.spark.annotation.AlphaComponent
 import org.apache.spark.ml.attribute.NominalAttribute
 import org.apache.spark.ml.param._
@@ -38,18 +41,19 @@ final class Bucketizer private[ml] (override val parent: Estimator[Bucketizer])
   def this() = this(null)
 
   /**
-   * Parameter for mapping continuous features into buckets. With n splits, there are n+1 buckets.
-   * A bucket defined by splits x,y holds values in the range [x,y). Splits should be strictly
-   * increasing. Values at -inf, inf must be explicitly provided to cover all Double values;
+   * Parameter for mapping continuous features into buckets. With n+1 splits, there are n buckets.
+   * A bucket defined by splits x,y holds values in the range [x,y) except the last bucket, which
+   * also includes y. Splits should be strictly increasing.
+   * Values at -inf, inf must be explicitly provided to cover all Double values;
    * otherwise, values outside the splits specified will be treated as errors.
    * @group param
    */
   val splits: Param[Array[Double]] = new Param[Array[Double]](this, "splits",
-    "Split points for mapping continuous features into buckets. With n splits, there are n+1 " +
-      "buckets. A bucket defined by splits x,y holds values in the range [x,y). The splits " +
-      "should be strictly increasing. Values at -inf, inf must be explicitly provided to cover" +
-      " all Double values; otherwise, values outside the splits specified will be treated as" +
-      " errors.",
+    "Split points for mapping continuous features into buckets. With n+1 splits, there are n " +
+      "buckets. A bucket defined by splits x,y holds values in the range [x,y) except the last " +
+      "bucket, which also includes y. The splits should be strictly increasing. " +
+      "Values at -inf, inf must be explicitly provided to cover all Double values; " +
+      "otherwise, values outside the splits specified will be treated as errors.",
     Bucketizer.checkSplits)
 
   /** @group getParam */
@@ -104,28 +108,25 @@ private[feature] object Bucketizer {
 
   /**
    * Binary searching in several buckets to place each data point.
-   * @throws RuntimeException if a feature is < splits.head or >= splits.last
+   * @throws SparkException if a feature is < splits.head or > splits.last
    */
-  def binarySearchForBuckets(
-      splits: Array[Double],
-      feature: Double): Double = {
-    // Check bounds. We make an exception for +inf so that it can exist in some bin.
-    if ((feature < splits.head) || (feature >= splits.last && feature != Double.PositiveInfinity)) {
-      throw new RuntimeException(s"Feature value $feature out of Bucketizer bounds" +
-        s" [${splits.head}, ${splits.last}). Check your features, or loosen " +
-        s"the lower/upper bound constraints.")
-    }
-    var left = 0
-    var right = splits.length - 2
-    while (left < right) {
-      val mid = (left + right) / 2
-      val split = splits(mid + 1)
-      if (feature < split) {
-        right = mid
+  def binarySearchForBuckets(splits: Array[Double], feature: Double): Double = {
+    if (feature == splits.last) {
+      splits.length - 2
+    } else {
+      val idx = ju.Arrays.binarySearch(splits, feature)
+      if (idx >= 0) {
+        idx
       } else {
-        left = mid + 1
+        val insertPos = -idx - 1
+        if (insertPos == 0 || insertPos == splits.length) {
+          throw new SparkException(s"Feature value $feature out of Bucketizer bounds" +
+            s" [${splits.head}, ${splits.last}]. Check your features, or loosen " +
+            s"the lower/upper bound constraints.")
+        } else {
+          insertPos - 1
+        }
       }
     }
-    left
   }
 }