aboutsummaryrefslogtreecommitdiff
path: root/mllib/src
diff options
context:
space:
mode:
authorXiangrui Meng <meng@databricks.com>2015-06-08 21:33:47 -0700
committerDB Tsai <dbt@netflix.com>2015-06-08 21:33:47 -0700
commit82870d507dfaeeaf315d6766ca1496205c6216d3 (patch)
tree4456e0fd36626cc9badfa70f10af8b806a420d65 /mllib/src
parentf3eec92ce7e13cc461d2f0404f26730259210f12 (diff)
downloadspark-82870d507dfaeeaf315d6766ca1496205c6216d3.tar.gz
spark-82870d507dfaeeaf315d6766ca1496205c6216d3.tar.bz2
spark-82870d507dfaeeaf315d6766ca1496205c6216d3.zip
[SPARK-8168] [MLLIB] Add Python friendly constructor to PipelineModel
This makes the constructor callable in Python. dbtsai Author: Xiangrui Meng <meng@databricks.com> Closes #6709 from mengxr/SPARK-8168 and squashes the following commits: f871de4 [Xiangrui Meng] Add Python friendly constructor to PipelineModel
Diffstat (limited to 'mllib/src')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala8
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala17
2 files changed, 25 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index 11a4722722..a9bd28df71 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -17,6 +17,9 @@
package org.apache.spark.ml
+import java.{util => ju}
+
+import scala.collection.JavaConverters._
import scala.collection.mutable.ListBuffer
import org.apache.spark.Logging
@@ -175,6 +178,11 @@ class PipelineModel private[ml] (
val stages: Array[Transformer])
extends Model[PipelineModel] with Logging {
+ /** A Java/Python-friendly auxiliary constructor. */
+ private[ml] def this(uid: String, stages: ju.List[Transformer]) = {
+ this(uid, stages.asScala.toArray)
+ }
+
override def validateParams(): Unit = {
super.validateParams()
stages.foreach(_.validateParams())
diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
index 05bf58e63a..29394fefcb 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.ml
+import scala.collection.JavaConverters._
+
import org.mockito.Matchers.{any, eq => meq}
import org.mockito.Mockito.when
import org.scalatest.mock.MockitoSugar.mock
@@ -81,4 +83,19 @@ class PipelineSuite extends SparkFunSuite {
pipeline.fit(dataset)
}
}
+
+ test("pipeline model constructors") {
+ val transform0 = mock[Transformer]
+ val model1 = mock[MyModel]
+
+ val stages = Array(transform0, model1)
+ val pipelineModel0 = new PipelineModel("pipeline0", stages)
+ assert(pipelineModel0.uid === "pipeline0")
+ assert(pipelineModel0.stages === stages)
+
+ val stagesAsList = stages.toList.asJava
+ val pipelineModel1 = new PipelineModel("pipeline1", stagesAsList)
+ assert(pipelineModel1.uid === "pipeline1")
+ assert(pipelineModel1.stages === stages)
+ }
}