diff options
author | Xiangrui Meng <meng@databricks.com> | 2015-06-08 21:33:47 -0700 |
---|---|---|
committer | DB Tsai <dbt@netflix.com> | 2015-06-08 21:33:47 -0700 |
commit | 82870d507dfaeeaf315d6766ca1496205c6216d3 (patch) | |
tree | 4456e0fd36626cc9badfa70f10af8b806a420d65 /mllib/src | |
parent | f3eec92ce7e13cc461d2f0404f26730259210f12 (diff) | |
download | spark-82870d507dfaeeaf315d6766ca1496205c6216d3.tar.gz spark-82870d507dfaeeaf315d6766ca1496205c6216d3.tar.bz2 spark-82870d507dfaeeaf315d6766ca1496205c6216d3.zip |
[SPARK-8168] [MLLIB] Add Python friendly constructor to PipelineModel
This makes the constructor callable in Python. dbtsai
Author: Xiangrui Meng <meng@databricks.com>
Closes #6709 from mengxr/SPARK-8168 and squashes the following commits:
f871de4 [Xiangrui Meng] Add Python friendly constructor to PipelineModel
Diffstat (limited to 'mllib/src')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala | 8 | ||||
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala | 17 |
2 files changed, 25 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index 11a4722722..a9bd28df71 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -17,6 +17,9 @@ package org.apache.spark.ml +import java.{util => ju} + +import scala.collection.JavaConverters._ import scala.collection.mutable.ListBuffer import org.apache.spark.Logging @@ -175,6 +178,11 @@ class PipelineModel private[ml] ( val stages: Array[Transformer]) extends Model[PipelineModel] with Logging { + /** A Java/Python-friendly auxiliary constructor. */ + private[ml] def this(uid: String, stages: ju.List[Transformer]) = { + this(uid, stages.asScala.toArray) + } + override def validateParams(): Unit = { super.validateParams() stages.foreach(_.validateParams()) diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index 05bf58e63a..29394fefcb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.ml +import scala.collection.JavaConverters._ + import org.mockito.Matchers.{any, eq => meq} import org.mockito.Mockito.when import org.scalatest.mock.MockitoSugar.mock @@ -81,4 +83,19 @@ class PipelineSuite extends SparkFunSuite { pipeline.fit(dataset) } } + + test("pipeline model constructors") { + val transform0 = mock[Transformer] + val model1 = mock[MyModel] + + val stages = Array(transform0, model1) + val pipelineModel0 = new PipelineModel("pipeline0", stages) + assert(pipelineModel0.uid === "pipeline0") + assert(pipelineModel0.stages === stages) + + val stagesAsList = stages.toList.asJava + val pipelineModel1 = new PipelineModel("pipeline1", stagesAsList) + assert(pipelineModel1.uid === "pipeline1") + assert(pipelineModel1.stages === stages) + } } |