diff options
author | Xiangrui Meng <meng@databricks.com> | 2015-04-12 22:41:05 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-04-12 22:41:05 -0700 |
commit | 685ddcf5253c0ecb39853802431e22b0c7b61dee (patch) | |
tree | bdca5d046db951a373eda0dd0c27a7e34531978f /mllib | |
parent | d3792f54974e16cbe8f10b3091d248e0bdd48986 (diff) | |
download | spark-685ddcf5253c0ecb39853802431e22b0c7b61dee.tar.gz spark-685ddcf5253c0ecb39853802431e22b0c7b61dee.tar.bz2 spark-685ddcf5253c0ecb39853802431e22b0c7b61dee.zip |
[SPARK-5886][ML] Add StringIndexer as a feature transformer
This PR adds string indexer, which takes a column of string labels and outputs a double column with labels indexed by their frequency.
TODOs:
- [x] store feature to index map in output metadata
Author: Xiangrui Meng <meng@databricks.com>
Closes #4735 from mengxr/SPARK-5886 and squashes the following commits:
d82575f [Xiangrui Meng] fix test
700e70f [Xiangrui Meng] rename LabelIndexer to StringIndexer
16a6f8c [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-5886
457166e [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-5886
f8b30f4 [Xiangrui Meng] update label indexer to output metadata
e81ec28 [Xiangrui Meng] Merge branch 'openhashmap-contains' into SPARK-5886-2
d6e6f1f [Xiangrui Meng] add contains to primitivekeyopenhashmap
748a69b [Xiangrui Meng] add contains to OpenHashMap
def3c5c [Xiangrui Meng] add LabelIndexer
Diffstat (limited to 'mllib')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala | 126 | ||||
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala | 52 |
2 files changed, 178 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala new file mode 100644 index 0000000000..61e6742e88 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.SparkException +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.attribute.NominalAttribute +import org.apache.spark.ml.param._ +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{StringType, StructType} +import org.apache.spark.util.collection.OpenHashMap + +/** + * Base trait for [[StringIndexer]] and [[StringIndexerModel]]. + */ +private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol { + + /** Validates and transforms the input schema. */ + protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = { + val map = this.paramMap ++ paramMap + checkInputColumn(schema, map(inputCol), StringType) + val inputFields = schema.fields + val outputColName = map(outputCol) + require(inputFields.forall(_.name != outputColName), + s"Output column $outputColName already exists.") + val attr = NominalAttribute.defaultAttr.withName(map(outputCol)) + val outputFields = inputFields :+ attr.toStructField() + StructType(outputFields) + } +} + +/** + * :: AlphaComponent :: + * A label indexer that maps a string column of labels to an ML column of label indices. + * The indices are in [0, numLabels), ordered by label frequencies. + * So the most frequent label gets index 0. + */ +@AlphaComponent +class StringIndexer extends Estimator[StringIndexerModel] with StringIndexerBase { + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + // TODO: handle unseen labels + + override def fit(dataset: DataFrame, paramMap: ParamMap): StringIndexerModel = { + val map = this.paramMap ++ paramMap + val counts = dataset.select(map(inputCol)).map(_.getString(0)).countByValue() + val labels = counts.toSeq.sortBy(-_._2).map(_._1).toArray + val model = new StringIndexerModel(this, map, labels) + Params.inheritValues(map, this, model) + model + } + + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + validateAndTransformSchema(schema, paramMap) + } +} + +/** + * :: AlphaComponent :: + * Model fitted by [[StringIndexer]]. + */ +@AlphaComponent +class StringIndexerModel private[ml] ( + override val parent: StringIndexer, + override val fittingParamMap: ParamMap, + labels: Array[String]) extends Model[StringIndexerModel] with StringIndexerBase { + + private val labelToIndex: OpenHashMap[String, Double] = { + val n = labels.length + val map = new OpenHashMap[String, Double](n) + var i = 0 + while (i < n) { + map.update(labels(i), i) + i += 1 + } + map + } + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { + val map = this.paramMap ++ paramMap + val indexer = udf { label: String => + if (labelToIndex.contains(label)) { + labelToIndex(label) + } else { + // TODO: handle unseen labels + throw new SparkException(s"Unseen label: $label.") + } + } + val outputColName = map(outputCol) + val metadata = NominalAttribute.defaultAttr + .withName(outputColName).withValues(labels).toStructField().metadata + dataset.select(col("*"), indexer(dataset(map(inputCol))).as(outputColName, metadata)) + } + + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + validateAndTransformSchema(schema, paramMap) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala new file mode 100644 index 0000000000..00b5d094d8 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.scalatest.FunSuite + +import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.SQLContext + +class StringIndexerSuite extends FunSuite with MLlibTestSparkContext { + private var sqlContext: SQLContext = _ + + override def beforeAll(): Unit = { + super.beforeAll() + sqlContext = new SQLContext(sc) + } + + test("StringIndexer") { + val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2) + val df = sqlContext.createDataFrame(data).toDF("id", "label") + val indexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("labelIndex") + .fit(df) + val transformed = indexer.transform(df) + val attr = Attribute.fromStructField(transformed.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attr.values.get === Array("a", "c", "b")) + val output = transformed.select("id", "labelIndex").map { r => + (r.getInt(0), r.getDouble(1)) + }.collect().toSet + // a -> 0, b -> 2, c -> 1 + val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)) + assert(output === expected) + } +} |