diff options
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala | 169 |
1 files changed, 29 insertions, 140 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 963f81cb3e..de563d4fad 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -17,27 +17,25 @@ package org.apache.spark.ml.tuning +import java.util.{List => JList} + +import scala.collection.JavaConverters._ + import com.github.fommil.netlib.F2jBLAS import org.apache.hadoop.fs.Path -import org.json4s.{DefaultFormats, JObject} -import org.json4s.jackson.JsonMethods._ +import org.json4s.DefaultFormats -import org.apache.spark.SparkContext import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.internal.Logging import org.apache.spark.ml._ -import org.apache.spark.ml.classification.OneVsRestParams import org.apache.spark.ml.evaluation.Evaluator -import org.apache.spark.ml.feature.RFormulaModel import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.HasSeed import org.apache.spark.ml.util._ -import org.apache.spark.ml.util.DefaultParamsReader.Metadata import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.types.StructType - /** * Params for [[CrossValidator]] and [[CrossValidatorModel]]. */ @@ -45,6 +43,7 @@ private[ml] trait CrossValidatorParams extends ValidatorParams with HasSeed { /** * Param for number of folds for cross validation. Must be >= 2. * Default: 3 + * * @group param */ val numFolds: IntParam = new IntParam(this, "numFolds", @@ -91,8 +90,8 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) @Since("2.0.0") def setSeed(value: Long): this.type = set(seed, value) - @Since("1.4.0") - override def fit(dataset: DataFrame): CrossValidatorModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): CrossValidatorModel = { val schema = dataset.schema transformSchema(schema, logging = true) val sqlCtx = dataset.sqlContext @@ -101,7 +100,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) val epm = $(estimatorParamMaps) val numModels = epm.length val metrics = new Array[Double](epm.length) - val splits = MLUtils.kFold(dataset.rdd, $(numFolds), $(seed)) + val splits = MLUtils.kFold(dataset.toDF.rdd, $(numFolds), $(seed)) splits.zipWithIndex.foreach { case ((training, validation), splitIndex) => val trainingDataset = sqlCtx.createDataFrame(training, schema).cache() val validationDataset = sqlCtx.createDataFrame(validation, schema).cache() @@ -163,10 +162,10 @@ object CrossValidator extends MLReadable[CrossValidator] { private[CrossValidator] class CrossValidatorWriter(instance: CrossValidator) extends MLWriter { - SharedReadWrite.validateParams(instance) + ValidatorParams.validateParams(instance) override protected def saveImpl(path: String): Unit = - SharedReadWrite.saveImpl(path, instance, sc) + ValidatorParams.saveImpl(path, instance, sc) } private class CrossValidatorReader extends MLReader[CrossValidator] { @@ -175,8 +174,11 @@ object CrossValidator extends MLReadable[CrossValidator] { private val className = classOf[CrossValidator].getName override def load(path: String): CrossValidator = { - val (metadata, estimator, evaluator, estimatorParamMaps, numFolds) = - SharedReadWrite.load(path, sc, className) + implicit val format = DefaultFormats + + val (metadata, estimator, evaluator, estimatorParamMaps) = + ValidatorParams.loadImpl(path, sc, className) + val numFolds = (metadata.params \ "numFolds").extract[Int] new CrossValidator(metadata.uid) .setEstimator(estimator) .setEvaluator(evaluator) @@ -184,123 +186,6 @@ object CrossValidator extends MLReadable[CrossValidator] { .setNumFolds(numFolds) } } - - private object CrossValidatorReader { - /** - * Examine the given estimator (which may be a compound estimator) and extract a mapping - * from UIDs to corresponding [[Params]] instances. - */ - def getUidMap(instance: Params): Map[String, Params] = { - val uidList = getUidMapImpl(instance) - val uidMap = uidList.toMap - if (uidList.size != uidMap.size) { - throw new RuntimeException("CrossValidator.load found a compound estimator with stages" + - s" with duplicate UIDs. List of UIDs: ${uidList.map(_._1).mkString(", ")}") - } - uidMap - } - - def getUidMapImpl(instance: Params): List[(String, Params)] = { - val subStages: Array[Params] = instance match { - case p: Pipeline => p.getStages.asInstanceOf[Array[Params]] - case pm: PipelineModel => pm.stages.asInstanceOf[Array[Params]] - case v: ValidatorParams => Array(v.getEstimator, v.getEvaluator) - case ovr: OneVsRestParams => - // TODO: SPARK-11892: This case may require special handling. - throw new UnsupportedOperationException("CrossValidator write will fail because it" + - " cannot yet handle an estimator containing type: ${ovr.getClass.getName}") - case rformModel: RFormulaModel => Array(rformModel.pipelineModel) - case _: Params => Array() - } - val subStageMaps = subStages.map(getUidMapImpl).foldLeft(List.empty[(String, Params)])(_ ++ _) - List((instance.uid, instance)) ++ subStageMaps - } - } - - private[tuning] object SharedReadWrite { - - /** - * Check that [[CrossValidator.evaluator]] and [[CrossValidator.estimator]] are Writable. - * This does not check [[CrossValidator.estimatorParamMaps]]. - */ - def validateParams(instance: ValidatorParams): Unit = { - def checkElement(elem: Params, name: String): Unit = elem match { - case stage: MLWritable => // good - case other => - throw new UnsupportedOperationException("CrossValidator write will fail " + - s" because it contains $name which does not implement Writable." + - s" Non-Writable $name: ${other.uid} of type ${other.getClass}") - } - checkElement(instance.getEvaluator, "evaluator") - checkElement(instance.getEstimator, "estimator") - // Check to make sure all Params apply to this estimator. Throw an error if any do not. - // Extraneous Params would cause problems when loading the estimatorParamMaps. - val uidToInstance: Map[String, Params] = CrossValidatorReader.getUidMap(instance) - instance.getEstimatorParamMaps.foreach { case pMap: ParamMap => - pMap.toSeq.foreach { case ParamPair(p, v) => - require(uidToInstance.contains(p.parent), s"CrossValidator save requires all Params in" + - s" estimatorParamMaps to apply to this CrossValidator, its Estimator, or its" + - s" Evaluator. An extraneous Param was found: $p") - } - } - } - - private[tuning] def saveImpl( - path: String, - instance: CrossValidatorParams, - sc: SparkContext, - extraMetadata: Option[JObject] = None): Unit = { - import org.json4s.JsonDSL._ - - val estimatorParamMapsJson = compact(render( - instance.getEstimatorParamMaps.map { case paramMap => - paramMap.toSeq.map { case ParamPair(p, v) => - Map("parent" -> p.parent, "name" -> p.name, "value" -> p.jsonEncode(v)) - } - }.toSeq - )) - val jsonParams = List( - "numFolds" -> parse(instance.numFolds.jsonEncode(instance.getNumFolds)), - "estimatorParamMaps" -> parse(estimatorParamMapsJson) - ) - DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, Some(jsonParams)) - - val evaluatorPath = new Path(path, "evaluator").toString - instance.getEvaluator.asInstanceOf[MLWritable].save(evaluatorPath) - val estimatorPath = new Path(path, "estimator").toString - instance.getEstimator.asInstanceOf[MLWritable].save(estimatorPath) - } - - private[tuning] def load[M <: Model[M]]( - path: String, - sc: SparkContext, - expectedClassName: String): (Metadata, Estimator[M], Evaluator, Array[ParamMap], Int) = { - - val metadata = DefaultParamsReader.loadMetadata(path, sc, expectedClassName) - - implicit val format = DefaultFormats - val evaluatorPath = new Path(path, "evaluator").toString - val evaluator = DefaultParamsReader.loadParamsInstance[Evaluator](evaluatorPath, sc) - val estimatorPath = new Path(path, "estimator").toString - val estimator = DefaultParamsReader.loadParamsInstance[Estimator[M]](estimatorPath, sc) - - val uidToParams = Map(evaluator.uid -> evaluator) ++ CrossValidatorReader.getUidMap(estimator) - - val numFolds = (metadata.params \ "numFolds").extract[Int] - val estimatorParamMaps: Array[ParamMap] = - (metadata.params \ "estimatorParamMaps").extract[Seq[Seq[Map[String, String]]]].map { - pMap => - val paramPairs = pMap.map { case pInfo: Map[String, String] => - val est = uidToParams(pInfo("parent")) - val param = est.getParam(pInfo("name")) - val value = param.jsonDecode(pInfo("value")) - param -> value - } - ParamMap(paramPairs: _*) - }.toArray - (metadata, estimator, evaluator, estimatorParamMaps, numFolds) - } - } } /** @@ -319,8 +204,13 @@ class CrossValidatorModel private[ml] ( @Since("1.5.0") val avgMetrics: Array[Double]) extends Model[CrossValidatorModel] with CrossValidatorParams with MLWritable { - @Since("1.4.0") - override def transform(dataset: DataFrame): DataFrame = { + /** A Python-friendly auxiliary constructor. */ + private[ml] def this(uid: String, bestModel: Model[_], avgMetrics: JList[Double]) = { + this(uid, bestModel, avgMetrics.asScala.toArray) + } + + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) bestModel.transform(dataset) } @@ -346,8 +236,6 @@ class CrossValidatorModel private[ml] ( @Since("1.6.0") object CrossValidatorModel extends MLReadable[CrossValidatorModel] { - import CrossValidator.SharedReadWrite - @Since("1.6.0") override def read: MLReader[CrossValidatorModel] = new CrossValidatorModelReader @@ -357,12 +245,12 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] { private[CrossValidatorModel] class CrossValidatorModelWriter(instance: CrossValidatorModel) extends MLWriter { - SharedReadWrite.validateParams(instance) + ValidatorParams.validateParams(instance) override protected def saveImpl(path: String): Unit = { import org.json4s.JsonDSL._ val extraMetadata = "avgMetrics" -> instance.avgMetrics.toSeq - SharedReadWrite.saveImpl(path, instance, sc, Some(extraMetadata)) + ValidatorParams.saveImpl(path, instance, sc, Some(extraMetadata)) val bestModelPath = new Path(path, "bestModel").toString instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath) } @@ -376,8 +264,9 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] { override def load(path: String): CrossValidatorModel = { implicit val format = DefaultFormats - val (metadata, estimator, evaluator, estimatorParamMaps, numFolds) = - SharedReadWrite.load(path, sc, className) + val (metadata, estimator, evaluator, estimatorParamMaps) = + ValidatorParams.loadImpl(path, sc, className) + val numFolds = (metadata.params \ "numFolds").extract[Int] val bestModelPath = new Path(path, "bestModel").toString val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc) val avgMetrics = (metadata.metadata \ "avgMetrics").extract[Seq[Double]].toArray |