From 6c5858bc65c8a8602422b46bfa9cf0a1fb296b88 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 13 Aug 2015 16:52:17 -0700 Subject: [SPARK-9922] [ML] rename StringIndexerReverse to IndexToString What `StringIndexerInverse` does is not strictly associated with `StringIndexer`, and the name is not clearly describing the transformation. Renaming to `IndexToString` might be better. ~~I also changed `invert` to `inverse` without arguments. `inputCol` and `outputCol` could be set after.~~ I also removed `invert`. jkbradley holdenk Author: Xiangrui Meng Closes #8152 from mengxr/SPARK-9922. --- .../apache/spark/ml/feature/StringIndexer.scala | 34 ++++++--------- .../spark/ml/feature/StringIndexerSuite.scala | 50 +++++++++++++++------- 2 files changed, 48 insertions(+), 36 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 9e4b0f0add..9f6e7b6b6b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -24,7 +24,7 @@ import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.Transformer -import org.apache.spark.ml.util.{Identifiable, MetadataUtils} +import org.apache.spark.ml.util.Identifiable import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, NumericType, StringType, StructType} @@ -59,6 +59,8 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha * If the input column is numeric, we cast it to string and index the string values. * The indices are in [0, numLabels), ordered by label frequencies. * So the most frequent label gets index 0. + * + * @see [[IndexToString]] for the inverse transformation */ @Experimental class StringIndexer(override val uid: String) extends Estimator[StringIndexerModel] @@ -170,34 +172,24 @@ class StringIndexerModel private[ml] ( val copied = new StringIndexerModel(uid, labels) copyValues(copied, extra).setParent(parent) } - - /** - * Return a model to perform the inverse transformation. - * Note: By default we keep the original columns during this transformation, so the inverse - * should only be used on new columns such as predicted labels. - */ - def invert(inputCol: String, outputCol: String): StringIndexerInverse = { - new StringIndexerInverse() - .setInputCol(inputCol) - .setOutputCol(outputCol) - .setLabels(labels) - } } /** * :: Experimental :: - * Transform a provided column back to the original input types using either the metadata - * on the input column, or if provided using the labels supplied by the user. - * Note: By default we keep the original columns during this transformation, - * so the inverse should only be used on new columns such as predicted labels. + * A [[Transformer]] that maps a column of string indices back to a new column of corresponding + * string values using either the ML attributes of the input column, or if provided using the labels + * supplied by the user. + * All original columns are kept during transformation. + * + * @see [[StringIndexer]] for converting strings into indices */ @Experimental -class StringIndexerInverse private[ml] ( +class IndexToString private[ml] ( override val uid: String) extends Transformer with HasInputCol with HasOutputCol { def this() = - this(Identifiable.randomUID("strIdxInv")) + this(Identifiable.randomUID("idxToStr")) /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -257,7 +249,7 @@ class StringIndexerInverse private[ml] ( } val indexer = udf { index: Double => val idx = index.toInt - if (0 <= idx && idx < values.size) { + if (0 <= idx && idx < values.length) { values(idx) } else { throw new SparkException(s"Unseen index: $index ??") @@ -268,7 +260,7 @@ class StringIndexerInverse private[ml] ( indexer(dataset($(inputCol)).cast(DoubleType)).as(outputColName)) } - override def copy(extra: ParamMap): StringIndexerInverse = { + override def copy(extra: ParamMap): IndexToString = { defaultCopy(extra) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 2d24914cb9..fa918ce648 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -17,12 +17,13 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkException -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.Row +import org.apache.spark.sql.functions.col class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -53,19 +54,6 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { // a -> 0, b -> 2, c -> 1 val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)) assert(output === expected) - // convert reverse our transform - val reversed = indexer.invert("labelIndex", "label2") - .transform(transformed) - .select("id", "label2") - assert(df.collect().map(r => (r.getInt(0), r.getString(1))).toSet === - reversed.collect().map(r => (r.getInt(0), r.getString(1))).toSet) - // Check invert using only metadata - val inverse2 = new StringIndexerInverse() - .setInputCol("labelIndex") - .setOutputCol("label2") - val reversed2 = inverse2.transform(transformed).select("id", "label2") - assert(df.collect().map(r => (r.getInt(0), r.getString(1))).toSet === - reversed2.collect().map(r => (r.getInt(0), r.getString(1))).toSet) } test("StringIndexerUnseen") { @@ -125,4 +113,36 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { val df = sqlContext.range(0L, 10L) assert(indexerModel.transform(df).eq(df)) } + + test("IndexToString params") { + val idxToStr = new IndexToString() + ParamsSuite.checkParams(idxToStr) + } + + test("IndexToString.transform") { + val labels = Array("a", "b", "c") + val df0 = sqlContext.createDataFrame(Seq( + (0, "a"), (1, "b"), (2, "c"), (0, "a") + )).toDF("index", "expected") + + val idxToStr0 = new IndexToString() + .setInputCol("index") + .setOutputCol("actual") + .setLabels(labels) + idxToStr0.transform(df0).select("actual", "expected").collect().foreach { + case Row(actual, expected) => + assert(actual === expected) + } + + val attr = NominalAttribute.defaultAttr.withValues(labels) + val df1 = df0.select(col("index").as("indexWithAttr", attr.toMetadata()), col("expected")) + + val idxToStr1 = new IndexToString() + .setInputCol("indexWithAttr") + .setOutputCol("actual") + idxToStr1.transform(df1).select("actual", "expected").collect().foreach { + case Row(actual, expected) => + assert(actual === expected) + } + } } -- cgit v1.2.3