aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorXiangrui Meng <meng@databricks.com>2015-05-12 14:24:26 -0700
committerJoseph K. Bradley <joseph@databricks.com>2015-05-12 14:24:26 -0700
commit23b9863e2aa7ecd0c4fa3aa8a59fdae09b4fe1d7 (patch)
tree90683222288bbb083b31ea1e1e2fc61f1d9649ed /mllib
parent2a41c0d71a13558f12c6811bf98791e01186f3ad (diff)
downloadspark-23b9863e2aa7ecd0c4fa3aa8a59fdae09b4fe1d7.tar.gz
spark-23b9863e2aa7ecd0c4fa3aa8a59fdae09b4fe1d7.tar.bz2
spark-23b9863e2aa7ecd0c4fa3aa8a59fdae09b4fe1d7.zip
[SPARK-7559] [MLLIB] Bucketizer should include the right most boundary in the last bucket.
We make special treatment for +inf in `Bucketizer`. This could be simplified by always including the largest split value in the last bucket. E.g., (x1, x2, x3) defines buckets [x1, x2) and [x2, x3]. This shouldn't affect user code much, and there are applications that need to include the right-most value. For example, we can bucketize ratings from 0 to 10 to bad, neutral, and good with splits 0, 4, 6, 10. It may reads weird if the users need to put 0, 4, 6, 10.1 (or 11). This also update the impl to use `Arrays.binarySearch` and `withClue` in test. yinxusen jkbradley Author: Xiangrui Meng <meng@databricks.com> Closes #6075 from mengxr/SPARK-7559 and squashes the following commits: e28f910 [Xiangrui Meng] update bucketizer impl
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala55
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala25
2 files changed, 41 insertions, 39 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
index 7dba64bc35..b28c88aaae 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -17,6 +17,9 @@
package org.apache.spark.ml.feature
+import java.{util => ju}
+
+import org.apache.spark.SparkException
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.param._
@@ -38,18 +41,19 @@ final class Bucketizer private[ml] (override val parent: Estimator[Bucketizer])
def this() = this(null)
/**
- * Parameter for mapping continuous features into buckets. With n splits, there are n+1 buckets.
- * A bucket defined by splits x,y holds values in the range [x,y). Splits should be strictly
- * increasing. Values at -inf, inf must be explicitly provided to cover all Double values;
+ * Parameter for mapping continuous features into buckets. With n+1 splits, there are n buckets.
+ * A bucket defined by splits x,y holds values in the range [x,y) except the last bucket, which
+ * also includes y. Splits should be strictly increasing.
+ * Values at -inf, inf must be explicitly provided to cover all Double values;
* otherwise, values outside the splits specified will be treated as errors.
* @group param
*/
val splits: Param[Array[Double]] = new Param[Array[Double]](this, "splits",
- "Split points for mapping continuous features into buckets. With n splits, there are n+1 " +
- "buckets. A bucket defined by splits x,y holds values in the range [x,y). The splits " +
- "should be strictly increasing. Values at -inf, inf must be explicitly provided to cover" +
- " all Double values; otherwise, values outside the splits specified will be treated as" +
- " errors.",
+ "Split points for mapping continuous features into buckets. With n+1 splits, there are n " +
+ "buckets. A bucket defined by splits x,y holds values in the range [x,y) except the last " +
+ "bucket, which also includes y. The splits should be strictly increasing. " +
+ "Values at -inf, inf must be explicitly provided to cover all Double values; " +
+ "otherwise, values outside the splits specified will be treated as errors.",
Bucketizer.checkSplits)
/** @group getParam */
@@ -104,28 +108,25 @@ private[feature] object Bucketizer {
/**
* Binary searching in several buckets to place each data point.
- * @throws RuntimeException if a feature is < splits.head or >= splits.last
+ * @throws SparkException if a feature is < splits.head or > splits.last
*/
- def binarySearchForBuckets(
- splits: Array[Double],
- feature: Double): Double = {
- // Check bounds. We make an exception for +inf so that it can exist in some bin.
- if ((feature < splits.head) || (feature >= splits.last && feature != Double.PositiveInfinity)) {
- throw new RuntimeException(s"Feature value $feature out of Bucketizer bounds" +
- s" [${splits.head}, ${splits.last}). Check your features, or loosen " +
- s"the lower/upper bound constraints.")
- }
- var left = 0
- var right = splits.length - 2
- while (left < right) {
- val mid = (left + right) / 2
- val split = splits(mid + 1)
- if (feature < split) {
- right = mid
+ def binarySearchForBuckets(splits: Array[Double], feature: Double): Double = {
+ if (feature == splits.last) {
+ splits.length - 2
+ } else {
+ val idx = ju.Arrays.binarySearch(splits, feature)
+ if (idx >= 0) {
+ idx
} else {
- left = mid + 1
+ val insertPos = -idx - 1
+ if (insertPos == 0 || insertPos == splits.length) {
+ throw new SparkException(s"Feature value $feature out of Bucketizer bounds" +
+ s" [${splits.head}, ${splits.last}]. Check your features, or loosen " +
+ s"the lower/upper bound constraints.")
+ } else {
+ insertPos - 1
+ }
}
}
- left
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
index acb46c0a35..1900820400 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
@@ -57,16 +57,18 @@ class BucketizerSuite extends FunSuite with MLlibTestSparkContext {
// Check for exceptions when using a set of invalid feature values.
val invalidData1: Array[Double] = Array(-0.9) ++ validData
- val invalidData2 = Array(0.5) ++ validData
+ val invalidData2 = Array(0.51) ++ validData
val badDF1 = sqlContext.createDataFrame(invalidData1.zipWithIndex).toDF("feature", "idx")
- intercept[RuntimeException]{
- bucketizer.transform(badDF1).collect()
- println("Invalid feature value -0.9 was not caught as an invalid feature!")
+ withClue("Invalid feature value -0.9 was not caught as an invalid feature!") {
+ intercept[SparkException] {
+ bucketizer.transform(badDF1).collect()
+ }
}
val badDF2 = sqlContext.createDataFrame(invalidData2.zipWithIndex).toDF("feature", "idx")
- intercept[RuntimeException]{
- bucketizer.transform(badDF2).collect()
- println("Invalid feature value 0.5 was not caught as an invalid feature!")
+ withClue("Invalid feature value 0.51 was not caught as an invalid feature!") {
+ intercept[SparkException] {
+ bucketizer.transform(badDF2).collect()
+ }
}
}
@@ -137,12 +139,11 @@ private object BucketizerSuite extends FunSuite {
}
var i = 0
while (i < splits.length - 1) {
- testFeature(splits(i), i) // Split i should fall in bucket i.
- testFeature((splits(i) + splits(i + 1)) / 2, i) // Value between splits i,i+1 should be in i.
+ // Split i should fall in bucket i.
+ testFeature(splits(i), i)
+ // Value between splits i,i+1 should be in i, which is also true if the (i+1)-th split is inf.
+ testFeature((splits(i) + splits(i + 1)) / 2, i)
i += 1
}
- if (splits.last === Double.PositiveInfinity) {
- testFeature(Double.PositiveInfinity, splits.length - 2)
- }
}
}