diff options
author | Joseph K. Bradley <joseph@databricks.com> | 2015-11-16 17:12:39 -0800 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2015-11-16 17:12:39 -0800 |
commit | 1c5475f1401d2233f4c61f213d1e2c2ee9673067 (patch) | |
tree | 320f6ac8a5e02aace474461962afe6a3b486ac1a /mllib/src/test/scala/org | |
parent | bd10eb81c98e5e9df453f721943a3e82d9f74ae4 (diff) | |
download | spark-1c5475f1401d2233f4c61f213d1e2c2ee9673067.tar.gz spark-1c5475f1401d2233f4c61f213d1e2c2ee9673067.tar.bz2 spark-1c5475f1401d2233f4c61f213d1e2c2ee9673067.zip |
[SPARK-11612][ML] Pipeline and PipelineModel persistence
Pipeline and PipelineModel extend Readable and Writable. Persistence succeeds only when all stages are Writable.
Note: This PR reinstates tests for other read/write functionality. It should probably not get merged until [https://issues.apache.org/jira/browse/SPARK-11672] gets fixed.
CC: mengxr
Author: Joseph K. Bradley <joseph@databricks.com>
Closes #9674 from jkbradley/pipeline-io.
Diffstat (limited to 'mllib/src/test/scala/org')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala | 120 | ||||
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala | 25 |
2 files changed, 132 insertions, 13 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index 1f2c9b75b6..484026b1ba 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -17,19 +17,25 @@ package org.apache.spark.ml +import java.io.File + import scala.collection.JavaConverters._ +import org.apache.hadoop.fs.{FileSystem, Path} import org.mockito.Matchers.{any, eq => meq} import org.mockito.Mockito.when import org.scalatest.mock.MockitoSugar.mock import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.Pipeline.SharedReadWrite import org.apache.spark.ml.feature.HashingTF -import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.param.{IntParam, ParamMap} +import org.apache.spark.ml.util._ +import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.types.StructType -class PipelineSuite extends SparkFunSuite { +class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { abstract class MyModel extends Model[MyModel] @@ -111,4 +117,112 @@ class PipelineSuite extends SparkFunSuite { assert(pipelineModel1.uid === "pipeline1") assert(pipelineModel1.stages === stages) } + + test("Pipeline read/write") { + val writableStage = new WritableStage("writableStage").setIntParam(56) + val pipeline = new Pipeline().setStages(Array(writableStage)) + + val pipeline2 = testDefaultReadWrite(pipeline, testParams = false) + assert(pipeline2.getStages.length === 1) + assert(pipeline2.getStages(0).isInstanceOf[WritableStage]) + val writableStage2 = pipeline2.getStages(0).asInstanceOf[WritableStage] + assert(writableStage.getIntParam === writableStage2.getIntParam) + } + + test("Pipeline read/write with non-Writable stage") { + val unWritableStage = new UnWritableStage("unwritableStage") + val unWritablePipeline = new Pipeline().setStages(Array(unWritableStage)) + withClue("Pipeline.write should fail when Pipeline contains non-Writable stage") { + intercept[UnsupportedOperationException] { + unWritablePipeline.write + } + } + } + + test("PipelineModel read/write") { + val writableStage = new WritableStage("writableStage").setIntParam(56) + val pipeline = + new PipelineModel("pipeline_89329327", Array(writableStage.asInstanceOf[Transformer])) + + val pipeline2 = testDefaultReadWrite(pipeline, testParams = false) + assert(pipeline2.stages.length === 1) + assert(pipeline2.stages(0).isInstanceOf[WritableStage]) + val writableStage2 = pipeline2.stages(0).asInstanceOf[WritableStage] + assert(writableStage.getIntParam === writableStage2.getIntParam) + + val path = new File(tempDir, pipeline.uid).getPath + val stagesDir = new Path(path, "stages").toString + val expectedStagePath = SharedReadWrite.getStagePath(writableStage.uid, 0, 1, stagesDir) + assert(FileSystem.get(sc.hadoopConfiguration).exists(new Path(expectedStagePath)), + s"Expected stage 0 of 1 with uid ${writableStage.uid} in Pipeline with uid ${pipeline.uid}" + + s" to be saved to path: $expectedStagePath") + } + + test("PipelineModel read/write: getStagePath") { + val stageUid = "myStage" + val stagesDir = new Path("pipeline", "stages").toString + def testStage(stageIdx: Int, numStages: Int, expectedPrefix: String): Unit = { + val path = SharedReadWrite.getStagePath(stageUid, stageIdx, numStages, stagesDir) + val expected = new Path(stagesDir, expectedPrefix + "_" + stageUid).toString + assert(path === expected) + } + testStage(0, 1, "0") + testStage(0, 9, "0") + testStage(0, 10, "00") + testStage(1, 10, "01") + testStage(12, 999, "012") + } + + test("PipelineModel read/write with non-Writable stage") { + val unWritableStage = new UnWritableStage("unwritableStage") + val unWritablePipeline = + new PipelineModel("pipeline_328957", Array(unWritableStage.asInstanceOf[Transformer])) + withClue("PipelineModel.write should fail when PipelineModel contains non-Writable stage") { + intercept[UnsupportedOperationException] { + unWritablePipeline.write + } + } + } +} + + +/** Used to test [[Pipeline]] with [[Writable]] stages */ +class WritableStage(override val uid: String) extends Transformer with Writable { + + final val intParam: IntParam = new IntParam(this, "intParam", "doc") + + def getIntParam: Int = $(intParam) + + def setIntParam(value: Int): this.type = set(intParam, value) + + setDefault(intParam -> 0) + + override def copy(extra: ParamMap): WritableStage = defaultCopy(extra) + + override def write: Writer = new DefaultParamsWriter(this) + + override def transform(dataset: DataFrame): DataFrame = dataset + + override def transformSchema(schema: StructType): StructType = schema +} + +object WritableStage extends Readable[WritableStage] { + + override def read: Reader[WritableStage] = new DefaultParamsReader[WritableStage] + + override def load(path: String): WritableStage = read.load(path) +} + +/** Used to test [[Pipeline]] with non-[[Writable]] stages */ +class UnWritableStage(override val uid: String) extends Transformer { + + final val intParam: IntParam = new IntParam(this, "intParam", "doc") + + setDefault(intParam -> 0) + + override def copy(extra: ParamMap): UnWritableStage = defaultCopy(extra) + + override def transform(dataset: DataFrame): DataFrame = dataset + + override def transformSchema(schema: StructType): StructType = schema } diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala index cac4bd9aa3..c37f0503f1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala @@ -30,10 +30,13 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => /** * Checks "overwrite" option and params. * @param instance ML instance to test saving/loading + * @param testParams If true, then test values of Params. Otherwise, just test overwrite option. * @tparam T ML instance type * @return Instance loaded from file */ - def testDefaultReadWrite[T <: Params with Writable](instance: T): T = { + def testDefaultReadWrite[T <: Params with Writable]( + instance: T, + testParams: Boolean = true): T = { val uid = instance.uid val path = new File(tempDir, uid).getPath @@ -46,16 +49,18 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => val newInstance = loader.load(path) assert(newInstance.uid === instance.uid) - instance.params.foreach { p => - if (instance.isDefined(p)) { - (instance.getOrDefault(p), newInstance.getOrDefault(p)) match { - case (Array(values), Array(newValues)) => - assert(values === newValues, s"Values do not match on param ${p.name}.") - case (value, newValue) => - assert(value === newValue, s"Values do not match on param ${p.name}.") + if (testParams) { + instance.params.foreach { p => + if (instance.isDefined(p)) { + (instance.getOrDefault(p), newInstance.getOrDefault(p)) match { + case (Array(values), Array(newValues)) => + assert(values === newValues, s"Values do not match on param ${p.name}.") + case (value, newValue) => + assert(value === newValue, s"Values do not match on param ${p.name}.") + } + } else { + assert(!newInstance.isDefined(p), s"Param ${p.name} shouldn't be defined.") } - } else { - assert(!newInstance.isDefined(p), s"Param ${p.name} shouldn't be defined.") } } |