author     Joseph K. Bradley <joseph@databricks.com>    2015-11-16 17:12:39 -0800
committer  Joseph K. Bradley <joseph@databricks.com>    2015-11-16 17:12:39 -0800
commit     1c5475f1401d2233f4c61f213d1e2c2ee9673067 (patch)
tree       320f6ac8a5e02aace474461962afe6a3b486ac1a /mllib/src/main/scala/org
parent     bd10eb81c98e5e9df453f721943a3e82d9f74ae4 (diff)
[SPARK-11612][ML] Pipeline and PipelineModel persistence
Pipeline and PipelineModel extend Readable and Writable. Persistence succeeds only when all stages are Writable.

Note: This PR reinstates tests for other read/write functionality. It should probably not get merged until [https://issues.apache.org/jira/browse/SPARK-11672] gets fixed.

CC: mengxr

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #9674 from jkbradley/pipeline-io.
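For context, a rough usage sketch of the API this commit adds. The stage types, UIDs, and paths are illustrative, a running SparkContext is assumed, and the save only succeeds if every stage itself implements Writable (per validateStages below):

import org.apache.spark.ml.{Pipeline, PipelineModel, PipelineStage}
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}

// Build a Pipeline from stages that implement Writable.
val tokenizer = new Tokenizer().setInputCol("text").setOutputCol("words")
val hashingTF = new HashingTF().setInputCol("words").setOutputCol("features")
val pipeline = new Pipeline().setStages(Array[PipelineStage](tokenizer, hashingTF))

// Save the unfitted Pipeline and read it back.
pipeline.write.save("/tmp/my-pipeline")
val restored: Pipeline = Pipeline.load("/tmp/my-pipeline")

// A fitted PipelineModel round-trips the same way:
//   model.write.save(path); val m = PipelineModel.load(path)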
Diffstat (limited to 'mllib/src/main/scala/org')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala       | 175
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala |   4
2 files changed, 174 insertions(+), 5 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index a3e59401c5..25f0c696f4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -22,12 +22,19 @@ import java.{util => ju}
import scala.collection.JavaConverters._
import scala.collection.mutable.ListBuffer
-import org.apache.spark.Logging
+import org.apache.hadoop.fs.Path
+import org.json4s._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.{SparkContext, Logging}
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.ml.param.{Param, ParamMap, Params}
-import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.util.Reader
+import org.apache.spark.ml.util.Writer
+import org.apache.spark.ml.util._
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.Utils
/**
* :: DeveloperApi ::
@@ -82,7 +89,7 @@ abstract class PipelineStage extends Params with Logging {
* an identity transformer.
*/
@Experimental
-class Pipeline(override val uid: String) extends Estimator[PipelineModel] {
+class Pipeline(override val uid: String) extends Estimator[PipelineModel] with Writable {
def this() = this(Identifiable.randomUID("pipeline"))
@@ -166,6 +173,131 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] {
"Cannot have duplicate components in a pipeline.")
theStages.foldLeft(schema)((cur, stage) => stage.transformSchema(cur))
}
+
+ override def write: Writer = new Pipeline.PipelineWriter(this)
+}
+
+object Pipeline extends Readable[Pipeline] {
+
+ override def read: Reader[Pipeline] = new PipelineReader
+
+ override def load(path: String): Pipeline = read.load(path)
+
+ private[ml] class PipelineWriter(instance: Pipeline) extends Writer {
+
+ SharedReadWrite.validateStages(instance.getStages)
+
+ override protected def saveImpl(path: String): Unit =
+ SharedReadWrite.saveImpl(instance, instance.getStages, sc, path)
+ }
+
+ private[ml] class PipelineReader extends Reader[Pipeline] {
+
+ /** Checked against metadata when loading model */
+ private val className = "org.apache.spark.ml.Pipeline"
+
+ override def load(path: String): Pipeline = {
+ val (uid: String, stages: Array[PipelineStage]) = SharedReadWrite.load(className, sc, path)
+ new Pipeline(uid).setStages(stages)
+ }
+ }
+
+ /** Methods for [[Reader]] and [[Writer]] shared between [[Pipeline]] and [[PipelineModel]] */
+ private[ml] object SharedReadWrite {
+
+ import org.json4s.JsonDSL._
+
+ /** Check that all stages are Writable */
+ def validateStages(stages: Array[PipelineStage]): Unit = {
+ stages.foreach {
+ case stage: Writable => // good
+ case other =>
+ throw new UnsupportedOperationException("Pipeline write will fail on this Pipeline" +
+ s" because it contains a stage which does not implement Writable. Non-Writable stage:" +
+ s" ${other.uid} of type ${other.getClass}")
+ }
+ }
+
+ /**
+ * Save metadata and stages for a [[Pipeline]] or [[PipelineModel]]
+ * - save metadata to path/metadata
+ * - save stages to stages/IDX_UID
+ */
+ def saveImpl(
+ instance: Params,
+ stages: Array[PipelineStage],
+ sc: SparkContext,
+ path: String): Unit = {
+ // Copied and edited from DefaultParamsWriter.saveMetadata
+ // TODO: modify DefaultParamsWriter.saveMetadata to avoid duplication
+ val uid = instance.uid
+ val cls = instance.getClass.getName
+ val stageUids = stages.map(_.uid)
+ val jsonParams = List("stageUids" -> parse(compact(render(stageUids.toSeq))))
+ val metadata = ("class" -> cls) ~
+ ("timestamp" -> System.currentTimeMillis()) ~
+ ("sparkVersion" -> sc.version) ~
+ ("uid" -> uid) ~
+ ("paramMap" -> jsonParams)
+ val metadataPath = new Path(path, "metadata").toString
+ val metadataJson = compact(render(metadata))
+ sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath)
+
+ // Save stages
+ val stagesDir = new Path(path, "stages").toString
+ stages.zipWithIndex.foreach { case (stage: Writable, idx: Int) =>
+ stage.write.save(getStagePath(stage.uid, idx, stages.length, stagesDir))
+ }
+ }
+
+ /**
+ * Load metadata and stages for a [[Pipeline]] or [[PipelineModel]]
+ * @return (UID, list of stages)
+ */
+ def load(
+ expectedClassName: String,
+ sc: SparkContext,
+ path: String): (String, Array[PipelineStage]) = {
+ val metadata = DefaultParamsReader.loadMetadata(path, sc, expectedClassName)
+
+ implicit val format = DefaultFormats
+ val stagesDir = new Path(path, "stages").toString
+ val stageUids: Array[String] = metadata.params match {
+ case JObject(pairs) =>
+ if (pairs.length != 1) {
+ // Should not happen unless file is corrupted or we have a bug.
+ throw new RuntimeException(
+ s"Pipeline read expected 1 Param (stageUids), but found ${pairs.length}.")
+ }
+ pairs.head match {
+ case ("stageUids", jsonValue) =>
+ jsonValue.extract[Seq[String]].toArray
+ case (paramName, jsonValue) =>
+ // Should not happen unless file is corrupted or we have a bug.
+ throw new RuntimeException(s"Pipeline read encountered unexpected Param $paramName" +
+ s" in metadata: ${metadata.metadataStr}")
+ }
+ case _ =>
+ throw new IllegalArgumentException(
+ s"Cannot recognize JSON metadata: ${metadata.metadataStr}.")
+ }
+ val stages: Array[PipelineStage] = stageUids.zipWithIndex.map { case (stageUid, idx) =>
+ val stagePath = SharedReadWrite.getStagePath(stageUid, idx, stageUids.length, stagesDir)
+ val stageMetadata = DefaultParamsReader.loadMetadata(stagePath, sc)
+ val cls = Utils.classForName(stageMetadata.className)
+ cls.getMethod("read").invoke(null).asInstanceOf[Reader[PipelineStage]].load(stagePath)
+ }
+ (metadata.uid, stages)
+ }
+
+ /** Get path for saving the given stage. */
+ def getStagePath(stageUid: String, stageIdx: Int, numStages: Int, stagesDir: String): String = {
+ val stageIdxDigits = numStages.toString.length
+ val idxFormat = s"%0${stageIdxDigits}d"
+ val stageDir = idxFormat.format(stageIdx) + "_" + stageUid
+ new Path(stagesDir, stageDir).toString
+ }
+ }
}
/**
@@ -176,7 +308,7 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] {
class PipelineModel private[ml] (
override val uid: String,
val stages: Array[Transformer])
- extends Model[PipelineModel] with Logging {
+ extends Model[PipelineModel] with Writable with Logging {
/** A Java/Python-friendly auxiliary constructor. */
private[ml] def this(uid: String, stages: ju.List[Transformer]) = {
@@ -200,4 +332,39 @@ class PipelineModel private[ml] (
override def copy(extra: ParamMap): PipelineModel = {
new PipelineModel(uid, stages.map(_.copy(extra))).setParent(parent)
}
+
+ override def write: Writer = new PipelineModel.PipelineModelWriter(this)
+}
+
+object PipelineModel extends Readable[PipelineModel] {
+
+ import Pipeline.SharedReadWrite
+
+ override def read: Reader[PipelineModel] = new PipelineModelReader
+
+ override def load(path: String): PipelineModel = read.load(path)
+
+ private[ml] class PipelineModelWriter(instance: PipelineModel) extends Writer {
+
+ SharedReadWrite.validateStages(instance.stages.asInstanceOf[Array[PipelineStage]])
+
+ override protected def saveImpl(path: String): Unit = SharedReadWrite.saveImpl(instance,
+ instance.stages.asInstanceOf[Array[PipelineStage]], sc, path)
+ }
+
+ private[ml] class PipelineModelReader extends Reader[PipelineModel] {
+
+ /** Checked against metadata when loading model */
+ private val className = "org.apache.spark.ml.PipelineModel"
+
+ override def load(path: String): PipelineModel = {
+ val (uid: String, stages: Array[PipelineStage]) = SharedReadWrite.load(className, sc, path)
+ val transformers = stages map {
+ case stage: Transformer => stage
+ case other => throw new RuntimeException(s"PipelineModel.read loaded a stage but found it" +
+ s" was not a Transformer. Bad stage ${other.uid} of type ${other.getClass}")
+ }
+ new PipelineModel(uid, transformers)
+ }
+ }
}
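Taken together, SharedReadWrite.saveImpl and getStagePath above yield an on-disk layout roughly like the following for a two-stage pipeline (directory names and UIDs are illustrative; metadata is written with saveAsTextFile, so it is a directory of part files rather than a single file):

/tmp/my-pipeline/
  metadata/                    pipeline-level JSON: class, timestamp, sparkVersion, uid, paramMap (stageUids)
  stages/
    0_tok_5e2d34f1/            stage 0, saved via its own Writer
      metadata/
    1_hashingTF_9c1a77b2/      stage 1
      metadata/

The index prefix is zero-padded to the number of digits in numStages (00_, 01_, ... once a pipeline has ten or more stages), so the stage directories list in pipeline order.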
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index ca896ed610..3169c9e9af 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -164,6 +164,8 @@ trait Readable[T] {
/**
* Reads an ML instance from the input path, a shortcut of `read.load(path)`.
+ *
+ * Note: Implementing classes should override this to be Java-friendly.
*/
@Since("1.6.0")
def load(path: String): T = read.load(path)
@@ -190,7 +192,7 @@ private[ml] object DefaultParamsWriter {
* - timestamp
* - sparkVersion
* - uid
- * - paramMap
+ * - paramMap: These must be encodable using [[org.apache.spark.ml.param.Param.jsonEncode()]].
*/
def saveMetadata(instance: Params, path: String, sc: SparkContext): Unit = {
val uid = instance.uid
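For reference, the metadata described above is written as a single compact JSON line; the pipeline-level metadata produced by SharedReadWrite.saveImpl in Pipeline.scala has the same shape, with stageUids as its only Param. The values here are illustrative:

{"class":"org.apache.spark.ml.Pipeline","timestamp":1447722759000,"sparkVersion":"1.6.0","uid":"pipeline_7a3b9c1d2e4f","paramMap":{"stageUids":["tok_5e2d34f1","hashingTF_9c1a77b2"]}}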