aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main
diff options
context:
space:
mode:
authorXiangrui Meng <meng@databricks.com>2015-08-13 16:52:17 -0700
committerXiangrui Meng <meng@databricks.com>2015-08-13 16:52:17 -0700
commit6c5858bc65c8a8602422b46bfa9cf0a1fb296b88 (patch)
treed666cc0ed10832e350109a1eb9fd554a75ae21da /mllib/src/main
parentc2520f501a200cf794bbe5dc9c385100f518d020 (diff)
downloadspark-6c5858bc65c8a8602422b46bfa9cf0a1fb296b88.tar.gz
spark-6c5858bc65c8a8602422b46bfa9cf0a1fb296b88.tar.bz2
spark-6c5858bc65c8a8602422b46bfa9cf0a1fb296b88.zip
[SPARK-9922] [ML] rename StringIndexerReverse to IndexToString
What `StringIndexerInverse` does is not strictly associated with `StringIndexer`, and the name is not clearly describing the transformation. Renaming to `IndexToString` might be better. ~~I also changed `invert` to `inverse` without arguments. `inputCol` and `outputCol` could be set after.~~ I also removed `invert`. jkbradley holdenk Author: Xiangrui Meng <meng@databricks.com> Closes #8152 from mengxr/SPARK-9922.
Diffstat (limited to 'mllib/src/main')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala34
1 files changed, 13 insertions, 21 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 9e4b0f0add..9f6e7b6b6b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -24,7 +24,7 @@ import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.Transformer
-import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
+import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DoubleType, NumericType, StringType, StructType}
@@ -59,6 +59,8 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha
* If the input column is numeric, we cast it to string and index the string values.
* The indices are in [0, numLabels), ordered by label frequencies.
* So the most frequent label gets index 0.
+ *
+ * @see [[IndexToString]] for the inverse transformation
*/
@Experimental
class StringIndexer(override val uid: String) extends Estimator[StringIndexerModel]
@@ -170,34 +172,24 @@ class StringIndexerModel private[ml] (
val copied = new StringIndexerModel(uid, labels)
copyValues(copied, extra).setParent(parent)
}
-
- /**
- * Return a model to perform the inverse transformation.
- * Note: By default we keep the original columns during this transformation, so the inverse
- * should only be used on new columns such as predicted labels.
- */
- def invert(inputCol: String, outputCol: String): StringIndexerInverse = {
- new StringIndexerInverse()
- .setInputCol(inputCol)
- .setOutputCol(outputCol)
- .setLabels(labels)
- }
}
/**
* :: Experimental ::
- * Transform a provided column back to the original input types using either the metadata
- * on the input column, or if provided using the labels supplied by the user.
- * Note: By default we keep the original columns during this transformation,
- * so the inverse should only be used on new columns such as predicted labels.
+ * A [[Transformer]] that maps a column of string indices back to a new column of corresponding
+ * string values using either the ML attributes of the input column, or if provided using the labels
+ * supplied by the user.
+ * All original columns are kept during transformation.
+ *
+ * @see [[StringIndexer]] for converting strings into indices
*/
@Experimental
-class StringIndexerInverse private[ml] (
+class IndexToString private[ml] (
override val uid: String) extends Transformer
with HasInputCol with HasOutputCol {
def this() =
- this(Identifiable.randomUID("strIdxInv"))
+ this(Identifiable.randomUID("idxToStr"))
/** @group setParam */
def setInputCol(value: String): this.type = set(inputCol, value)
@@ -257,7 +249,7 @@ class StringIndexerInverse private[ml] (
}
val indexer = udf { index: Double =>
val idx = index.toInt
- if (0 <= idx && idx < values.size) {
+ if (0 <= idx && idx < values.length) {
values(idx)
} else {
throw new SparkException(s"Unseen index: $index ??")
@@ -268,7 +260,7 @@ class StringIndexerInverse private[ml] (
indexer(dataset($(inputCol)).cast(DoubleType)).as(outputColName))
}
- override def copy(extra: ParamMap): StringIndexerInverse = {
+ override def copy(extra: ParamMap): IndexToString = {
defaultCopy(extra)
}
}