aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala175
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala4
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala120
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala25
4 files changed, 306 insertions, 18 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index a3e59401c5..25f0c696f4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -22,12 +22,19 @@ import java.{util => ju}
import scala.collection.JavaConverters._
import scala.collection.mutable.ListBuffer
-import org.apache.spark.Logging
+import org.apache.hadoop.fs.Path
+import org.json4s._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.{SparkContext, Logging}
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.ml.param.{Param, ParamMap, Params}
-import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.util.Reader
+import org.apache.spark.ml.util.Writer
+import org.apache.spark.ml.util._
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.Utils
/**
* :: DeveloperApi ::
@@ -82,7 +89,7 @@ abstract class PipelineStage extends Params with Logging {
* an identity transformer.
*/
@Experimental
-class Pipeline(override val uid: String) extends Estimator[PipelineModel] {
+class Pipeline(override val uid: String) extends Estimator[PipelineModel] with Writable {
def this() = this(Identifiable.randomUID("pipeline"))
@@ -166,6 +173,131 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] {
"Cannot have duplicate components in a pipeline.")
theStages.foldLeft(schema)((cur, stage) => stage.transformSchema(cur))
}
+
+ override def write: Writer = new Pipeline.PipelineWriter(this)
+}
+
+object Pipeline extends Readable[Pipeline] {
+
+ override def read: Reader[Pipeline] = new PipelineReader
+
+ override def load(path: String): Pipeline = read.load(path)
+
+ private[ml] class PipelineWriter(instance: Pipeline) extends Writer {
+
+ SharedReadWrite.validateStages(instance.getStages)
+
+ override protected def saveImpl(path: String): Unit =
+ SharedReadWrite.saveImpl(instance, instance.getStages, sc, path)
+ }
+
+ private[ml] class PipelineReader extends Reader[Pipeline] {
+
+ /** Checked against metadata when loading model */
+ private val className = "org.apache.spark.ml.Pipeline"
+
+ override def load(path: String): Pipeline = {
+ val (uid: String, stages: Array[PipelineStage]) = SharedReadWrite.load(className, sc, path)
+ new Pipeline(uid).setStages(stages)
+ }
+ }
+
+ /** Methods for [[Reader]] and [[Writer]] shared between [[Pipeline]] and [[PipelineModel]] */
+ private[ml] object SharedReadWrite {
+
+ import org.json4s.JsonDSL._
+
+ /** Check that all stages are Writable */
+ def validateStages(stages: Array[PipelineStage]): Unit = {
+ stages.foreach {
+ case stage: Writable => // good
+ case other =>
+ throw new UnsupportedOperationException("Pipeline write will fail on this Pipeline" +
+ s" because it contains a stage which does not implement Writable. Non-Writable stage:" +
+ s" ${other.uid} of type ${other.getClass}")
+ }
+ }
+
+ /**
+ * Save metadata and stages for a [[Pipeline]] or [[PipelineModel]]
+ * - save metadata to path/metadata
+ * - save stages to stages/IDX_UID
+ */
+ def saveImpl(
+ instance: Params,
+ stages: Array[PipelineStage],
+ sc: SparkContext,
+ path: String): Unit = {
+ // Copied and edited from DefaultParamsWriter.saveMetadata
+ // TODO: modify DefaultParamsWriter.saveMetadata to avoid duplication
+ val uid = instance.uid
+ val cls = instance.getClass.getName
+ val stageUids = stages.map(_.uid)
+ val jsonParams = List("stageUids" -> parse(compact(render(stageUids.toSeq))))
+ val metadata = ("class" -> cls) ~
+ ("timestamp" -> System.currentTimeMillis()) ~
+ ("sparkVersion" -> sc.version) ~
+ ("uid" -> uid) ~
+ ("paramMap" -> jsonParams)
+ val metadataPath = new Path(path, "metadata").toString
+ val metadataJson = compact(render(metadata))
+ sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath)
+
+ // Save stages
+ val stagesDir = new Path(path, "stages").toString
+ stages.zipWithIndex.foreach { case (stage: Writable, idx: Int) =>
+ stage.write.save(getStagePath(stage.uid, idx, stages.length, stagesDir))
+ }
+ }
+
+ /**
+ * Load metadata and stages for a [[Pipeline]] or [[PipelineModel]]
+ * @return (UID, list of stages)
+ */
+ def load(
+ expectedClassName: String,
+ sc: SparkContext,
+ path: String): (String, Array[PipelineStage]) = {
+ val metadata = DefaultParamsReader.loadMetadata(path, sc, expectedClassName)
+
+ implicit val format = DefaultFormats
+ val stagesDir = new Path(path, "stages").toString
+ val stageUids: Array[String] = metadata.params match {
+ case JObject(pairs) =>
+ if (pairs.length != 1) {
+ // Should not happen unless file is corrupted or we have a bug.
+ throw new RuntimeException(
+ s"Pipeline read expected 1 Param (stageUids), but found ${pairs.length}.")
+ }
+ pairs.head match {
+ case ("stageUids", jsonValue) =>
+ jsonValue.extract[Seq[String]].toArray
+ case (paramName, jsonValue) =>
+ // Should not happen unless file is corrupted or we have a bug.
+ throw new RuntimeException(s"Pipeline read encountered unexpected Param $paramName" +
+ s" in metadata: ${metadata.metadataStr}")
+ }
+ case _ =>
+ throw new IllegalArgumentException(
+ s"Cannot recognize JSON metadata: ${metadata.metadataStr}.")
+ }
+ val stages: Array[PipelineStage] = stageUids.zipWithIndex.map { case (stageUid, idx) =>
+ val stagePath = SharedReadWrite.getStagePath(stageUid, idx, stageUids.length, stagesDir)
+ val stageMetadata = DefaultParamsReader.loadMetadata(stagePath, sc)
+ val cls = Utils.classForName(stageMetadata.className)
+ cls.getMethod("read").invoke(null).asInstanceOf[Reader[PipelineStage]].load(stagePath)
+ }
+ (metadata.uid, stages)
+ }
+
+ /** Get path for saving the given stage. */
+ def getStagePath(stageUid: String, stageIdx: Int, numStages: Int, stagesDir: String): String = {
+ val stageIdxDigits = numStages.toString.length
+ val idxFormat = s"%0${stageIdxDigits}d"
+ val stageDir = idxFormat.format(stageIdx) + "_" + stageUid
+ new Path(stagesDir, stageDir).toString
+ }
+ }
}
/**
@@ -176,7 +308,7 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] {
class PipelineModel private[ml] (
override val uid: String,
val stages: Array[Transformer])
- extends Model[PipelineModel] with Logging {
+ extends Model[PipelineModel] with Writable with Logging {
/** A Java/Python-friendly auxiliary constructor. */
private[ml] def this(uid: String, stages: ju.List[Transformer]) = {
@@ -200,4 +332,39 @@ class PipelineModel private[ml] (
override def copy(extra: ParamMap): PipelineModel = {
new PipelineModel(uid, stages.map(_.copy(extra))).setParent(parent)
}
+
+ override def write: Writer = new PipelineModel.PipelineModelWriter(this)
+}
+
+object PipelineModel extends Readable[PipelineModel] {
+
+ import Pipeline.SharedReadWrite
+
+ override def read: Reader[PipelineModel] = new PipelineModelReader
+
+ override def load(path: String): PipelineModel = read.load(path)
+
+ private[ml] class PipelineModelWriter(instance: PipelineModel) extends Writer {
+
+ SharedReadWrite.validateStages(instance.stages.asInstanceOf[Array[PipelineStage]])
+
+ override protected def saveImpl(path: String): Unit = SharedReadWrite.saveImpl(instance,
+ instance.stages.asInstanceOf[Array[PipelineStage]], sc, path)
+ }
+
+ private[ml] class PipelineModelReader extends Reader[PipelineModel] {
+
+ /** Checked against metadata when loading model */
+ private val className = "org.apache.spark.ml.PipelineModel"
+
+ override def load(path: String): PipelineModel = {
+ val (uid: String, stages: Array[PipelineStage]) = SharedReadWrite.load(className, sc, path)
+ val transformers = stages map {
+ case stage: Transformer => stage
+ case other => throw new RuntimeException(s"PipelineModel.read loaded a stage but found it" +
+ s" was not a Transformer. Bad stage ${other.uid} of type ${other.getClass}")
+ }
+ new PipelineModel(uid, transformers)
+ }
+ }
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index ca896ed610..3169c9e9af 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -164,6 +164,8 @@ trait Readable[T] {
/**
* Reads an ML instance from the input path, a shortcut of `read.load(path)`.
+ *
+ * Note: Implementing classes should override this to be Java-friendly.
*/
@Since("1.6.0")
def load(path: String): T = read.load(path)
@@ -190,7 +192,7 @@ private[ml] object DefaultParamsWriter {
* - timestamp
* - sparkVersion
* - uid
- * - paramMap
+ * - paramMap: These must be encodable using [[org.apache.spark.ml.param.Param.jsonEncode()]].
*/
def saveMetadata(instance: Params, path: String, sc: SparkContext): Unit = {
val uid = instance.uid
diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
index 1f2c9b75b6..484026b1ba 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
@@ -17,19 +17,25 @@
package org.apache.spark.ml
+import java.io.File
+
import scala.collection.JavaConverters._
+import org.apache.hadoop.fs.{FileSystem, Path}
import org.mockito.Matchers.{any, eq => meq}
import org.mockito.Mockito.when
import org.scalatest.mock.MockitoSugar.mock
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.Pipeline.SharedReadWrite
import org.apache.spark.ml.feature.HashingTF
-import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.ml.param.{IntParam, ParamMap}
+import org.apache.spark.ml.util._
+import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.types.StructType
-class PipelineSuite extends SparkFunSuite {
+class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
abstract class MyModel extends Model[MyModel]
@@ -111,4 +117,112 @@ class PipelineSuite extends SparkFunSuite {
assert(pipelineModel1.uid === "pipeline1")
assert(pipelineModel1.stages === stages)
}
+
+ test("Pipeline read/write") {
+ val writableStage = new WritableStage("writableStage").setIntParam(56)
+ val pipeline = new Pipeline().setStages(Array(writableStage))
+
+ val pipeline2 = testDefaultReadWrite(pipeline, testParams = false)
+ assert(pipeline2.getStages.length === 1)
+ assert(pipeline2.getStages(0).isInstanceOf[WritableStage])
+ val writableStage2 = pipeline2.getStages(0).asInstanceOf[WritableStage]
+ assert(writableStage.getIntParam === writableStage2.getIntParam)
+ }
+
+ test("Pipeline read/write with non-Writable stage") {
+ val unWritableStage = new UnWritableStage("unwritableStage")
+ val unWritablePipeline = new Pipeline().setStages(Array(unWritableStage))
+ withClue("Pipeline.write should fail when Pipeline contains non-Writable stage") {
+ intercept[UnsupportedOperationException] {
+ unWritablePipeline.write
+ }
+ }
+ }
+
+ test("PipelineModel read/write") {
+ val writableStage = new WritableStage("writableStage").setIntParam(56)
+ val pipeline =
+ new PipelineModel("pipeline_89329327", Array(writableStage.asInstanceOf[Transformer]))
+
+ val pipeline2 = testDefaultReadWrite(pipeline, testParams = false)
+ assert(pipeline2.stages.length === 1)
+ assert(pipeline2.stages(0).isInstanceOf[WritableStage])
+ val writableStage2 = pipeline2.stages(0).asInstanceOf[WritableStage]
+ assert(writableStage.getIntParam === writableStage2.getIntParam)
+
+ val path = new File(tempDir, pipeline.uid).getPath
+ val stagesDir = new Path(path, "stages").toString
+ val expectedStagePath = SharedReadWrite.getStagePath(writableStage.uid, 0, 1, stagesDir)
+ assert(FileSystem.get(sc.hadoopConfiguration).exists(new Path(expectedStagePath)),
+ s"Expected stage 0 of 1 with uid ${writableStage.uid} in Pipeline with uid ${pipeline.uid}" +
+ s" to be saved to path: $expectedStagePath")
+ }
+
+ test("PipelineModel read/write: getStagePath") {
+ val stageUid = "myStage"
+ val stagesDir = new Path("pipeline", "stages").toString
+ def testStage(stageIdx: Int, numStages: Int, expectedPrefix: String): Unit = {
+ val path = SharedReadWrite.getStagePath(stageUid, stageIdx, numStages, stagesDir)
+ val expected = new Path(stagesDir, expectedPrefix + "_" + stageUid).toString
+ assert(path === expected)
+ }
+ testStage(0, 1, "0")
+ testStage(0, 9, "0")
+ testStage(0, 10, "00")
+ testStage(1, 10, "01")
+ testStage(12, 999, "012")
+ }
+
+ test("PipelineModel read/write with non-Writable stage") {
+ val unWritableStage = new UnWritableStage("unwritableStage")
+ val unWritablePipeline =
+ new PipelineModel("pipeline_328957", Array(unWritableStage.asInstanceOf[Transformer]))
+ withClue("PipelineModel.write should fail when PipelineModel contains non-Writable stage") {
+ intercept[UnsupportedOperationException] {
+ unWritablePipeline.write
+ }
+ }
+ }
+}
+
+
+/** Used to test [[Pipeline]] with [[Writable]] stages */
+class WritableStage(override val uid: String) extends Transformer with Writable {
+
+ final val intParam: IntParam = new IntParam(this, "intParam", "doc")
+
+ def getIntParam: Int = $(intParam)
+
+ def setIntParam(value: Int): this.type = set(intParam, value)
+
+ setDefault(intParam -> 0)
+
+ override def copy(extra: ParamMap): WritableStage = defaultCopy(extra)
+
+ override def write: Writer = new DefaultParamsWriter(this)
+
+ override def transform(dataset: DataFrame): DataFrame = dataset
+
+ override def transformSchema(schema: StructType): StructType = schema
+}
+
+object WritableStage extends Readable[WritableStage] {
+
+ override def read: Reader[WritableStage] = new DefaultParamsReader[WritableStage]
+
+ override def load(path: String): WritableStage = read.load(path)
+}
+
+/** Used to test [[Pipeline]] with non-[[Writable]] stages */
+class UnWritableStage(override val uid: String) extends Transformer {
+
+ final val intParam: IntParam = new IntParam(this, "intParam", "doc")
+
+ setDefault(intParam -> 0)
+
+ override def copy(extra: ParamMap): UnWritableStage = defaultCopy(extra)
+
+ override def transform(dataset: DataFrame): DataFrame = dataset
+
+ override def transformSchema(schema: StructType): StructType = schema
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
index cac4bd9aa3..c37f0503f1 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
@@ -30,10 +30,13 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
/**
* Checks "overwrite" option and params.
* @param instance ML instance to test saving/loading
+ * @param testParams If true, then test values of Params. Otherwise, just test overwrite option.
* @tparam T ML instance type
* @return Instance loaded from file
*/
- def testDefaultReadWrite[T <: Params with Writable](instance: T): T = {
+ def testDefaultReadWrite[T <: Params with Writable](
+ instance: T,
+ testParams: Boolean = true): T = {
val uid = instance.uid
val path = new File(tempDir, uid).getPath
@@ -46,16 +49,18 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
val newInstance = loader.load(path)
assert(newInstance.uid === instance.uid)
- instance.params.foreach { p =>
- if (instance.isDefined(p)) {
- (instance.getOrDefault(p), newInstance.getOrDefault(p)) match {
- case (Array(values), Array(newValues)) =>
- assert(values === newValues, s"Values do not match on param ${p.name}.")
- case (value, newValue) =>
- assert(value === newValue, s"Values do not match on param ${p.name}.")
+ if (testParams) {
+ instance.params.foreach { p =>
+ if (instance.isDefined(p)) {
+ (instance.getOrDefault(p), newInstance.getOrDefault(p)) match {
+ case (Array(values), Array(newValues)) =>
+ assert(values === newValues, s"Values do not match on param ${p.name}.")
+ case (value, newValue) =>
+ assert(value === newValue, s"Values do not match on param ${p.name}.")
+ }
+ } else {
+ assert(!newInstance.isDefined(p), s"Param ${p.name} shouldn't be defined.")
}
- } else {
- assert(!newInstance.isDefined(p), s"Param ${p.name} shouldn't be defined.")
}
}