author     Sandy Ryza <sandy@cloudera.com>        2015-05-05 12:34:02 -0700
committer  Xiangrui Meng <meng@databricks.com>    2015-05-05 12:34:02 -0700
commit     47728db7cfac995d9417cdf0e16d07391aabd581
tree       4479d80b2c281512c29ea0f32d21168cca493e58
parent     ee374e89cd1f08730fed9d50b742627d5b19d241
[SPARK-5888] [MLLIB] Add OneHotEncoder as a Transformer
This patch adds a one-hot encoder for categorical features. Planning to add documentation and another test after getting feedback on the approach.

A couple of choices made here:

* There's an `includeFirst` option which, if false, creates numCategories - 1 columns and, if true, creates numCategories columns. The default is true, which is the behavior in scikit-learn.
* The user is expected to pass a `Seq` of category names when instantiating a `OneHotEncoder`. These can easily be obtained from a `StringIndexer`. The names are used for the output column names, which take the form colName_categoryName.

Author: Sandy Ryza <sandy@cloudera.com>

Closes #5500 from sryza/sandy-spark-5888 and squashes the following commits:

f383250 [Sandy Ryza] Infer label names automatically
6e257b9 [Sandy Ryza] Review comments
7c539cf [Sandy Ryza] Vector transformers
1c182dd [Sandy Ryza] SPARK-5888. [MLLIB]. Add OneHotEncoder as a Transformer
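For context, a minimal usage sketch of the new transformer chained after a `StringIndexer`, roughly as described above. The toy data, column names, and the ambient `sqlContext` are assumptions (e.g. a Spark 1.x shell with this patch applied), not part of the patch:

import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer}

// Hypothetical toy data; assumes an existing sqlContext as in spark-shell.
val df = sqlContext.createDataFrame(Seq(
  (0, "a"), (1, "b"), (2, "c"), (3, "a")
)).toDF("id", "category")

// StringIndexer maps each distinct string to a Double index and records the
// label values in the output column's metadata.
val indexed = new StringIndexer()
  .setInputCol("category")
  .setOutputCol("categoryIndex")
  .fit(df)
  .transform(df)

// OneHotEncoder reads that metadata and expands each index into a binary vector.
val encoded = new OneHotEncoder()
  .setInputCol("categoryIndex")
  .setOutputCol("categoryVec")
  .setIncludeFirst(true)   // the default; false drops the first category's slot
  .transform(indexed)

encoded.select("id", "category", "categoryVec").show()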
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala       107
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala   80
2 files changed, 187 insertions(+), 0 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
new file mode 100644
index 0000000000..46514ae5f0
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.SparkException
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.UnaryTransformer
+import org.apache.spark.ml.attribute.{Attribute, BinaryAttribute, NominalAttribute}
+import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT}
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
+import org.apache.spark.ml.util.SchemaUtils
+import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
+
+/**
+ * A one-hot encoder that maps a column of label indices to a column of binary vectors, with
+ * at most a single one-value. By default, the binary vector has an element for each category, so
+ * with 5 categories, an input value of 2.0 would map to an output vector of
+ * (0.0, 0.0, 1.0, 0.0, 0.0). If includeFirst is set to false, the first category is omitted, so the
+ * output vector for the previous example would be (0.0, 1.0, 0.0, 0.0) and an input value
+ * of 0.0 would map to a vector of all zeros. Including the first category makes the vector columns
+ * linearly dependent because they sum up to one.
+ */
+@AlphaComponent
+class OneHotEncoder extends UnaryTransformer[Double, Vector, OneHotEncoder]
+ with HasInputCol with HasOutputCol {
+
+ /**
+ * Whether to include a component in the encoded vectors for the first category, defaults to true.
+ * @group param
+ */
+ final val includeFirst: BooleanParam =
+ new BooleanParam(this, "includeFirst", "include first category")
+ setDefault(includeFirst -> true)
+
+ private var categories: Array[String] = _
+
+ /** @group setParam */
+ def setIncludeFirst(value: Boolean): this.type = set(includeFirst, value)
+
+ /** @group setParam */
+ override def setInputCol(value: String): this.type = set(inputCol, value)
+
+ /** @group setParam */
+ override def setOutputCol(value: String): this.type = set(outputCol, value)
+
+ override def transformSchema(schema: StructType): StructType = {
+ SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
+ val inputFields = schema.fields
+ val outputColName = $(outputCol)
+ require(inputFields.forall(_.name != $(outputCol)),
+ s"Output column ${$(outputCol)} already exists.")
+
+ val inputColAttr = Attribute.fromStructField(schema($(inputCol)))
+ categories = inputColAttr match {
+ case nominal: NominalAttribute =>
+ nominal.values.getOrElse((0 until nominal.numValues.get).map(_.toString).toArray)
+ case binary: BinaryAttribute => binary.values.getOrElse(Array("0", "1"))
+ case _ =>
+ throw new SparkException(s"OneHotEncoder input column ${$(inputCol)} is not nominal")
+ }
+
+ val attrValues = (if ($(includeFirst)) categories else categories.drop(1)).toArray
+ val attr = NominalAttribute.defaultAttr.withName(outputColName).withValues(attrValues)
+ val outputFields = inputFields :+ attr.toStructField()
+ StructType(outputFields)
+ }
+
+ protected override def createTransformFunc(): (Double) => Vector = {
+ val first = $(includeFirst)
+ val vecLen = if (first) categories.length else categories.length - 1
+ val oneValue = Array(1.0)
+ val emptyValues = Array[Double]()
+ val emptyIndices = Array[Int]()
+ label: Double => {
+ val values = if (first || label != 0.0) oneValue else emptyValues
+ val indices = if (first) {
+ Array(label.toInt)
+ } else if (label != 0.0) {
+ Array(label.toInt - 1)
+ } else {
+ emptyIndices
+ }
+ Vectors.sparse(vecLen, indices, values)
+ }
+ }
+
+ /**
+ * Returns the data type of the output column.
+ */
+ protected def outputDataType: DataType = new VectorUDT
+}
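As a standalone illustration of the encoding rule implemented by `createTransformFunc` above, the sketch below re-expresses the same mapping with mllib's `Vectors` API directly. The `encode` helper and the category count of 5 are made up for illustration and are not part of the patch:

import org.apache.spark.mllib.linalg.{Vector, Vectors}

// Same rule as the transform function: with includeFirst = true the vector has
// one slot per category and the label's slot is 1.0; with includeFirst = false
// the first category's slot is dropped and label 0.0 maps to all zeros.
def encode(label: Double, numCategories: Int, includeFirst: Boolean): Vector = {
  val vecLen = if (includeFirst) numCategories else numCategories - 1
  if (includeFirst) {
    Vectors.sparse(vecLen, Array(label.toInt), Array(1.0))
  } else if (label != 0.0) {
    Vectors.sparse(vecLen, Array(label.toInt - 1), Array(1.0))
  } else {
    Vectors.sparse(vecLen, Array.empty[Int], Array.empty[Double])
  }
}

encode(2.0, 5, includeFirst = true)   // (0.0, 0.0, 1.0, 0.0, 0.0)
encode(2.0, 5, includeFirst = false)  // (0.0, 1.0, 0.0, 0.0)
encode(0.0, 5, includeFirst = false)  // (0.0, 0.0, 0.0, 0.0)

Returning sparse vectors keeps the per-row cost proportional to the single active index rather than to the total number of categories.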
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
new file mode 100644
index 0000000000..92ec407b98
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.{DataFrame, SQLContext}
+
+
+class OneHotEncoderSuite extends FunSuite with MLlibTestSparkContext {
+ private var sqlContext: SQLContext = _
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ sqlContext = new SQLContext(sc)
+ }
+
+ def stringIndexed(): DataFrame = {
+ val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
+ val df = sqlContext.createDataFrame(data).toDF("id", "label")
+ val indexer = new StringIndexer()
+ .setInputCol("label")
+ .setOutputCol("labelIndex")
+ .fit(df)
+ indexer.transform(df)
+ }
+
+ test("OneHotEncoder includeFirst = true") {
+ val transformed = stringIndexed()
+ val encoder = new OneHotEncoder()
+ .setInputCol("labelIndex")
+ .setOutputCol("labelVec")
+ val encoded = encoder.transform(transformed)
+
+ val output = encoded.select("id", "labelVec").map { r =>
+ val vec = r.get(1).asInstanceOf[Vector]
+ (r.getInt(0), vec(0), vec(1), vec(2))
+ }.collect().toSet
+ // a -> 0, b -> 2, c -> 1
+ val expected = Set((0, 1.0, 0.0, 0.0), (1, 0.0, 0.0, 1.0), (2, 0.0, 1.0, 0.0),
+ (3, 1.0, 0.0, 0.0), (4, 1.0, 0.0, 0.0), (5, 0.0, 1.0, 0.0))
+ assert(output === expected)
+ }
+
+ test("OneHotEncoder includeFirst = false") {
+ val transformed = stringIndexed()
+ val encoder = new OneHotEncoder()
+ .setIncludeFirst(false)
+ .setInputCol("labelIndex")
+ .setOutputCol("labelVec")
+ val encoded = encoder.transform(transformed)
+
+ val output = encoded.select("id", "labelVec").map { r =>
+ val vec = r.get(1).asInstanceOf[Vector]
+ (r.getInt(0), vec(0), vec(1))
+ }.collect().toSet
+ // a -> 0, b -> 2, c -> 1
+ val expected = Set((0, 0.0, 0.0), (1, 0.0, 1.0), (2, 1.0, 0.0),
+ (3, 0.0, 0.0), (4, 0.0, 0.0), (5, 1.0, 0.0))
+ assert(output === expected)
+ }
+
+}
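A note on the expected sets in the tests: `StringIndexer` assigns indices by descending label frequency, so with three "a"s, two "c"s, and one "b" the mapping is a -> 0, c -> 1, b -> 2, which is what the `// a -> 0, b -> 2, c -> 1` comments encode. The encoder discovers those labels through the `NominalAttribute` metadata that `StringIndexer` attaches to its output column; a small sketch of that round trip is below (the column name and value ordering are chosen to match the test data; this is an illustration, not part of the patch):

import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
import org.apache.spark.sql.types.StructField

// Hypothetical reconstruction of the metadata StringIndexer attaches to its
// output column; transformSchema reads it back via Attribute.fromStructField.
val attr = NominalAttribute.defaultAttr
  .withName("labelIndex")
  .withValues(Array("a", "c", "b"))   // ordered by descending frequency
val field: StructField = attr.toStructField()

Attribute.fromStructField(field) match {
  case nominal: NominalAttribute =>
    println(nominal.values.map(_.mkString(", ")).getOrElse("no values"))  // a, c, b
  case other =>
    println(s"not a nominal column: $other")
}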