From 6cbde337a539e5bb170d0eb81f715a95ee9c9af3 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Fri, 5 Aug 2016 22:07:59 +0100 Subject: [SPARK-16750][FOLLOW-UP][ML] Add transformSchema for StringIndexer/VectorAssembler and fix failed tests. ## What changes were proposed in this pull request? This is follow-up for #14378. When we add ```transformSchema``` for all estimators and transformers, I found there are tests failed for ```StringIndexer``` and ```VectorAssembler```. So I moved these parts of work separately in this PR, to make it more clear to review. The corresponding tests should throw ```IllegalArgumentException``` at schema validation period after we add ```transformSchema```. It's efficient that to throw exception at the start of ```fit``` or ```transform``` rather than during the process. ## How was this patch tested? Modified unit tests. Author: Yanbo Liang Closes #14455 from yanboliang/transformSchema. --- .../org/apache/spark/ml/feature/StringIndexerSuite.scala | 12 ++++++++++-- .../org/apache/spark/ml/feature/VectorAssemblerSuite.scala | 4 ++-- 2 files changed, 12 insertions(+), 4 deletions(-) (limited to 'mllib/src/test') diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index c221d4aa55..b478fea5e7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -120,12 +120,20 @@ class StringIndexerSuite test("StringIndexerModel can't overwrite output column") { val df = spark.createDataFrame(Seq((1, 2), (3, 4))).toDF("input", "output") + intercept[IllegalArgumentException] { + new StringIndexer() + .setInputCol("input") + .setOutputCol("output") + .fit(df) + } + val indexer = new StringIndexer() .setInputCol("input") - .setOutputCol("output") + .setOutputCol("indexedInput") .fit(df) + intercept[IllegalArgumentException] { - indexer.transform(df) + indexer.setOutputCol("output").transform(df) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala index 14973e79bf..561493fbaf 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala @@ -74,10 +74,10 @@ class VectorAssemblerSuite val assembler = new VectorAssembler() .setInputCols(Array("a", "b", "c")) .setOutputCol("features") - val thrown = intercept[SparkException] { + val thrown = intercept[IllegalArgumentException] { assembler.transform(df) } - assert(thrown.getMessage contains "VectorAssembler does not support the StringType type") + assert(thrown.getMessage contains "Data type StringType is not supported") } test("ML attributes") { -- cgit v1.2.3