diff options
author | Sandy Ryza <sandy@cloudera.com> | 2015-05-05 12:34:02 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-05-05 12:34:11 -0700 |
commit | 94ac9eba2188d2b1d7140bec4929a77fde66474f (patch) | |
tree | 07b10beb3560c32cf6c836ff6607f2898859620d /mllib/src/main | |
parent | dfb6bfce42b2b91977f0190548a691d0f72b71c5 (diff) | |
download | spark-94ac9eba2188d2b1d7140bec4929a77fde66474f.tar.gz spark-94ac9eba2188d2b1d7140bec4929a77fde66474f.tar.bz2 spark-94ac9eba2188d2b1d7140bec4929a77fde66474f.zip |
[SPARK-5888] [MLLIB] Add OneHotEncoder as a Transformer
This patch adds a one hot encoder for categorical features. Planning to add documentation and another test after getting feedback on the approach.
A couple choices made here:
* There's an `includeFirst` option which, if false, creates numCategories - 1 columns and, if true, creates numCategories columns. The default is true, which is the behavior in scikit-learn.
* The user is expected to pass a `Seq` of category names when instantiating a `OneHotEncoder`. These can be easily gotten from a `StringIndexer`. The names are used for the output column names, which take the form colName_categoryName.
Author: Sandy Ryza <sandy@cloudera.com>
Closes #5500 from sryza/sandy-spark-5888 and squashes the following commits:
f383250 [Sandy Ryza] Infer label names automatically
6e257b9 [Sandy Ryza] Review comments
7c539cf [Sandy Ryza] Vector transformers
1c182dd [Sandy Ryza] SPARK-5888. [MLLIB]. Add OneHotEncoder as a Transformer
(cherry picked from commit 47728db7cfac995d9417cdf0e16d07391aabd581)
Signed-off-by: Xiangrui Meng <meng@databricks.com>
Diffstat (limited to 'mllib/src/main')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala | 107 |
1 files changed, 107 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala new file mode 100644 index 0000000000..46514ae5f0 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.SparkException +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.UnaryTransformer +import org.apache.spark.ml.attribute.{Attribute, BinaryAttribute, NominalAttribute} +import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} +import org.apache.spark.ml.util.SchemaUtils +import org.apache.spark.sql.types.{DataType, DoubleType, StructType} + +/** + * A one-hot encoder that maps a column of label indices to a column of binary vectors, with + * at most a single one-value. By default, the binary vector has an element for each category, so + * with 5 categories, an input value of 2.0 would map to an output vector of + * (0.0, 0.0, 1.0, 0.0, 0.0). If includeFirst is set to false, the first category is omitted, so the + * output vector for the previous example would be (0.0, 1.0, 0.0, 0.0) and an input value + * of 0.0 would map to a vector of all zeros. Including the first category makes the vector columns + * linearly dependent because they sum up to one. + */ +@AlphaComponent +class OneHotEncoder extends UnaryTransformer[Double, Vector, OneHotEncoder] + with HasInputCol with HasOutputCol { + + /** + * Whether to include a component in the encoded vectors for the first category, defaults to true. + * @group param + */ + final val includeFirst: BooleanParam = + new BooleanParam(this, "includeFirst", "include first category") + setDefault(includeFirst -> true) + + private var categories: Array[String] = _ + + /** @group setParam */ + def setIncludeFirst(value: Boolean): this.type = set(includeFirst, value) + + /** @group setParam */ + override def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + override def setOutputCol(value: String): this.type = set(outputCol, value) + + override def transformSchema(schema: StructType): StructType = { + SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType) + val inputFields = schema.fields + val outputColName = $(outputCol) + require(inputFields.forall(_.name != $(outputCol)), + s"Output column ${$(outputCol)} already exists.") + + val inputColAttr = Attribute.fromStructField(schema($(inputCol))) + categories = inputColAttr match { + case nominal: NominalAttribute => + nominal.values.getOrElse((0 until nominal.numValues.get).map(_.toString).toArray) + case binary: BinaryAttribute => binary.values.getOrElse(Array("0", "1")) + case _ => + throw new SparkException(s"OneHotEncoder input column ${$(inputCol)} is not nominal") + } + + val attrValues = (if ($(includeFirst)) categories else categories.drop(1)).toArray + val attr = NominalAttribute.defaultAttr.withName(outputColName).withValues(attrValues) + val outputFields = inputFields :+ attr.toStructField() + StructType(outputFields) + } + + protected override def createTransformFunc(): (Double) => Vector = { + val first = $(includeFirst) + val vecLen = if (first) categories.length else categories.length - 1 + val oneValue = Array(1.0) + val emptyValues = Array[Double]() + val emptyIndices = Array[Int]() + label: Double => { + val values = if (first || label != 0.0) oneValue else emptyValues + val indices = if (first) { + Array(label.toInt) + } else if (label != 0.0) { + Array(label.toInt - 1) + } else { + emptyIndices + } + Vectors.sparse(vecLen, indices, values) + } + } + + /** + * Returns the data type of the output column. + */ + protected def outputDataType: DataType = new VectorUDT +} |