Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala      | 55
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala | 25
2 files changed, 41 insertions(+), 39 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
index 7dba64bc35..b28c88aaae 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -17,6 +17,9 @@

 package org.apache.spark.ml.feature

+import java.{util => ju}
+
+import org.apache.spark.SparkException
 import org.apache.spark.annotation.AlphaComponent
 import org.apache.spark.ml.attribute.NominalAttribute
 import org.apache.spark.ml.param._
@@ -38,18 +41,19 @@ final class Bucketizer private[ml] (override val parent: Estimator[Bucketizer])

   def this() = this(null)

   /**
-   * Parameter for mapping continuous features into buckets. With n splits, there are n+1 buckets.
-   * A bucket defined by splits x,y holds values in the range [x,y). Splits should be strictly
-   * increasing. Values at -inf, inf must be explicitly provided to cover all Double values;
+   * Parameter for mapping continuous features into buckets. With n+1 splits, there are n buckets.
+   * A bucket defined by splits x,y holds values in the range [x,y) except the last bucket, which
+   * also includes y. Splits should be strictly increasing.
+   * Values at -inf, inf must be explicitly provided to cover all Double values;
    * otherwise, values outside the splits specified will be treated as errors.
    * @group param
    */
   val splits: Param[Array[Double]] = new Param[Array[Double]](this, "splits",
-    "Split points for mapping continuous features into buckets. With n splits, there are n+1 " +
-      "buckets. A bucket defined by splits x,y holds values in the range [x,y). The splits " +
-      "should be strictly increasing. Values at -inf, inf must be explicitly provided to cover" +
-      " all Double values; otherwise, values outside the splits specified will be treated as" +
-      " errors.",
+    "Split points for mapping continuous features into buckets. With n+1 splits, there are n " +
+      "buckets. A bucket defined by splits x,y holds values in the range [x,y) except the last " +
+      "bucket, which also includes y. The splits should be strictly increasing. " +
+      "Values at -inf, inf must be explicitly provided to cover all Double values; " +
+      "otherwise, values outside the splits specified will be treated as errors.",
     Bucketizer.checkSplits)

   /** @group getParam */
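
To make the revised wording concrete, here is a small illustrative sketch; the split points and values are invented for the example, not taken from the patch:

```scala
// With n+1 = 4 splits there are n = 3 buckets.
val splits = Array(Double.NegativeInfinity, 0.0, 1.0, Double.PositiveInfinity)
// Bucket 0 covers [-inf, 0.0), bucket 1 covers [0.0, 1.0),
// and the last bucket covers [1.0, +inf] -- closed on the right,
// so splits.last no longer needs a special case to land in a bucket.
```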
@@ -104,28 +108,25 @@ private[feature] object Bucketizer {

   /**
    * Binary searching in several buckets to place each data point.
-   * @throws RuntimeException if a feature is < splits.head or >= splits.last
+   * @throws SparkException if a feature is < splits.head or > splits.last
    */
-  def binarySearchForBuckets(
-      splits: Array[Double],
-      feature: Double): Double = {
-    // Check bounds. We make an exception for +inf so that it can exist in some bin.
-    if ((feature < splits.head) || (feature >= splits.last && feature != Double.PositiveInfinity)) {
-      throw new RuntimeException(s"Feature value $feature out of Bucketizer bounds" +
-        s" [${splits.head}, ${splits.last}). Check your features, or loosen " +
-        s"the lower/upper bound constraints.")
-    }
-    var left = 0
-    var right = splits.length - 2
-    while (left < right) {
-      val mid = (left + right) / 2
-      val split = splits(mid + 1)
-      if (feature < split) {
-        right = mid
+  def binarySearchForBuckets(splits: Array[Double], feature: Double): Double = {
+    if (feature == splits.last) {
+      splits.length - 2
+    } else {
+      val idx = ju.Arrays.binarySearch(splits, feature)
+      if (idx >= 0) {
+        idx
       } else {
-        left = mid + 1
+        val insertPos = -idx - 1
+        if (insertPos == 0 || insertPos == splits.length) {
+          throw new SparkException(s"Feature value $feature out of Bucketizer bounds" +
+            s" [${splits.head}, ${splits.last}]. Check your features, or loosen " +
+            s"the lower/upper bound constraints.")
+        } else {
+          insertPos - 1
+        }
       }
     }
-    left
   }
 }
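
For readers who want to trace the new lookup outside of Spark, here is a minimal, self-contained sketch of the same technique — java.util.Arrays.binarySearch plus a closed last bucket — with the object name and sample splits invented for illustration:

```scala
import java.{util => ju}

object BucketLookupSketch {
  // Mirrors the patched binarySearchForBuckets: returns the bucket index
  // for `feature`, treating the last bucket as closed on the right.
  def bucketFor(splits: Array[Double], feature: Double): Int = {
    if (feature == splits.last) {
      splits.length - 2 // splits.last belongs to the final bucket
    } else {
      val idx = ju.Arrays.binarySearch(splits, feature)
      if (idx >= 0) {
        idx // an exact hit on split i opens bucket i
      } else {
        val insertPos = -idx - 1 // binarySearch encodes the insertion point
        require(insertPos != 0 && insertPos != splits.length,
          s"Feature value $feature out of bounds [${splits.head}, ${splits.last}]")
        insertPos - 1 // strictly inside (splits(i), splits(i + 1)) => bucket i
      }
    }
  }

  def main(args: Array[String]): Unit = {
    val splits = Array(-0.5, 0.0, 0.5) // 3 splits => 2 buckets
    println(bucketFor(splits, -0.5)) // 0: left edge of bucket 0
    println(bucketFor(splits, 0.2))  // 1: inside [0.0, 0.5)
    println(bucketFor(splits, 0.5))  // 1: equals splits.last, last bucket
  }
}
```

The error branch uses require instead of SparkException only so the sketch compiles without Spark on the classpath.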
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
index acb46c0a35..1900820400 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
@@ -57,16 +57,18 @@ class BucketizerSuite extends FunSuite with MLlibTestSparkContext {

     // Check for exceptions when using a set of invalid feature values.
     val invalidData1: Array[Double] = Array(-0.9) ++ validData
-    val invalidData2 = Array(0.5) ++ validData
+    val invalidData2 = Array(0.51) ++ validData
     val badDF1 = sqlContext.createDataFrame(invalidData1.zipWithIndex).toDF("feature", "idx")
-    intercept[RuntimeException]{
-      bucketizer.transform(badDF1).collect()
-      println("Invalid feature value -0.9 was not caught as an invalid feature!")
+    withClue("Invalid feature value -0.9 was not caught as an invalid feature!") {
+      intercept[SparkException] {
+        bucketizer.transform(badDF1).collect()
+      }
     }
     val badDF2 = sqlContext.createDataFrame(invalidData2.zipWithIndex).toDF("feature", "idx")
-    intercept[RuntimeException]{
-      bucketizer.transform(badDF2).collect()
-      println("Invalid feature value 0.5 was not caught as an invalid feature!")
+    withClue("Invalid feature value 0.51 was not caught as an invalid feature!") {
+      intercept[SparkException] {
+        bucketizer.transform(badDF2).collect()
+      }
     }
   }
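
The test rework swaps an unconditional println inside intercept for ScalaTest's withClue, which attaches the message to the failure if no exception is thrown. A minimal illustration of the pattern follows; the suite name and thrown exception are stand-ins:

```scala
import org.apache.spark.SparkException
import org.scalatest.FunSuite

class ClueExampleSuite extends FunSuite {
  test("withClue labels a missing-exception failure") {
    withClue("Invalid feature value was not caught as an invalid feature!") {
      intercept[SparkException] {
        // Stand-in for bucketizer.transform(badDF).collect(); if this body
        // completes without throwing, the test fails with the clue above.
        throw new SparkException("Feature value out of Bucketizer bounds")
      }
    }
  }
}
```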
@@ -137,12 +139,11 @@ private object BucketizerSuite extends FunSuite {
     }
     var i = 0
     while (i < splits.length - 1) {
-      testFeature(splits(i), i) // Split i should fall in bucket i.
-      testFeature((splits(i) + splits(i + 1)) / 2, i) // Value between splits i,i+1 should be in i.
+      // Split i should fall in bucket i.
+      testFeature(splits(i), i)
+      // Value between splits i,i+1 should be in i, which is also true if the (i+1)-th split is inf.
+      testFeature((splits(i) + splits(i + 1)) / 2, i)
       i += 1
     }
-    if (splits.last === Double.PositiveInfinity) {
-      testFeature(Double.PositiveInfinity, splits.length - 2)
-    }
   }
 }
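
Putting the pieces together, a hedged end-to-end sketch of the patched behavior against the Spark 1.4-era API used in this suite; the app name, column names, and sample values are invented for the example:

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.ml.feature.Bucketizer
import org.apache.spark.sql.SQLContext

object BucketizerDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setMaster("local[2]").setAppName("bucketizer-demo"))
    val sqlContext = new SQLContext(sc)

    // 5 splits => 4 buckets; -inf and +inf cover the whole Double range.
    val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity)
    val data = Array(-0.9, -0.5, 0.2, 0.5, 0.9)
    val df = sqlContext.createDataFrame(data.zipWithIndex).toDF("feature", "idx")

    val bucketizer = new Bucketizer()
      .setInputCol("feature")
      .setOutputCol("bucket")
      .setSplits(splits)

    // 0.5 equals split 3, so it opens bucket 3; a value equal to splits.last
    // now falls in the final bucket instead of raising an exception.
    bucketizer.transform(df).show()

    sc.stop()
  }
}
```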