aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorYu ISHIKAWA <yuu.ishikawa@gmail.com>2016-02-25 13:21:33 -0800
committerXiangrui Meng <meng@databricks.com>2016-02-25 13:21:33 -0800
commit14e2700de29d06460179a94cc9816bcd37344cf7 (patch)
tree727441369903be48e3033201d60285f2bd8f23b7 /mllib
parentfb8bb04766005e8935607069c0155d639f407e8a (diff)
downloadspark-14e2700de29d06460179a94cc9816bcd37344cf7.tar.gz
spark-14e2700de29d06460179a94cc9816bcd37344cf7.tar.bz2
spark-14e2700de29d06460179a94cc9816bcd37344cf7.zip
[SPARK-12874][ML] ML StringIndexer does not protect itself from column name duplication
## What changes were proposed in this pull request? ML StringIndexer does not protect itself from column name duplication. We should still improve a way to validate a schema of `StringIndexer` and `StringIndexerModel`. However, it would be great to fix at another issue. ## How was this patch tested? unit test Author: Yu ISHIKAWA <yuu.ishikawa@gmail.com> Closes #11370 from yu-iskw/SPARK-12874.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala1
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala11
2 files changed, 12 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 912bd95a2e..555f1130e4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -150,6 +150,7 @@ class StringIndexerModel (
"Skip StringIndexerModel.")
return dataset
}
+ validateAndTransformSchema(dataset.schema)
val indexer = udf { label: String =>
if (labelToIndex.contains(label)) {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index 5d199ca9b5..0dbaed2522 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -118,6 +118,17 @@ class StringIndexerSuite
assert(indexerModel.transform(df).eq(df))
}
+ test("StringIndexerModel can't overwrite output column") {
+ val df = sqlContext.createDataFrame(Seq((1, 2), (3, 4))).toDF("input", "output")
+ val indexer = new StringIndexer()
+ .setInputCol("input")
+ .setOutputCol("output")
+ .fit(df)
+ intercept[IllegalArgumentException] {
+ indexer.transform(df)
+ }
+ }
+
test("StringIndexer read/write") {
val t = new StringIndexer()
.setInputCol("myInputCol")