From a9f1c0c57b9be586dbada09dab91dcfce31141d9 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 26 May 2015 23:51:32 -0700 Subject: [SPARK-7535] [.1] [MLLIB] minor changes to the pipeline API 1. removed `Params.validateParams(extra)` 2. added `Evaluate.evaluate(dataset, paramPairs*)` 3. updated `RegressionEvaluator` doc jkbradley Author: Xiangrui Meng Closes #6392 from mengxr/SPARK-7535.1 and squashes the following commits: 5ff5af8 [Xiangrui Meng] add unit test for CV.validateParams f1f8369 [Xiangrui Meng] update CV.validateParams() to test estimatorParamMaps 607445d [Xiangrui Meng] merge master 8716f5f [Xiangrui Meng] specify default metric name in RegressionEvaluator e4e5631 [Xiangrui Meng] update RegressionEvaluator doc 801e864 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7535.1 fcbd3e2 [Xiangrui Meng] Merge branch 'master' into SPARK-7535.1 2192316 [Xiangrui Meng] remove validateParams(extra); add evaluate(dataset, extra*) --- .../main/scala/org/apache/spark/ml/Pipeline.scala | 9 ++-- .../spark/ml/evaluation/RegressionEvaluator.scala | 4 +- .../scala/org/apache/spark/ml/param/params.scala | 13 ------ .../apache/spark/ml/tuning/CrossValidator.scala | 23 ++++++---- .../org/apache/spark/ml/param/ParamsSuite.scala | 2 +- .../spark/ml/tuning/CrossValidatorSuite.scala | 52 +++++++++++++++++++++- 6 files changed, 71 insertions(+), 32 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index 9da3ff65c7..11a4722722 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -97,12 +97,9 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] { /** @group getParam */ def getStages: Array[PipelineStage] = $(stages).clone() - override def validateParams(paramMap: ParamMap): Unit = { - val map = extractParamMap(paramMap) - getStages.foreach { - case pStage: Params => pStage.validateParams(map) - case _ => - } + override def validateParams(): Unit = { + super.validateParams() + $(stages).foreach(_.validateParams()) } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala index 1771177e1e..abb1b35bed 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala @@ -36,8 +36,8 @@ final class RegressionEvaluator(override val uid: String) def this() = this(Identifiable.randomUID("regEval")) /** - * param for metric name in evaluation - * @group param supports mse, rmse, r2, mae as valid metric names. + * param for metric name in evaluation (supports `"rmse"` (default), `"mse"`, `"r2"`, and `"mae"`) + * @group param */ val metricName: Param[String] = { val allowedParams = ParamValidators.inArray(Array("mse", "rmse", "r2", "mae")) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 1afa59c994..473488dce9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -333,19 +333,6 @@ trait Params extends Identifiable with Serializable { .map(m => m.invoke(this).asInstanceOf[Param[_]]) } - /** - * Validates parameter values stored internally plus the input parameter map. - * Raises an exception if any parameter is invalid. - * - * This only needs to check for interactions between parameters. - * Parameter value checks which do not depend on other parameters are handled by - * [[Param.validate()]]. This method does not handle input/output column parameters; - * those are checked during schema validation. - */ - def validateParams(paramMap: ParamMap): Unit = { - copy(paramMap).validateParams() - } - /** * Validates parameter values stored internally. * Raise an exception if any parameter value is invalid. diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 2e5a629561..6434b64aed 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -102,12 +102,6 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM /** @group setParam */ def setNumFolds(value: Int): this.type = set(numFolds, value) - override def validateParams(paramMap: ParamMap): Unit = { - getEstimatorParamMaps.foreach { eMap => - getEstimator.validateParams(eMap ++ paramMap) - } - } - override def fit(dataset: DataFrame): CrossValidatorModel = { val schema = dataset.schema transformSchema(schema, logging = true) @@ -147,6 +141,14 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM override def transformSchema(schema: StructType): StructType = { $(estimator).transformSchema(schema) } + + override def validateParams(): Unit = { + super.validateParams() + val est = $(estimator) + for (paramMap <- $(estimatorParamMaps)) { + est.copy(paramMap).validateParams() + } + } } /** @@ -159,8 +161,8 @@ class CrossValidatorModel private[ml] ( val bestModel: Model[_]) extends Model[CrossValidatorModel] with CrossValidatorParams { - override def validateParams(paramMap: ParamMap): Unit = { - bestModel.validateParams(paramMap) + override def validateParams(): Unit = { + bestModel.validateParams() } override def transform(dataset: DataFrame): DataFrame = { @@ -171,4 +173,9 @@ class CrossValidatorModel private[ml] ( override def transformSchema(schema: StructType): StructType = { bestModel.transformSchema(schema) } + + override def copy(extra: ParamMap): CrossValidatorModel = { + val copied = new CrossValidatorModel(uid, bestModel.copy(extra).asInstanceOf[Model[_]]) + copyValues(copied, extra) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index d270ad7613..04f2af4727 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -135,7 +135,7 @@ class ParamsSuite extends FunSuite { intercept[IllegalArgumentException] { solver.validateParams() } - solver.validateParams(ParamMap(inputCol -> "input")) + solver.copy(ParamMap(inputCol -> "input")).validateParams() solver.setInputCol("input") assert(solver.isSet(inputCol)) assert(solver.isDefined(inputCol)) diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index 05313d440f..65972ec79b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -19,11 +19,15 @@ package org.apache.spark.ml.tuning import org.scalatest.FunSuite +import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.classification.LogisticRegression -import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator +import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator} +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.param.shared.HasInputCol import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{SQLContext, DataFrame} +import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.types.StructType class CrossValidatorSuite extends FunSuite with MLlibTestSparkContext { @@ -53,4 +57,48 @@ class CrossValidatorSuite extends FunSuite with MLlibTestSparkContext { assert(parent.getRegParam === 0.001) assert(parent.getMaxIter === 10) } + + test("validateParams should check estimatorParamMaps") { + import CrossValidatorSuite._ + + val est = new MyEstimator("est") + val eval = new MyEvaluator + val paramMaps = new ParamGridBuilder() + .addGrid(est.inputCol, Array("input1", "input2")) + .build() + + val cv = new CrossValidator() + .setEstimator(est) + .setEstimatorParamMaps(paramMaps) + .setEvaluator(eval) + + cv.validateParams() // This should pass. + + val invalidParamMaps = paramMaps :+ ParamMap(est.inputCol -> "") + cv.setEstimatorParamMaps(invalidParamMaps) + intercept[IllegalArgumentException] { + cv.validateParams() + } + } +} + +object CrossValidatorSuite { + + abstract class MyModel extends Model[MyModel] + + class MyEstimator(override val uid: String) extends Estimator[MyModel] with HasInputCol { + + override def validateParams(): Unit = require($(inputCol).nonEmpty) + + override def fit(dataset: DataFrame): MyModel = ??? + + override def transformSchema(schema: StructType): StructType = ??? + } + + class MyEvaluator extends Evaluator { + + override def evaluate(dataset: DataFrame): Double = ??? + + override val uid: String = "eval" + } } -- cgit v1.2.3