aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
authorMenglong TAN <tanmenglong@renrenche.com>2017-03-14 07:45:42 -0700
committerJoseph K. Bradley <joseph@databricks.com>2017-03-14 07:45:42 -0700
commit85941ecf28362f35718ebcd3a22dbb17adb49154 (patch)
tree37bdc1a9558c4c07a26991f115132c13cdbecf17 /mllib/src/test
parentd4a637cd46b6dd5cc71ea17a55c4a26186e592c7 (diff)
downloadspark-85941ecf28362f35718ebcd3a22dbb17adb49154.tar.gz
spark-85941ecf28362f35718ebcd3a22dbb17adb49154.tar.bz2
spark-85941ecf28362f35718ebcd3a22dbb17adb49154.zip
[SPARK-11569][ML] Fix StringIndexer to handle null value properly
## What changes were proposed in this pull request? This PR is to enhance StringIndexer with NULL values handling. Before the PR, StringIndexer will throw an exception when encounters NULL values. With this PR: - handleInvalid=error: Throw an exception as before - handleInvalid=skip: Skip null values as well as unseen labels - handleInvalid=keep: Give null values an additional index as well as unseen labels BTW, I noticed someone was trying to solve the same problem ( #9920 ) but seems getting no progress or response for a long time. Would you mind to give me a chance to solve it ? I'm eager to help. :-) ## How was this patch tested? new unit tests Author: Menglong TAN <tanmenglong@renrenche.com> Author: Menglong TAN <tanmenglong@gmail.com> Closes #17233 from crackcell/11569_StringIndexer_NULL.
Diffstat (limited to 'mllib/src/test')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala45
1 files changed, 45 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index 188dffb3dd..8d9042b31e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -122,6 +122,51 @@ class StringIndexerSuite
assert(output === expected)
}
+ test("StringIndexer with NULLs") {
+ val data: Seq[(Int, String)] = Seq((0, "a"), (1, "b"), (2, "b"), (3, null))
+ val data2: Seq[(Int, String)] = Seq((0, "a"), (1, "b"), (3, null))
+ val df = data.toDF("id", "label")
+ val df2 = data2.toDF("id", "label")
+
+ val indexer = new StringIndexer()
+ .setInputCol("label")
+ .setOutputCol("labelIndex")
+
+ withClue("StringIndexer should throw error when setHandleInvalid=error " +
+ "when given NULL values") {
+ intercept[SparkException] {
+ indexer.setHandleInvalid("error")
+ indexer.fit(df).transform(df2).collect()
+ }
+ }
+
+ indexer.setHandleInvalid("skip")
+ val transformedSkip = indexer.fit(df).transform(df2)
+ val attrSkip = Attribute
+ .fromStructField(transformedSkip.schema("labelIndex"))
+ .asInstanceOf[NominalAttribute]
+ assert(attrSkip.values.get === Array("b", "a"))
+ val outputSkip = transformedSkip.select("id", "labelIndex").rdd.map { r =>
+ (r.getInt(0), r.getDouble(1))
+ }.collect().toSet
+ // a -> 1, b -> 0
+ val expectedSkip = Set((0, 1.0), (1, 0.0))
+ assert(outputSkip === expectedSkip)
+
+ indexer.setHandleInvalid("keep")
+ val transformedKeep = indexer.fit(df).transform(df2)
+ val attrKeep = Attribute
+ .fromStructField(transformedKeep.schema("labelIndex"))
+ .asInstanceOf[NominalAttribute]
+ assert(attrKeep.values.get === Array("b", "a", "__unknown"))
+ val outputKeep = transformedKeep.select("id", "labelIndex").rdd.map { r =>
+ (r.getInt(0), r.getDouble(1))
+ }.collect().toSet
+ // a -> 1, b -> 0, null -> 2
+ val expectedKeep = Set((0, 1.0), (1, 0.0), (3, 2.0))
+ assert(outputKeep === expectedKeep)
+ }
+
test("StringIndexerModel should keep silent if the input column does not exist.") {
val indexerModel = new StringIndexerModel("indexer", Array("a", "b", "c"))
.setInputCol("label")