Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala | 82
1 file changed, 77 insertions, 5 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index c95e536abd..7dec07ea14 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -21,13 +21,18 @@ import java.io.IOException
 
 import org.apache.hadoop.fs.Path
 import org.json4s._
-import org.json4s.jackson.JsonMethods._
+import org.json4s.{DefaultFormats, JObject}
 import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
 
 import org.apache.spark.SparkContext
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.internal.Logging
+import org.apache.spark.ml._
+import org.apache.spark.ml.classification.{OneVsRest, OneVsRestModel}
+import org.apache.spark.ml.feature.RFormulaModel
 import org.apache.spark.ml.param.{ParamPair, Params}
+import org.apache.spark.ml.tuning.ValidatorParams
 import org.apache.spark.sql.SQLContext
 import org.apache.spark.util.Utils
 
@@ -139,6 +144,7 @@ private[ml] trait DefaultParamsWritable extends MLWritable { self: Params =>
 
 /**
  * Abstract class for utility classes that can load ML instances.
+ *
  * @tparam T ML instance type
  */
 @Experimental
@@ -157,6 +163,7 @@ abstract class MLReader[T] extends BaseReadWrite {
 
 /**
  * Trait for objects that provide [[MLReader]].
+ *
  * @tparam T ML instance type
 */
 @Experimental
@@ -187,6 +194,7 @@ private[ml] trait DefaultParamsReadable[T] extends MLReadable[T] {
  * Default [[MLWriter]] implementation for transformers and estimators that contain basic
  * (json4s-serializable) params and no data. This will not handle more complex params or types with
  * data (e.g., models with coefficients).
+ *
  * @param instance object to save
  */
 private[ml] class DefaultParamsWriter(instance: Params) extends MLWriter {
@@ -206,6 +214,7 @@ private[ml] object DefaultParamsWriter {
   *  - uid
   *  - paramMap
   *  - (optionally, extra metadata)
+  *
   * @param extraMetadata Extra metadata to be saved at same level as uid, paramMap, etc.
   * @param paramMap If given, this is saved in the "paramMap" field.
   *                 Otherwise, all [[org.apache.spark.ml.param.Param]]s are encoded using
@@ -217,6 +226,22 @@ private[ml] object DefaultParamsWriter {
       sc: SparkContext,
       extraMetadata: Option[JObject] = None,
       paramMap: Option[JValue] = None): Unit = {
+    val metadataPath = new Path(path, "metadata").toString
+    val metadataJson = getMetadataToSave(instance, sc, extraMetadata, paramMap)
+    sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath)
+  }
+
+  /**
+   * Helper for [[saveMetadata()]] which extracts the JSON to save.
+   * This is useful for ensemble models which need to save metadata for many sub-models.
+   *
+   * @see [[saveMetadata()]] for details on what this includes.
+   */
+  def getMetadataToSave(
+      instance: Params,
+      sc: SparkContext,
+      extraMetadata: Option[JObject] = None,
+      paramMap: Option[JValue] = None): String = {
     val uid = instance.uid
     val cls = instance.getClass.getName
     val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]]
@@ -234,9 +259,8 @@ private[ml] object DefaultParamsWriter {
       case None =>
         basicMetadata
     }
-    val metadataPath = new Path(path, "metadata").toString
-    val metadataJson = compact(render(metadata))
-    sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath)
+    val metadataJson: String = compact(render(metadata))
+    metadataJson
   }
 }
 
@@ -244,6 +268,7 @@
  * Default [[MLReader]] implementation for transformers and estimators that contain basic
  * (json4s-serializable) params and no data. This will not handle more complex params or types with
  * data (e.g., models with coefficients).
+ *
  * @tparam T ML instance type
  * TODO: Consider adding check for correct class name.
  */
@@ -263,6 +288,7 @@ private[ml] object DefaultParamsReader {
 
   /**
    * All info from metadata file.
+   *
    * @param params paramMap, as a [[JValue]]
    * @param metadata All metadata, including the other fields
    * @param metadataJson Full metadata file String (for debugging)
@@ -299,13 +325,26 @@ private[ml] object DefaultParamsReader {
   }
 
   /**
-   * Load metadata from file.
+   * Load metadata saved using [[DefaultParamsWriter.saveMetadata()]]
+   *
    * @param expectedClassName If non empty, this is checked against the loaded metadata.
    * @throws IllegalArgumentException if expectedClassName is specified and does not match metadata
    */
   def loadMetadata(path: String, sc: SparkContext, expectedClassName: String = ""): Metadata = {
     val metadataPath = new Path(path, "metadata").toString
     val metadataStr = sc.textFile(metadataPath, 1).first()
+    parseMetadata(metadataStr, expectedClassName)
+  }
+
+  /**
+   * Parse metadata JSON string produced by [[DefaultParamsWriter.getMetadataToSave()]].
+   * This is a helper function for [[loadMetadata()]].
+   *
+   * @param metadataStr JSON string of metadata
+   * @param expectedClassName If non empty, this is checked against the loaded metadata.
+   * @throws IllegalArgumentException if expectedClassName is specified and does not match metadata
+   */
+  def parseMetadata(metadataStr: String, expectedClassName: String = ""): Metadata = {
     val metadata = parse(metadataStr)
 
     implicit val format = DefaultFormats
@@ -352,3 +391,36 @@ private[ml] object DefaultParamsReader {
     cls.getMethod("read").invoke(null).asInstanceOf[MLReader[T]].load(path)
   }
 }
+
+/**
+ * Default Meta-Algorithm read and write implementation.
+ */
+private[ml] object MetaAlgorithmReadWrite {
+  /**
+   * Examine the given estimator (which may be a compound estimator) and extract a mapping
+   * from UIDs to corresponding [[Params]] instances.
+   */
+  def getUidMap(instance: Params): Map[String, Params] = {
+    val uidList = getUidMapImpl(instance)
+    val uidMap = uidList.toMap
+    if (uidList.size != uidMap.size) {
+      throw new RuntimeException(s"${instance.getClass.getName}.load found a compound estimator" +
+        s" with stages with duplicate UIDs. List of UIDs: ${uidList.map(_._1).mkString(", ")}.")
+    }
+    uidMap
+  }
+
+  private def getUidMapImpl(instance: Params): List[(String, Params)] = {
+    val subStages: Array[Params] = instance match {
+      case p: Pipeline => p.getStages.asInstanceOf[Array[Params]]
+      case pm: PipelineModel => pm.stages.asInstanceOf[Array[Params]]
+      case v: ValidatorParams => Array(v.getEstimator, v.getEvaluator)
+      case ovr: OneVsRest => Array(ovr.getClassifier)
+      case ovrModel: OneVsRestModel => Array(ovrModel.getClassifier) ++ ovrModel.models
+      case rformModel: RFormulaModel => Array(rformModel.pipelineModel)
+      case _: Params => Array()
+    }
+    val subStageMaps = subStages.map(getUidMapImpl).foldLeft(List.empty[(String, Params)])(_ ++ _)
+    List((instance.uid, instance)) ++ subStageMaps
+  }
+}
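
Note: the writer-side change above splits saveMetadata so that getMetadataToSave returns the metadata JSON as a String instead of writing it to a metadata/ directory, and parseMetadata is the symmetric reader-side helper for loadMetadata. A minimal sketch of how an ensemble writer might use the pair; the EnsembleMetadataSketch object is hypothetical (not part of this commit) and is placed under org.apache.spark.ml.util so the private[ml] members are visible.

package org.apache.spark.ml.util

import org.apache.spark.SparkContext
import org.apache.spark.ml.param.Params

// Hypothetical sketch, not part of this commit.
object EnsembleMetadataSketch {

  // Build one metadata JSON string per sub-model. getMetadataToSave returns the
  // same JSON that saveMetadata writes, but as a String, so an ensemble writer
  // can batch many entries into a single file instead of one directory each.
  def collectSubModelMetadata(subModels: Seq[Params], sc: SparkContext): Seq[String] =
    subModels.map(m => DefaultParamsWriter.getMetadataToSave(m, sc))

  // Reader side: parseMetadata turns each stored JSON string back into a
  // Metadata object without any filesystem access.
  def parseAll(jsonStrings: Seq[String]): Seq[DefaultParamsReader.Metadata] =
    jsonStrings.map(s => DefaultParamsReader.parseMetadata(s))
}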
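
The new MetaAlgorithmReadWrite.getUidMap recursively walks a compound estimator (Pipeline, PipelineModel, validators, OneVsRest/OneVsRestModel, RFormulaModel) and returns a map from each UID to its Params instance, throwing a RuntimeException if two stages share a UID. A minimal sketch of the result for a two-stage pipeline; the UidMapSketch object and the stage names tf and lr are illustrative only.

package org.apache.spark.ml.util

import org.apache.spark.ml.{Pipeline, PipelineStage}
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.HashingTF

// Hypothetical sketch, not part of this commit.
object UidMapSketch {
  def main(args: Array[String]): Unit = {
    val tf = new HashingTF()
    val lr = new LogisticRegression()
    val pipeline = new Pipeline().setStages(Array[PipelineStage](tf, lr))

    // Collects (uid -> Params) for the pipeline and every nested stage:
    // Map(pipeline.uid -> pipeline, tf.uid -> tf, lr.uid -> lr).
    val uidMap = MetaAlgorithmReadWrite.getUidMap(pipeline)
    assert(uidMap(lr.uid) eq lr)
  }
}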