aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/test')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala13
1 files changed, 13 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index 99f82bea42..d0295a0fe2 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -47,6 +47,19 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
// a -> 0, b -> 2, c -> 1
val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0))
assert(output === expected)
+ // convert reverse our transform
+ val reversed = indexer.invert("labelIndex", "label2")
+ .transform(transformed)
+ .select("id", "label2")
+ assert(df.collect().map(r => (r.getInt(0), r.getString(1))).toSet ===
+ reversed.collect().map(r => (r.getInt(0), r.getString(1))).toSet)
+ // Check invert using only metadata
+ val inverse2 = new StringIndexerInverse()
+ .setInputCol("labelIndex")
+ .setOutputCol("label2")
+ val reversed2 = inverse2.transform(transformed).select("id", "label2")
+ assert(df.collect().map(r => (r.getInt(0), r.getString(1))).toSet ===
+ reversed2.collect().map(r => (r.getInt(0), r.getString(1))).toSet)
}
test("StringIndexer with a numeric input column") {