aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2016-08-05 22:07:59 +0100
committerSean Owen <sowen@cloudera.com>2016-08-05 22:07:59 +0100
commit6cbde337a539e5bb170d0eb81f715a95ee9c9af3 (patch)
treebb169670d130594c9130a2c1b600b2975301e664 /mllib/src/test
parent1f96c97f2374a95140a0c72b1f4eae50ac21d84a (diff)
downloadspark-6cbde337a539e5bb170d0eb81f715a95ee9c9af3.tar.gz
spark-6cbde337a539e5bb170d0eb81f715a95ee9c9af3.tar.bz2
spark-6cbde337a539e5bb170d0eb81f715a95ee9c9af3.zip
[SPARK-16750][FOLLOW-UP][ML] Add transformSchema for StringIndexer/VectorAssembler and fix failed tests.
## What changes were proposed in this pull request? This is follow-up for #14378. When we add ```transformSchema``` for all estimators and transformers, I found there are tests failed for ```StringIndexer``` and ```VectorAssembler```. So I moved these parts of work separately in this PR, to make it more clear to review. The corresponding tests should throw ```IllegalArgumentException``` at schema validation period after we add ```transformSchema```. It's efficient that to throw exception at the start of ```fit``` or ```transform``` rather than during the process. ## How was this patch tested? Modified unit tests. Author: Yanbo Liang <ybliang8@gmail.com> Closes #14455 from yanboliang/transformSchema.
Diffstat (limited to 'mllib/src/test')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala12
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala4
2 files changed, 12 insertions, 4 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index c221d4aa55..b478fea5e7 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -120,12 +120,20 @@ class StringIndexerSuite
test("StringIndexerModel can't overwrite output column") {
val df = spark.createDataFrame(Seq((1, 2), (3, 4))).toDF("input", "output")
+ intercept[IllegalArgumentException] {
+ new StringIndexer()
+ .setInputCol("input")
+ .setOutputCol("output")
+ .fit(df)
+ }
+
val indexer = new StringIndexer()
.setInputCol("input")
- .setOutputCol("output")
+ .setOutputCol("indexedInput")
.fit(df)
+
intercept[IllegalArgumentException] {
- indexer.transform(df)
+ indexer.setOutputCol("output").transform(df)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
index 14973e79bf..561493fbaf 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
@@ -74,10 +74,10 @@ class VectorAssemblerSuite
val assembler = new VectorAssembler()
.setInputCols(Array("a", "b", "c"))
.setOutputCol("features")
- val thrown = intercept[SparkException] {
+ val thrown = intercept[IllegalArgumentException] {
assembler.transform(df)
}
- assert(thrown.getMessage contains "VectorAssembler does not support the StringType type")
+ assert(thrown.getMessage contains "Data type StringType is not supported")
}
test("ML attributes") {