aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala169
1 files changed, 29 insertions, 140 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 963f81cb3e..de563d4fad 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -17,27 +17,25 @@
package org.apache.spark.ml.tuning
+import java.util.{List => JList}
+
+import scala.collection.JavaConverters._
+
import com.github.fommil.netlib.F2jBLAS
import org.apache.hadoop.fs.Path
-import org.json4s.{DefaultFormats, JObject}
-import org.json4s.jackson.JsonMethods._
+import org.json4s.DefaultFormats
-import org.apache.spark.SparkContext
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.internal.Logging
import org.apache.spark.ml._
-import org.apache.spark.ml.classification.OneVsRestParams
import org.apache.spark.ml.evaluation.Evaluator
-import org.apache.spark.ml.feature.RFormulaModel
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.HasSeed
import org.apache.spark.ml.util._
-import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.mllib.util.MLUtils
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types.StructType
-
/**
* Params for [[CrossValidator]] and [[CrossValidatorModel]].
*/
@@ -45,6 +43,7 @@ private[ml] trait CrossValidatorParams extends ValidatorParams with HasSeed {
/**
* Param for number of folds for cross validation. Must be >= 2.
* Default: 3
+ *
* @group param
*/
val numFolds: IntParam = new IntParam(this, "numFolds",
@@ -91,8 +90,8 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
@Since("2.0.0")
def setSeed(value: Long): this.type = set(seed, value)
- @Since("1.4.0")
- override def fit(dataset: DataFrame): CrossValidatorModel = {
+ @Since("2.0.0")
+ override def fit(dataset: Dataset[_]): CrossValidatorModel = {
val schema = dataset.schema
transformSchema(schema, logging = true)
val sqlCtx = dataset.sqlContext
@@ -101,7 +100,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
val epm = $(estimatorParamMaps)
val numModels = epm.length
val metrics = new Array[Double](epm.length)
- val splits = MLUtils.kFold(dataset.rdd, $(numFolds), $(seed))
+ val splits = MLUtils.kFold(dataset.toDF.rdd, $(numFolds), $(seed))
splits.zipWithIndex.foreach { case ((training, validation), splitIndex) =>
val trainingDataset = sqlCtx.createDataFrame(training, schema).cache()
val validationDataset = sqlCtx.createDataFrame(validation, schema).cache()
@@ -163,10 +162,10 @@ object CrossValidator extends MLReadable[CrossValidator] {
private[CrossValidator] class CrossValidatorWriter(instance: CrossValidator) extends MLWriter {
- SharedReadWrite.validateParams(instance)
+ ValidatorParams.validateParams(instance)
override protected def saveImpl(path: String): Unit =
- SharedReadWrite.saveImpl(path, instance, sc)
+ ValidatorParams.saveImpl(path, instance, sc)
}
private class CrossValidatorReader extends MLReader[CrossValidator] {
@@ -175,8 +174,11 @@ object CrossValidator extends MLReadable[CrossValidator] {
private val className = classOf[CrossValidator].getName
override def load(path: String): CrossValidator = {
- val (metadata, estimator, evaluator, estimatorParamMaps, numFolds) =
- SharedReadWrite.load(path, sc, className)
+ implicit val format = DefaultFormats
+
+ val (metadata, estimator, evaluator, estimatorParamMaps) =
+ ValidatorParams.loadImpl(path, sc, className)
+ val numFolds = (metadata.params \ "numFolds").extract[Int]
new CrossValidator(metadata.uid)
.setEstimator(estimator)
.setEvaluator(evaluator)
@@ -184,123 +186,6 @@ object CrossValidator extends MLReadable[CrossValidator] {
.setNumFolds(numFolds)
}
}
-
- private object CrossValidatorReader {
- /**
- * Examine the given estimator (which may be a compound estimator) and extract a mapping
- * from UIDs to corresponding [[Params]] instances.
- */
- def getUidMap(instance: Params): Map[String, Params] = {
- val uidList = getUidMapImpl(instance)
- val uidMap = uidList.toMap
- if (uidList.size != uidMap.size) {
- throw new RuntimeException("CrossValidator.load found a compound estimator with stages" +
- s" with duplicate UIDs. List of UIDs: ${uidList.map(_._1).mkString(", ")}")
- }
- uidMap
- }
-
- def getUidMapImpl(instance: Params): List[(String, Params)] = {
- val subStages: Array[Params] = instance match {
- case p: Pipeline => p.getStages.asInstanceOf[Array[Params]]
- case pm: PipelineModel => pm.stages.asInstanceOf[Array[Params]]
- case v: ValidatorParams => Array(v.getEstimator, v.getEvaluator)
- case ovr: OneVsRestParams =>
- // TODO: SPARK-11892: This case may require special handling.
- throw new UnsupportedOperationException("CrossValidator write will fail because it" +
- " cannot yet handle an estimator containing type: ${ovr.getClass.getName}")
- case rformModel: RFormulaModel => Array(rformModel.pipelineModel)
- case _: Params => Array()
- }
- val subStageMaps = subStages.map(getUidMapImpl).foldLeft(List.empty[(String, Params)])(_ ++ _)
- List((instance.uid, instance)) ++ subStageMaps
- }
- }
-
- private[tuning] object SharedReadWrite {
-
- /**
- * Check that [[CrossValidator.evaluator]] and [[CrossValidator.estimator]] are Writable.
- * This does not check [[CrossValidator.estimatorParamMaps]].
- */
- def validateParams(instance: ValidatorParams): Unit = {
- def checkElement(elem: Params, name: String): Unit = elem match {
- case stage: MLWritable => // good
- case other =>
- throw new UnsupportedOperationException("CrossValidator write will fail " +
- s" because it contains $name which does not implement Writable." +
- s" Non-Writable $name: ${other.uid} of type ${other.getClass}")
- }
- checkElement(instance.getEvaluator, "evaluator")
- checkElement(instance.getEstimator, "estimator")
- // Check to make sure all Params apply to this estimator. Throw an error if any do not.
- // Extraneous Params would cause problems when loading the estimatorParamMaps.
- val uidToInstance: Map[String, Params] = CrossValidatorReader.getUidMap(instance)
- instance.getEstimatorParamMaps.foreach { case pMap: ParamMap =>
- pMap.toSeq.foreach { case ParamPair(p, v) =>
- require(uidToInstance.contains(p.parent), s"CrossValidator save requires all Params in" +
- s" estimatorParamMaps to apply to this CrossValidator, its Estimator, or its" +
- s" Evaluator. An extraneous Param was found: $p")
- }
- }
- }
-
- private[tuning] def saveImpl(
- path: String,
- instance: CrossValidatorParams,
- sc: SparkContext,
- extraMetadata: Option[JObject] = None): Unit = {
- import org.json4s.JsonDSL._
-
- val estimatorParamMapsJson = compact(render(
- instance.getEstimatorParamMaps.map { case paramMap =>
- paramMap.toSeq.map { case ParamPair(p, v) =>
- Map("parent" -> p.parent, "name" -> p.name, "value" -> p.jsonEncode(v))
- }
- }.toSeq
- ))
- val jsonParams = List(
- "numFolds" -> parse(instance.numFolds.jsonEncode(instance.getNumFolds)),
- "estimatorParamMaps" -> parse(estimatorParamMapsJson)
- )
- DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, Some(jsonParams))
-
- val evaluatorPath = new Path(path, "evaluator").toString
- instance.getEvaluator.asInstanceOf[MLWritable].save(evaluatorPath)
- val estimatorPath = new Path(path, "estimator").toString
- instance.getEstimator.asInstanceOf[MLWritable].save(estimatorPath)
- }
-
- private[tuning] def load[M <: Model[M]](
- path: String,
- sc: SparkContext,
- expectedClassName: String): (Metadata, Estimator[M], Evaluator, Array[ParamMap], Int) = {
-
- val metadata = DefaultParamsReader.loadMetadata(path, sc, expectedClassName)
-
- implicit val format = DefaultFormats
- val evaluatorPath = new Path(path, "evaluator").toString
- val evaluator = DefaultParamsReader.loadParamsInstance[Evaluator](evaluatorPath, sc)
- val estimatorPath = new Path(path, "estimator").toString
- val estimator = DefaultParamsReader.loadParamsInstance[Estimator[M]](estimatorPath, sc)
-
- val uidToParams = Map(evaluator.uid -> evaluator) ++ CrossValidatorReader.getUidMap(estimator)
-
- val numFolds = (metadata.params \ "numFolds").extract[Int]
- val estimatorParamMaps: Array[ParamMap] =
- (metadata.params \ "estimatorParamMaps").extract[Seq[Seq[Map[String, String]]]].map {
- pMap =>
- val paramPairs = pMap.map { case pInfo: Map[String, String] =>
- val est = uidToParams(pInfo("parent"))
- val param = est.getParam(pInfo("name"))
- val value = param.jsonDecode(pInfo("value"))
- param -> value
- }
- ParamMap(paramPairs: _*)
- }.toArray
- (metadata, estimator, evaluator, estimatorParamMaps, numFolds)
- }
- }
}
/**
@@ -319,8 +204,13 @@ class CrossValidatorModel private[ml] (
@Since("1.5.0") val avgMetrics: Array[Double])
extends Model[CrossValidatorModel] with CrossValidatorParams with MLWritable {
- @Since("1.4.0")
- override def transform(dataset: DataFrame): DataFrame = {
+ /** A Python-friendly auxiliary constructor. */
+ private[ml] def this(uid: String, bestModel: Model[_], avgMetrics: JList[Double]) = {
+ this(uid, bestModel, avgMetrics.asScala.toArray)
+ }
+
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
bestModel.transform(dataset)
}
@@ -346,8 +236,6 @@ class CrossValidatorModel private[ml] (
@Since("1.6.0")
object CrossValidatorModel extends MLReadable[CrossValidatorModel] {
- import CrossValidator.SharedReadWrite
-
@Since("1.6.0")
override def read: MLReader[CrossValidatorModel] = new CrossValidatorModelReader
@@ -357,12 +245,12 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] {
private[CrossValidatorModel]
class CrossValidatorModelWriter(instance: CrossValidatorModel) extends MLWriter {
- SharedReadWrite.validateParams(instance)
+ ValidatorParams.validateParams(instance)
override protected def saveImpl(path: String): Unit = {
import org.json4s.JsonDSL._
val extraMetadata = "avgMetrics" -> instance.avgMetrics.toSeq
- SharedReadWrite.saveImpl(path, instance, sc, Some(extraMetadata))
+ ValidatorParams.saveImpl(path, instance, sc, Some(extraMetadata))
val bestModelPath = new Path(path, "bestModel").toString
instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath)
}
@@ -376,8 +264,9 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] {
override def load(path: String): CrossValidatorModel = {
implicit val format = DefaultFormats
- val (metadata, estimator, evaluator, estimatorParamMaps, numFolds) =
- SharedReadWrite.load(path, sc, className)
+ val (metadata, estimator, evaluator, estimatorParamMaps) =
+ ValidatorParams.loadImpl(path, sc, className)
+ val numFolds = (metadata.params \ "numFolds").extract[Int]
val bestModelPath = new Path(path, "bestModel").toString
val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc)
val avgMetrics = (metadata.metadata \ "avgMetrics").extract[Seq[Double]].toArray