aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorHolden Karau <holden@pigscanfly.ca>2015-08-01 01:09:38 -0700
committerJoseph K. Bradley <joseph@databricks.com>2015-08-01 01:09:38 -0700
commit65038973a17904e0e04d453799ec108af240fbab (patch)
treee90123661088eb27645dcf0f9f684b9f8ab344b2 /mllib
parent60ea7ab4bbfaea29a6cdf4e0e71ddc56afd04de6 (diff)
downloadspark-65038973a17904e0e04d453799ec108af240fbab.tar.gz
spark-65038973a17904e0e04d453799ec108af240fbab.tar.bz2
spark-65038973a17904e0e04d453799ec108af240fbab.zip
[SPARK-7446] [MLLIB] Add inverse transform for string indexer
It is useful to convert the encoded indices back to their string representation for result inspection. We can add a function which creates an inverse transformation. Author: Holden Karau <holden@pigscanfly.ca> Closes #6339 from holdenk/SPARK-7446-inverse-transform-for-string-indexer and squashes the following commits: 7cdf915 [Holden Karau] scala style comment fix b9cffb6 [Holden Karau] Update the labels param to have the metadata note 6a38edb [Holden Karau] Setting the default needs to come after the value gets defined 9e241d8 [Holden Karau] use Array.empty 21c8cfa [Holden Karau] Merge branch 'master' into SPARK-7446-inverse-transform-for-string-indexer 64dd3a3 [Holden Karau] Merge branch 'master' into SPARK-7446-inverse-transform-for-string-indexer 4f06c59 [Holden Karau] Fix comment styles, use empty array as the default, etc. a60c0e3 [Holden Karau] CR feedback (remove old constructor, add a note about use of setLabels) 1987b95 [Holden Karau] Use default copy 71e8d66 [Holden Karau] Make labels a local param for StringIndexerInverse 8450d0b [Holden Karau] Use the labels param in StringIndexerInverse 7464019 [Holden Karau] Add a labels param 868b1a9 [Holden Karau] Update scaladoc since we don't have labelsCol anymore 5aa38bf [Holden Karau] Add an inverse test using only meta data, pass labels when calling inverse method f3e0c64 [Holden Karau] CR feedback ebed932 [Holden Karau] Add Experimental tag and some scaladocs. Also don't require that the inputCol has the metadata on it, instead have the labelsCol specified when creating the inverse. 
03ebf95 [Holden Karau] Add explicit type for invert function ecc65e0 [Holden Karau] Read the metadata correctly, use the array, pass the test a42d773 [Holden Karau] Fix test to supply cols as per new invert method 16cc3c3 [Holden Karau] Add an invert method d4bcb20 [Holden Karau] Make the inverse string indexer into a transformer (still needs test updates but compiles) e8bf3ad [Holden Karau] Merge branch 'master' into SPARK-7446-inverse-transform-for-string-indexer c3fdee1 [Holden Karau] Some WIP refactoring based on jkbradley's CR feedback. Definite work-in-progress 557bef8 [Holden Karau] Instead of using a private inverse transform, add an invert function so we can use it in a pipeline 88779c1 [Holden Karau] fix long line 78b28c1 [Holden Karau] Finish reverse part and add a test :) bb16a6a [Holden Karau] Some progress
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala108
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala13
2 files changed, 118 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index bf7be363b8..ebfa972532 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -20,13 +20,14 @@ package org.apache.spark.ml.feature
import org.apache.spark.SparkException
import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.{Estimator, Model}
-import org.apache.spark.ml.attribute.NominalAttribute
+import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.Transformer
+import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{NumericType, StringType, StructType}
+import org.apache.spark.sql.types.{DoubleType, NumericType, StringType, StructType}
import org.apache.spark.util.collection.OpenHashMap
/**
@@ -151,4 +152,105 @@ class StringIndexerModel private[ml] (
val copied = new StringIndexerModel(uid, labels)
copyValues(copied, extra)
}
+
+  /**
+   * Build the transformer that undoes this model's indexing, turning label indices read from
+   * `inputCol` back into their string labels written to `outputCol`.
+   * The returned transformer is configured with this model's labels.
+   * Note: By default we keep the original columns during this transformation, so the inverse
+   * should only be used on new columns such as predicted labels.
+   */
+  def invert(inputCol: String, outputCol: String): StringIndexerInverse = {
+    val inverse = new StringIndexerInverse()
+    inverse.setInputCol(inputCol)
+    inverse.setOutputCol(outputCol)
+    inverse.setLabels(labels)
+  }
+}
+
+/**
+ * :: Experimental ::
+ * Transform a column of label indices back to the original string labels. The labels are
+ * taken from the `labels` param when it is non-empty; otherwise they are read from the
+ * nominal attribute metadata on the input column.
+ * Note: By default we keep the original columns during this transformation,
+ * so the inverse should only be used on new columns such as predicted labels.
+ */
+@Experimental
+class StringIndexerInverse private[ml] (
+    override val uid: String) extends Transformer
+    with HasInputCol with HasOutputCol {
+
+  def this() =
+    this(Identifiable.randomUID("strIdxInv"))
+
+  /** @group setParam */
+  def setInputCol(value: String): this.type = set(inputCol, value)
+
+  /** @group setParam */
+  def setOutputCol(value: String): this.type = set(outputCol, value)
+
+  /**
+   * Optional labels to be provided by the user. If they are not supplied, the column
+   * metadata is read for labels. The default value is an empty array, and an empty
+   * array is ignored in favor of the column metadata.
+   * @group setParam
+   */
+  def setLabels(value: Array[String]): this.type = set(labels, value)
+
+  /**
+   * Param for array of labels.
+   * Optional labels to be provided by the user. If they are not supplied, the column
+   * metadata is read for labels.
+   * @group param
+   */
+  final val labels: StringArrayParam = new StringArrayParam(this, "labels",
+    "array of labels, if not provided metadata from inputCol is used instead.")
+  // The default must be registered after the param itself has been defined.
+  setDefault(labels, Array.empty[String])
+
+  /**
+   * Optional labels to be provided by the user. If they are not supplied, the column
+   * metadata is read for labels.
+   * @group getParam
+   */
+  final def getLabels: Array[String] = $(labels)
+
+  /**
+   * Transform the schema for the inverse transformation.
+   * Requires the input column to be numeric and the output column name to be unused, then
+   * appends the output column as a string field.
+   */
+  override def transformSchema(schema: StructType): StructType = {
+    val inputColName = $(inputCol)
+    val inputDataType = schema(inputColName).dataType
+    require(inputDataType.isInstanceOf[NumericType],
+      s"The input column $inputColName must be a numeric type, " +
+        s"but got $inputDataType.")
+    val inputFields = schema.fields
+    val outputColName = $(outputCol)
+    require(inputFields.forall(_.name != outputColName),
+      s"Output column $outputColName already exists.")
+    // transform() emits the label strings, so the declared output field must be a string
+    // column rather than a numeric field carrying nominal metadata.
+    val outputField = org.apache.spark.sql.types.StructField(outputColName, StringType)
+    StructType(inputFields :+ outputField)
+  }
+
+  /**
+   * Append the output column holding the label for each index found in the input column.
+   * Throws a SparkException when the labels param is empty and the input column carries no
+   * nominal label values, or (when the column is evaluated) for an index that is not a whole
+   * number inside the label range.
+   */
+  override def transform(dataset: DataFrame): DataFrame = {
+    val inputColName = $(inputCol)
+    // If the labels array is empty use column metadata
+    val values: Array[String] = if ($(labels).isEmpty) {
+      Attribute.fromStructField(dataset.schema(inputColName)) match {
+        case nominal: NominalAttribute =>
+          nominal.values.getOrElse(throw new SparkException(
+            s"The nominal metadata on column $inputColName has no label values; " +
+              "provide them with setLabels."))
+        case _ =>
+          throw new SparkException(
+            s"Column $inputColName has no nominal label metadata; " +
+              "provide the labels with setLabels.")
+      }
+    } else {
+      $(labels)
+    }
+    val indexer = udf { index: Double =>
+      val idx = index.toInt
+      // toInt truncates (and turns NaN into 0), so also require the index to be a whole
+      // number; otherwise a value such as -0.5 or NaN would silently map to label 0.
+      if (idx.toDouble == index && 0 <= idx && idx < values.length) {
+        values(idx)
+      } else {
+        throw new SparkException(s"Unseen index: $index ??")
+      }
+    }
+    dataset.select(col("*"),
+      indexer(dataset(inputColName).cast(DoubleType)).as($(outputCol)))
+  }
+
+  override def copy(extra: ParamMap): StringIndexerInverse = {
+    defaultCopy(extra)
+  }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index 99f82bea42..d0295a0fe2 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -47,6 +47,19 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
// a -> 0, b -> 2, c -> 1
val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0))
assert(output === expected)
+ // Invert the transform and check that we recover the original labels
+ val reversed = indexer.invert("labelIndex", "label2")
+ .transform(transformed)
+ .select("id", "label2")
+ assert(df.collect().map(r => (r.getInt(0), r.getString(1))).toSet ===
+ reversed.collect().map(r => (r.getInt(0), r.getString(1))).toSet)
+ // Check invert using only metadata
+ val inverse2 = new StringIndexerInverse()
+ .setInputCol("labelIndex")
+ .setOutputCol("label2")
+ val reversed2 = inverse2.transform(transformed).select("id", "label2")
+ assert(df.collect().map(r => (r.getInt(0), r.getString(1))).toSet ===
+ reversed2.collect().map(r => (r.getInt(0), r.getString(1))).toSet)
}
test("StringIndexer with a numeric input column") {