From a9f1c0c57b9be586dbada09dab91dcfce31141d9 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 26 May 2015 23:51:32 -0700 Subject: [SPARK-7535] [.1] [MLLIB] minor changes to the pipeline API 1. removed `Params.validateParams(extra)` 2. added `Evaluate.evaluate(dataset, paramPairs*)` 3. updated `RegressionEvaluator` doc jkbradley Author: Xiangrui Meng Closes #6392 from mengxr/SPARK-7535.1 and squashes the following commits: 5ff5af8 [Xiangrui Meng] add unit test for CV.validateParams f1f8369 [Xiangrui Meng] update CV.validateParams() to test estimatorParamMaps 607445d [Xiangrui Meng] merge master 8716f5f [Xiangrui Meng] specify default metric name in RegressionEvaluator e4e5631 [Xiangrui Meng] update RegressionEvaluator doc 801e864 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7535.1 fcbd3e2 [Xiangrui Meng] Merge branch 'master' into SPARK-7535.1 2192316 [Xiangrui Meng] remove validateParams(extra); add evaluate(dataset, extra*) --- .../main/scala/org/apache/spark/ml/Pipeline.scala | 9 +++------ .../spark/ml/evaluation/RegressionEvaluator.scala | 4 ++-- .../scala/org/apache/spark/ml/param/params.scala | 13 ------------ .../apache/spark/ml/tuning/CrossValidator.scala | 23 ++++++++++++++-------- 4 files changed, 20 insertions(+), 29 deletions(-) (limited to 'mllib/src/main') diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index 9da3ff65c7..11a4722722 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -97,12 +97,9 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] { /** @group getParam */ def getStages: Array[PipelineStage] = $(stages).clone() - override def validateParams(paramMap: ParamMap): Unit = { - val map = extractParamMap(paramMap) - getStages.foreach { - case pStage: Params => pStage.validateParams(map) - case _ => - } + override def validateParams(): Unit = { + super.validateParams() + $(stages).foreach(_.validateParams()) } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala index 1771177e1e..abb1b35bed 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala @@ -36,8 +36,8 @@ final class RegressionEvaluator(override val uid: String) def this() = this(Identifiable.randomUID("regEval")) /** - * param for metric name in evaluation - * @group param supports mse, rmse, r2, mae as valid metric names. + * param for metric name in evaluation (supports `"rmse"` (default), `"mse"`, `"r2"`, and `"mae"`) + * @group param */ val metricName: Param[String] = { val allowedParams = ParamValidators.inArray(Array("mse", "rmse", "r2", "mae")) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 1afa59c994..473488dce9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -333,19 +333,6 @@ trait Params extends Identifiable with Serializable { .map(m => m.invoke(this).asInstanceOf[Param[_]]) } - /** - * Validates parameter values stored internally plus the input parameter map. - * Raises an exception if any parameter is invalid. - * - * This only needs to check for interactions between parameters. - * Parameter value checks which do not depend on other parameters are handled by - * [[Param.validate()]]. This method does not handle input/output column parameters; - * those are checked during schema validation. - */ - def validateParams(paramMap: ParamMap): Unit = { - copy(paramMap).validateParams() - } - /** * Validates parameter values stored internally. * Raise an exception if any parameter value is invalid. diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 2e5a629561..6434b64aed 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -102,12 +102,6 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM /** @group setParam */ def setNumFolds(value: Int): this.type = set(numFolds, value) - override def validateParams(paramMap: ParamMap): Unit = { - getEstimatorParamMaps.foreach { eMap => - getEstimator.validateParams(eMap ++ paramMap) - } - } - override def fit(dataset: DataFrame): CrossValidatorModel = { val schema = dataset.schema transformSchema(schema, logging = true) @@ -147,6 +141,14 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM override def transformSchema(schema: StructType): StructType = { $(estimator).transformSchema(schema) } + + override def validateParams(): Unit = { + super.validateParams() + val est = $(estimator) + for (paramMap <- $(estimatorParamMaps)) { + est.copy(paramMap).validateParams() + } + } } /** @@ -159,8 +161,8 @@ class CrossValidatorModel private[ml] ( val bestModel: Model[_]) extends Model[CrossValidatorModel] with CrossValidatorParams { - override def validateParams(paramMap: ParamMap): Unit = { - bestModel.validateParams(paramMap) + override def validateParams(): Unit = { + bestModel.validateParams() } override def transform(dataset: DataFrame): DataFrame = { @@ -171,4 +173,9 @@ class CrossValidatorModel private[ml] ( override def transformSchema(schema: StructType): StructType = { bestModel.transformSchema(schema) } + + override def copy(extra: ParamMap): CrossValidatorModel = { + val copied = new CrossValidatorModel(uid, bestModel.copy(extra).asInstanceOf[Model[_]]) + copyValues(copied, extra) + } } -- cgit v1.2.3