author     Joseph K. Bradley <joseph@databricks.com>  2015-11-22 21:48:48 -0800
committer  Xiangrui Meng <meng@databricks.com>        2015-11-22 21:48:48 -0800
commit     a6fda0bfc16a13b28b1cecc96f1ff91363089144 (patch)
tree       cad3ab93b533e4611aa86d79298ae34f376d7dd5 /mllib
parent     fe89c1817d668e46adf70d0896c42c22a547c76a (diff)
[SPARK-6791][ML] Add read/write for CrossValidator and Evaluators
I believe this works for general estimators within CrossValidator, including compound estimators. (See the complex unit test.) Added read/write for all 3 Evaluators as well.

CC: mengxr yanboliang

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #9848 from jkbradley/cv-io.
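For orientation, a minimal usage sketch of the persistence this patch enables; it is not taken from the patch itself, and the estimator settings, grid values, and /tmp paths are illustrative only.

    import org.apache.spark.ml.classification.LogisticRegression
    import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
    import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

    val lr = new LogisticRegression().setMaxIter(10)
    val evaluator = new BinaryClassificationEvaluator().setMetricName("areaUnderPR")
    val grid = new ParamGridBuilder()
      .addGrid(lr.regParam, Array(0.1, 0.01))
      .build()

    val cv = new CrossValidator()
      .setEstimator(lr)
      .setEvaluator(evaluator)
      .setEstimatorParamMaps(grid)
      .setNumFolds(3)

    // CrossValidator now mixes in MLWritable and its companion object is MLReadable,
    // so the configured (unfitted) validator round-trips like other Pipeline stages.
    cv.save("/tmp/cv")
    val restoredCv = CrossValidator.load("/tmp/cv")

    // The three Evaluators gained DefaultParamsWritable/DefaultParamsReadable as well.
    evaluator.save("/tmp/cv-evaluator")
    val restoredEval = BinaryClassificationEvaluator.load("/tmp/cv-evaluator")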
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala | 38
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala | 11
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala | 12
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala | 11
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala | 14
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala | 229
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala | 48
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala | 4
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala | 13
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala | 13
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala | 12
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala | 202
12 files changed, 522 insertions(+), 85 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index 6f15b37abc..4b2b3f8489 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -34,7 +34,6 @@ import org.apache.spark.ml.util.MLWriter
import org.apache.spark.ml.util._
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType
-import org.apache.spark.util.Utils
/**
* :: DeveloperApi ::
@@ -232,20 +231,9 @@ object Pipeline extends MLReadable[Pipeline] {
stages: Array[PipelineStage],
sc: SparkContext,
path: String): Unit = {
- // Copied and edited from DefaultParamsWriter.saveMetadata
- // TODO: modify DefaultParamsWriter.saveMetadata to avoid duplication
- val uid = instance.uid
- val cls = instance.getClass.getName
val stageUids = stages.map(_.uid)
val jsonParams = List("stageUids" -> parse(compact(render(stageUids.toSeq))))
- val metadata = ("class" -> cls) ~
- ("timestamp" -> System.currentTimeMillis()) ~
- ("sparkVersion" -> sc.version) ~
- ("uid" -> uid) ~
- ("paramMap" -> jsonParams)
- val metadataPath = new Path(path, "metadata").toString
- val metadataJson = compact(render(metadata))
- sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath)
+ DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap = Some(jsonParams))
// Save stages
val stagesDir = new Path(path, "stages").toString
@@ -266,30 +254,10 @@ object Pipeline extends MLReadable[Pipeline] {
implicit val format = DefaultFormats
val stagesDir = new Path(path, "stages").toString
- val stageUids: Array[String] = metadata.params match {
- case JObject(pairs) =>
- if (pairs.length != 1) {
- // Should not happen unless file is corrupted or we have a bug.
- throw new RuntimeException(
- s"Pipeline read expected 1 Param (stageUids), but found ${pairs.length}.")
- }
- pairs.head match {
- case ("stageUids", jsonValue) =>
- jsonValue.extract[Seq[String]].toArray
- case (paramName, jsonValue) =>
- // Should not happen unless file is corrupted or we have a bug.
- throw new RuntimeException(s"Pipeline read encountered unexpected Param $paramName" +
- s" in metadata: ${metadata.metadataStr}")
- }
- case _ =>
- throw new IllegalArgumentException(
- s"Cannot recognize JSON metadata: ${metadata.metadataStr}.")
- }
+ val stageUids: Array[String] = (metadata.params \ "stageUids").extract[Seq[String]].toArray
val stages: Array[PipelineStage] = stageUids.zipWithIndex.map { case (stageUid, idx) =>
val stagePath = SharedReadWrite.getStagePath(stageUid, idx, stageUids.length, stagesDir)
- val stageMetadata = DefaultParamsReader.loadMetadata(stagePath, sc)
- val cls = Utils.classForName(stageMetadata.className)
- cls.getMethod("read").invoke(null).asInstanceOf[MLReader[PipelineStage]].load(stagePath)
+ DefaultParamsReader.loadParamsInstance[PipelineStage](stagePath, sc)
}
(metadata.uid, stages)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
index 1fe3abaca8..bfb70963b1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.evaluation
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
+import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.sql.{DataFrame, Row}
@@ -33,7 +33,7 @@ import org.apache.spark.sql.types.DoubleType
@Since("1.2.0")
@Experimental
class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override val uid: String)
- extends Evaluator with HasRawPredictionCol with HasLabelCol {
+ extends Evaluator with HasRawPredictionCol with HasLabelCol with DefaultParamsWritable {
@Since("1.2.0")
def this() = this(Identifiable.randomUID("binEval"))
@@ -105,3 +105,10 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va
@Since("1.4.1")
override def copy(extra: ParamMap): BinaryClassificationEvaluator = defaultCopy(extra)
}
+
+@Since("1.6.0")
+object BinaryClassificationEvaluator extends DefaultParamsReadable[BinaryClassificationEvaluator] {
+
+ @Since("1.6.0")
+ override def load(path: String): BinaryClassificationEvaluator = super.load(path)
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
index df5f04ca5a..c44db0ec59 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.evaluation
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.param.{ParamMap, ParamValidators, Param}
import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
-import org.apache.spark.ml.util.{SchemaUtils, Identifiable}
+import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, SchemaUtils, Identifiable}
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.sql.{Row, DataFrame}
import org.apache.spark.sql.types.DoubleType
@@ -32,7 +32,7 @@ import org.apache.spark.sql.types.DoubleType
@Since("1.5.0")
@Experimental
class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") override val uid: String)
- extends Evaluator with HasPredictionCol with HasLabelCol {
+ extends Evaluator with HasPredictionCol with HasLabelCol with DefaultParamsWritable {
@Since("1.5.0")
def this() = this(Identifiable.randomUID("mcEval"))
@@ -101,3 +101,11 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid
@Since("1.5.0")
override def copy(extra: ParamMap): MulticlassClassificationEvaluator = defaultCopy(extra)
}
+
+@Since("1.6.0")
+object MulticlassClassificationEvaluator
+ extends DefaultParamsReadable[MulticlassClassificationEvaluator] {
+
+ @Since("1.6.0")
+ override def load(path: String): MulticlassClassificationEvaluator = super.load(path)
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
index ba012f444d..daaa174a08 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.evaluation
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
-import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
+import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
import org.apache.spark.mllib.evaluation.RegressionMetrics
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions._
@@ -33,7 +33,7 @@ import org.apache.spark.sql.types.{DoubleType, FloatType}
@Since("1.4.0")
@Experimental
final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val uid: String)
- extends Evaluator with HasPredictionCol with HasLabelCol {
+ extends Evaluator with HasPredictionCol with HasLabelCol with DefaultParamsWritable {
@Since("1.4.0")
def this() = this(Identifiable.randomUID("regEval"))
@@ -104,3 +104,10 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui
@Since("1.5.0")
override def copy(extra: ParamMap): RegressionEvaluator = defaultCopy(extra)
}
+
+@Since("1.6.0")
+object RegressionEvaluator extends DefaultParamsReadable[RegressionEvaluator] {
+
+ @Since("1.6.0")
+ override def load(path: String): RegressionEvaluator = super.load(path)
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 4d35177ad9..b798aa1fab 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -27,9 +27,8 @@ import scala.util.hashing.byteswap64
import com.github.fommil.netlib.BLAS.{getInstance => blas}
import org.apache.hadoop.fs.{FileSystem, Path}
-import org.json4s.{DefaultFormats, JValue}
+import org.json4s.DefaultFormats
import org.json4s.JsonDSL._
-import org.json4s.jackson.JsonMethods._
import org.apache.spark.{Logging, Partitioner}
import org.apache.spark.annotation.{Since, DeveloperApi, Experimental}
@@ -240,7 +239,7 @@ object ALSModel extends MLReadable[ALSModel] {
private[ALSModel] class ALSModelWriter(instance: ALSModel) extends MLWriter {
override protected def saveImpl(path: String): Unit = {
- val extraMetadata = render("rank" -> instance.rank)
+ val extraMetadata = "rank" -> instance.rank
DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata))
val userPath = new Path(path, "userFactors").toString
instance.userFactors.write.format("parquet").save(userPath)
@@ -257,14 +256,7 @@ object ALSModel extends MLReadable[ALSModel] {
override def load(path: String): ALSModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
implicit val format = DefaultFormats
- val rank: Int = metadata.extraMetadata match {
- case Some(m: JValue) =>
- (m \ "rank").extract[Int]
- case None =>
- throw new RuntimeException(s"ALSModel loader could not read rank from JSON metadata:" +
- s" ${metadata.metadataStr}")
- }
-
+ val rank = (metadata.metadata \ "rank").extract[Int]
val userPath = new Path(path, "userFactors").toString
val userFactors = sqlContext.read.format("parquet").load(userPath)
val itemPath = new Path(path, "itemFactors").toString
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 77d9948ed8..83a9048374 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -18,17 +18,24 @@
package org.apache.spark.ml.tuning
import com.github.fommil.netlib.F2jBLAS
+import org.apache.hadoop.fs.Path
+import org.json4s.{JObject, DefaultFormats}
+import org.json4s.jackson.JsonMethods._
-import org.apache.spark.Logging
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.ml.classification.OneVsRestParams
+import org.apache.spark.ml.feature.RFormulaModel
+import org.apache.spark.{SparkContext, Logging}
+import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml._
import org.apache.spark.ml.evaluation.Evaluator
import org.apache.spark.ml.param._
-import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType
+
/**
* Params for [[CrossValidator]] and [[CrossValidatorModel]].
*/
@@ -53,7 +60,7 @@ private[ml] trait CrossValidatorParams extends ValidatorParams {
*/
@Experimental
class CrossValidator(override val uid: String) extends Estimator[CrossValidatorModel]
- with CrossValidatorParams with Logging {
+ with CrossValidatorParams with MLWritable with Logging {
def this() = this(Identifiable.randomUID("cv"))
@@ -131,6 +138,166 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM
}
copied
}
+
+ // Currently, this only works if all [[Param]]s in [[estimatorParamMaps]] are simple types.
+ // E.g., this may fail if a [[Param]] is an instance of an [[Estimator]].
+ // However, this case should be unusual.
+ @Since("1.6.0")
+ override def write: MLWriter = new CrossValidator.CrossValidatorWriter(this)
+}
+
+@Since("1.6.0")
+object CrossValidator extends MLReadable[CrossValidator] {
+
+ @Since("1.6.0")
+ override def read: MLReader[CrossValidator] = new CrossValidatorReader
+
+ @Since("1.6.0")
+ override def load(path: String): CrossValidator = super.load(path)
+
+ private[CrossValidator] class CrossValidatorWriter(instance: CrossValidator) extends MLWriter {
+
+ SharedReadWrite.validateParams(instance)
+
+ override protected def saveImpl(path: String): Unit =
+ SharedReadWrite.saveImpl(path, instance, sc)
+ }
+
+ private class CrossValidatorReader extends MLReader[CrossValidator] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[CrossValidator].getName
+
+ override def load(path: String): CrossValidator = {
+ val (metadata, estimator, evaluator, estimatorParamMaps, numFolds) =
+ SharedReadWrite.load(path, sc, className)
+ new CrossValidator(metadata.uid)
+ .setEstimator(estimator)
+ .setEvaluator(evaluator)
+ .setEstimatorParamMaps(estimatorParamMaps)
+ .setNumFolds(numFolds)
+ }
+ }
+
+ private object CrossValidatorReader {
+ /**
+ * Examine the given estimator (which may be a compound estimator) and extract a mapping
+ * from UIDs to corresponding [[Params]] instances.
+ */
+ def getUidMap(instance: Params): Map[String, Params] = {
+ val uidList = getUidMapImpl(instance)
+ val uidMap = uidList.toMap
+ if (uidList.size != uidMap.size) {
+ throw new RuntimeException("CrossValidator.load found a compound estimator with stages" +
+ s" with duplicate UIDs. List of UIDs: ${uidList.map(_._1).mkString(", ")}")
+ }
+ uidMap
+ }
+
+ def getUidMapImpl(instance: Params): List[(String, Params)] = {
+ val subStages: Array[Params] = instance match {
+ case p: Pipeline => p.getStages.asInstanceOf[Array[Params]]
+ case pm: PipelineModel => pm.stages.asInstanceOf[Array[Params]]
+ case v: ValidatorParams => Array(v.getEstimator, v.getEvaluator)
+ case ovr: OneVsRestParams =>
+ // TODO: SPARK-11892: This case may require special handling.
+ throw new UnsupportedOperationException("CrossValidator write will fail because it" +
+ " cannot yet handle an estimator containing type: ${ovr.getClass.getName}")
+ case rform: RFormulaModel =>
+ // TODO: SPARK-11891: This case may require special handling.
+ throw new UnsupportedOperationException("CrossValidator write will fail because it" +
+ " cannot yet handle an estimator containing an RFormulaModel")
+ case _: Params => Array()
+ }
+ val subStageMaps = subStages.map(getUidMapImpl).foldLeft(List.empty[(String, Params)])(_ ++ _)
+ List((instance.uid, instance)) ++ subStageMaps
+ }
+ }
+
+ private[tuning] object SharedReadWrite {
+
+ /**
+ * Check that [[CrossValidator.evaluator]] and [[CrossValidator.estimator]] are Writable.
+ * This does not check [[CrossValidator.estimatorParamMaps]].
+ */
+ def validateParams(instance: ValidatorParams): Unit = {
+ def checkElement(elem: Params, name: String): Unit = elem match {
+ case stage: MLWritable => // good
+ case other =>
+ throw new UnsupportedOperationException("CrossValidator write will fail " +
+ s" because it contains $name which does not implement Writable." +
+ s" Non-Writable $name: ${other.uid} of type ${other.getClass}")
+ }
+ checkElement(instance.getEvaluator, "evaluator")
+ checkElement(instance.getEstimator, "estimator")
+ // Check to make sure all Params apply to this estimator. Throw an error if any do not.
+ // Extraneous Params would cause problems when loading the estimatorParamMaps.
+ val uidToInstance: Map[String, Params] = CrossValidatorReader.getUidMap(instance)
+ instance.getEstimatorParamMaps.foreach { case pMap: ParamMap =>
+ pMap.toSeq.foreach { case ParamPair(p, v) =>
+ require(uidToInstance.contains(p.parent), s"CrossValidator save requires all Params in" +
+ s" estimatorParamMaps to apply to this CrossValidator, its Estimator, or its" +
+ s" Evaluator. An extraneous Param was found: $p")
+ }
+ }
+ }
+
+ private[tuning] def saveImpl(
+ path: String,
+ instance: CrossValidatorParams,
+ sc: SparkContext,
+ extraMetadata: Option[JObject] = None): Unit = {
+ import org.json4s.JsonDSL._
+
+ val estimatorParamMapsJson = compact(render(
+ instance.getEstimatorParamMaps.map { case paramMap =>
+ paramMap.toSeq.map { case ParamPair(p, v) =>
+ Map("parent" -> p.parent, "name" -> p.name, "value" -> p.jsonEncode(v))
+ }
+ }.toSeq
+ ))
+ val jsonParams = List(
+ "numFolds" -> parse(instance.numFolds.jsonEncode(instance.getNumFolds)),
+ "estimatorParamMaps" -> parse(estimatorParamMapsJson)
+ )
+ DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, Some(jsonParams))
+
+ val evaluatorPath = new Path(path, "evaluator").toString
+ instance.getEvaluator.asInstanceOf[MLWritable].save(evaluatorPath)
+ val estimatorPath = new Path(path, "estimator").toString
+ instance.getEstimator.asInstanceOf[MLWritable].save(estimatorPath)
+ }
+
+ private[tuning] def load[M <: Model[M]](
+ path: String,
+ sc: SparkContext,
+ expectedClassName: String): (Metadata, Estimator[M], Evaluator, Array[ParamMap], Int) = {
+
+ val metadata = DefaultParamsReader.loadMetadata(path, sc, expectedClassName)
+
+ implicit val format = DefaultFormats
+ val evaluatorPath = new Path(path, "evaluator").toString
+ val evaluator = DefaultParamsReader.loadParamsInstance[Evaluator](evaluatorPath, sc)
+ val estimatorPath = new Path(path, "estimator").toString
+ val estimator = DefaultParamsReader.loadParamsInstance[Estimator[M]](estimatorPath, sc)
+
+ val uidToParams = Map(evaluator.uid -> evaluator) ++ CrossValidatorReader.getUidMap(estimator)
+
+ val numFolds = (metadata.params \ "numFolds").extract[Int]
+ val estimatorParamMaps: Array[ParamMap] =
+ (metadata.params \ "estimatorParamMaps").extract[Seq[Seq[Map[String, String]]]].map {
+ pMap =>
+ val paramPairs = pMap.map { case pInfo: Map[String, String] =>
+ val est = uidToParams(pInfo("parent"))
+ val param = est.getParam(pInfo("name"))
+ val value = param.jsonDecode(pInfo("value"))
+ param -> value
+ }
+ ParamMap(paramPairs: _*)
+ }.toArray
+ (metadata, estimator, evaluator, estimatorParamMaps, numFolds)
+ }
+ }
}
/**
@@ -139,14 +306,14 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM
*
* @param bestModel The best model selected from k-fold cross validation.
* @param avgMetrics Average cross-validation metrics for each paramMap in
- * [[estimatorParamMaps]], in the corresponding order.
+ * [[CrossValidator.estimatorParamMaps]], in the corresponding order.
*/
@Experimental
class CrossValidatorModel private[ml] (
override val uid: String,
val bestModel: Model[_],
val avgMetrics: Array[Double])
- extends Model[CrossValidatorModel] with CrossValidatorParams {
+ extends Model[CrossValidatorModel] with CrossValidatorParams with MLWritable {
override def validateParams(): Unit = {
bestModel.validateParams()
@@ -168,4 +335,54 @@ class CrossValidatorModel private[ml] (
avgMetrics.clone())
copyValues(copied, extra).setParent(parent)
}
+
+ @Since("1.6.0")
+ override def write: MLWriter = new CrossValidatorModel.CrossValidatorModelWriter(this)
+}
+
+@Since("1.6.0")
+object CrossValidatorModel extends MLReadable[CrossValidatorModel] {
+
+ import CrossValidator.SharedReadWrite
+
+ @Since("1.6.0")
+ override def read: MLReader[CrossValidatorModel] = new CrossValidatorModelReader
+
+ @Since("1.6.0")
+ override def load(path: String): CrossValidatorModel = super.load(path)
+
+ private[CrossValidatorModel]
+ class CrossValidatorModelWriter(instance: CrossValidatorModel) extends MLWriter {
+
+ SharedReadWrite.validateParams(instance)
+
+ override protected def saveImpl(path: String): Unit = {
+ import org.json4s.JsonDSL._
+ val extraMetadata = "avgMetrics" -> instance.avgMetrics.toSeq
+ SharedReadWrite.saveImpl(path, instance, sc, Some(extraMetadata))
+ val bestModelPath = new Path(path, "bestModel").toString
+ instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath)
+ }
+ }
+
+ private class CrossValidatorModelReader extends MLReader[CrossValidatorModel] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[CrossValidatorModel].getName
+
+ override def load(path: String): CrossValidatorModel = {
+ implicit val format = DefaultFormats
+
+ val (metadata, estimator, evaluator, estimatorParamMaps, numFolds) =
+ SharedReadWrite.load(path, sc, className)
+ val bestModelPath = new Path(path, "bestModel").toString
+ val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc)
+ val avgMetrics = (metadata.metadata \ "avgMetrics").extract[Seq[Double]].toArray
+ val cv = new CrossValidatorModel(metadata.uid, bestModel, avgMetrics)
+ cv.set(cv.estimator, estimator)
+ .set(cv.evaluator, evaluator)
+ .set(cv.estimatorParamMaps, estimatorParamMaps)
+ .set(cv.numFolds, numFolds)
+ }
+ }
}
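As an aside, a hedged sketch of the on-disk layout produced by the writers above; the fitted cvModel and the output path are assumed placeholders, and the subdirectory names follow SharedReadWrite.saveImpl and CrossValidatorModelWriter.saveImpl shown in this file.

    // cvModel is an assumed, already-fitted CrossValidatorModel; the output path is a placeholder.
    cvModel.write.save("/tmp/cv-model")
    // Expected layout, per the writers above:
    //   /tmp/cv-model/metadata/   - JSON metadata: class, uid, sparkVersion, paramMap
    //                               (numFolds, estimatorParamMaps), plus avgMetrics for the model
    //   /tmp/cv-model/estimator/  - the Estimator, written via its own MLWritable implementation
    //   /tmp/cv-model/evaluator/  - the Evaluator, written the same way
    //   /tmp/cv-model/bestModel/  - the best Model selected by cross-validation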
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index ff9322dba1..8484b1f801 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -202,25 +202,36 @@ private[ml] object DefaultParamsWriter {
* - timestamp
* - sparkVersion
* - uid
- * - paramMap: These must be encodable using [[org.apache.spark.ml.param.Param.jsonEncode()]].
+ * - paramMap
+ * - (optionally, extra metadata)
+ * @param extraMetadata Extra metadata to be saved at same level as uid, paramMap, etc.
+ * @param paramMap If given, this is saved in the "paramMap" field.
+ * Otherwise, all [[org.apache.spark.ml.param.Param]]s are encoded using
+ * [[org.apache.spark.ml.param.Param.jsonEncode()]].
*/
def saveMetadata(
instance: Params,
path: String,
sc: SparkContext,
- extraMetadata: Option[JValue] = None): Unit = {
+ extraMetadata: Option[JObject] = None,
+ paramMap: Option[JValue] = None): Unit = {
val uid = instance.uid
val cls = instance.getClass.getName
val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]]
- val jsonParams = params.map { case ParamPair(p, v) =>
+ val jsonParams = paramMap.getOrElse(render(params.map { case ParamPair(p, v) =>
p.name -> parse(p.jsonEncode(v))
- }.toList
- val metadata = ("class" -> cls) ~
+ }.toList))
+ val basicMetadata = ("class" -> cls) ~
("timestamp" -> System.currentTimeMillis()) ~
("sparkVersion" -> sc.version) ~
("uid" -> uid) ~
- ("paramMap" -> jsonParams) ~
- ("extraMetadata" -> extraMetadata)
+ ("paramMap" -> jsonParams)
+ val metadata = extraMetadata match {
+ case Some(jObject) =>
+ basicMetadata ~ jObject
+ case None =>
+ basicMetadata
+ }
val metadataPath = new Path(path, "metadata").toString
val metadataJson = compact(render(metadata))
sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath)
@@ -251,8 +262,8 @@ private[ml] object DefaultParamsReader {
/**
* All info from metadata file.
* @param params paramMap, as a [[JValue]]
- * @param extraMetadata Extra metadata saved by [[DefaultParamsWriter.saveMetadata()]]
- * @param metadataStr Full metadata file String (for debugging)
+ * @param metadata All metadata, including the other fields
+ * @param metadataJson Full metadata file String (for debugging)
*/
case class Metadata(
className: String,
@@ -260,8 +271,8 @@ private[ml] object DefaultParamsReader {
timestamp: Long,
sparkVersion: String,
params: JValue,
- extraMetadata: Option[JValue],
- metadataStr: String)
+ metadata: JValue,
+ metadataJson: String)
/**
* Load metadata from file.
@@ -279,13 +290,12 @@ private[ml] object DefaultParamsReader {
val timestamp = (metadata \ "timestamp").extract[Long]
val sparkVersion = (metadata \ "sparkVersion").extract[String]
val params = metadata \ "paramMap"
- val extraMetadata = (metadata \ "extraMetadata").extract[Option[JValue]]
if (expectedClassName.nonEmpty) {
require(className == expectedClassName, s"Error loading metadata: Expected class name" +
s" $expectedClassName but found class name $className")
}
- Metadata(className, uid, timestamp, sparkVersion, params, extraMetadata, metadataStr)
+ Metadata(className, uid, timestamp, sparkVersion, params, metadata, metadataStr)
}
/**
@@ -303,7 +313,17 @@ private[ml] object DefaultParamsReader {
}
case _ =>
throw new IllegalArgumentException(
- s"Cannot recognize JSON metadata: ${metadata.metadataStr}.")
+ s"Cannot recognize JSON metadata: ${metadata.metadataJson}.")
}
}
+
+ /**
+ * Load a [[Params]] instance from the given path, and return it.
+ * This assumes the instance implements [[MLReadable]].
+ */
+ def loadParamsInstance[T](path: String, sc: SparkContext): T = {
+ val metadata = DefaultParamsReader.loadMetadata(path, sc)
+ val cls = Utils.classForName(metadata.className)
+ cls.getMethod("read").invoke(null).asInstanceOf[MLReader[T]].load(path)
+ }
}
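To make the saveMetadata doc comment above concrete, here is a hedged illustration of roughly what the metadata written for a plain evaluator (no extraMetadata) looks like; the uid, timestamp, and Spark version values are invented, and only the fields listed in that doc comment are shown.

    {"class":"org.apache.spark.ml.evaluation.BinaryClassificationEvaluator",
     "timestamp":1448250000000,
     "sparkVersion":"1.6.0",
     "uid":"binEval_0123456789ab",
     "paramMap":{"rawPredictionCol":"rawPrediction","labelCol":"label","metricName":"areaUnderROC"}}

When extraMetadata is supplied (for example "rank" for ALSModel or "avgMetrics" for CrossValidatorModel), those fields are merged at the same top level as "class" and "uid", per the extraMetadata match in saveMetadata above.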
diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
index 12aba6bc6d..8c86767456 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
@@ -17,11 +17,9 @@
package org.apache.spark.ml
-import java.io.File
-
import scala.collection.JavaConverters._
-import org.apache.hadoop.fs.{FileSystem, Path}
+import org.apache.hadoop.fs.Path
import org.mockito.Matchers.{any, eq => meq}
import org.mockito.Mockito.when
import org.scalatest.mock.MockitoSugar.mock
diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala
index def869fe66..a535c1218e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala
@@ -19,10 +19,21 @@ package org.apache.spark.ml.evaluation
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.mllib.util.MLlibTestSparkContext
-class BinaryClassificationEvaluatorSuite extends SparkFunSuite {
+class BinaryClassificationEvaluatorSuite
+ extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
test("params") {
ParamsSuite.checkParams(new BinaryClassificationEvaluator)
}
+
+ test("read/write") {
+ val evaluator = new BinaryClassificationEvaluator()
+ .setRawPredictionCol("myRawPrediction")
+ .setLabelCol("myLabel")
+ .setMetricName("areaUnderPR")
+ testDefaultReadWrite(evaluator)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala
index 6d8412b0b3..7ee65975d2 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala
@@ -19,10 +19,21 @@ package org.apache.spark.ml.evaluation
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.mllib.util.MLlibTestSparkContext
-class MulticlassClassificationEvaluatorSuite extends SparkFunSuite {
+class MulticlassClassificationEvaluatorSuite
+ extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
test("params") {
ParamsSuite.checkParams(new MulticlassClassificationEvaluator)
}
+
+ test("read/write") {
+ val evaluator = new MulticlassClassificationEvaluator()
+ .setPredictionCol("myPrediction")
+ .setLabelCol("myLabel")
+ .setMetricName("recall")
+ testDefaultReadWrite(evaluator)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala
index aa722da323..60886bf77d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala
@@ -20,10 +20,12 @@ package org.apache.spark.ml.evaluation
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.regression.LinearRegression
+import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
import org.apache.spark.mllib.util.TestingUtils._
-class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext {
+class RegressionEvaluatorSuite
+ extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
test("params") {
ParamsSuite.checkParams(new RegressionEvaluator)
@@ -73,4 +75,12 @@ class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext
evaluator.setMetricName("mae")
assert(evaluator.evaluate(predictions) ~== 0.08036075 absTol 0.001)
}
+
+ test("read/write") {
+ val evaluator = new RegressionEvaluator()
+ .setPredictionCol("myPrediction")
+ .setLabelCol("myLabel")
+ .setMetricName("r2")
+ testDefaultReadWrite(evaluator)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
index cbe09292a0..dd6366050c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -18,19 +18,22 @@
package org.apache.spark.ml.tuning
import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.util.MLTestingUtils
-import org.apache.spark.ml.{Estimator, Model}
-import org.apache.spark.ml.classification.LogisticRegression
+import org.apache.spark.ml.feature.HashingTF
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.{Pipeline, Estimator, Model}
+import org.apache.spark.ml.classification.{LogisticRegressionModel, LogisticRegression}
import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator}
-import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.param.{ParamPair, ParamMap}
import org.apache.spark.ml.param.shared.HasInputCol
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
+import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.types.StructType
-class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext {
+class CrossValidatorSuite
+ extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@transient var dataset: DataFrame = _
@@ -95,7 +98,7 @@ class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext {
}
test("validateParams should check estimatorParamMaps") {
- import CrossValidatorSuite._
+ import CrossValidatorSuite.{MyEstimator, MyEvaluator}
val est = new MyEstimator("est")
val eval = new MyEvaluator
@@ -116,9 +119,194 @@ class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext {
cv.validateParams()
}
}
+
+ test("read/write: CrossValidator with simple estimator") {
+ val lr = new LogisticRegression().setMaxIter(3)
+ val evaluator = new BinaryClassificationEvaluator()
+ .setMetricName("areaUnderPR") // not default metric
+ val paramMaps = new ParamGridBuilder()
+ .addGrid(lr.regParam, Array(0.1, 0.2))
+ .build()
+ val cv = new CrossValidator()
+ .setEstimator(lr)
+ .setEvaluator(evaluator)
+ .setNumFolds(20)
+ .setEstimatorParamMaps(paramMaps)
+
+ val cv2 = testDefaultReadWrite(cv, testParams = false)
+
+ assert(cv.uid === cv2.uid)
+ assert(cv.getNumFolds === cv2.getNumFolds)
+
+ assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator])
+ val evaluator2 = cv2.getEvaluator.asInstanceOf[BinaryClassificationEvaluator]
+ assert(evaluator.uid === evaluator2.uid)
+ assert(evaluator.getMetricName === evaluator2.getMetricName)
+
+ cv2.getEstimator match {
+ case lr2: LogisticRegression =>
+ assert(lr.uid === lr2.uid)
+ assert(lr.getMaxIter === lr2.getMaxIter)
+ case other =>
+ throw new AssertionError(s"Loaded CrossValidator expected estimator of type" +
+ s" LogisticRegression but found ${other.getClass.getName}")
+ }
+
+ CrossValidatorSuite.compareParamMaps(cv.getEstimatorParamMaps, cv2.getEstimatorParamMaps)
+ }
+
+ test("read/write: CrossValidator with complex estimator") {
+ // workflow: CrossValidator[Pipeline[HashingTF, CrossValidator[LogisticRegression]]]
+ val lrEvaluator = new BinaryClassificationEvaluator()
+ .setMetricName("areaUnderPR") // not default metric
+
+ val lr = new LogisticRegression().setMaxIter(3)
+ val lrParamMaps = new ParamGridBuilder()
+ .addGrid(lr.regParam, Array(0.1, 0.2))
+ .build()
+ val lrcv = new CrossValidator()
+ .setEstimator(lr)
+ .setEvaluator(lrEvaluator)
+ .setEstimatorParamMaps(lrParamMaps)
+
+ val hashingTF = new HashingTF()
+ val pipeline = new Pipeline().setStages(Array(hashingTF, lrcv))
+ val paramMaps = new ParamGridBuilder()
+ .addGrid(hashingTF.numFeatures, Array(10, 20))
+ .addGrid(lr.elasticNetParam, Array(0.0, 1.0))
+ .build()
+ val evaluator = new BinaryClassificationEvaluator()
+
+ val cv = new CrossValidator()
+ .setEstimator(pipeline)
+ .setEvaluator(evaluator)
+ .setNumFolds(20)
+ .setEstimatorParamMaps(paramMaps)
+
+ val cv2 = testDefaultReadWrite(cv, testParams = false)
+
+ assert(cv.uid === cv2.uid)
+ assert(cv.getNumFolds === cv2.getNumFolds)
+
+ assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator])
+ assert(cv.getEvaluator.uid === cv2.getEvaluator.uid)
+
+ CrossValidatorSuite.compareParamMaps(cv.getEstimatorParamMaps, cv2.getEstimatorParamMaps)
+
+ cv2.getEstimator match {
+ case pipeline2: Pipeline =>
+ assert(pipeline.uid === pipeline2.uid)
+ pipeline2.getStages match {
+ case Array(hashingTF2: HashingTF, lrcv2: CrossValidator) =>
+ assert(hashingTF.uid === hashingTF2.uid)
+ lrcv2.getEstimator match {
+ case lr2: LogisticRegression =>
+ assert(lr.uid === lr2.uid)
+ assert(lr.getMaxIter === lr2.getMaxIter)
+ case other =>
+ throw new AssertionError(s"Loaded internal CrossValidator expected to be" +
+ s" LogisticRegression but found type ${other.getClass.getName}")
+ }
+ assert(lrcv.uid === lrcv2.uid)
+ assert(lrcv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator])
+ assert(lrEvaluator.uid === lrcv2.getEvaluator.uid)
+ CrossValidatorSuite.compareParamMaps(lrParamMaps, lrcv2.getEstimatorParamMaps)
+ case other =>
+ throw new AssertionError("Loaded Pipeline expected stages (HashingTF, CrossValidator)" +
+ " but found: " + other.map(_.getClass.getName).mkString(", "))
+ }
+ case other =>
+ throw new AssertionError(s"Loaded CrossValidator expected estimator of type" +
+ s" CrossValidator but found ${other.getClass.getName}")
+ }
+ }
+
+ test("read/write: CrossValidator fails for extraneous Param") {
+ val lr = new LogisticRegression()
+ val lr2 = new LogisticRegression()
+ val evaluator = new BinaryClassificationEvaluator()
+ val paramMaps = new ParamGridBuilder()
+ .addGrid(lr.regParam, Array(0.1, 0.2))
+ .addGrid(lr2.regParam, Array(0.1, 0.2))
+ .build()
+ val cv = new CrossValidator()
+ .setEstimator(lr)
+ .setEvaluator(evaluator)
+ .setEstimatorParamMaps(paramMaps)
+ withClue("CrossValidator.write failed to catch extraneous Param error") {
+ intercept[IllegalArgumentException] {
+ cv.write
+ }
+ }
+ }
+
+ test("read/write: CrossValidatorModel") {
+ val lr = new LogisticRegression()
+ .setThreshold(0.6)
+ val lrModel = new LogisticRegressionModel(lr.uid, Vectors.dense(1.0, 2.0), 1.2)
+ .setThreshold(0.6)
+ val evaluator = new BinaryClassificationEvaluator()
+ .setMetricName("areaUnderPR") // not default metric
+ val paramMaps = new ParamGridBuilder()
+ .addGrid(lr.regParam, Array(0.1, 0.2))
+ .build()
+ val cv = new CrossValidatorModel("cvUid", lrModel, Array(0.3, 0.6))
+ cv.set(cv.estimator, lr)
+ .set(cv.evaluator, evaluator)
+ .set(cv.numFolds, 20)
+ .set(cv.estimatorParamMaps, paramMaps)
+
+ val cv2 = testDefaultReadWrite(cv, testParams = false)
+
+ assert(cv.uid === cv2.uid)
+ assert(cv.getNumFolds === cv2.getNumFolds)
+
+ assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator])
+ val evaluator2 = cv2.getEvaluator.asInstanceOf[BinaryClassificationEvaluator]
+ assert(evaluator.uid === evaluator2.uid)
+ assert(evaluator.getMetricName === evaluator2.getMetricName)
+
+ cv2.getEstimator match {
+ case lr2: LogisticRegression =>
+ assert(lr.uid === lr2.uid)
+ assert(lr.getThreshold === lr2.getThreshold)
+ case other =>
+ throw new AssertionError(s"Loaded CrossValidator expected estimator of type" +
+ s" LogisticRegression but found ${other.getClass.getName}")
+ }
+
+ CrossValidatorSuite.compareParamMaps(cv.getEstimatorParamMaps, cv2.getEstimatorParamMaps)
+
+ cv2.bestModel match {
+ case lrModel2: LogisticRegressionModel =>
+ assert(lrModel.uid === lrModel2.uid)
+ assert(lrModel.getThreshold === lrModel2.getThreshold)
+ assert(lrModel.coefficients === lrModel2.coefficients)
+ assert(lrModel.intercept === lrModel2.intercept)
+ case other =>
+ throw new AssertionError(s"Loaded CrossValidator expected bestModel of type" +
+ s" LogisticRegressionModel but found ${other.getClass.getName}")
+ }
+ assert(cv.avgMetrics === cv2.avgMetrics)
+ }
}
-object CrossValidatorSuite {
+object CrossValidatorSuite extends SparkFunSuite {
+
+ /**
+ * Assert sequences of estimatorParamMaps are identical.
+ * Params must be simple types comparable with `===`.
+ */
+ def compareParamMaps(pMaps: Array[ParamMap], pMaps2: Array[ParamMap]): Unit = {
+ assert(pMaps.length === pMaps2.length)
+ pMaps.zip(pMaps2).foreach { case (pMap, pMap2) =>
+ assert(pMap.size === pMap2.size)
+ pMap.toSeq.foreach { case ParamPair(p, v) =>
+ assert(pMap2.contains(p))
+ assert(pMap2(p) === v)
+ }
+ }
+ }
abstract class MyModel extends Model[MyModel]