Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala  82
1 file changed, 77 insertions(+), 5 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index c95e536abd..7dec07ea14 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -21,13 +21,18 @@ import java.io.IOException
import org.apache.hadoop.fs.Path
import org.json4s._
-import org.json4s.jackson.JsonMethods._
+import org.json4s.{DefaultFormats, JObject}
import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
import org.apache.spark.SparkContext
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.internal.Logging
+import org.apache.spark.ml._
+import org.apache.spark.ml.classification.{OneVsRest, OneVsRestModel}
+import org.apache.spark.ml.feature.RFormulaModel
import org.apache.spark.ml.param.{ParamPair, Params}
+import org.apache.spark.ml.tuning.ValidatorParams
import org.apache.spark.sql.SQLContext
import org.apache.spark.util.Utils
@@ -139,6 +144,7 @@ private[ml] trait DefaultParamsWritable extends MLWritable { self: Params =>
/**
* Abstract class for utility classes that can load ML instances.
+ *
* @tparam T ML instance type
*/
@Experimental
@@ -157,6 +163,7 @@ abstract class MLReader[T] extends BaseReadWrite {
/**
* Trait for objects that provide [[MLReader]].
+ *
* @tparam T ML instance type
*/
@Experimental
@@ -187,6 +194,7 @@ private[ml] trait DefaultParamsReadable[T] extends MLReadable[T] {
* Default [[MLWriter]] implementation for transformers and estimators that contain basic
* (json4s-serializable) params and no data. This will not handle more complex params or types with
* data (e.g., models with coefficients).
+ *
* @param instance object to save
*/
private[ml] class DefaultParamsWriter(instance: Params) extends MLWriter {
@@ -206,6 +214,7 @@ private[ml] object DefaultParamsWriter {
* - uid
* - paramMap
* - (optionally, extra metadata)
+ *
* @param extraMetadata Extra metadata to be saved at same level as uid, paramMap, etc.
* @param paramMap If given, this is saved in the "paramMap" field.
* Otherwise, all [[org.apache.spark.ml.param.Param]]s are encoded using
@@ -217,6 +226,22 @@ private[ml] object DefaultParamsWriter {
sc: SparkContext,
extraMetadata: Option[JObject] = None,
paramMap: Option[JValue] = None): Unit = {
+ val metadataPath = new Path(path, "metadata").toString
+ val metadataJson = getMetadataToSave(instance, sc, extraMetadata, paramMap)
+ sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath)
+ }
+
+ /**
+ * Helper for [[saveMetadata()]] which extracts the JSON to save.
+ * This is useful for ensemble models which need to save metadata for many sub-models.
+ *
+ * @see [[saveMetadata()]] for details on what this includes.
+ */
+ def getMetadataToSave(
+ instance: Params,
+ sc: SparkContext,
+ extraMetadata: Option[JObject] = None,
+ paramMap: Option[JValue] = None): String = {
val uid = instance.uid
val cls = instance.getClass.getName
val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]]
@@ -234,9 +259,8 @@ private[ml] object DefaultParamsWriter {
case None =>
basicMetadata
}
- val metadataPath = new Path(path, "metadata").toString
- val metadataJson = compact(render(metadata))
- sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath)
+ val metadataJson: String = compact(render(metadata))
+ metadataJson
}
}
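
This hunk splits the old saveMetadata() into two steps: getMetadataToSave(), which renders the metadata JSON as a String, and the actual write, so ensemble writers can collect one metadata blob per sub-model and persist them in a layout of their own. A minimal sketch of both entry points follows; the package and object names are hypothetical, and the code must live under org.apache.spark.ml because DefaultParamsWriter is private[ml]:

    package org.apache.spark.ml.example  // hypothetical; private[ml] access requires an ml subpackage

    import org.apache.spark.SparkContext
    import org.apache.spark.ml.param.Params
    import org.apache.spark.ml.util.DefaultParamsWriter

    object SaveSketch {
      def save(instance: Params, path: String, sc: SparkContext): Unit = {
        // Writes <path>/metadata with the rendered JSON, exactly as the
        // refactored saveMetadata() now does via getMetadataToSave().
        DefaultParamsWriter.saveMetadata(instance, path, sc)

        // Or obtain the JSON string without touching the filesystem,
        // e.g. to embed per-sub-model metadata in an ensemble's own layout.
        val json: String = DefaultParamsWriter.getMetadataToSave(instance, sc)
        println(json)
      }
    }
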
@@ -244,6 +268,7 @@ private[ml] object DefaultParamsWriter {
* Default [[MLReader]] implementation for transformers and estimators that contain basic
* (json4s-serializable) params and no data. This will not handle more complex params or types with
* data (e.g., models with coefficients).
+ *
* @tparam T ML instance type
* TODO: Consider adding check for correct class name.
*/
@@ -263,6 +288,7 @@ private[ml] object DefaultParamsReader {
/**
* All info from metadata file.
+ *
* @param params paramMap, as a [[JValue]]
* @param metadata All metadata, including the other fields
* @param metadataJson Full metadata file String (for debugging)
@@ -299,13 +325,26 @@ private[ml] object DefaultParamsReader {
}
/**
- * Load metadata from file.
+ * Load metadata saved using [[DefaultParamsWriter.saveMetadata()]].
+ *
* @param expectedClassName If non-empty, this is checked against the loaded metadata.
* @throws IllegalArgumentException if expectedClassName is specified and does not match metadata
*/
def loadMetadata(path: String, sc: SparkContext, expectedClassName: String = ""): Metadata = {
val metadataPath = new Path(path, "metadata").toString
val metadataStr = sc.textFile(metadataPath, 1).first()
+ parseMetadata(metadataStr, expectedClassName)
+ }
+
+ /**
+ * Parse metadata JSON string produced by [[DefaultParamsWriter.getMetadataToSave()]].
+ * This is a helper function for [[loadMetadata()]].
+ *
+ * @param metadataStr JSON string of metadata
+ * @param expectedClassName If non-empty, this is checked against the loaded metadata.
+ * @throws IllegalArgumentException if expectedClassName is specified and does not match metadata
+ */
+ def parseMetadata(metadataStr: String, expectedClassName: String = ""): Metadata = {
val metadata = parse(metadataStr)
implicit val format = DefaultFormats
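
The read side mirrors this split: loadMetadata() still fetches <path>/metadata from storage, while the new parseMetadata() converts an already-loaded JSON string into a Metadata instance. A round-trip sketch under the same hypothetical packaging assumption:

    package org.apache.spark.ml.example  // hypothetical; private[ml] access requires an ml subpackage

    import org.apache.spark.SparkContext
    import org.apache.spark.ml.param.Params
    import org.apache.spark.ml.util.{DefaultParamsReader, DefaultParamsWriter}

    object RoundTripSketch {
      def roundTrip(instance: Params, sc: SparkContext): Unit = {
        val json = DefaultParamsWriter.getMetadataToSave(instance, sc)
        // parseMetadata() validates the class name when one is supplied
        // and throws IllegalArgumentException on a mismatch.
        val meta = DefaultParamsReader.parseMetadata(json, instance.getClass.getName)
        assert(meta.uid == instance.uid)
        assert(meta.metadataJson == json)
      }
    }
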
@@ -352,3 +391,36 @@ private[ml] object DefaultParamsReader {
cls.getMethod("read").invoke(null).asInstanceOf[MLReader[T]].load(path)
}
}
+
+/**
+ * Default Meta-Algorithm read and write implementation.
+ */
+private[ml] object MetaAlgorithmReadWrite {
+ /**
+ * Examine the given estimator (which may be a compound estimator) and extract a mapping
+ * from UIDs to corresponding [[Params]] instances.
+ */
+ def getUidMap(instance: Params): Map[String, Params] = {
+ val uidList = getUidMapImpl(instance)
+ val uidMap = uidList.toMap
+ if (uidList.size != uidMap.size) {
+ throw new RuntimeException(s"${instance.getClass.getName}.load found a compound estimator" +
+ s" whose stages have duplicate UIDs. List of UIDs: ${uidList.map(_._1).mkString(", ")}.")
+ }
+ uidMap
+ }
+
+ private def getUidMapImpl(instance: Params): List[(String, Params)] = {
+ val subStages: Array[Params] = instance match {
+ case p: Pipeline => p.getStages.asInstanceOf[Array[Params]]
+ case pm: PipelineModel => pm.stages.asInstanceOf[Array[Params]]
+ case v: ValidatorParams => Array(v.getEstimator, v.getEvaluator)
+ case ovr: OneVsRest => Array(ovr.getClassifier)
+ case ovrModel: OneVsRestModel => Array(ovrModel.getClassifier) ++ ovrModel.models
+ case rformModel: RFormulaModel => Array(rformModel.pipelineModel)
+ case _: Params => Array()
+ }
+ val subStageMaps = subStages.map(getUidMapImpl).foldLeft(List.empty[(String, Params)])(_ ++ _)
+ List((instance.uid, instance)) ++ subStageMaps
+ }
+}
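
The new MetaAlgorithmReadWrite.getUidMap() gives meta-algorithm readers and writers a flat view of every stage in a possibly nested estimator graph, failing fast on duplicate UIDs. A sketch of its behaviour on a two-stage Pipeline; the UIDs are hypothetical, and the explicit Array[PipelineStage] annotation avoids Scala's invariant-array inference issue with mixed stage types:

    package org.apache.spark.ml.example  // hypothetical; private[ml] access requires an ml subpackage

    import org.apache.spark.ml.{Pipeline, PipelineStage}
    import org.apache.spark.ml.classification.LogisticRegression
    import org.apache.spark.ml.feature.HashingTF
    import org.apache.spark.ml.param.Params
    import org.apache.spark.ml.util.MetaAlgorithmReadWrite

    object UidMapSketch {
      val pipeline = new Pipeline("pipe_0").setStages(Array[PipelineStage](
        new HashingTF("tf_0"),
        new LogisticRegression("lr_0")))

      // Recursively walks the pipeline and its stages:
      // Map("pipe_0" -> pipeline, "tf_0" -> the HashingTF, "lr_0" -> the LogisticRegression)
      val uidMap: Map[String, Params] = MetaAlgorithmReadWrite.getUidMap(pipeline)

      // Listing the same stage object twice would duplicate its UID and
      // trigger the RuntimeException guarded against above.
    }
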