diff options
Diffstat (limited to 'mllib/src/test')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala | 45 |
1 files changed, 45 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 188dffb3dd..8d9042b31e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -122,6 +122,51 @@ class StringIndexerSuite assert(output === expected) } + test("StringIndexer with NULLs") { + val data: Seq[(Int, String)] = Seq((0, "a"), (1, "b"), (2, "b"), (3, null)) + val data2: Seq[(Int, String)] = Seq((0, "a"), (1, "b"), (3, null)) + val df = data.toDF("id", "label") + val df2 = data2.toDF("id", "label") + + val indexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("labelIndex") + + withClue("StringIndexer should throw error when setHandleInvalid=error " + + "when given NULL values") { + intercept[SparkException] { + indexer.setHandleInvalid("error") + indexer.fit(df).transform(df2).collect() + } + } + + indexer.setHandleInvalid("skip") + val transformedSkip = indexer.fit(df).transform(df2) + val attrSkip = Attribute + .fromStructField(transformedSkip.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attrSkip.values.get === Array("b", "a")) + val outputSkip = transformedSkip.select("id", "labelIndex").rdd.map { r => + (r.getInt(0), r.getDouble(1)) + }.collect().toSet + // a -> 1, b -> 0 + val expectedSkip = Set((0, 1.0), (1, 0.0)) + assert(outputSkip === expectedSkip) + + indexer.setHandleInvalid("keep") + val transformedKeep = indexer.fit(df).transform(df2) + val attrKeep = Attribute + .fromStructField(transformedKeep.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attrKeep.values.get === Array("b", "a", "__unknown")) + val outputKeep = transformedKeep.select("id", "labelIndex").rdd.map { r => + (r.getInt(0), r.getDouble(1)) + }.collect().toSet + // a -> 1, b -> 0, null -> 2 + val expectedKeep = Set((0, 1.0), (1, 0.0), (3, 2.0)) + assert(outputKeep === expectedKeep) + } + test("StringIndexerModel should keep silent if the input column does not exist.") { val indexerModel = new StringIndexerModel("indexer", Array("a", "b", "c")) .setInputCol("label") |