Diffstat (limited to 'mllib')
5 files changed, 310 insertions, 142 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 963f81cb3e..040b0093b9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -19,25 +19,19 @@ package org.apache.spark.ml.tuning
 
 import com.github.fommil.netlib.F2jBLAS
 import org.apache.hadoop.fs.Path
-import org.json4s.{DefaultFormats, JObject}
-import org.json4s.jackson.JsonMethods._
+import org.json4s.DefaultFormats
 
-import org.apache.spark.SparkContext
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml._
-import org.apache.spark.ml.classification.OneVsRestParams
 import org.apache.spark.ml.evaluation.Evaluator
-import org.apache.spark.ml.feature.RFormulaModel
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared.HasSeed
 import org.apache.spark.ml.util._
-import org.apache.spark.ml.util.DefaultParamsReader.Metadata
 import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.types.StructType
-
 /**
  * Params for [[CrossValidator]] and [[CrossValidatorModel]].
  */
@@ -45,6 +39,7 @@ private[ml] trait CrossValidatorParams extends ValidatorParams with HasSeed {
   /**
    * Param for number of folds for cross validation. Must be >= 2.
    * Default: 3
+   *
    * @group param
    */
  val numFolds: IntParam = new IntParam(this, "numFolds",
@@ -163,10 +158,10 @@ object CrossValidator extends MLReadable[CrossValidator] {
   private[CrossValidator] class CrossValidatorWriter(instance: CrossValidator) extends MLWriter {
 
-    SharedReadWrite.validateParams(instance)
+    ValidatorParams.validateParams(instance)
 
     override protected def saveImpl(path: String): Unit =
-      SharedReadWrite.saveImpl(path, instance, sc)
+      ValidatorParams.saveImpl(path, instance, sc)
   }
 
   private class CrossValidatorReader extends MLReader[CrossValidator] {
@@ -175,8 +170,11 @@ object CrossValidator extends MLReadable[CrossValidator] {
     private val className = classOf[CrossValidator].getName
 
     override def load(path: String): CrossValidator = {
-      val (metadata, estimator, evaluator, estimatorParamMaps, numFolds) =
-        SharedReadWrite.load(path, sc, className)
+      implicit val format = DefaultFormats
+
+      val (metadata, estimator, evaluator, estimatorParamMaps) =
+        ValidatorParams.loadImpl(path, sc, className)
+      val numFolds = (metadata.params \ "numFolds").extract[Int]
       new CrossValidator(metadata.uid)
         .setEstimator(estimator)
         .setEvaluator(evaluator)
@@ -184,123 +182,6 @@ object CrossValidator extends MLReadable[CrossValidator] {
         .setEstimatorParamMaps(estimatorParamMaps)
         .setNumFolds(numFolds)
     }
   }
-
-  private object CrossValidatorReader {
-    /**
-     * Examine the given estimator (which may be a compound estimator) and extract a mapping
-     * from UIDs to corresponding [[Params]] instances.
-     */
-    def getUidMap(instance: Params): Map[String, Params] = {
-      val uidList = getUidMapImpl(instance)
-      val uidMap = uidList.toMap
-      if (uidList.size != uidMap.size) {
-        throw new RuntimeException("CrossValidator.load found a compound estimator with stages" +
-          s" with duplicate UIDs. List of UIDs: ${uidList.map(_._1).mkString(", ")}")
-      }
-      uidMap
-    }
-
-    def getUidMapImpl(instance: Params): List[(String, Params)] = {
-      val subStages: Array[Params] = instance match {
-        case p: Pipeline => p.getStages.asInstanceOf[Array[Params]]
-        case pm: PipelineModel => pm.stages.asInstanceOf[Array[Params]]
-        case v: ValidatorParams => Array(v.getEstimator, v.getEvaluator)
-        case ovr: OneVsRestParams =>
-          // TODO: SPARK-11892: This case may require special handling.
-          throw new UnsupportedOperationException("CrossValidator write will fail because it" +
-            " cannot yet handle an estimator containing type: ${ovr.getClass.getName}")
-        case rformModel: RFormulaModel => Array(rformModel.pipelineModel)
-        case _: Params => Array()
-      }
-      val subStageMaps = subStages.map(getUidMapImpl).foldLeft(List.empty[(String, Params)])(_ ++ _)
-      List((instance.uid, instance)) ++ subStageMaps
-    }
-  }
-
-  private[tuning] object SharedReadWrite {
-
-    /**
-     * Check that [[CrossValidator.evaluator]] and [[CrossValidator.estimator]] are Writable.
-     * This does not check [[CrossValidator.estimatorParamMaps]].
-     */
-    def validateParams(instance: ValidatorParams): Unit = {
-      def checkElement(elem: Params, name: String): Unit = elem match {
-        case stage: MLWritable => // good
-        case other =>
-          throw new UnsupportedOperationException("CrossValidator write will fail " +
-            s" because it contains $name which does not implement Writable." +
-            s" Non-Writable $name: ${other.uid} of type ${other.getClass}")
-      }
-      checkElement(instance.getEvaluator, "evaluator")
-      checkElement(instance.getEstimator, "estimator")
-      // Check to make sure all Params apply to this estimator.  Throw an error if any do not.
-      // Extraneous Params would cause problems when loading the estimatorParamMaps.
-      val uidToInstance: Map[String, Params] = CrossValidatorReader.getUidMap(instance)
-      instance.getEstimatorParamMaps.foreach { case pMap: ParamMap =>
-        pMap.toSeq.foreach { case ParamPair(p, v) =>
-          require(uidToInstance.contains(p.parent), s"CrossValidator save requires all Params in" +
-            s" estimatorParamMaps to apply to this CrossValidator, its Estimator, or its" +
-            s" Evaluator. An extraneous Param was found: $p")
-        }
-      }
-    }
-
-    private[tuning] def saveImpl(
-        path: String,
-        instance: CrossValidatorParams,
-        sc: SparkContext,
-        extraMetadata: Option[JObject] = None): Unit = {
-      import org.json4s.JsonDSL._
-
-      val estimatorParamMapsJson = compact(render(
-        instance.getEstimatorParamMaps.map { case paramMap =>
-          paramMap.toSeq.map { case ParamPair(p, v) =>
-            Map("parent" -> p.parent, "name" -> p.name, "value" -> p.jsonEncode(v))
-          }
-        }.toSeq
-      ))
-      val jsonParams = List(
-        "numFolds" -> parse(instance.numFolds.jsonEncode(instance.getNumFolds)),
-        "estimatorParamMaps" -> parse(estimatorParamMapsJson)
-      )
-      DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, Some(jsonParams))
-
-      val evaluatorPath = new Path(path, "evaluator").toString
-      instance.getEvaluator.asInstanceOf[MLWritable].save(evaluatorPath)
-      val estimatorPath = new Path(path, "estimator").toString
-      instance.getEstimator.asInstanceOf[MLWritable].save(estimatorPath)
-    }
-
-    private[tuning] def load[M <: Model[M]](
-        path: String,
-        sc: SparkContext,
-        expectedClassName: String): (Metadata, Estimator[M], Evaluator, Array[ParamMap], Int) = {
-
-      val metadata = DefaultParamsReader.loadMetadata(path, sc, expectedClassName)
-
-      implicit val format = DefaultFormats
-      val evaluatorPath = new Path(path, "evaluator").toString
-      val evaluator = DefaultParamsReader.loadParamsInstance[Evaluator](evaluatorPath, sc)
-      val estimatorPath = new Path(path, "estimator").toString
-      val estimator = DefaultParamsReader.loadParamsInstance[Estimator[M]](estimatorPath, sc)
-
-      val uidToParams = Map(evaluator.uid -> evaluator) ++ CrossValidatorReader.getUidMap(estimator)
-
-      val numFolds = (metadata.params \ "numFolds").extract[Int]
-      val estimatorParamMaps: Array[ParamMap] =
-        (metadata.params \ "estimatorParamMaps").extract[Seq[Seq[Map[String, String]]]].map {
-          pMap =>
-            val paramPairs = pMap.map { case pInfo: Map[String, String] =>
-              val est = uidToParams(pInfo("parent"))
-              val param = est.getParam(pInfo("name"))
-              val value = param.jsonDecode(pInfo("value"))
-              param -> value
-            }
-            ParamMap(paramPairs: _*)
-        }.toArray
-      (metadata, estimator, evaluator, estimatorParamMaps, numFolds)
-    }
-  }
 }
 
 /**
@@ -346,8 +227,6 @@ class CrossValidatorModel private[ml] (
 @Since("1.6.0")
 object CrossValidatorModel extends MLReadable[CrossValidatorModel] {
 
-  import CrossValidator.SharedReadWrite
-
   @Since("1.6.0")
   override def read: MLReader[CrossValidatorModel] = new CrossValidatorModelReader
 
@@ -357,12 +236,12 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] {
   private[CrossValidatorModel]
   class CrossValidatorModelWriter(instance: CrossValidatorModel) extends MLWriter {
 
-    SharedReadWrite.validateParams(instance)
+    ValidatorParams.validateParams(instance)
 
     override protected def saveImpl(path: String): Unit = {
       import org.json4s.JsonDSL._
       val extraMetadata = "avgMetrics" -> instance.avgMetrics.toSeq
-      SharedReadWrite.saveImpl(path, instance, sc, Some(extraMetadata))
+      ValidatorParams.saveImpl(path, instance, sc, Some(extraMetadata))
       val bestModelPath = new Path(path, "bestModel").toString
       instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath)
     }
   }
@@ -376,8 +255,9 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] {
     override def load(path: String): CrossValidatorModel = {
       implicit val format = DefaultFormats
 
-      val (metadata, estimator, evaluator, estimatorParamMaps, numFolds) =
-        SharedReadWrite.load(path, sc, className)
+      val (metadata, estimator, evaluator, estimatorParamMaps) =
+        ValidatorParams.loadImpl(path, sc, className)
+      val numFolds = (metadata.params \ "numFolds").extract[Int]
       val bestModelPath = new Path(path, "bestModel").toString
       val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc)
       val avgMetrics = (metadata.metadata \ "avgMetrics").extract[Seq[Double]].toArray
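Note: the CrossValidator.scala hunks above only reroute persistence through the shared ValidatorParams helpers; the public save/load surface is unchanged. A minimal round-trip sketch, assuming a local SparkContext and a scratch /tmp path (object name, master, and path are illustrative, not part of the patch):

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.ml.classification.LogisticRegression
    import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
    import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

    object CrossValidatorPersistenceSketch {
      def main(args: Array[String]): Unit = {
        // MLWriter/MLReader resolve their SQLContext from the active SparkContext.
        val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("cv-persist"))

        val lr = new LogisticRegression()
        val grid = new ParamGridBuilder()
          .addGrid(lr.regParam, Array(0.01, 0.1))
          .build()
        val cv = new CrossValidator()
          .setEstimator(lr)
          .setEvaluator(new BinaryClassificationEvaluator())
          .setEstimatorParamMaps(grid)
          .setNumFolds(3)

        // saveImpl writes numFolds and estimatorParamMaps into the metadata JSON,
        // and the estimator/evaluator under <path>/estimator and <path>/evaluator.
        cv.write.overwrite().save("/tmp/cv-persist-demo")

        // loadImpl restores the shared pieces; CrossValidatorReader then extracts
        // the validator-specific numFolds from metadata.params.
        val restored = CrossValidator.load("/tmp/cv-persist-demo")
        assert(restored.getNumFolds == 3)

        sc.stop()
      }
    }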
@Since("1.5.0") val validationMetrics: Array[Double]) - extends Model[TrainValidationSplitModel] with TrainValidationSplitParams { + extends Model[TrainValidationSplitModel] with TrainValidationSplitParams with MLWritable { @Since("1.5.0") override def transform(dataset: DataFrame): DataFrame = { @@ -167,4 +212,53 @@ class TrainValidationSplitModel private[ml] ( validationMetrics.clone()) copyValues(copied, extra) } + + @Since("2.0.0") + override def write: MLWriter = new TrainValidationSplitModel.TrainValidationSplitModelWriter(this) +} + +@Since("2.0.0") +object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] { + + @Since("2.0.0") + override def read: MLReader[TrainValidationSplitModel] = new TrainValidationSplitModelReader + + @Since("2.0.0") + override def load(path: String): TrainValidationSplitModel = super.load(path) + + private[TrainValidationSplitModel] + class TrainValidationSplitModelWriter(instance: TrainValidationSplitModel) extends MLWriter { + + ValidatorParams.validateParams(instance) + + override protected def saveImpl(path: String): Unit = { + import org.json4s.JsonDSL._ + val extraMetadata = "validationMetrics" -> instance.validationMetrics.toSeq + ValidatorParams.saveImpl(path, instance, sc, Some(extraMetadata)) + val bestModelPath = new Path(path, "bestModel").toString + instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath) + } + } + + private class TrainValidationSplitModelReader extends MLReader[TrainValidationSplitModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[TrainValidationSplitModel].getName + + override def load(path: String): TrainValidationSplitModel = { + implicit val format = DefaultFormats + + val (metadata, estimator, evaluator, estimatorParamMaps) = + ValidatorParams.loadImpl(path, sc, className) + val trainRatio = (metadata.params \ "trainRatio").extract[Double] + val bestModelPath = new Path(path, "bestModel").toString + val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc) + val validationMetrics = (metadata.metadata \ "validationMetrics").extract[Seq[Double]].toArray + val tvs = new TrainValidationSplitModel(metadata.uid, bestModel, validationMetrics) + tvs.set(tvs.estimator, estimator) + .set(tvs.evaluator, evaluator) + .set(tvs.estimatorParamMaps, estimatorParamMaps) + .set(tvs.trainRatio, trainRatio) + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala index 953456e8f0..7a4e106aeb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala @@ -17,9 +17,17 @@ package org.apache.spark.ml.tuning -import org.apache.spark.ml.Estimator +import org.apache.hadoop.fs.Path +import org.json4s.{DefaultFormats, _} +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.SparkContext +import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.evaluation.Evaluator -import org.apache.spark.ml.param.{Param, ParamMap, Params} +import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params} +import org.apache.spark.ml.util.{DefaultParamsReader, DefaultParamsWriter, MetaAlgorithmReadWrite, + MLWritable} +import org.apache.spark.ml.util.DefaultParamsReader.Metadata import org.apache.spark.sql.types.StructType /** @@ -69,3 +77,108 @@ private[ml] trait ValidatorParams extends Params { 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
index 953456e8f0..7a4e106aeb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
@@ -17,9 +17,17 @@
 
 package org.apache.spark.ml.tuning
 
-import org.apache.spark.ml.Estimator
+import org.apache.hadoop.fs.Path
+import org.json4s.{DefaultFormats, _}
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.SparkContext
+import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.evaluation.Evaluator
-import org.apache.spark.ml.param.{Param, ParamMap, Params}
+import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
+import org.apache.spark.ml.util.{DefaultParamsReader, DefaultParamsWriter, MetaAlgorithmReadWrite,
+  MLWritable}
+import org.apache.spark.ml.util.DefaultParamsReader.Metadata
 import org.apache.spark.sql.types.StructType
 
 /**
@@ -69,3 +77,108 @@ private[ml] trait ValidatorParams extends Params {
     est.copy(firstEstimatorParamMap).transformSchema(schema)
   }
 }
+
+private[ml] object ValidatorParams {
+  /**
+   * Check that [[ValidatorParams.evaluator]] and [[ValidatorParams.estimator]] are Writable.
+   * This does not check [[ValidatorParams.estimatorParamMaps]].
+   */
+  def validateParams(instance: ValidatorParams): Unit = {
+    def checkElement(elem: Params, name: String): Unit = elem match {
+      case stage: MLWritable => // good
+      case other =>
+        throw new UnsupportedOperationException(instance.getClass.getName + " write will fail " +
+          s" because it contains $name which does not implement Writable." +
+          s" Non-Writable $name: ${other.uid} of type ${other.getClass}")
+    }
+    checkElement(instance.getEvaluator, "evaluator")
+    checkElement(instance.getEstimator, "estimator")
+    // Check to make sure all Params apply to this estimator. Throw an error if any do not.
+    // Extraneous Params would cause problems when loading the estimatorParamMaps.
+    val uidToInstance: Map[String, Params] = MetaAlgorithmReadWrite.getUidMap(instance)
+    instance.getEstimatorParamMaps.foreach { case pMap: ParamMap =>
+      pMap.toSeq.foreach { case ParamPair(p, v) =>
+        require(uidToInstance.contains(p.parent), s"ValidatorParams save requires all Params in" +
+          s" estimatorParamMaps to apply to this ValidatorParams, its Estimator, or its" +
+          s" Evaluator. An extraneous Param was found: $p")
+      }
+    }
+  }
+
+  /**
+   * Generic implementation of save for [[ValidatorParams]] types.
+   * This handles all [[ValidatorParams]] fields and saves [[Param]] values, but the implementing
+   * class needs to handle model data.
+   */
+  def saveImpl(
+      path: String,
+      instance: ValidatorParams,
+      sc: SparkContext,
+      extraMetadata: Option[JObject] = None): Unit = {
+    import org.json4s.JsonDSL._
+
+    val estimatorParamMapsJson = compact(render(
+      instance.getEstimatorParamMaps.map { case paramMap =>
+        paramMap.toSeq.map { case ParamPair(p, v) =>
+          Map("parent" -> p.parent, "name" -> p.name, "value" -> p.jsonEncode(v))
+        }
+      }.toSeq
+    ))
+
+    val validatorSpecificParams = instance match {
+      case cv: CrossValidatorParams =>
+        List("numFolds" -> parse(cv.numFolds.jsonEncode(cv.getNumFolds)))
+      case tvs: TrainValidationSplitParams =>
+        List("trainRatio" -> parse(tvs.trainRatio.jsonEncode(tvs.getTrainRatio)))
+      case _ =>
+        // This should not happen.
+        throw new NotImplementedError("ValidatorParams.saveImpl does not handle type: " +
+          instance.getClass.getCanonicalName)
+    }
+
+    val jsonParams = validatorSpecificParams ++ List(
+      "estimatorParamMaps" -> parse(estimatorParamMapsJson))
+
+    DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, Some(jsonParams))
+
+    val evaluatorPath = new Path(path, "evaluator").toString
+    instance.getEvaluator.asInstanceOf[MLWritable].save(evaluatorPath)
+    val estimatorPath = new Path(path, "estimator").toString
+    instance.getEstimator.asInstanceOf[MLWritable].save(estimatorPath)
+  }
+
+  /**
+   * Generic implementation of load for [[ValidatorParams]] types.
+   * This handles all [[ValidatorParams]] fields, but the implementing
+   * class needs to handle model data and special [[Param]] values.
+   */
+  def loadImpl[M <: Model[M]](
+      path: String,
+      sc: SparkContext,
+      expectedClassName: String): (Metadata, Estimator[M], Evaluator, Array[ParamMap]) = {
+
+    val metadata = DefaultParamsReader.loadMetadata(path, sc, expectedClassName)
+
+    implicit val format = DefaultFormats
+    val evaluatorPath = new Path(path, "evaluator").toString
+    val evaluator = DefaultParamsReader.loadParamsInstance[Evaluator](evaluatorPath, sc)
+    val estimatorPath = new Path(path, "estimator").toString
+    val estimator = DefaultParamsReader.loadParamsInstance[Estimator[M]](estimatorPath, sc)
+
+    val uidToParams = Map(evaluator.uid -> evaluator) ++ MetaAlgorithmReadWrite.getUidMap(estimator)
+
+    val estimatorParamMaps: Array[ParamMap] =
+      (metadata.params \ "estimatorParamMaps").extract[Seq[Seq[Map[String, String]]]].map {
+        pMap =>
+          val paramPairs = pMap.map { case pInfo: Map[String, String] =>
+            val est = uidToParams(pInfo("parent"))
+            val param = est.getParam(pInfo("name"))
+            val value = param.jsonDecode(pInfo("value"))
+            param -> value
+          }
+          ParamMap(paramPairs: _*)
+      }.toArray
+
+    (metadata, estimator, evaluator, estimatorParamMaps)
+  }
+}
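Note: ValidatorParams.saveImpl above is now the single place where estimatorParamMaps is flattened to JSON. The following sketch mirrors that encoding outside the writer; it runs without a SparkContext, the object name is invented, and the printed uid suffix differs per run:

    import org.json4s.JsonDSL._
    import org.json4s.jackson.JsonMethods._

    import org.apache.spark.ml.classification.LogisticRegression
    import org.apache.spark.ml.param.ParamPair
    import org.apache.spark.ml.tuning.ParamGridBuilder

    object ParamMapJsonSketch {
      def main(args: Array[String]): Unit = {
        val lr = new LogisticRegression()
        val grid = new ParamGridBuilder()
          .addGrid(lr.regParam, Array(0.1, 0.2))
          .build()

        // Same encoding as ValidatorParams.saveImpl: every ParamPair becomes a
        // {"parent": <uid>, "name": <param name>, "value": <jsonEncode(value)>} record.
        val estimatorParamMapsJson = compact(render(
          grid.map { paramMap =>
            paramMap.toSeq.map { case ParamPair(p, v) =>
              Map("parent" -> p.parent, "name" -> p.name, "value" -> p.jsonEncode(v))
            }
          }.toSeq
        ))

        // Prints something like:
        // [[{"parent":"logreg_...","name":"regParam","value":"0.1"}],
        //  [{"parent":"logreg_...","name":"regParam","value":"0.2"}]]
        println(estimatorParamMapsJson)
      }
    }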
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index c95e536abd..5a596cad06 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -21,13 +21,18 @@ import java.io.IOException
 
 import org.apache.hadoop.fs.Path
 import org.json4s._
-import org.json4s.jackson.JsonMethods._
+import org.json4s.{DefaultFormats, JObject}
 import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
 
 import org.apache.spark.SparkContext
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.internal.Logging
+import org.apache.spark.ml._
+import org.apache.spark.ml.classification.OneVsRestParams
+import org.apache.spark.ml.feature.RFormulaModel
 import org.apache.spark.ml.param.{ParamPair, Params}
+import org.apache.spark.ml.tuning.ValidatorParams
 import org.apache.spark.sql.SQLContext
 import org.apache.spark.util.Utils
 
@@ -352,3 +357,38 @@ private[ml] object DefaultParamsReader {
     cls.getMethod("read").invoke(null).asInstanceOf[MLReader[T]].load(path)
   }
 }
+
+/**
+ * Default Meta-Algorithm read and write implementation.
+ */
+private[ml] object MetaAlgorithmReadWrite {
+  /**
+   * Examine the given estimator (which may be a compound estimator) and extract a mapping
+   * from UIDs to corresponding [[Params]] instances.
+   */
+  def getUidMap(instance: Params): Map[String, Params] = {
+    val uidList = getUidMapImpl(instance)
+    val uidMap = uidList.toMap
+    if (uidList.size != uidMap.size) {
+      throw new RuntimeException(s"${instance.getClass.getName}.load found a compound estimator" +
+        s" with stages with duplicate UIDs. List of UIDs: ${uidList.map(_._1).mkString(", ")}.")
+    }
+    uidMap
+  }
+
+  private def getUidMapImpl(instance: Params): List[(String, Params)] = {
+    val subStages: Array[Params] = instance match {
+      case p: Pipeline => p.getStages.asInstanceOf[Array[Params]]
+      case pm: PipelineModel => pm.stages.asInstanceOf[Array[Params]]
+      case v: ValidatorParams => Array(v.getEstimator, v.getEvaluator)
+      case ovr: OneVsRestParams =>
+        // TODO: SPARK-11892: This case may require special handling.
+        throw new UnsupportedOperationException(s"${instance.getClass.getName} write will fail" +
+          s" because it cannot yet handle an estimator containing type: ${ovr.getClass.getName}.")
+      case rformModel: RFormulaModel => Array(rformModel.pipelineModel)
+      case _: Params => Array()
+    }
+    val subStageMaps = subStages.map(getUidMapImpl).foldLeft(List.empty[(String, Params)])(_ ++ _)
+    List((instance.uid, instance)) ++ subStageMaps
+  }
+}
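Note: MetaAlgorithmReadWrite.getUidMap is what both validateParams and loadImpl use to resolve Param parents, and the OneVsRestParams case above is an explicitly unhandled gap (SPARK-11892). A hedged sketch of how that limitation surfaces to a caller; at this commit the rejection may come from either the Writable check in validateParams or the OneVsRestParams case while walking the estimator graph, but both raise UnsupportedOperationException when the writer is constructed (object name is invented):

    import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest}
    import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
    import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

    object OneVsRestWriteLimitSketch {
      def main(args: Array[String]): Unit = {
        val ovr = new OneVsRest().setClassifier(new LogisticRegression())
        val grid = new ParamGridBuilder().build() // a single empty ParamMap suffices here
        val cv = new CrossValidator()
          .setEstimator(ovr)
          .setEvaluator(new MulticlassClassificationEvaluator())
          .setEstimatorParamMaps(grid)

        // The writer constructor validates eagerly, so the unsupported compound
        // estimator is reported on .write, not at save time.
        try {
          cv.write
          println("unexpected: writer was constructed")
        } catch {
          case e: UnsupportedOperationException =>
            println(s"not yet supported: ${e.getMessage}")
        }
      }
    }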
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
index cf8dcefebc..7cf7b3e087 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
@@ -19,17 +19,20 @@ package org.apache.spark.ml.tuning
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.{Estimator, Model}
-import org.apache.spark.ml.classification.LogisticRegression
+import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
 import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator}
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.param.shared.HasInputCol
 import org.apache.spark.ml.regression.LinearRegression
+import org.apache.spark.ml.util.DefaultReadWriteTest
 import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
+import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.types.StructType
 
-class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext {
+class TrainValidationSplitSuite
+  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
   test("train validation with logistic regression") {
     val dataset = sqlContext.createDataFrame(
       sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2))
@@ -105,6 +108,44 @@ class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext
       cv.transformSchema(new StructType())
     }
   }
+
+  test("read/write: TrainValidationSplit") {
+    val lr = new LogisticRegression().setMaxIter(3)
+    val evaluator = new BinaryClassificationEvaluator()
+    val paramMaps = new ParamGridBuilder()
+        .addGrid(lr.regParam, Array(0.1, 0.2))
+        .build()
+    val tvs = new TrainValidationSplit()
+      .setEstimator(lr)
+      .setEvaluator(evaluator)
+      .setTrainRatio(0.5)
+      .setEstimatorParamMaps(paramMaps)
+
+    val tvs2 = testDefaultReadWrite(tvs, testParams = false)
+
+    assert(tvs.getTrainRatio === tvs2.getTrainRatio)
+  }
+
+  test("read/write: TrainValidationSplitModel") {
+    val lr = new LogisticRegression()
+      .setThreshold(0.6)
+    val lrModel = new LogisticRegressionModel(lr.uid, Vectors.dense(1.0, 2.0), 1.2)
+      .setThreshold(0.6)
+    val evaluator = new BinaryClassificationEvaluator()
+    val paramMaps = new ParamGridBuilder()
+        .addGrid(lr.regParam, Array(0.1, 0.2))
+        .build()
+    val tvs = new TrainValidationSplitModel("cvUid", lrModel, Array(0.3, 0.6))
+    tvs.set(tvs.estimator, lr)
+      .set(tvs.evaluator, evaluator)
+      .set(tvs.trainRatio, 0.5)
+      .set(tvs.estimatorParamMaps, paramMaps)
+
+    val tvs2 = testDefaultReadWrite(tvs, testParams = false)
+
+    assert(tvs.getTrainRatio === tvs2.getTrainRatio)
+    assert(tvs.validationMetrics === tvs2.validationMetrics)
+  }
 }
 
 object TrainValidationSplitSuite {
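Note: the new tests above round-trip the estimator and a hand-built model through testDefaultReadWrite. For completeness, a user-level fit/save/load sketch against the same 1.6/2.0-era API; the local master, toy data, object name, and /tmp path are illustrative assumptions, not part of the patch:

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.ml.classification.LogisticRegression
    import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
    import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit, TrainValidationSplitModel}
    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.sql.SQLContext

    object TrainValidationSplitRoundTripSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("tvs-roundtrip"))
        val sqlContext = new SQLContext(sc)

        // Tiny balanced toy set; real usage would have far more rows.
        val data = sqlContext.createDataFrame(
          (0 until 20).map(i => ((i % 2).toDouble, Vectors.dense(i.toDouble, (i % 2).toDouble)))
        ).toDF("label", "features")

        val lr = new LogisticRegression().setMaxIter(5)
        val grid = new ParamGridBuilder().addGrid(lr.regParam, Array(0.01, 0.1)).build()
        val tvs = new TrainValidationSplit()
          .setEstimator(lr)
          .setEvaluator(new BinaryClassificationEvaluator())
          .setEstimatorParamMaps(grid)
          .setTrainRatio(0.75)

        val model = tvs.fit(data)
        model.write.overwrite().save("/tmp/tvs-model-demo")

        // The loaded model carries bestModel plus the validationMetrics the
        // writer stored as extra metadata.
        val restored = TrainValidationSplitModel.load("/tmp/tvs-model-demo")
        println(restored.validationMetrics.mkString(", "))

        sc.stop()
      }
    }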