about summary refs log tree commit diff
path: root/mllib
diff options
context:
space:
mode:
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala 1
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala 11
2 files changed, 12 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 912bd95a2e..555f1130e4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -150,6 +150,7 @@ class StringIndexerModel (
"Skip StringIndexerModel.")
return dataset
}
+ validateAndTransformSchema(dataset.schema)
val indexer = udf { label: String =>
if (labelToIndex.contains(label)) {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index 5d199ca9b5..0dbaed2522 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -118,6 +118,17 @@ class StringIndexerSuite
assert(indexerModel.transform(df).eq(df))
}
+ test("StringIndexerModel can't overwrite output column") {
+ val df = sqlContext.createDataFrame(Seq((1, 2), (3, 4))).toDF("input", "output")
+ val indexer = new StringIndexer()
+ .setInputCol("input")
+ .setOutputCol("output")
+ .fit(df)
+ intercept[IllegalArgumentException] {
+ indexer.transform(df)
+ }
+ }
+
test("StringIndexer read/write") {
val t = new StringIndexer()
.setInputCol("myInputCol")