aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
authorJoseph K. Bradley <joseph@databricks.com>2015-08-14 14:05:03 -0700
committerXiangrui Meng <meng@databricks.com>2015-08-14 14:05:03 -0700
commit2a6590e510aba3bfc6603d280023128b3f5ac702 (patch)
tree9c6c53fb2ac079c2498f1d9d2d73fc1c4bd60088 /mllib/src/test
parent11ed2b180ec86523a94679a8b8132fadb911ccd5 (diff)
downloadspark-2a6590e510aba3bfc6603d280023128b3f5ac702.tar.gz
spark-2a6590e510aba3bfc6603d280023128b3f5ac702.tar.bz2
spark-2a6590e510aba3bfc6603d280023128b3f5ac702.zip
[SPARK-9981] [ML] Made labels public for StringIndexerModel
Also added unit test for integration between StringIndexerModel and IndexToString CC: holdenk We realized we should have left in your unit test (to catch the issue with removing the inverse() method), so this adds it back. mengxr Author: Joseph K. Bradley <joseph@databricks.com> Closes #8211 from jkbradley/stridx-labels.
Diffstat (limited to 'mllib/src/test')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala18
1 files changed, 18 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index 0b4c8ba71e..05e05bdc64 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -147,4 +147,22 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(actual === expected)
}
}
+
+ test("StringIndexer, IndexToString are inverses") {
+ val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
+ val df = sqlContext.createDataFrame(data).toDF("id", "label")
+ val indexer = new StringIndexer()
+ .setInputCol("label")
+ .setOutputCol("labelIndex")
+ .fit(df)
+ val transformed = indexer.transform(df)
+ val idx2str = new IndexToString()
+ .setInputCol("labelIndex")
+ .setOutputCol("sameLabel")
+ .setLabels(indexer.labels)
+ idx2str.transform(transformed).select("label", "sameLabel").collect().foreach {
+ case Row(a: String, b: String) =>
+ assert(a === b)
+ }
+ }
}