aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
authorJoseph K. Bradley <joseph@databricks.com>2016-04-27 16:11:12 -0700
committerJoseph K. Bradley <joseph@databricks.com>2016-04-27 16:11:12 -0700
commitf5ebb18c45ffdee2756a80f64239cb9158df1a11 (patch)
treed7994c8299067f2118b3d9fedd0b41ea0a1cb438 /mllib/src/test
parent6466d6c8a47273f08451ab5950d31d130c685c7a (diff)
downloadspark-f5ebb18c45ffdee2756a80f64239cb9158df1a11.tar.gz
spark-f5ebb18c45ffdee2756a80f64239cb9158df1a11.tar.bz2
spark-f5ebb18c45ffdee2756a80f64239cb9158df1a11.zip
[SPARK-14671][ML] Pipeline setStages should handle subclasses of PipelineStage
## What changes were proposed in this pull request? Pipeline.setStages failed for some code examples which worked in 1.5 but fail in 1.6. This tends to occur when using a mix of transformers from ml.feature. It is because Java Arrays are non-covariant and the addition of MLWritable to some transformers means the stages0/1 arrays above are not of type Array[PipelineStage]. This PR modifies the following to accept subclasses of PipelineStage: * Pipeline.setStages() * Params.w() ## How was this patch tested? Unit test which fails to compile before this fix. Author: Joseph K. Bradley <joseph@databricks.com> Closes #12430 from jkbradley/pipeline-setstages.
Diffstat (limited to 'mllib/src/test')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala9
1 files changed, 8 insertions, 1 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
index a8c4ac6d05..1de638f245 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
@@ -27,7 +27,7 @@ import org.scalatest.mock.MockitoSugar.mock
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.Pipeline.SharedReadWrite
import org.apache.spark.ml.feature.{HashingTF, MinMaxScaler}
-import org.apache.spark.ml.param.{IntParam, ParamMap}
+import org.apache.spark.ml.param.{IntParam, ParamMap, ParamPair}
import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext
@@ -201,6 +201,13 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
pipeline.fit(df)
}
}
+
+ test("Pipeline.setStages should handle Java Arrays being non-covariant") {
+ val stages0 = Array(new UnWritableStage("b"))
+ val stages1 = Array(new WritableStage("a"))
+ val steps = stages0 ++ stages1
+ val p = new Pipeline().setStages(steps)
+ }
}