author     Xusen Yin <yinxusen@gmail.com>            2016-03-28 15:40:06 -0700
committer  Joseph K. Bradley <joseph@databricks.com> 2016-03-28 15:40:06 -0700
commit     8c11d1aab8522c75d78bc6b30402c64e8d9ff065 (patch)
tree       7e31c6256fcb49db9eb0f47b86b24daa24764650 /mllib/src/main
parent     39f743a6231cbd8cc770a28f43ee601eff28d597 (diff)
[SPARK-11893] Model export/import for spark.ml: TrainValidationSplit

https://issues.apache.org/jira/browse/SPARK-11893

cc jkbradley

In order to share read/write logic with `TrainValidationSplit`, I move the shared code out of `CrossValidator` into the `ValidatorParams` object in the tuning package, with the generic UID handling landing in `MetaAlgorithmReadWrite` under `ml.util`. To reduce repeated tests, I move the complex tests from `CrossValidatorSuite` to `SharedReadWriteSuite` and create a fake validator called `MyValidator` to test the shared code. With this in place, any newly added `Validator` can reuse the common read/write path and only needs to implement save/load for its extra params.

Author: Xusen Yin <yinxusen@gmail.com>
Author: Joseph K. Bradley <joseph@databricks.com>

Closes #9971 from yinxusen/SPARK-11893.
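As a usage illustration (not part of the patch itself): with this change, `TrainValidationSplit` gains the standard `MLWritable`/`MLReadable` persistence API that `CrossValidator` already had. A minimal sketch, assuming a Spark 2.0-era application and placeholder paths:

import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}

// Configure a validator over a small, illustrative param grid.
val lr = new LogisticRegression()
val grid = new ParamGridBuilder()
  .addGrid(lr.regParam, Array(0.01, 0.1))
  .build()
val tvs = new TrainValidationSplit()
  .setEstimator(lr)
  .setEvaluator(new BinaryClassificationEvaluator())
  .setEstimatorParamMaps(grid)
  .setTrainRatio(0.75)

// Save the (unfitted) validator and load it back; paths are examples only.
tvs.write.overwrite().save("/tmp/spark-11893/tvs")
val loadedTvs = TrainValidationSplit.load("/tmp/spark-11893/tvs")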
Diffstat (limited to 'mllib/src/main')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala        148
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala  100
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala       117
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala                42
4 files changed, 267 insertions(+), 140 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 963f81cb3e..040b0093b9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -19,25 +19,19 @@ package org.apache.spark.ml.tuning
import com.github.fommil.netlib.F2jBLAS
import org.apache.hadoop.fs.Path
-import org.json4s.{DefaultFormats, JObject}
-import org.json4s.jackson.JsonMethods._
+import org.json4s.DefaultFormats
-import org.apache.spark.SparkContext
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.internal.Logging
import org.apache.spark.ml._
-import org.apache.spark.ml.classification.OneVsRestParams
import org.apache.spark.ml.evaluation.Evaluator
-import org.apache.spark.ml.feature.RFormulaModel
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.HasSeed
import org.apache.spark.ml.util._
-import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType
-
/**
* Params for [[CrossValidator]] and [[CrossValidatorModel]].
*/
@@ -45,6 +39,7 @@ private[ml] trait CrossValidatorParams extends ValidatorParams with HasSeed {
/**
* Param for number of folds for cross validation. Must be >= 2.
* Default: 3
+ *
* @group param
*/
val numFolds: IntParam = new IntParam(this, "numFolds",
@@ -163,10 +158,10 @@ object CrossValidator extends MLReadable[CrossValidator] {
private[CrossValidator] class CrossValidatorWriter(instance: CrossValidator) extends MLWriter {
- SharedReadWrite.validateParams(instance)
+ ValidatorParams.validateParams(instance)
override protected def saveImpl(path: String): Unit =
- SharedReadWrite.saveImpl(path, instance, sc)
+ ValidatorParams.saveImpl(path, instance, sc)
}
private class CrossValidatorReader extends MLReader[CrossValidator] {
@@ -175,8 +170,11 @@ object CrossValidator extends MLReadable[CrossValidator] {
private val className = classOf[CrossValidator].getName
override def load(path: String): CrossValidator = {
- val (metadata, estimator, evaluator, estimatorParamMaps, numFolds) =
- SharedReadWrite.load(path, sc, className)
+ implicit val format = DefaultFormats
+
+ val (metadata, estimator, evaluator, estimatorParamMaps) =
+ ValidatorParams.loadImpl(path, sc, className)
+ val numFolds = (metadata.params \ "numFolds").extract[Int]
new CrossValidator(metadata.uid)
.setEstimator(estimator)
.setEvaluator(evaluator)
@@ -184,123 +182,6 @@ object CrossValidator extends MLReadable[CrossValidator] {
.setNumFolds(numFolds)
}
}
-
- private object CrossValidatorReader {
- /**
- * Examine the given estimator (which may be a compound estimator) and extract a mapping
- * from UIDs to corresponding [[Params]] instances.
- */
- def getUidMap(instance: Params): Map[String, Params] = {
- val uidList = getUidMapImpl(instance)
- val uidMap = uidList.toMap
- if (uidList.size != uidMap.size) {
- throw new RuntimeException("CrossValidator.load found a compound estimator with stages" +
- s" with duplicate UIDs. List of UIDs: ${uidList.map(_._1).mkString(", ")}")
- }
- uidMap
- }
-
- def getUidMapImpl(instance: Params): List[(String, Params)] = {
- val subStages: Array[Params] = instance match {
- case p: Pipeline => p.getStages.asInstanceOf[Array[Params]]
- case pm: PipelineModel => pm.stages.asInstanceOf[Array[Params]]
- case v: ValidatorParams => Array(v.getEstimator, v.getEvaluator)
- case ovr: OneVsRestParams =>
- // TODO: SPARK-11892: This case may require special handling.
- throw new UnsupportedOperationException("CrossValidator write will fail because it" +
- " cannot yet handle an estimator containing type: ${ovr.getClass.getName}")
- case rformModel: RFormulaModel => Array(rformModel.pipelineModel)
- case _: Params => Array()
- }
- val subStageMaps = subStages.map(getUidMapImpl).foldLeft(List.empty[(String, Params)])(_ ++ _)
- List((instance.uid, instance)) ++ subStageMaps
- }
- }
-
- private[tuning] object SharedReadWrite {
-
- /**
- * Check that [[CrossValidator.evaluator]] and [[CrossValidator.estimator]] are Writable.
- * This does not check [[CrossValidator.estimatorParamMaps]].
- */
- def validateParams(instance: ValidatorParams): Unit = {
- def checkElement(elem: Params, name: String): Unit = elem match {
- case stage: MLWritable => // good
- case other =>
- throw new UnsupportedOperationException("CrossValidator write will fail " +
- s" because it contains $name which does not implement Writable." +
- s" Non-Writable $name: ${other.uid} of type ${other.getClass}")
- }
- checkElement(instance.getEvaluator, "evaluator")
- checkElement(instance.getEstimator, "estimator")
- // Check to make sure all Params apply to this estimator. Throw an error if any do not.
- // Extraneous Params would cause problems when loading the estimatorParamMaps.
- val uidToInstance: Map[String, Params] = CrossValidatorReader.getUidMap(instance)
- instance.getEstimatorParamMaps.foreach { case pMap: ParamMap =>
- pMap.toSeq.foreach { case ParamPair(p, v) =>
- require(uidToInstance.contains(p.parent), s"CrossValidator save requires all Params in" +
- s" estimatorParamMaps to apply to this CrossValidator, its Estimator, or its" +
- s" Evaluator. An extraneous Param was found: $p")
- }
- }
- }
-
- private[tuning] def saveImpl(
- path: String,
- instance: CrossValidatorParams,
- sc: SparkContext,
- extraMetadata: Option[JObject] = None): Unit = {
- import org.json4s.JsonDSL._
-
- val estimatorParamMapsJson = compact(render(
- instance.getEstimatorParamMaps.map { case paramMap =>
- paramMap.toSeq.map { case ParamPair(p, v) =>
- Map("parent" -> p.parent, "name" -> p.name, "value" -> p.jsonEncode(v))
- }
- }.toSeq
- ))
- val jsonParams = List(
- "numFolds" -> parse(instance.numFolds.jsonEncode(instance.getNumFolds)),
- "estimatorParamMaps" -> parse(estimatorParamMapsJson)
- )
- DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, Some(jsonParams))
-
- val evaluatorPath = new Path(path, "evaluator").toString
- instance.getEvaluator.asInstanceOf[MLWritable].save(evaluatorPath)
- val estimatorPath = new Path(path, "estimator").toString
- instance.getEstimator.asInstanceOf[MLWritable].save(estimatorPath)
- }
-
- private[tuning] def load[M <: Model[M]](
- path: String,
- sc: SparkContext,
- expectedClassName: String): (Metadata, Estimator[M], Evaluator, Array[ParamMap], Int) = {
-
- val metadata = DefaultParamsReader.loadMetadata(path, sc, expectedClassName)
-
- implicit val format = DefaultFormats
- val evaluatorPath = new Path(path, "evaluator").toString
- val evaluator = DefaultParamsReader.loadParamsInstance[Evaluator](evaluatorPath, sc)
- val estimatorPath = new Path(path, "estimator").toString
- val estimator = DefaultParamsReader.loadParamsInstance[Estimator[M]](estimatorPath, sc)
-
- val uidToParams = Map(evaluator.uid -> evaluator) ++ CrossValidatorReader.getUidMap(estimator)
-
- val numFolds = (metadata.params \ "numFolds").extract[Int]
- val estimatorParamMaps: Array[ParamMap] =
- (metadata.params \ "estimatorParamMaps").extract[Seq[Seq[Map[String, String]]]].map {
- pMap =>
- val paramPairs = pMap.map { case pInfo: Map[String, String] =>
- val est = uidToParams(pInfo("parent"))
- val param = est.getParam(pInfo("name"))
- val value = param.jsonDecode(pInfo("value"))
- param -> value
- }
- ParamMap(paramPairs: _*)
- }.toArray
- (metadata, estimator, evaluator, estimatorParamMaps, numFolds)
- }
- }
}
/**
@@ -346,8 +227,6 @@ class CrossValidatorModel private[ml] (
@Since("1.6.0")
object CrossValidatorModel extends MLReadable[CrossValidatorModel] {
- import CrossValidator.SharedReadWrite
-
@Since("1.6.0")
override def read: MLReader[CrossValidatorModel] = new CrossValidatorModelReader
@@ -357,12 +236,12 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] {
private[CrossValidatorModel]
class CrossValidatorModelWriter(instance: CrossValidatorModel) extends MLWriter {
- SharedReadWrite.validateParams(instance)
+ ValidatorParams.validateParams(instance)
override protected def saveImpl(path: String): Unit = {
import org.json4s.JsonDSL._
val extraMetadata = "avgMetrics" -> instance.avgMetrics.toSeq
- SharedReadWrite.saveImpl(path, instance, sc, Some(extraMetadata))
+ ValidatorParams.saveImpl(path, instance, sc, Some(extraMetadata))
val bestModelPath = new Path(path, "bestModel").toString
instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath)
}
@@ -376,8 +255,9 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] {
override def load(path: String): CrossValidatorModel = {
implicit val format = DefaultFormats
- val (metadata, estimator, evaluator, estimatorParamMaps, numFolds) =
- SharedReadWrite.load(path, sc, className)
+ val (metadata, estimator, evaluator, estimatorParamMaps) =
+ ValidatorParams.loadImpl(path, sc, className)
+ val numFolds = (metadata.params \ "numFolds").extract[Int]
val bestModelPath = new Path(path, "bestModel").toString
val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc)
val avgMetrics = (metadata.metadata \ "avgMetrics").extract[Seq[Double]].toArray
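For orientation (an illustrative sketch, not part of the diff): from a user's perspective, `CrossValidator` persistence is unchanged by this refactor. `numFolds` now travels through `metadata.params` and is re-extracted on load, so a round trip behaves as before. Assuming the `lr`/`grid` setup sketched above:

import org.apache.spark.ml.tuning.CrossValidator

val cv = new CrossValidator()
  .setEstimator(lr)
  .setEvaluator(new BinaryClassificationEvaluator())
  .setEstimatorParamMaps(grid)
  .setNumFolds(3)
cv.write.overwrite().save("/tmp/spark-11893/cv")
val loadedCv = CrossValidator.load("/tmp/spark-11893/cv")
assert(loadedCv.getNumFolds == 3)  // numFolds survives via metadata.params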
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
index 70fa5f0234..4d1d6364d7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
@@ -17,12 +17,15 @@
package org.apache.spark.ml.tuning
+import org.apache.hadoop.fs.Path
+import org.json4s.DefaultFormats
+
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.internal.Logging
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.evaluation.Evaluator
import org.apache.spark.ml.param.{DoubleParam, ParamMap, ParamValidators}
-import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.util._
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType
@@ -33,6 +36,7 @@ private[ml] trait TrainValidationSplitParams extends ValidatorParams {
/**
* Param for ratio between train and validation data. Must be between 0 and 1.
* Default: 0.75
+ *
* @group param
*/
val trainRatio: DoubleParam = new DoubleParam(this, "trainRatio",
@@ -55,7 +59,7 @@ private[ml] trait TrainValidationSplitParams extends ValidatorParams {
@Experimental
class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: String)
extends Estimator[TrainValidationSplitModel]
- with TrainValidationSplitParams with Logging {
+ with TrainValidationSplitParams with MLWritable with Logging {
@Since("1.5.0")
def this() = this(Identifiable.randomUID("tvs"))
@@ -130,6 +134,47 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
}
copied
}
+
+ @Since("2.0.0")
+ override def write: MLWriter = new TrainValidationSplit.TrainValidationSplitWriter(this)
+}
+
+@Since("2.0.0")
+object TrainValidationSplit extends MLReadable[TrainValidationSplit] {
+
+ @Since("2.0.0")
+ override def read: MLReader[TrainValidationSplit] = new TrainValidationSplitReader
+
+ @Since("2.0.0")
+ override def load(path: String): TrainValidationSplit = super.load(path)
+
+ private[TrainValidationSplit] class TrainValidationSplitWriter(instance: TrainValidationSplit)
+ extends MLWriter {
+
+ ValidatorParams.validateParams(instance)
+
+ override protected def saveImpl(path: String): Unit =
+ ValidatorParams.saveImpl(path, instance, sc)
+ }
+
+ private class TrainValidationSplitReader extends MLReader[TrainValidationSplit] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[TrainValidationSplit].getName
+
+ override def load(path: String): TrainValidationSplit = {
+ implicit val format = DefaultFormats
+
+ val (metadata, estimator, evaluator, estimatorParamMaps) =
+ ValidatorParams.loadImpl(path, sc, className)
+ val trainRatio = (metadata.params \ "trainRatio").extract[Double]
+ new TrainValidationSplit(metadata.uid)
+ .setEstimator(estimator)
+ .setEvaluator(evaluator)
+ .setEstimatorParamMaps(estimatorParamMaps)
+ .setTrainRatio(trainRatio)
+ }
+ }
}
/**
@@ -146,7 +191,7 @@ class TrainValidationSplitModel private[ml] (
@Since("1.5.0") override val uid: String,
@Since("1.5.0") val bestModel: Model[_],
@Since("1.5.0") val validationMetrics: Array[Double])
- extends Model[TrainValidationSplitModel] with TrainValidationSplitParams {
+ extends Model[TrainValidationSplitModel] with TrainValidationSplitParams with MLWritable {
@Since("1.5.0")
override def transform(dataset: DataFrame): DataFrame = {
@@ -167,4 +212,53 @@ class TrainValidationSplitModel private[ml] (
validationMetrics.clone())
copyValues(copied, extra)
}
+
+ @Since("2.0.0")
+ override def write: MLWriter = new TrainValidationSplitModel.TrainValidationSplitModelWriter(this)
+}
+
+@Since("2.0.0")
+object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] {
+
+ @Since("2.0.0")
+ override def read: MLReader[TrainValidationSplitModel] = new TrainValidationSplitModelReader
+
+ @Since("2.0.0")
+ override def load(path: String): TrainValidationSplitModel = super.load(path)
+
+ private[TrainValidationSplitModel]
+ class TrainValidationSplitModelWriter(instance: TrainValidationSplitModel) extends MLWriter {
+
+ ValidatorParams.validateParams(instance)
+
+ override protected def saveImpl(path: String): Unit = {
+ import org.json4s.JsonDSL._
+ val extraMetadata = "validationMetrics" -> instance.validationMetrics.toSeq
+ ValidatorParams.saveImpl(path, instance, sc, Some(extraMetadata))
+ val bestModelPath = new Path(path, "bestModel").toString
+ instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath)
+ }
+ }
+
+ private class TrainValidationSplitModelReader extends MLReader[TrainValidationSplitModel] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[TrainValidationSplitModel].getName
+
+ override def load(path: String): TrainValidationSplitModel = {
+ implicit val format = DefaultFormats
+
+ val (metadata, estimator, evaluator, estimatorParamMaps) =
+ ValidatorParams.loadImpl(path, sc, className)
+ val trainRatio = (metadata.params \ "trainRatio").extract[Double]
+ val bestModelPath = new Path(path, "bestModel").toString
+ val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc)
+ val validationMetrics = (metadata.metadata \ "validationMetrics").extract[Seq[Double]].toArray
+ val tvs = new TrainValidationSplitModel(metadata.uid, bestModel, validationMetrics)
+ tvs.set(tvs.estimator, estimator)
+ .set(tvs.evaluator, evaluator)
+ .set(tvs.estimatorParamMaps, estimatorParamMaps)
+ .set(tvs.trainRatio, trainRatio)
+ }
+ }
}
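On the model side (again a hedged sketch; `tvs` from above and an assumed `training` DataFrame): fitting yields a `TrainValidationSplitModel` whose `validationMetrics` are persisted as extra metadata, mirroring `avgMetrics` on `CrossValidatorModel`:

import org.apache.spark.ml.tuning.TrainValidationSplitModel

val tvsModel = tvs.fit(training)  // `training` is an assumed DataFrame
tvsModel.write.overwrite().save("/tmp/spark-11893/tvs-model")
val loadedModel = TrainValidationSplitModel.load("/tmp/spark-11893/tvs-model")
loadedModel.validationMetrics  // restored from the "validationMetrics" extra metadata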
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
index 953456e8f0..7a4e106aeb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
@@ -17,9 +17,17 @@
package org.apache.spark.ml.tuning
-import org.apache.spark.ml.Estimator
+import org.apache.hadoop.fs.Path
+import org.json4s.{DefaultFormats, _}
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.SparkContext
+import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.evaluation.Evaluator
-import org.apache.spark.ml.param.{Param, ParamMap, Params}
+import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
+import org.apache.spark.ml.util.{DefaultParamsReader, DefaultParamsWriter, MetaAlgorithmReadWrite,
+ MLWritable}
+import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.sql.types.StructType
/**
@@ -69,3 +77,108 @@ private[ml] trait ValidatorParams extends Params {
est.copy(firstEstimatorParamMap).transformSchema(schema)
}
}
+
+private[ml] object ValidatorParams {
+ /**
+ * Check that [[ValidatorParams.evaluator]] and [[ValidatorParams.estimator]] are Writable.
+ * This does not check [[ValidatorParams.estimatorParamMaps]].
+ */
+ def validateParams(instance: ValidatorParams): Unit = {
+ def checkElement(elem: Params, name: String): Unit = elem match {
+ case stage: MLWritable => // good
+ case other =>
+ throw new UnsupportedOperationException(instance.getClass.getName + " write will fail " +
+ s" because it contains $name which does not implement Writable." +
+ s" Non-Writable $name: ${other.uid} of type ${other.getClass}")
+ }
+ checkElement(instance.getEvaluator, "evaluator")
+ checkElement(instance.getEstimator, "estimator")
+ // Check to make sure all Params apply to this estimator. Throw an error if any do not.
+ // Extraneous Params would cause problems when loading the estimatorParamMaps.
+ val uidToInstance: Map[String, Params] = MetaAlgorithmReadWrite.getUidMap(instance)
+ instance.getEstimatorParamMaps.foreach { case pMap: ParamMap =>
+ pMap.toSeq.foreach { case ParamPair(p, v) =>
+ require(uidToInstance.contains(p.parent), s"ValidatorParams save requires all Params in" +
+ s" estimatorParamMaps to apply to this ValidatorParams, its Estimator, or its" +
+ s" Evaluator. An extraneous Param was found: $p")
+ }
+ }
+ }
+
+ /**
+ * Generic implementation of save for [[ValidatorParams]] types.
+ * This handles all [[ValidatorParams]] fields and saves [[Param]] values, but the implementing
+ * class needs to handle model data.
+ */
+ def saveImpl(
+ path: String,
+ instance: ValidatorParams,
+ sc: SparkContext,
+ extraMetadata: Option[JObject] = None): Unit = {
+ import org.json4s.JsonDSL._
+
+ val estimatorParamMapsJson = compact(render(
+ instance.getEstimatorParamMaps.map { case paramMap =>
+ paramMap.toSeq.map { case ParamPair(p, v) =>
+ Map("parent" -> p.parent, "name" -> p.name, "value" -> p.jsonEncode(v))
+ }
+ }.toSeq
+ ))
+
+ val validatorSpecificParams = instance match {
+ case cv: CrossValidatorParams =>
+ List("numFolds" -> parse(cv.numFolds.jsonEncode(cv.getNumFolds)))
+ case tvs: TrainValidationSplitParams =>
+ List("trainRatio" -> parse(tvs.trainRatio.jsonEncode(tvs.getTrainRatio)))
+ case _ =>
+ // This should not happen.
+ throw new NotImplementedError("ValidatorParams.saveImpl does not handle type: " +
+ instance.getClass.getCanonicalName)
+ }
+
+ val jsonParams = validatorSpecificParams ++ List(
+ "estimatorParamMaps" -> parse(estimatorParamMapsJson))
+
+ DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, Some(jsonParams))
+
+ val evaluatorPath = new Path(path, "evaluator").toString
+ instance.getEvaluator.asInstanceOf[MLWritable].save(evaluatorPath)
+ val estimatorPath = new Path(path, "estimator").toString
+ instance.getEstimator.asInstanceOf[MLWritable].save(estimatorPath)
+ }
+
+ /**
+ * Generic implementation of load for [[ValidatorParams]] types.
+ * This handles all [[ValidatorParams]] fields, but the implementing
+ * class needs to handle model data and special [[Param]] values.
+ */
+ def loadImpl[M <: Model[M]](
+ path: String,
+ sc: SparkContext,
+ expectedClassName: String): (Metadata, Estimator[M], Evaluator, Array[ParamMap]) = {
+
+ val metadata = DefaultParamsReader.loadMetadata(path, sc, expectedClassName)
+
+ implicit val format = DefaultFormats
+ val evaluatorPath = new Path(path, "evaluator").toString
+ val evaluator = DefaultParamsReader.loadParamsInstance[Evaluator](evaluatorPath, sc)
+ val estimatorPath = new Path(path, "estimator").toString
+ val estimator = DefaultParamsReader.loadParamsInstance[Estimator[M]](estimatorPath, sc)
+
+ val uidToParams = Map(evaluator.uid -> evaluator) ++ MetaAlgorithmReadWrite.getUidMap(estimator)
+
+ val estimatorParamMaps: Array[ParamMap] =
+ (metadata.params \ "estimatorParamMaps").extract[Seq[Seq[Map[String, String]]]].map {
+ pMap =>
+ val paramPairs = pMap.map { case pInfo: Map[String, String] =>
+ val est = uidToParams(pInfo("parent"))
+ val param = est.getParam(pInfo("name"))
+ val value = param.jsonDecode(pInfo("value"))
+ param -> value
+ }
+ ParamMap(paramPairs: _*)
+ }.toArray
+
+ (metadata, estimator, evaluator, estimatorParamMaps)
+ }
+}
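The split of responsibilities shows up in `loadImpl`'s return type: it hands back only the shared fields, and each reader pulls its own params out of `metadata.params` with json4s. A self-contained sketch of that extraction pattern (the JSON literal is a stand-in for real metadata):

import org.json4s._
import org.json4s.jackson.JsonMethods._

implicit val format = DefaultFormats
// Stand-in for the metadata.params JSON written by saveImpl above.
val params = parse("""{"trainRatio": 0.75, "estimatorParamMaps": []}""")
val trainRatio = (params \ "trainRatio").extract[Double]  // 0.75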
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index c95e536abd..5a596cad06 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -21,13 +21,18 @@ import java.io.IOException
import org.apache.hadoop.fs.Path
import org.json4s._
-import org.json4s.jackson.JsonMethods._
+import org.json4s.{DefaultFormats, JObject}
import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
import org.apache.spark.SparkContext
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.internal.Logging
+import org.apache.spark.ml._
+import org.apache.spark.ml.classification.OneVsRestParams
+import org.apache.spark.ml.feature.RFormulaModel
import org.apache.spark.ml.param.{ParamPair, Params}
+import org.apache.spark.ml.tuning.ValidatorParams
import org.apache.spark.sql.SQLContext
import org.apache.spark.util.Utils
@@ -352,3 +357,38 @@ private[ml] object DefaultParamsReader {
cls.getMethod("read").invoke(null).asInstanceOf[MLReader[T]].load(path)
}
}
+
+/**
+ * Default Meta-Algorithm read and write implementation.
+ */
+private[ml] object MetaAlgorithmReadWrite {
+ /**
+ * Examine the given estimator (which may be a compound estimator) and extract a mapping
+ * from UIDs to corresponding [[Params]] instances.
+ */
+ def getUidMap(instance: Params): Map[String, Params] = {
+ val uidList = getUidMapImpl(instance)
+ val uidMap = uidList.toMap
+ if (uidList.size != uidMap.size) {
+ throw new RuntimeException(s"${instance.getClass.getName}.load found a compound estimator" +
+ s" with stages with duplicate UIDs. List of UIDs: ${uidList.map(_._1).mkString(", ")}.")
+ }
+ uidMap
+ }
+
+ private def getUidMapImpl(instance: Params): List[(String, Params)] = {
+ val subStages: Array[Params] = instance match {
+ case p: Pipeline => p.getStages.asInstanceOf[Array[Params]]
+ case pm: PipelineModel => pm.stages.asInstanceOf[Array[Params]]
+ case v: ValidatorParams => Array(v.getEstimator, v.getEvaluator)
+ case ovr: OneVsRestParams =>
+ // TODO: SPARK-11892: This case may require special handling.
+ throw new UnsupportedOperationException(s"${instance.getClass.getName} write will fail" +
+ s" because it cannot yet handle an estimator containing type: ${ovr.getClass.getName}.")
+ case rformModel: RFormulaModel => Array(rformModel.pipelineModel)
+ case _: Params => Array()
+ }
+ val subStageMaps = subStages.map(getUidMapImpl).foldLeft(List.empty[(String, Params)])(_ ++ _)
+ List((instance.uid, instance)) ++ subStageMaps
+ }
+}
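To make the UID flattening concrete (a hand-written sketch, since `MetaAlgorithmReadWrite` is `private[ml]`): for a two-stage `Pipeline`, `getUidMap` returns the pipeline's own UID plus one entry per stage, which is what lets `validateParams` resolve `p.parent` for every `ParamPair` in `estimatorParamMaps`:

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.HashingTF

val tf = new HashingTF()
val lr2 = new LogisticRegression()
val pipeline = new Pipeline().setStages(Array(tf, lr2))
// Equivalent of MetaAlgorithmReadWrite.getUidMap(pipeline):
val uidMap: Map[String, org.apache.spark.ml.param.Params] =
  Map(pipeline.uid -> pipeline, tf.uid -> tf, lr2.uid -> lr2)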