aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/test')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala45
1 files changed, 45 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index 188dffb3dd..8d9042b31e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -122,6 +122,51 @@ class StringIndexerSuite
assert(output === expected)
}
+ test("StringIndexer with NULLs") { // null labels under handleInvalid = error / skip / keep
+ val data: Seq[(Int, String)] = Seq((0, "a"), (1, "b"), (2, "b"), (3, null))
+ val data2: Seq[(Int, String)] = Seq((0, "a"), (1, "b"), (3, null))
+ val df = data.toDF("id", "label")
+ val df2 = data2.toDF("id", "label")
+ // df (labels a, b, b, null) is what gets fitted; df2 (labels a, b, null) is what gets transformed.
+ val indexer = new StringIndexer()
+ .setInputCol("label")
+ .setOutputCol("labelIndex")
+ // Case 1: handleInvalid = "error" -- fit + transform + collect is expected to throw.
+ withClue("StringIndexer should throw error when setHandleInvalid=error " +
+ "when given NULL values") {
+ intercept[SparkException] {
+ indexer.setHandleInvalid("error")
+ indexer.fit(df).transform(df2).collect()
+ }
+ }
+ // Case 2: handleInvalid = "skip" -- the null-labelled row (id 3) must not appear in the output.
+ indexer.setHandleInvalid("skip")
+ val transformedSkip = indexer.fit(df).transform(df2)
+ val attrSkip = Attribute
+ .fromStructField(transformedSkip.schema("labelIndex"))
+ .asInstanceOf[NominalAttribute]
+ assert(attrSkip.values.get === Array("b", "a")) // no extra slot is expected for nulls here
+ val outputSkip = transformedSkip.select("id", "labelIndex").rdd.map { r =>
+ (r.getInt(0), r.getDouble(1))
+ }.collect().toSet
+ // a -> 1, b -> 0 ("b" occurs twice in df, "a" once); id 3 has a null label and is left out
+ val expectedSkip = Set((0, 1.0), (1, 0.0))
+ assert(outputSkip === expectedSkip)
+ // Case 3: handleInvalid = "keep" -- the null-labelled row is retained and gets its own index.
+ indexer.setHandleInvalid("keep")
+ val transformedKeep = indexer.fit(df).transform(df2)
+ val attrKeep = Attribute
+ .fromStructField(transformedKeep.schema("labelIndex"))
+ .asInstanceOf[NominalAttribute]
+ assert(attrKeep.values.get === Array("b", "a", "__unknown")) // third slot is the extra one
+ val outputKeep = transformedKeep.select("id", "labelIndex").rdd.map { r =>
+ (r.getInt(0), r.getDouble(1))
+ }.collect().toSet
+ // a -> 1, b -> 0, null -> 2 (index 2 is the "__unknown" entry of the attribute values)
+ val expectedKeep = Set((0, 1.0), (1, 0.0), (3, 2.0))
+ assert(outputKeep === expectedKeep)
+ }
+
test("StringIndexerModel should keep silent if the input column does not exist.") {
val indexerModel = new StringIndexerModel("indexer", Array("a", "b", "c"))
.setInputCol("label")