aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/scala/org/apache
diff options
context:
space:
mode:
authorJoseph K. Bradley <joseph@databricks.com>2016-03-25 16:00:09 -0700
committerXiangrui Meng <meng@databricks.com>2016-03-25 16:00:09 -0700
commit54d13bed87fcf2968f77e1f1153e85184ec91d78 (patch)
treefe1aba5bcf9c5af86db892a6fd2facfa7114bab0 /mllib/src/test/scala/org/apache
parentff7cc45f521c63ce40f955b8995d52a79dca17b4 (diff)
downloadspark-54d13bed87fcf2968f77e1f1153e85184ec91d78.tar.gz
spark-54d13bed87fcf2968f77e1f1153e85184ec91d78.tar.bz2
spark-54d13bed87fcf2968f77e1f1153e85184ec91d78.zip
[SPARK-14159][ML] Fixed bug in StringIndexer + related issue in RFormula
## What changes were proposed in this pull request? StringIndexerModel.transform sets the output column metadata to use name inputCol. It should not. Fixing this causes a problem with the metadata produced by RFormula. Fix in RFormula: I added the StringIndexer columns to prefixesToRewrite, and I modified VectorAttributeRewriter to find and replace all "prefixes" since attributes collect multiple prefixes from StringIndexer + Interaction. Note that "prefixes" is no longer accurate since internal strings may be replaced. ## How was this patch tested? Unit test which failed before this fix. Author: Joseph K. Bradley <joseph@databricks.com> Closes #11965 from jkbradley/StringIndexer-fix.
Diffstat (limited to 'mllib/src/test/scala/org/apache')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala13
1 files changed, 13 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index d40e69dced..2c3255ef33 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -210,4 +210,17 @@ class StringIndexerSuite
.setLabels(Array("a", "b", "c"))
testDefaultReadWrite(t)
}
+
+ test("StringIndexer metadata") {
+ val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
+ val df = sqlContext.createDataFrame(data).toDF("id", "label")
+ val indexer = new StringIndexer()
+ .setInputCol("label")
+ .setOutputCol("labelIndex")
+ .fit(df)
+ val transformed = indexer.transform(df)
+ val attrs =
+ NominalAttribute.decodeStructField(transformed.schema("labelIndex"), preserveName = true)
+ assert(attrs.name.nonEmpty && attrs.name.get === "labelIndex")
+ }
}