author     Joseph K. Bradley <joseph@databricks.com>    2015-11-16 17:12:39 -0800
committer  Joseph K. Bradley <joseph@databricks.com>    2015-11-16 17:12:39 -0800
commit     1c5475f1401d2233f4c61f213d1e2c2ee9673067 (patch)
tree       320f6ac8a5e02aace474461962afe6a3b486ac1a /mllib/src/test/scala/org
parent     bd10eb81c98e5e9df453f721943a3e82d9f74ae4 (diff)
[SPARK-11612][ML] Pipeline and PipelineModel persistence
Pipeline and PipelineModel extend Readable and Writable. Persistence succeeds only when all stages are Writable.

Note: This PR reinstates tests for other read/write functionality. It should probably not get merged until [https://issues.apache.org/jira/browse/SPARK-11672] gets fixed.

CC: mengxr

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #9674 from jkbradley/pipeline-io.
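For orientation, a hypothetical usage sketch of the persistence API this patch adds; the stage variables (tokenizer, lr), the paths, and trainingData are illustrative rather than taken from the patch, and the save/load calls mirror the helpers exercised in the tests below:

import org.apache.spark.ml.{Pipeline, PipelineModel}

// tokenizer and lr stand in for any stages whose classes implement Writable.
val pipeline = new Pipeline().setStages(Array(tokenizer, lr))
pipeline.write.overwrite().save("/tmp/unfit-pipeline")         // throws if any stage is not Writable
val restoredPipeline = Pipeline.load("/tmp/unfit-pipeline")

// Fitting and persisting the resulting PipelineModel works the same way.
val model: PipelineModel = pipeline.fit(trainingData)          // trainingData assumed to exist
model.write.save("/tmp/fitted-pipeline")
val restoredModel = PipelineModel.load("/tmp/fitted-pipeline")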
Diffstat (limited to 'mllib/src/test/scala/org')
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala             120
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala   25
2 files changed, 132 insertions(+), 13 deletions(-)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
index 1f2c9b75b6..484026b1ba 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
@@ -17,19 +17,25 @@
package org.apache.spark.ml
+import java.io.File
+
import scala.collection.JavaConverters._
+import org.apache.hadoop.fs.{FileSystem, Path}
import org.mockito.Matchers.{any, eq => meq}
import org.mockito.Mockito.when
import org.scalatest.mock.MockitoSugar.mock
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.Pipeline.SharedReadWrite
import org.apache.spark.ml.feature.HashingTF
-import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.ml.param.{IntParam, ParamMap}
+import org.apache.spark.ml.util._
+import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.types.StructType
-class PipelineSuite extends SparkFunSuite {
+class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
abstract class MyModel extends Model[MyModel]
@@ -111,4 +117,112 @@ class PipelineSuite extends SparkFunSuite {
assert(pipelineModel1.uid === "pipeline1")
assert(pipelineModel1.stages === stages)
}
+
+ test("Pipeline read/write") {
+ val writableStage = new WritableStage("writableStage").setIntParam(56)
+ val pipeline = new Pipeline().setStages(Array(writableStage))
+
+ val pipeline2 = testDefaultReadWrite(pipeline, testParams = false)
+ assert(pipeline2.getStages.length === 1)
+ assert(pipeline2.getStages(0).isInstanceOf[WritableStage])
+ val writableStage2 = pipeline2.getStages(0).asInstanceOf[WritableStage]
+ assert(writableStage.getIntParam === writableStage2.getIntParam)
+ }
+
+ test("Pipeline read/write with non-Writable stage") {
+ val unWritableStage = new UnWritableStage("unwritableStage")
+ val unWritablePipeline = new Pipeline().setStages(Array(unWritableStage))
+ withClue("Pipeline.write should fail when Pipeline contains non-Writable stage") {
+ intercept[UnsupportedOperationException] {
+ unWritablePipeline.write
+ }
+ }
+ }
+
+ test("PipelineModel read/write") {
+ val writableStage = new WritableStage("writableStage").setIntParam(56)
+ val pipeline =
+ new PipelineModel("pipeline_89329327", Array(writableStage.asInstanceOf[Transformer]))
+
+ val pipeline2 = testDefaultReadWrite(pipeline, testParams = false)
+ assert(pipeline2.stages.length === 1)
+ assert(pipeline2.stages(0).isInstanceOf[WritableStage])
+ val writableStage2 = pipeline2.stages(0).asInstanceOf[WritableStage]
+ assert(writableStage.getIntParam === writableStage2.getIntParam)
+
+ val path = new File(tempDir, pipeline.uid).getPath
+ val stagesDir = new Path(path, "stages").toString
+ val expectedStagePath = SharedReadWrite.getStagePath(writableStage.uid, 0, 1, stagesDir)
+ assert(FileSystem.get(sc.hadoopConfiguration).exists(new Path(expectedStagePath)),
+ s"Expected stage 0 of 1 with uid ${writableStage.uid} in Pipeline with uid ${pipeline.uid}" +
+ s" to be saved to path: $expectedStagePath")
+ }
+
+ test("PipelineModel read/write: getStagePath") {
+ val stageUid = "myStage"
+ val stagesDir = new Path("pipeline", "stages").toString
+ def testStage(stageIdx: Int, numStages: Int, expectedPrefix: String): Unit = {
+ val path = SharedReadWrite.getStagePath(stageUid, stageIdx, numStages, stagesDir)
+ val expected = new Path(stagesDir, expectedPrefix + "_" + stageUid).toString
+ assert(path === expected)
+ }
+ testStage(0, 1, "0")
+ testStage(0, 9, "0")
+ testStage(0, 10, "00")
+ testStage(1, 10, "01")
+ testStage(12, 999, "012")
+ }
+
+ test("PipelineModel read/write with non-Writable stage") {
+ val unWritableStage = new UnWritableStage("unwritableStage")
+ val unWritablePipeline =
+ new PipelineModel("pipeline_328957", Array(unWritableStage.asInstanceOf[Transformer]))
+ withClue("PipelineModel.write should fail when PipelineModel contains non-Writable stage") {
+ intercept[UnsupportedOperationException] {
+ unWritablePipeline.write
+ }
+ }
+ }
+}
+
+
+/** Used to test [[Pipeline]] with [[Writable]] stages */
+class WritableStage(override val uid: String) extends Transformer with Writable {
+
+ final val intParam: IntParam = new IntParam(this, "intParam", "doc")
+
+ def getIntParam: Int = $(intParam)
+
+ def setIntParam(value: Int): this.type = set(intParam, value)
+
+ setDefault(intParam -> 0)
+
+ override def copy(extra: ParamMap): WritableStage = defaultCopy(extra)
+
+ override def write: Writer = new DefaultParamsWriter(this)
+
+ override def transform(dataset: DataFrame): DataFrame = dataset
+
+ override def transformSchema(schema: StructType): StructType = schema
+}
+
+object WritableStage extends Readable[WritableStage] {
+
+ override def read: Reader[WritableStage] = new DefaultParamsReader[WritableStage]
+
+ override def load(path: String): WritableStage = read.load(path)
+}
+
+/** Used to test [[Pipeline]] with non-[[Writable]] stages */
+class UnWritableStage(override val uid: String) extends Transformer {
+
+ final val intParam: IntParam = new IntParam(this, "intParam", "doc")
+
+ setDefault(intParam -> 0)
+
+ override def copy(extra: ParamMap): UnWritableStage = defaultCopy(extra)
+
+ override def transform(dataset: DataFrame): DataFrame = dataset
+
+ override def transformSchema(schema: StructType): StructType = schema
}
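The "getStagePath" test above pins down the directory layout for saved stages: each stage goes under stages/ with a prefix of its index, zero-padded to the digit count of the total number of stages, followed by an underscore and the stage UID. A minimal sketch of that rule, re-derived from the test's expectations rather than copied from the patch (stagePathSketch is a hypothetical helper):

import org.apache.hadoop.fs.Path

// Hypothetical re-derivation of the naming rule asserted by the test cases
// (stage 0 of 1 -> "0", 1 of 10 -> "01", 12 of 999 -> "012").
def stagePathSketch(stageUid: String, stageIdx: Int, numStages: Int, stagesDir: String): String = {
  val idxDigits = numStages.toString.length              // 10 stages -> 2 digits, 999 -> 3
  val prefix = s"%0${idxDigits}d".format(stageIdx)       // e.g. index 1 of 10 -> "01"
  new Path(stagesDir, s"${prefix}_$stageUid").toString   // e.g. "pipeline/stages/01_myStage"
}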
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
index cac4bd9aa3..c37f0503f1 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
@@ -30,10 +30,13 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
/**
* Checks "overwrite" option and params.
* @param instance ML instance to test saving/loading
+ * @param testParams If true, then test values of Params. Otherwise, just test overwrite option.
* @tparam T ML instance type
* @return Instance loaded from file
*/
- def testDefaultReadWrite[T <: Params with Writable](instance: T): T = {
+ def testDefaultReadWrite[T <: Params with Writable](
+ instance: T,
+ testParams: Boolean = true): T = {
val uid = instance.uid
val path = new File(tempDir, uid).getPath
@@ -46,16 +49,18 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
val newInstance = loader.load(path)
assert(newInstance.uid === instance.uid)
- instance.params.foreach { p =>
- if (instance.isDefined(p)) {
- (instance.getOrDefault(p), newInstance.getOrDefault(p)) match {
- case (Array(values), Array(newValues)) =>
- assert(values === newValues, s"Values do not match on param ${p.name}.")
- case (value, newValue) =>
- assert(value === newValue, s"Values do not match on param ${p.name}.")
+ if (testParams) {
+ instance.params.foreach { p =>
+ if (instance.isDefined(p)) {
+ (instance.getOrDefault(p), newInstance.getOrDefault(p)) match {
+ case (Array(values), Array(newValues)) =>
+ assert(values === newValues, s"Values do not match on param ${p.name}.")
+ case (value, newValue) =>
+ assert(value === newValue, s"Values do not match on param ${p.name}.")
+ }
+ } else {
+ assert(!newInstance.isDefined(p), s"Param ${p.name} shouldn't be defined.")
}
- } else {
- assert(!newInstance.isDefined(p), s"Param ${p.name} shouldn't be defined.")
}
}