author    Joseph K. Bradley <joseph@databricks.com>  2016-04-27 16:11:12 -0700
committer Joseph K. Bradley <joseph@databricks.com>  2016-04-27 16:11:12 -0700
commit    f5ebb18c45ffdee2756a80f64239cb9158df1a11 (patch)
tree      d7994c8299067f2118b3d9fedd0b41ea0a1cb438 /mllib/src/test
parent    6466d6c8a47273f08451ab5950d31d130c685c7a (diff)
[SPARK-14671][ML] Pipeline setStages should handle subclasses of PipelineStage
## What changes were proposed in this pull request?
Pipeline.setStages fails to compile for some code examples that worked in 1.5 but fail in 1.6. This tends to occur when using a mix of transformers from ml.feature. The cause is that Scala Arrays are invariant (non-covariant), so the addition of MLWritable to some transformers means that arrays such as stages0/stages1 in the new test are no longer inferred as Array[PipelineStage]. This PR modifies the following to accept subclasses of PipelineStage:
* Pipeline.setStages()
* Params.w()
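The variance issue and the shape of the fix can be sketched without a Spark dependency. Everything below (PipelineStage, MLWritable, the stage classes, and the setStages body) is a simplified stand-in for the real Spark classes, not their actual implementations:

```scala
// Minimal sketch of the SPARK-14671 variance problem (stand-in classes,
// not the real Spark implementations).
abstract class PipelineStage { def name: String }
trait MLWritable
class WritableStage(val name: String) extends PipelineStage with MLWritable
class UnWritableStage(val name: String) extends PipelineStage

class Pipeline {
  private var stages: Array[PipelineStage] = Array.empty

  // Old signature: setStages(value: Array[PipelineStage]).
  // Scala's Array[T] is invariant in T, so an Array[WritableStage] is NOT an
  // Array[PipelineStage], and passing one failed to compile in 1.6.
  // The fix: an upper-bounded existential accepts arrays of any subclass.
  def setStages(value: Array[_ <: PipelineStage]): this.type = {
    // JVM arrays are covariant at runtime, so this cast is safe to read from.
    stages = value.asInstanceOf[Array[PipelineStage]]
    this
  }

  def stageNames: Seq[String] = stages.map(_.name).toSeq
}
```

With the old signature, `new Pipeline().setStages(Array(new WritableStage("a")))` would not compile, because the argument is inferred as Array[WritableStage]; with the bounded signature it compiles, which is exactly what the new PipelineSuite test checks.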
## How was this patch tested?
Unit test which fails to compile before this fix.
Author: Joseph K. Bradley <joseph@databricks.com>
Closes #12430 from jkbradley/pipeline-setstages.
Diffstat (limited to 'mllib/src/test')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala | 9 |
1 file changed, 8 insertions(+), 1 deletion(-)
```diff
diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
index a8c4ac6d05..1de638f245 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
@@ -27,7 +27,7 @@ import org.scalatest.mock.MockitoSugar.mock
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.Pipeline.SharedReadWrite
 import org.apache.spark.ml.feature.{HashingTF, MinMaxScaler}
-import org.apache.spark.ml.param.{IntParam, ParamMap}
+import org.apache.spark.ml.param.{IntParam, ParamMap, ParamPair}
 import org.apache.spark.ml.util._
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.MLlibTestSparkContext
@@ -201,6 +201,13 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
       pipeline.fit(df)
     }
   }
+
+  test("Pipeline.setStages should handle Java Arrays being non-covariant") {
+    val stages0 = Array(new UnWritableStage("b"))
+    val stages1 = Array(new WritableStage("a"))
+    val steps = stages0 ++ stages1
+    val p = new Pipeline().setStages(steps)
+  }
 }
```