aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main
diff options
context:
space:
mode:
authorXusen Yin <yinxusen@gmail.com>2015-05-11 18:41:22 -0700
committerJoseph K. Bradley <joseph@databricks.com>2015-05-11 18:41:36 -0700
commitf1888159894c2ed446cae0ce52fad199e2dda4cd (patch)
treece8b947170e2f9bc3d3326c940bccc5a29f1d02c /mllib/src/main
parente1e599d58ce5ad9e9c0f9e78dd9961c55ea69850 (diff)
downloadspark-f1888159894c2ed446cae0ce52fad199e2dda4cd.tar.gz
spark-f1888159894c2ed446cae0ce52fad199e2dda4cd.tar.bz2
spark-f1888159894c2ed446cae0ce52fad199e2dda4cd.zip
[SPARK-5893] [ML] Add bucketizer
JIRA issue [here](https://issues.apache.org/jira/browse/SPARK-5893). One thing to make clear, the `buckets` parameter, which is an array of `Double`, performs as split points. Say, ```scala buckets = Array(-0.5, 0.0, 0.5) ``` splits the real number into 4 ranges, (-inf, -0.5], (-0.5, 0.0], (0.0, 0.5], (0.5, +inf), which is encoded as 0, 1, 2, 3. Author: Xusen Yin <yinxusen@gmail.com> Author: Joseph K. Bradley <joseph@databricks.com> Closes #5980 from yinxusen/SPARK-5893 and squashes the following commits: dc8c843 [Xusen Yin] Merge pull request #4 from jkbradley/yinxusen-SPARK-5893 1ca973a [Joseph K. Bradley] one more bucketizer test 34f124a [Joseph K. Bradley] Removed lowerInclusive, upperInclusive params from Bucketizer, and used splits instead. eacfcfa [Xusen Yin] change ML attribute from splits into buckets c3cc770 [Xusen Yin] add more unit test for binary search 3a16cc2 [Xusen Yin] refine comments and names ac77859 [Xusen Yin] fix style error fb30d79 [Xusen Yin] fix and test binary search 2466322 [Xusen Yin] refactor Bucketizer 11fb00a [Xusen Yin] change it into an Estimator 998bc87 [Xusen Yin] check buckets 4024cf1 [Xusen Yin] add test suite 5fe190e [Xusen Yin] add bucketizer (cherry picked from commit 35fb42a0b01d3043b7d5e27256d1b45a08583aab) Signed-off-by: Joseph K. Bradley <joseph@databricks.com>
Diffstat (limited to 'mllib/src/main')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala131
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala11
2 files changed, 142 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
new file mode 100644
index 0000000000..7dba64bc35
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.attribute.NominalAttribute
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
+import org.apache.spark.ml.util.SchemaUtils
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.sql._
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
+
+/**
+ * :: AlphaComponent ::
+ * `Bucketizer` maps a column of continuous features to a column of feature buckets.
+ */
+@AlphaComponent
+final class Bucketizer private[ml] (override val parent: Estimator[Bucketizer])
+ extends Model[Bucketizer] with HasInputCol with HasOutputCol {
+
+ def this() = this(null)
+
+ /**
+ * Parameter for mapping continuous features into buckets. With n splits, there are n+1 buckets.
+ * A bucket defined by splits x,y holds values in the range [x,y). Splits should be strictly
+ * increasing. Values at -inf, inf must be explicitly provided to cover all Double values;
+ * otherwise, values outside the splits specified will be treated as errors.
+ * @group param
+ */
+ val splits: Param[Array[Double]] = new Param[Array[Double]](this, "splits",
+ "Split points for mapping continuous features into buckets. With n splits, there are n+1 " +
+ "buckets. A bucket defined by splits x,y holds values in the range [x,y). The splits " +
+ "should be strictly increasing. Values at -inf, inf must be explicitly provided to cover" +
+ " all Double values; otherwise, values outside the splits specified will be treated as" +
+ " errors.",
+ Bucketizer.checkSplits)
+
+ /** @group getParam */
+ def getSplits: Array[Double] = $(splits)
+
+ /** @group setParam */
+ def setSplits(value: Array[Double]): this.type = set(splits, value)
+
+ /** @group setParam */
+ def setInputCol(value: String): this.type = set(inputCol, value)
+
+ /** @group setParam */
+ def setOutputCol(value: String): this.type = set(outputCol, value)
+
+ override def transform(dataset: DataFrame): DataFrame = {
+ transformSchema(dataset.schema)
+ val bucketizer = udf { feature: Double =>
+ Bucketizer.binarySearchForBuckets($(splits), feature)
+ }
+ val newCol = bucketizer(dataset($(inputCol)))
+ val newField = prepOutputField(dataset.schema)
+ dataset.withColumn($(outputCol), newCol.as($(outputCol), newField.metadata))
+ }
+
+ private def prepOutputField(schema: StructType): StructField = {
+ val buckets = $(splits).sliding(2).map(bucket => bucket.mkString(", ")).toArray
+ val attr = new NominalAttribute(name = Some($(outputCol)), isOrdinal = Some(true),
+ values = Some(buckets))
+ attr.toStructField()
+ }
+
+ override def transformSchema(schema: StructType): StructType = {
+ SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
+ SchemaUtils.appendColumn(schema, prepOutputField(schema))
+ }
+}
+
+private[feature] object Bucketizer {
+ /** We require splits to be of length >= 3 and to be in strictly increasing order. */
+ def checkSplits(splits: Array[Double]): Boolean = {
+ if (splits.length < 3) {
+ false
+ } else {
+ var i = 0
+ while (i < splits.length - 1) {
+ if (splits(i) >= splits(i + 1)) return false
+ i += 1
+ }
+ true
+ }
+ }
+
+ /**
+ * Binary searching in several buckets to place each data point.
+ * @throws RuntimeException if a feature is < splits.head or >= splits.last
+ */
+ def binarySearchForBuckets(
+ splits: Array[Double],
+ feature: Double): Double = {
+ // Check bounds. We make an exception for +inf so that it can exist in some bin.
+ if ((feature < splits.head) || (feature >= splits.last && feature != Double.PositiveInfinity)) {
+ throw new RuntimeException(s"Feature value $feature out of Bucketizer bounds" +
+ s" [${splits.head}, ${splits.last}). Check your features, or loosen " +
+ s"the lower/upper bound constraints.")
+ }
+ var left = 0
+ var right = splits.length - 2
+ while (left < right) {
+ val mid = (left + right) / 2
+ val split = splits(mid + 1)
+ if (feature < split) {
+ right = mid
+ } else {
+ left = mid + 1
+ }
+ }
+ left
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
index 0383bf0b38..11592b77eb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
@@ -58,4 +58,15 @@ object SchemaUtils {
val outputFields = schema.fields :+ StructField(colName, dataType, nullable = false)
StructType(outputFields)
}
+
+ /**
+ * Appends a new column to the input schema. This fails if the given output column already exists.
+ * @param schema input schema
+ * @param col New column schema
+ * @return new schema with the input column appended
+ */
+ def appendColumn(schema: StructType, col: StructField): StructType = {
+ require(!schema.fieldNames.contains(col.name), s"Column ${col.name} already exists.")
+ StructType(schema.fields :+ col)
+ }
}