From f5ebb18c45ffdee2756a80f64239cb9158df1a11 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 27 Apr 2016 16:11:12 -0700 Subject: [SPARK-14671][ML] Pipeline setStages should handle subclasses of PipelineStage ## What changes were proposed in this pull request? Pipeline.setStages failed for some code examples which worked in 1.5 but fail in 1.6. This tends to occur when using a mix of transformers from ml.feature. It is because Java Arrays are non-covariant and the addition of MLWritable to some transformers means the stages0/1 arrays above are not of type Array[PipelineStage]. This PR modifies the following to accept subclasses of PipelineStage: * Pipeline.setStages() * Params.w() ## How was this patch tested? Unit test which fails to compile before this fix. Author: Joseph K. Bradley Closes #12430 from jkbradley/pipeline-setstages. --- mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala | 5 ++++- mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala | 9 ++++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index 82066726a0..b02aea92b7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -103,7 +103,10 @@ class Pipeline @Since("1.4.0") ( /** @group setParam */ @Since("1.2.0") - def setStages(value: Array[PipelineStage]): this.type = { set(stages, value); this } + def setStages(value: Array[_ <: PipelineStage]): this.type = { + set(stages, value.asInstanceOf[Array[PipelineStage]]) + this + } // Below, we clone stages so that modifications to the list of stages will not change // the Param value in the Pipeline. diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index a8c4ac6d05..1de638f245 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -27,7 +27,7 @@ import org.scalatest.mock.MockitoSugar.mock import org.apache.spark.SparkFunSuite import org.apache.spark.ml.Pipeline.SharedReadWrite import org.apache.spark.ml.feature.{HashingTF, MinMaxScaler} -import org.apache.spark.ml.param.{IntParam, ParamMap} +import org.apache.spark.ml.param.{IntParam, ParamMap, ParamPair} import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -201,6 +201,13 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul pipeline.fit(df) } } + + test("Pipeline.setStages should handle Java Arrays being non-covariant") { + val stages0 = Array(new UnWritableStage("b")) + val stages1 = Array(new WritableStage("a")) + val steps = stages0 ++ stages1 + val p = new Pipeline().setStages(steps) + } } -- cgit v1.2.3