 mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala      | 55
 mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala | 25
 2 files changed, 41 insertions, 39 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
index 7dba64bc35..b28c88aaae 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -17,6 +17,9 @@
 
 package org.apache.spark.ml.feature
 
+import java.{util => ju}
+
+import org.apache.spark.SparkException
 import org.apache.spark.annotation.AlphaComponent
 import org.apache.spark.ml.attribute.NominalAttribute
 import org.apache.spark.ml.param._
@@ -38,18 +41,19 @@ final class Bucketizer private[ml] (override val parent: Estimator[Bucketizer])
   def this() = this(null)
 
   /**
-   * Parameter for mapping continuous features into buckets. With n splits, there are n+1 buckets.
-   * A bucket defined by splits x,y holds values in the range [x,y). Splits should be strictly
-   * increasing. Values at -inf, inf must be explicitly provided to cover all Double values;
+   * Parameter for mapping continuous features into buckets. With n+1 splits, there are n buckets.
+   * A bucket defined by splits x,y holds values in the range [x,y) except the last bucket, which
+   * also includes y. Splits should be strictly increasing.
+   * Values at -inf, inf must be explicitly provided to cover all Double values;
    * otherwise, values outside the splits specified will be treated as errors.
    * @group param
    */
   val splits: Param[Array[Double]] = new Param[Array[Double]](this, "splits",
-    "Split points for mapping continuous features into buckets. With n splits, there are n+1 " +
-      "buckets. A bucket defined by splits x,y holds values in the range [x,y). The splits " +
-      "should be strictly increasing. Values at -inf, inf must be explicitly provided to cover" +
-      " all Double values; otherwise, values outside the splits specified will be treated as" +
-      " errors.",
+    "Split points for mapping continuous features into buckets. With n+1 splits, there are n " +
+      "buckets. A bucket defined by splits x,y holds values in the range [x,y) except the last " +
+      "bucket, which also includes y. The splits should be strictly increasing. " +
+      "Values at -inf, inf must be explicitly provided to cover all Double values; " +
+      "otherwise, values outside the splits specified will be treated as errors.",
     Bucketizer.checkSplits)
 
   /** @group getParam */
@@ -104,28 +108,25 @@ private[feature] object Bucketizer {
 
   /**
    * Binary searching in several buckets to place each data point.
-   * @throws RuntimeException if a feature is < splits.head or >= splits.last
+   * @throws SparkException if a feature is < splits.head or > splits.last
    */
-  def binarySearchForBuckets(
-      splits: Array[Double],
-      feature: Double): Double = {
-    // Check bounds. We make an exception for +inf so that it can exist in some bin.
-    if ((feature < splits.head) || (feature >= splits.last && feature != Double.PositiveInfinity)) {
-      throw new RuntimeException(s"Feature value $feature out of Bucketizer bounds" +
-        s" [${splits.head}, ${splits.last}). Check your features, or loosen " +
-        s"the lower/upper bound constraints.")
-    }
-    var left = 0
-    var right = splits.length - 2
-    while (left < right) {
-      val mid = (left + right) / 2
-      val split = splits(mid + 1)
-      if (feature < split) {
-        right = mid
+  def binarySearchForBuckets(splits: Array[Double], feature: Double): Double = {
+    if (feature == splits.last) {
+      splits.length - 2
+    } else {
+      val idx = ju.Arrays.binarySearch(splits, feature)
+      if (idx >= 0) {
+        idx
       } else {
-        left = mid + 1
+        val insertPos = -idx - 1
+        if (insertPos == 0 || insertPos == splits.length) {
+          throw new SparkException(s"Feature value $feature out of Bucketizer bounds" +
+            s" [${splits.head}, ${splits.last}]. Check your features, or loosen " +
+            s"the lower/upper bound constraints.")
+        } else {
+          insertPos - 1
+        }
       }
     }
-    left
   }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
index acb46c0a35..1900820400 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
@@ -57,16 +57,18 @@ class BucketizerSuite extends FunSuite with MLlibTestSparkContext {
 
     // Check for exceptions when using a set of invalid feature values.
     val invalidData1: Array[Double] = Array(-0.9) ++ validData
-    val invalidData2 = Array(0.5) ++ validData
+    val invalidData2 = Array(0.51) ++ validData
     val badDF1 = sqlContext.createDataFrame(invalidData1.zipWithIndex).toDF("feature", "idx")
-    intercept[RuntimeException]{
-      bucketizer.transform(badDF1).collect()
-      println("Invalid feature value -0.9 was not caught as an invalid feature!")
+    withClue("Invalid feature value -0.9 was not caught as an invalid feature!") {
+      intercept[SparkException] {
+        bucketizer.transform(badDF1).collect()
+      }
     }
     val badDF2 = sqlContext.createDataFrame(invalidData2.zipWithIndex).toDF("feature", "idx")
-    intercept[RuntimeException]{
-      bucketizer.transform(badDF2).collect()
-      println("Invalid feature value 0.5 was not caught as an invalid feature!")
+    withClue("Invalid feature value 0.51 was not caught as an invalid feature!") {
+      intercept[SparkException] {
+        bucketizer.transform(badDF2).collect()
+      }
     }
   }
 
@@ -137,12 +139,11 @@ private object BucketizerSuite extends FunSuite {
     }
     var i = 0
     while (i < splits.length - 1) {
-      testFeature(splits(i), i) // Split i should fall in bucket i.
-      testFeature((splits(i) + splits(i + 1)) / 2, i) // Value between splits i,i+1 should be in i.
+      // Split i should fall in bucket i.
+      testFeature(splits(i), i)
+      // Value between splits i,i+1 should be in i, which is also true if the (i+1)-th split is inf.
+      testFeature((splits(i) + splits(i + 1)) / 2, i)
      i += 1
     }
-    if (splits.last === Double.PositiveInfinity) {
-      testFeature(Double.PositiveInfinity, splits.length - 2)
-    }
   }
 }
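For reference, java.util.Arrays.binarySearch returns the index of the key when it is present and -(insertionPoint) - 1 when it is not, which is why the patched code decodes a negative result with -idx - 1 and rejects insertion points of 0 or splits.length as out of bounds. The sketch below is a standalone copy of the new bucketing logic for experimentation outside Spark, not the Spark source itself; the object name BucketizerSketch and the main method are illustrative, and sys.error stands in for SparkException so it compiles without Spark on the classpath. The split semantics are the ones documented above: n+1 strictly increasing splits define n buckets, with the last bucket closed on the right.

import java.{util => ju}

object BucketizerSketch {
  // Same contract as the patched Bucketizer.binarySearchForBuckets:
  // splits must be strictly increasing; the last split is folded into the last bucket.
  def binarySearchForBuckets(splits: Array[Double], feature: Double): Double = {
    if (feature == splits.last) {
      splits.length - 2
    } else {
      val idx = ju.Arrays.binarySearch(splits, feature)
      if (idx >= 0) {
        // Exact hit on split i means the value starts bucket i.
        idx
      } else {
        // Miss: binarySearch encodes the insertion point as -(insertionPoint) - 1.
        val insertPos = -idx - 1
        if (insertPos == 0 || insertPos == splits.length) {
          // Outside [splits.head, splits.last]; the real code throws SparkException here.
          sys.error(s"Feature value $feature out of bounds [${splits.head}, ${splits.last}].")
        } else {
          insertPos - 1
        }
      }
    }
  }

  def main(args: Array[String]): Unit = {
    val splits = Array(-0.5, 0.0, 0.5)            // 3 splits -> 2 buckets: [-0.5, 0.0) and [0.0, 0.5]
    println(binarySearchForBuckets(splits, -0.2)) // 0.0
    println(binarySearchForBuckets(splits, 0.0))  // 1.0
    println(binarySearchForBuckets(splits, 0.5))  // 1.0 (upper bound joins the last bucket)
    // binarySearchForBuckets(splits, 0.51)       // would fail: outside [-0.5, 0.5]
  }
}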