diff options
Diffstat (limited to 'mllib/src/test')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala | 13 |
1 files changed, 13 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 99f82bea42..d0295a0fe2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -47,6 +47,19 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { // a -> 0, b -> 2, c -> 1 val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)) assert(output === expected) + // Reverse our transform and check that the original labels are recovered + val reversed = indexer.invert("labelIndex", "label2") + .transform(transformed) + .select("id", "label2") + assert(df.collect().map(r => (r.getInt(0), r.getString(1))).toSet === + reversed.collect().map(r => (r.getInt(0), r.getString(1))).toSet) + // Check invert using only metadata + val inverse2 = new StringIndexerInverse() + .setInputCol("labelIndex") + .setOutputCol("label2") + val reversed2 = inverse2.transform(transformed).select("id", "label2") + assert(df.collect().map(r => (r.getInt(0), r.getString(1))).toSet === + reversed2.collect().map(r => (r.getInt(0), r.getString(1))).toSet) } test("StringIndexer with a numeric input column") { |