author     Xiangrui Meng <meng@databricks.com>    2015-08-13 16:52:17 -0700
committer  Xiangrui Meng <meng@databricks.com>    2015-08-13 16:52:17 -0700
commit     6c5858bc65c8a8602422b46bfa9cf0a1fb296b88 (patch)
tree       d666cc0ed10832e350109a1eb9fd554a75ae21da /mllib/src
parent     c2520f501a200cf794bbe5dc9c385100f518d020 (diff)
download   spark-6c5858bc65c8a8602422b46bfa9cf0a1fb296b88.tar.gz
           spark-6c5858bc65c8a8602422b46bfa9cf0a1fb296b88.tar.bz2
           spark-6c5858bc65c8a8602422b46bfa9cf0a1fb296b88.zip
[SPARK-9922] [ML] rename StringIndexerInverse to IndexToString
What `StringIndexerInverse` does is not strictly tied to `StringIndexer`, and the name does not clearly describe the transformation. Renaming it to `IndexToString` might be better. ~~I also changed `invert` to `inverse` without arguments. `inputCol` and `outputCol` could be set after.~~ I also removed `invert`.

cc jkbradley holdenk

Author: Xiangrui Meng <meng@databricks.com>

Closes #8152 from mengxr/SPARK-9922.
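As a quick illustration of the renamed API (not part of the patch; the column names, data, and `sqlContext` setup below are assumed, in the style of the Spark 1.5-era examples):

    import org.apache.spark.ml.feature.{IndexToString, StringIndexer}

    // Hypothetical input; the "id"/"label" column names are only illustrative.
    val df = sqlContext.createDataFrame(
      Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"))
    ).toDF("id", "label")

    // Index the string column; the most frequent label gets index 0.
    val indexed = new StringIndexer()
      .setInputCol("label")
      .setOutputCol("labelIndex")
      .fit(df)
      .transform(df)

    // Map the indices back to strings, reading the labels from the ML
    // attributes that StringIndexerModel attached to "labelIndex".
    val converted = new IndexToString()
      .setInputCol("labelIndex")
      .setOutputCol("originalLabel")
      .transform(indexed)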
Diffstat (limited to 'mllib/src')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala       34
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala  50
2 files changed, 48 insertions(+), 36 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 9e4b0f0add..9f6e7b6b6b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -24,7 +24,7 @@ import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.Transformer
-import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
+import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DoubleType, NumericType, StringType, StructType}
@@ -59,6 +59,8 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha
* If the input column is numeric, we cast it to string and index the string values.
* The indices are in [0, numLabels), ordered by label frequencies.
* So the most frequent label gets index 0.
+ *
+ * @see [[IndexToString]] for the inverse transformation
*/
@Experimental
class StringIndexer(override val uid: String) extends Estimator[StringIndexerModel]
@@ -170,34 +172,24 @@ class StringIndexerModel private[ml] (
val copied = new StringIndexerModel(uid, labels)
copyValues(copied, extra).setParent(parent)
}
-
- /**
- * Return a model to perform the inverse transformation.
- * Note: By default we keep the original columns during this transformation, so the inverse
- * should only be used on new columns such as predicted labels.
- */
- def invert(inputCol: String, outputCol: String): StringIndexerInverse = {
- new StringIndexerInverse()
- .setInputCol(inputCol)
- .setOutputCol(outputCol)
- .setLabels(labels)
- }
}
/**
* :: Experimental ::
- * Transform a provided column back to the original input types using either the metadata
- * on the input column, or if provided using the labels supplied by the user.
- * Note: By default we keep the original columns during this transformation,
- * so the inverse should only be used on new columns such as predicted labels.
+ * A [[Transformer]] that maps a column of string indices back to a new column of corresponding
+ * string values using either the ML attributes of the input column, or if provided using the labels
+ * supplied by the user.
+ * All original columns are kept during transformation.
+ *
+ * @see [[StringIndexer]] for converting strings into indices
*/
@Experimental
-class StringIndexerInverse private[ml] (
+class IndexToString private[ml] (
override val uid: String) extends Transformer
with HasInputCol with HasOutputCol {
def this() =
- this(Identifiable.randomUID("strIdxInv"))
+ this(Identifiable.randomUID("idxToStr"))
/** @group setParam */
def setInputCol(value: String): this.type = set(inputCol, value)
@@ -257,7 +249,7 @@ class StringIndexerInverse private[ml] (
}
val indexer = udf { index: Double =>
val idx = index.toInt
- if (0 <= idx && idx < values.size) {
+ if (0 <= idx && idx < values.length) {
values(idx)
} else {
throw new SparkException(s"Unseen index: $index ??")
@@ -268,7 +260,7 @@ class StringIndexerInverse private[ml] (
indexer(dataset($(inputCol)).cast(DoubleType)).as(outputColName))
}
- override def copy(extra: ParamMap): StringIndexerInverse = {
+ override def copy(extra: ParamMap): IndexToString = {
defaultCopy(extra)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index 2d24914cb9..fa918ce648 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -17,12 +17,13 @@
package org.apache.spark.ml.feature
-import org.apache.spark.SparkException
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.functions.col
class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
@@ -53,19 +54,6 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
// a -> 0, b -> 2, c -> 1
val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0))
assert(output === expected)
- // convert reverse our transform
- val reversed = indexer.invert("labelIndex", "label2")
- .transform(transformed)
- .select("id", "label2")
- assert(df.collect().map(r => (r.getInt(0), r.getString(1))).toSet ===
- reversed.collect().map(r => (r.getInt(0), r.getString(1))).toSet)
- // Check invert using only metadata
- val inverse2 = new StringIndexerInverse()
- .setInputCol("labelIndex")
- .setOutputCol("label2")
- val reversed2 = inverse2.transform(transformed).select("id", "label2")
- assert(df.collect().map(r => (r.getInt(0), r.getString(1))).toSet ===
- reversed2.collect().map(r => (r.getInt(0), r.getString(1))).toSet)
}
test("StringIndexerUnseen") {
@@ -125,4 +113,36 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
val df = sqlContext.range(0L, 10L)
assert(indexerModel.transform(df).eq(df))
}
+
+ test("IndexToString params") {
+ val idxToStr = new IndexToString()
+ ParamsSuite.checkParams(idxToStr)
+ }
+
+ test("IndexToString.transform") {
+ val labels = Array("a", "b", "c")
+ val df0 = sqlContext.createDataFrame(Seq(
+ (0, "a"), (1, "b"), (2, "c"), (0, "a")
+ )).toDF("index", "expected")
+
+ val idxToStr0 = new IndexToString()
+ .setInputCol("index")
+ .setOutputCol("actual")
+ .setLabels(labels)
+ idxToStr0.transform(df0).select("actual", "expected").collect().foreach {
+ case Row(actual, expected) =>
+ assert(actual === expected)
+ }
+
+ val attr = NominalAttribute.defaultAttr.withValues(labels)
+ val df1 = df0.select(col("index").as("indexWithAttr", attr.toMetadata()), col("expected"))
+
+ val idxToStr1 = new IndexToString()
+ .setInputCol("indexWithAttr")
+ .setOutputCol("actual")
+ idxToStr1.transform(df1).select("actual", "expected").collect().foreach {
+ case Row(actual, expected) =>
+ assert(actual === expected)
+ }
+ }
}
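For completeness, a sketch of the other mode the new `IndexToString.transform` test covers: when the input column carries no ML attributes, the label mapping has to be supplied explicitly via `setLabels`. The data and column names below are illustrative, not taken from the patch.

    import org.apache.spark.ml.feature.IndexToString

    // Indices with no attribute metadata, so setLabels must provide the mapping.
    val df = sqlContext.createDataFrame(
      Seq((0.0, "a"), (1.0, "b"), (2.0, "c"), (0.0, "a"))
    ).toDF("index", "expected")

    val idxToStr = new IndexToString()
      .setInputCol("index")
      .setOutputCol("actual")
      .setLabels(Array("a", "b", "c"))  // index 0 -> "a", 1 -> "b", 2 -> "c"

    // "actual" should match "expected" for every row.
    idxToStr.transform(df).show()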