From 85941ecf28362f35718ebcd3a22dbb17adb49154 Mon Sep 17 00:00:00 2001 From: Menglong TAN Date: Tue, 14 Mar 2017 07:45:42 -0700 Subject: [SPARK-11569][ML] Fix StringIndexer to handle null value properly ## What changes were proposed in this pull request? This PR is to enhance StringIndexer with NULL values handling. Before the PR, StringIndexer will throw an exception when encounters NULL values. With this PR: - handleInvalid=error: Throw an exception as before - handleInvalid=skip: Skip null values as well as unseen labels - handleInvalid=keep: Give null values an additional index as well as unseen labels BTW, I noticed someone was trying to solve the same problem ( #9920 ) but seems getting no progress or response for a long time. Would you mind to give me a chance to solve it ? I'm eager to help. :-) ## How was this patch tested? new unit tests Author: Menglong TAN Author: Menglong TAN Closes #17233 from crackcell/11569_StringIndexer_NULL. --- .../spark/ml/feature/StringIndexerSuite.scala | 45 ++++++++++++++++++++++ 1 file changed, 45 insertions(+) (limited to 'mllib/src/test') diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 188dffb3dd..8d9042b31e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -122,6 +122,51 @@ class StringIndexerSuite assert(output === expected) } + test("StringIndexer with NULLs") { + val data: Seq[(Int, String)] = Seq((0, "a"), (1, "b"), (2, "b"), (3, null)) + val data2: Seq[(Int, String)] = Seq((0, "a"), (1, "b"), (3, null)) + val df = data.toDF("id", "label") + val df2 = data2.toDF("id", "label") + + val indexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("labelIndex") + + withClue("StringIndexer should throw error when setHandleInvalid=error " + + "when given NULL values") { + intercept[SparkException] { + indexer.setHandleInvalid("error") + indexer.fit(df).transform(df2).collect() + } + } + + indexer.setHandleInvalid("skip") + val transformedSkip = indexer.fit(df).transform(df2) + val attrSkip = Attribute + .fromStructField(transformedSkip.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attrSkip.values.get === Array("b", "a")) + val outputSkip = transformedSkip.select("id", "labelIndex").rdd.map { r => + (r.getInt(0), r.getDouble(1)) + }.collect().toSet + // a -> 1, b -> 0 + val expectedSkip = Set((0, 1.0), (1, 0.0)) + assert(outputSkip === expectedSkip) + + indexer.setHandleInvalid("keep") + val transformedKeep = indexer.fit(df).transform(df2) + val attrKeep = Attribute + .fromStructField(transformedKeep.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attrKeep.values.get === Array("b", "a", "__unknown")) + val outputKeep = transformedKeep.select("id", "labelIndex").rdd.map { r => + (r.getInt(0), r.getDouble(1)) + }.collect().toSet + // a -> 1, b -> 0, null -> 2 + val expectedKeep = Set((0, 1.0), (1, 0.0), (3, 2.0)) + assert(outputKeep === expectedKeep) + } + test("StringIndexerModel should keep silent if the input column does not exist.") { val indexerModel = new StringIndexerModel("indexer", Array("a", "b", "c")) .setInputCol("label") -- cgit v1.2.3