author | Yu ISHIKAWA <yuu.ishikawa@gmail.com> | 2016-02-25 13:21:33 -0800 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2016-02-25 13:21:33 -0800 |
commit | 14e2700de29d06460179a94cc9816bcd37344cf7 (patch) | |
tree | 727441369903be48e3033201d60285f2bd8f23b7 /mllib | |
parent | fb8bb04766005e8935607069c0155d639f407e8a (diff) | |
[SPARK-12874][ML] ML StringIndexer does not protect itself from column name duplication
## What changes were proposed in this pull request?
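ML `StringIndexer` does not protect itself from column name duplication. This patch makes `StringIndexerModel` call `validateAndTransformSchema(dataset.schema)` before transforming, so that `transform` fails with an `IllegalArgumentException` when the output column already exists in the input dataset.
The way `StringIndexer` and `StringIndexerModel` validate a schema could still be improved, but that is better handled in a separate issue.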
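The kind of check this relies on can be illustrated without Spark. The following is a minimal sketch, assuming the schema validation amounts to a `require` that the output column name is not already present among the input columns; `OutputColumnCheck` and `appendOutputColumn` are hypothetical names used only for illustration and are not part of Spark's API.

```scala
import scala.util.Try

object OutputColumnCheck {
  // Hypothetical helper (not Spark's API): rejects an output column name that
  // is already present, otherwise returns the column names with it appended.
  // `require` throws IllegalArgumentException when its condition is false.
  def appendOutputColumn(fieldNames: Seq[String], outputCol: String): Seq[String] = {
    require(!fieldNames.contains(outputCol), s"Output column $outputCol already exists.")
    fieldNames :+ outputCol
  }

  def main(args: Array[String]): Unit = {
    // Same column layout as the DataFrame in the unit test of this patch.
    val columns = Seq("input", "output")

    // Reusing an existing column name as the output column is rejected.
    val clash = Try(appendOutputColumn(columns, "output"))
    assert(clash.isFailure)
    assert(clash.failed.get.isInstanceOf[IllegalArgumentException])

    // A fresh output column name is accepted and appended.
    assert(appendOutputColumn(columns, "indexed") == Seq("input", "output", "indexed"))
  }
}
```

Failing fast on a name clash, rather than overwriting the existing column, matches the behavior the new unit test expects from `transform`.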
## How was this patch tested?
unit test
Author: Yu ISHIKAWA <yuu.ishikawa@gmail.com>
Closes #11370 from yu-iskw/SPARK-12874.
Diffstat (limited to 'mllib')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala | 1 |
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala | 11 |
2 files changed, 12 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 912bd95a2e..555f1130e4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -150,6 +150,7 @@ class StringIndexerModel (
         "Skip StringIndexerModel.")
       return dataset
     }
+    validateAndTransformSchema(dataset.schema)
 
     val indexer = udf { label: String =>
       if (labelToIndex.contains(label)) {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index 5d199ca9b5..0dbaed2522 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -118,6 +118,17 @@ class StringIndexerSuite
     assert(indexerModel.transform(df).eq(df))
   }
 
+  test("StringIndexerModel can't overwrite output column") {
+    val df = sqlContext.createDataFrame(Seq((1, 2), (3, 4))).toDF("input", "output")
+    val indexer = new StringIndexer()
+      .setInputCol("input")
+      .setOutputCol("output")
+      .fit(df)
+    intercept[IllegalArgumentException] {
+      indexer.transform(df)
+    }
+  }
+
   test("StringIndexer read/write") {
     val t = new StringIndexer()
       .setInputCol("myInputCol")