aboutsummaryrefslogtreecommitdiff
path: root/mllib/src
diff options
context:
space:
mode:
authorXiangrui Meng <meng@databricks.com>2015-04-12 22:41:05 -0700
committerXiangrui Meng <meng@databricks.com>2015-04-12 22:41:05 -0700
commit685ddcf5253c0ecb39853802431e22b0c7b61dee (patch)
treebdca5d046db951a373eda0dd0c27a7e34531978f /mllib/src
parentd3792f54974e16cbe8f10b3091d248e0bdd48986 (diff)
downloadspark-685ddcf5253c0ecb39853802431e22b0c7b61dee.tar.gz
spark-685ddcf5253c0ecb39853802431e22b0c7b61dee.tar.bz2
spark-685ddcf5253c0ecb39853802431e22b0c7b61dee.zip
[SPARK-5886][ML] Add StringIndexer as a feature transformer
This PR adds string indexer, which takes a column of string labels and outputs a double column with labels indexed by their frequency. TODOs: - [x] store feature to index map in output metadata Author: Xiangrui Meng <meng@databricks.com> Closes #4735 from mengxr/SPARK-5886 and squashes the following commits: d82575f [Xiangrui Meng] fix test 700e70f [Xiangrui Meng] rename LabelIndexer to StringIndexer 16a6f8c [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-5886 457166e [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-5886 f8b30f4 [Xiangrui Meng] update label indexer to output metadata e81ec28 [Xiangrui Meng] Merge branch 'openhashmap-contains' into SPARK-5886-2 d6e6f1f [Xiangrui Meng] add contains to primitivekeyopenhashmap 748a69b [Xiangrui Meng] add contains to OpenHashMap def3c5c [Xiangrui Meng] add LabelIndexer
Diffstat (limited to 'mllib/src')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala126
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala52
2 files changed, 178 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
new file mode 100644
index 0000000000..61e6742e88
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.SparkException
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.attribute.NominalAttribute
+import org.apache.spark.ml.param._
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.{StringType, StructType}
+import org.apache.spark.util.collection.OpenHashMap
+
/**
 * Common params and schema validation shared by [[StringIndexer]] and [[StringIndexerModel]].
 */
private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol {

  /**
   * Checks that the input column exists and is a string column, then returns the schema
   * with the nominal output column appended. Fails if the output column already exists.
   */
  protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = {
    val merged = this.paramMap ++ paramMap
    checkInputColumn(schema, merged(inputCol), StringType)
    val outputColName = merged(outputCol)
    // Refuse to shadow an existing column of the same name.
    require(!schema.fields.exists(_.name == outputColName),
      s"Output column $outputColName already exists.")
    val outputAttr = NominalAttribute.defaultAttr.withName(outputColName)
    StructType(schema.fields :+ outputAttr.toStructField())
  }
}
+
/**
 * :: AlphaComponent ::
 * A label indexer that maps a string column of labels to an ML column of label indices.
 * Indices are in [0, numLabels), ordered by descending label frequency, so the most
 * frequent label receives index 0.
 */
@AlphaComponent
class StringIndexer extends Estimator[StringIndexerModel] with StringIndexerBase {

  /** @group setParam */
  def setInputCol(value: String): this.type = set(inputCol, value)

  /** @group setParam */
  def setOutputCol(value: String): this.type = set(outputCol, value)

  // TODO: handle unseen labels

  override def fit(dataset: DataFrame, paramMap: ParamMap): StringIndexerModel = {
    val merged = this.paramMap ++ paramMap
    // Count how often each label occurs across the input column.
    val labelCounts = dataset.select(merged(inputCol)).map(_.getString(0)).countByValue()
    // Order labels by descending frequency; position in the array becomes the index.
    val orderedLabels = labelCounts.toSeq.sortBy(-_._2).map(_._1).toArray
    val model = new StringIndexerModel(this, merged, orderedLabels)
    Params.inheritValues(merged, this, model)
    model
  }

  override def transformSchema(schema: StructType, paramMap: ParamMap): StructType =
    validateAndTransformSchema(schema, paramMap)
}
+
/**
 * :: AlphaComponent ::
 * Model fitted by [[StringIndexer]]. Maps each label seen during fitting to its index,
 * where labels are ordered by descending frequency (most frequent label has index 0).
 */
@AlphaComponent
class StringIndexerModel private[ml] (
    override val parent: StringIndexer,
    override val fittingParamMap: ParamMap,
    labels: Array[String]) extends Model[StringIndexerModel] with StringIndexerBase {

  // Lookup table from label to its index; the index is the label's position in `labels`.
  private val labelToIndex: OpenHashMap[String, Double] = {
    val lookup = new OpenHashMap[String, Double](labels.length)
    labels.iterator.zipWithIndex.foreach { case (label, idx) =>
      lookup.update(label, idx.toDouble)
    }
    lookup
  }

  /** @group setParam */
  def setInputCol(value: String): this.type = set(inputCol, value)

  /** @group setParam */
  def setOutputCol(value: String): this.type = set(outputCol, value)

  override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
    val merged = this.paramMap ++ paramMap
    // UDF that resolves each label to its fitted index; labels not seen during
    // fitting cause a runtime failure.
    val indexer = udf { label: String =>
      if (labelToIndex.contains(label)) {
        labelToIndex(label)
      } else {
        // TODO: handle unseen labels
        throw new SparkException(s"Unseen label: $label.")
      }
    }
    val outputColName = merged(outputCol)
    // Attach the fitted label values as nominal-attribute metadata on the output column.
    val metadata = NominalAttribute.defaultAttr
      .withName(outputColName).withValues(labels).toStructField().metadata
    dataset.select(col("*"), indexer(dataset(merged(inputCol))).as(outputColName, metadata))
  }

  override def transformSchema(schema: StructType, paramMap: ParamMap): StructType =
    validateAndTransformSchema(schema, paramMap)
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
new file mode 100644
index 0000000000..00b5d094d8
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.SQLContext
+
class StringIndexerSuite extends FunSuite with MLlibTestSparkContext {
  private var sqlContext: SQLContext = _

  override def beforeAll(): Unit = {
    super.beforeAll()
    sqlContext = new SQLContext(sc)
  }

  test("StringIndexer") {
    // Frequencies: a -> 3, c -> 2, b -> 1, so expected indices are a=0, c=1, b=2.
    val rows = Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c"))
    val df = sqlContext.createDataFrame(sc.parallelize(rows, 2)).toDF("id", "label")
    val indexer = new StringIndexer()
      .setInputCol("label")
      .setOutputCol("labelIndex")
      .fit(df)
    val indexed = indexer.transform(df)
    // The output column should carry nominal metadata listing labels by descending frequency.
    val attr = Attribute.fromStructField(indexed.schema("labelIndex"))
      .asInstanceOf[NominalAttribute]
    assert(attr.values.get === Array("a", "c", "b"))
    val actual = indexed.select("id", "labelIndex").map { r =>
      (r.getInt(0), r.getDouble(1))
    }.collect().toSet
    // a -> 0, b -> 2, c -> 1
    val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0))
    assert(actual === expected)
  }
}