diff options
author | Joseph K. Bradley <joseph@databricks.com> | 2015-11-16 17:12:39 -0800 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2015-11-16 17:12:39 -0800 |
commit | 1c5475f1401d2233f4c61f213d1e2c2ee9673067 (patch) | |
tree | 320f6ac8a5e02aace474461962afe6a3b486ac1a /mllib/src/main/scala/org | |
parent | bd10eb81c98e5e9df453f721943a3e82d9f74ae4 (diff) | |
download | spark-1c5475f1401d2233f4c61f213d1e2c2ee9673067.tar.gz spark-1c5475f1401d2233f4c61f213d1e2c2ee9673067.tar.bz2 spark-1c5475f1401d2233f4c61f213d1e2c2ee9673067.zip |
[SPARK-11612][ML] Pipeline and PipelineModel persistence
Pipeline and PipelineModel extend Readable and Writable. Persistence succeeds only when all stages are Writable.
Note: This PR reinstates tests for other read/write functionality. It should probably not get merged until [https://issues.apache.org/jira/browse/SPARK-11672] gets fixed.
CC: mengxr
Author: Joseph K. Bradley <joseph@databricks.com>
Closes #9674 from jkbradley/pipeline-io.
Diffstat (limited to 'mllib/src/main/scala/org')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala | 175 | ||||
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala | 4 |
2 files changed, 174 insertions, 5 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index a3e59401c5..25f0c696f4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -22,12 +22,19 @@ import java.{util => ju} import scala.collection.JavaConverters._ import scala.collection.mutable.ListBuffer -import org.apache.spark.Logging +import org.apache.hadoop.fs.Path +import org.json4s._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.{SparkContext, Logging} import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.ml.param.{Param, ParamMap, Params} -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util.Reader +import org.apache.spark.ml.util.Writer +import org.apache.spark.ml.util._ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType +import org.apache.spark.util.Utils /** * :: DeveloperApi :: @@ -82,7 +89,7 @@ abstract class PipelineStage extends Params with Logging { * an identity transformer. */ @Experimental -class Pipeline(override val uid: String) extends Estimator[PipelineModel] { +class Pipeline(override val uid: String) extends Estimator[PipelineModel] with Writable { def this() = this(Identifiable.randomUID("pipeline")) @@ -166,6 +173,131 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] { "Cannot have duplicate components in a pipeline.") theStages.foldLeft(schema)((cur, stage) => stage.transformSchema(cur)) } + + override def write: Writer = new Pipeline.PipelineWriter(this) +} + +object Pipeline extends Readable[Pipeline] { + + override def read: Reader[Pipeline] = new PipelineReader + + override def load(path: String): Pipeline = read.load(path) + + private[ml] class PipelineWriter(instance: Pipeline) extends Writer { + + SharedReadWrite.validateStages(instance.getStages) + + override protected def saveImpl(path: String): Unit = + SharedReadWrite.saveImpl(instance, instance.getStages, sc, path) + } + + private[ml] class PipelineReader extends Reader[Pipeline] { + + /** Checked against metadata when loading model */ + private val className = "org.apache.spark.ml.Pipeline" + + override def load(path: String): Pipeline = { + val (uid: String, stages: Array[PipelineStage]) = SharedReadWrite.load(className, sc, path) + new Pipeline(uid).setStages(stages) + } + } + + /** Methods for [[Reader]] and [[Writer]] shared between [[Pipeline]] and [[PipelineModel]] */ + private[ml] object SharedReadWrite { + + import org.json4s.JsonDSL._ + + /** Check that all stages are Writable */ + def validateStages(stages: Array[PipelineStage]): Unit = { + stages.foreach { + case stage: Writable => // good + case other => + throw new UnsupportedOperationException("Pipeline write will fail on this Pipeline" + + s" because it contains a stage which does not implement Writable. Non-Writable stage:" + + s" ${other.uid} of type ${other.getClass}") + } + } + + /** + * Save metadata and stages for a [[Pipeline]] or [[PipelineModel]] + * - save metadata to path/metadata + * - save stages to stages/IDX_UID + */ + def saveImpl( + instance: Params, + stages: Array[PipelineStage], + sc: SparkContext, + path: String): Unit = { + // Copied and edited from DefaultParamsWriter.saveMetadata + // TODO: modify DefaultParamsWriter.saveMetadata to avoid duplication + val uid = instance.uid + val cls = instance.getClass.getName + val stageUids = stages.map(_.uid) + val jsonParams = List("stageUids" -> parse(compact(render(stageUids.toSeq)))) + val metadata = ("class" -> cls) ~ + ("timestamp" -> System.currentTimeMillis()) ~ + ("sparkVersion" -> sc.version) ~ + ("uid" -> uid) ~ + ("paramMap" -> jsonParams) + val metadataPath = new Path(path, "metadata").toString + val metadataJson = compact(render(metadata)) + sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath) + + // Save stages + val stagesDir = new Path(path, "stages").toString + stages.zipWithIndex.foreach { case (stage: Writable, idx: Int) => + stage.write.save(getStagePath(stage.uid, idx, stages.length, stagesDir)) + } + } + + /** + * Load metadata and stages for a [[Pipeline]] or [[PipelineModel]] + * @return (UID, list of stages) + */ + def load( + expectedClassName: String, + sc: SparkContext, + path: String): (String, Array[PipelineStage]) = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, expectedClassName) + + implicit val format = DefaultFormats + val stagesDir = new Path(path, "stages").toString + val stageUids: Array[String] = metadata.params match { + case JObject(pairs) => + if (pairs.length != 1) { + // Should not happen unless file is corrupted or we have a bug. + throw new RuntimeException( + s"Pipeline read expected 1 Param (stageUids), but found ${pairs.length}.") + } + pairs.head match { + case ("stageUids", jsonValue) => + jsonValue.extract[Seq[String]].toArray + case (paramName, jsonValue) => + // Should not happen unless file is corrupted or we have a bug. + throw new RuntimeException(s"Pipeline read encountered unexpected Param $paramName" + + s" in metadata: ${metadata.metadataStr}") + } + case _ => + throw new IllegalArgumentException( + s"Cannot recognize JSON metadata: ${metadata.metadataStr}.") + } + val stages: Array[PipelineStage] = stageUids.zipWithIndex.map { case (stageUid, idx) => + val stagePath = SharedReadWrite.getStagePath(stageUid, idx, stageUids.length, stagesDir) + val stageMetadata = DefaultParamsReader.loadMetadata(stagePath, sc) + val cls = Utils.classForName(stageMetadata.className) + cls.getMethod("read").invoke(null).asInstanceOf[Reader[PipelineStage]].load(stagePath) + } + (metadata.uid, stages) + } + + /** Get path for saving the given stage. */ + def getStagePath(stageUid: String, stageIdx: Int, numStages: Int, stagesDir: String): String = { + val stageIdxDigits = numStages.toString.length + val idxFormat = s"%0${stageIdxDigits}d" + val stageDir = idxFormat.format(stageIdx) + "_" + stageUid + new Path(stagesDir, stageDir).toString + } + } } /** @@ -176,7 +308,7 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] { class PipelineModel private[ml] ( override val uid: String, val stages: Array[Transformer]) - extends Model[PipelineModel] with Logging { + extends Model[PipelineModel] with Writable with Logging { /** A Java/Python-friendly auxiliary constructor. */ private[ml] def this(uid: String, stages: ju.List[Transformer]) = { @@ -200,4 +332,39 @@ class PipelineModel private[ml] ( override def copy(extra: ParamMap): PipelineModel = { new PipelineModel(uid, stages.map(_.copy(extra))).setParent(parent) } + + override def write: Writer = new PipelineModel.PipelineModelWriter(this) +} + +object PipelineModel extends Readable[PipelineModel] { + + import Pipeline.SharedReadWrite + + override def read: Reader[PipelineModel] = new PipelineModelReader + + override def load(path: String): PipelineModel = read.load(path) + + private[ml] class PipelineModelWriter(instance: PipelineModel) extends Writer { + + SharedReadWrite.validateStages(instance.stages.asInstanceOf[Array[PipelineStage]]) + + override protected def saveImpl(path: String): Unit = SharedReadWrite.saveImpl(instance, + instance.stages.asInstanceOf[Array[PipelineStage]], sc, path) + } + + private[ml] class PipelineModelReader extends Reader[PipelineModel] { + + /** Checked against metadata when loading model */ + private val className = "org.apache.spark.ml.PipelineModel" + + override def load(path: String): PipelineModel = { + val (uid: String, stages: Array[PipelineStage]) = SharedReadWrite.load(className, sc, path) + val transformers = stages map { + case stage: Transformer => stage + case other => throw new RuntimeException(s"PipelineModel.read loaded a stage but found it" + + s" was not a Transformer. Bad stage ${other.uid} of type ${other.getClass}") + } + new PipelineModel(uid, transformers) + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index ca896ed610..3169c9e9af 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -164,6 +164,8 @@ trait Readable[T] { /** * Reads an ML instance from the input path, a shortcut of `read.load(path)`. + * + * Note: Implementing classes should override this to be Java-friendly. */ @Since("1.6.0") def load(path: String): T = read.load(path) @@ -190,7 +192,7 @@ private[ml] object DefaultParamsWriter { * - timestamp * - sparkVersion * - uid - * - paramMap + * - paramMap: These must be encodable using [[org.apache.spark.ml.param.Param.jsonEncode()]]. */ def saveMetadata(instance: Params, path: String, sc: SparkContext): Unit = { val uid = instance.uid |