diff options
author | Joseph K. Bradley <joseph@databricks.com> | 2016-03-25 16:00:09 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2016-03-25 16:00:09 -0700 |
commit | 54d13bed87fcf2968f77e1f1153e85184ec91d78 (patch) | |
tree | fe1aba5bcf9c5af86db892a6fd2facfa7114bab0 /mllib/src/test/scala/org | |
parent | ff7cc45f521c63ce40f955b8995d52a79dca17b4 (diff) | |
download | spark-54d13bed87fcf2968f77e1f1153e85184ec91d78.tar.gz spark-54d13bed87fcf2968f77e1f1153e85184ec91d78.tar.bz2 spark-54d13bed87fcf2968f77e1f1153e85184ec91d78.zip |
[SPARK-14159][ML] Fixed bug in StringIndexer + related issue in RFormula
## What changes were proposed in this pull request?
StringIndexerModel.transform sets the output column metadata to use name inputCol. It should not. Fixing this causes a problem with the metadata produced by RFormula.
Fix in RFormula: I added the StringIndexer columns to prefixesToRewrite, and I modified VectorAttributeRewriter to find and replace all "prefixes" since attributes collect multiple prefixes from StringIndexer + Interaction.
Note that "prefixes" is no longer accurate since internal strings may be replaced.
## How was this patch tested?
Unit test which failed before this fix.
Author: Joseph K. Bradley <joseph@databricks.com>
Closes #11965 from jkbradley/StringIndexer-fix.
Diffstat (limited to 'mllib/src/test/scala/org')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala | 13 |
1 files changed, 13 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index d40e69dced..2c3255ef33 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -210,4 +210,17 @@ class StringIndexerSuite .setLabels(Array("a", "b", "c")) testDefaultReadWrite(t) } + + test("StringIndexer metadata") { + val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2) + val df = sqlContext.createDataFrame(data).toDF("id", "label") + val indexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("labelIndex") + .fit(df) + val transformed = indexer.transform(df) + val attrs = + NominalAttribute.decodeStructField(transformed.schema("labelIndex"), preserveName = true) + assert(attrs.name.nonEmpty && attrs.name.get === "labelIndex") + } } |