author     Holden Karau <holden@pigscanfly.ca>        2015-08-11 11:33:36 -0700
committer  Joseph K. Bradley <joseph@databricks.com>  2015-08-11 11:33:36 -0700
commit     dbd778d84d094ca142bc08c351478595b280bc2a (patch)
tree       b7de552c15cca4f7317566aba5680e7538210a88 /mllib/src/test
parent     8cad854ef6a2066de5adffcca6b79a205ccfd5f3 (diff)
[SPARK-8764] [ML] string indexer should take option to handle unseen values
As a precursor to adding a public constructor, add an option to handle unseen values by skipping rather than throwing an exception (the default remains throwing an exception).

Author: Holden Karau <holden@pigscanfly.ca>

Closes #7266 from holdenk/SPARK-8764-string-indexer-should-take-option-to-handle-unseen-values and squashes the following commits:

38a4de9 [Holden Karau] fix long line
045bf22 [Holden Karau] Add a second b entry so b gets 0 for sure
81dd312 [Holden Karau] Update the docs for handleInvalid param to be more descriptive
7f37f6e [Holden Karau] remove extra space (scala style)
414e249 [Holden Karau] And switch to using handleInvalid instead of skipInvalid
1e53f9b [Holden Karau] update the param (codegen side)
7a22215 [Holden Karau] fix typo
100a39b [Holden Karau] Merge in master
aa5b093 [Holden Karau] Since we filter we should never go down this code path if getSkipInvalid is true
75ffa69 [Holden Karau] Remove extra newline
d69ef5e [Holden Karau] Add a test
b5734be [Holden Karau] Add support for unseen labels
afecd4e [Holden Karau] Add a param to skip invalid entries.
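For readers skimming the change, a minimal usage sketch of the new handleInvalid param follows. It is a sketch only, assuming sc and sqlContext are already in scope (as in the test added below); the data and variable names are illustrative, not part of the commit.

    import org.apache.spark.ml.feature.StringIndexer

    // Fit on data that only contains the labels "a" and "b".
    val train = sqlContext.createDataFrame(
      sc.parallelize(Seq((0, "a"), (1, "b"), (2, "b")))).toDF("id", "label")
    // This data contains "c", a label never seen during fitting.
    val test = sqlContext.createDataFrame(
      sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c")))).toDF("id", "label")

    // Default behavior ("error"): transforming rows with an unseen label
    // throws a SparkException.
    val strictModel = new StringIndexer()
      .setInputCol("label")
      .setOutputCol("labelIndex")
      .fit(train)

    // With handleInvalid = "skip", rows with unseen labels are filtered out
    // instead of failing the job.
    val skipModel = new StringIndexer()
      .setInputCol("label")
      .setOutputCol("labelIndex")
      .setHandleInvalid("skip")
      .fit(train)
    skipModel.transform(test).show()  // only the "a" and "b" rows remain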
Diffstat (limited to 'mllib/src/test')
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala  32
1 file changed, 32 insertions(+), 0 deletions(-)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index d0295a0fe2..b111036087 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.ml.feature
+import org.apache.spark.SparkException
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
import org.apache.spark.ml.param.ParamsSuite
@@ -62,6 +63,37 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
reversed2.collect().map(r => (r.getInt(0), r.getString(1))).toSet)
}
+ test("StringIndexerUnseen") {
+ val data = sc.parallelize(Seq((0, "a"), (1, "b"), (4, "b")), 2)
+ val data2 = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c")), 2)
+ val df = sqlContext.createDataFrame(data).toDF("id", "label")
+ val df2 = sqlContext.createDataFrame(data2).toDF("id", "label")
+ val indexer = new StringIndexer()
+ .setInputCol("label")
+ .setOutputCol("labelIndex")
+ .fit(df)
+ // Verify we throw by default with unseen values
+ intercept[SparkException] {
+ indexer.transform(df2).collect()
+ }
+ val indexerSkipInvalid = new StringIndexer()
+ .setInputCol("label")
+ .setOutputCol("labelIndex")
+ .setHandleInvalid("skip")
+ .fit(df)
+ // Verify that we skip the c record
+ val transformed = indexerSkipInvalid.transform(df2)
+ val attr = Attribute.fromStructField(transformed.schema("labelIndex"))
+ .asInstanceOf[NominalAttribute]
+ assert(attr.values.get === Array("b", "a"))
+ val output = transformed.select("id", "labelIndex").map { r =>
+ (r.getInt(0), r.getDouble(1))
+ }.collect().toSet
+ // a -> 1, b -> 0
+ val expected = Set((0, 1.0), (1, 0.0))
+ assert(output === expected)
+ }
+
test("StringIndexer with a numeric input column") {
val data = sc.parallelize(Seq((0, 100), (1, 200), (2, 300), (3, 100), (4, 100), (5, 300)), 2)
val df = sqlContext.createDataFrame(data).toDF("id", "label")