diff options
author | Joseph K. Bradley <joseph@databricks.com> | 2016-04-27 16:11:12 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2016-04-27 16:11:12 -0700 |
commit | f5ebb18c45ffdee2756a80f64239cb9158df1a11 (patch) | |
tree | d7994c8299067f2118b3d9fedd0b41ea0a1cb438 /mllib/src | |
parent | 6466d6c8a47273f08451ab5950d31d130c685c7a (diff) | |
download | spark-f5ebb18c45ffdee2756a80f64239cb9158df1a11.tar.gz spark-f5ebb18c45ffdee2756a80f64239cb9158df1a11.tar.bz2 spark-f5ebb18c45ffdee2756a80f64239cb9158df1a11.zip |
[SPARK-14671][ML] Pipeline setStages should handle subclasses of PipelineStage
## What changes were proposed in this pull request?
Pipeline.setStages failed for some code examples which worked in 1.5 but fail in 1.6. This tends to occur when using a mix of transformers from ml.feature. It is because Java Arrays are non-covariant and the addition of MLWritable to some transformers means the stages0/1 arrays above are not of type Array[PipelineStage]. This PR modifies the following to accept subclasses of PipelineStage:
* Pipeline.setStages()
* Params.w()
## How was this patch tested?
Unit test which fails to compile before this fix.
Author: Joseph K. Bradley <joseph@databricks.com>
Closes #12430 from jkbradley/pipeline-setstages.
Diffstat (limited to 'mllib/src')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala | 5 | ||||
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala | 9 |
2 files changed, 12 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index 82066726a0..b02aea92b7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -103,7 +103,10 @@ class Pipeline @Since("1.4.0") ( /** @group setParam */ @Since("1.2.0") - def setStages(value: Array[PipelineStage]): this.type = { set(stages, value); this } + def setStages(value: Array[_ <: PipelineStage]): this.type = { + set(stages, value.asInstanceOf[Array[PipelineStage]]) + this + } // Below, we clone stages so that modifications to the list of stages will not change // the Param value in the Pipeline. diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index a8c4ac6d05..1de638f245 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -27,7 +27,7 @@ import org.scalatest.mock.MockitoSugar.mock import org.apache.spark.SparkFunSuite import org.apache.spark.ml.Pipeline.SharedReadWrite import org.apache.spark.ml.feature.{HashingTF, MinMaxScaler} -import org.apache.spark.ml.param.{IntParam, ParamMap} +import org.apache.spark.ml.param.{IntParam, ParamMap, ParamPair} import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -201,6 +201,13 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul pipeline.fit(df) } } + + test("Pipeline.setStages should handle Java Arrays being non-covariant") { + val stages0 = Array(new UnWritableStage("b")) + val stages1 = Array(new WritableStage("a")) + val steps = stages0 ++ stages1 + val p = new Pipeline().setStages(steps) + } } |