author     Xiangrui Meng <meng@databricks.com>   2015-05-26 23:51:32 -0700
committer  Xiangrui Meng <meng@databricks.com>   2015-05-26 23:51:32 -0700
commit     a9f1c0c57b9be586dbada09dab91dcfce31141d9 (patch)
tree       3555425d377a4f89b60ac68ce2f2d2b65b93840e /mllib
parent     b463e6d618e69c535297e51f41eca4f91bd33cc8 (diff)
[SPARK-7535] [.1] [MLLIB] minor changes to the pipeline API
1. removed `Params.validateParams(extra)`
2. added `Evaluator.evaluate(dataset, paramPairs*)`
3. updated `RegressionEvaluator` doc

jkbradley

Author: Xiangrui Meng <meng@databricks.com>

Closes #6392 from mengxr/SPARK-7535.1 and squashes the following commits:

5ff5af8 [Xiangrui Meng] add unit test for CV.validateParams
f1f8369 [Xiangrui Meng] update CV.validateParams() to test estimatorParamMaps
607445d [Xiangrui Meng] merge master
8716f5f [Xiangrui Meng] specify default metric name in RegressionEvaluator
e4e5631 [Xiangrui Meng] update RegressionEvaluator doc
801e864 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7535.1
fcbd3e2 [Xiangrui Meng] Merge branch 'master' into SPARK-7535.1
2192316 [Xiangrui Meng] remove validateParams(extra); add evaluate(dataset, extra*)
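A usage sketch of item 2. The `Evaluator` change itself is not among the files shown in this mllib-limited diffstat, so the exact overload form here is an assumption based on the commit message; `predictions` is a hypothetical DataFrame with "prediction" and "label" columns:

    import org.apache.spark.ml.evaluation.RegressionEvaluator
    import org.apache.spark.ml.param.ParamMap

    // Extra params are applied to a copy of the evaluator at evaluation
    // time, so the evaluator itself is never mutated. `predictions` is a
    // hypothetical DataFrame with "prediction" and "label" columns.
    val eval = new RegressionEvaluator()
    val r2 = eval.evaluate(predictions, ParamMap(eval.metricName -> "r2"))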
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala                        |  9
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala  |  4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/param/params.scala                    | 13
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala           | 23
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala               |  2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala      | 52
6 files changed, 71 insertions(+), 32 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index 9da3ff65c7..11a4722722 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -97,12 +97,9 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] {
/** @group getParam */
def getStages: Array[PipelineStage] = $(stages).clone()
- override def validateParams(paramMap: ParamMap): Unit = {
- val map = extractParamMap(paramMap)
- getStages.foreach {
- case pStage: Params => pStage.validateParams(map)
- case _ =>
- }
+ override def validateParams(): Unit = {
+ super.validateParams()
+ $(stages).foreach(_.validateParams())
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
index 1771177e1e..abb1b35bed 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
@@ -36,8 +36,8 @@ final class RegressionEvaluator(override val uid: String)
def this() = this(Identifiable.randomUID("regEval"))
/**
- * param for metric name in evaluation
- * @group param supports mse, rmse, r2, mae as valid metric names.
+ * param for metric name in evaluation (supports `"rmse"` (default), `"mse"`, `"r2"`, and `"mae"`)
+ * @group param
*/
val metricName: Param[String] = {
val allowedParams = ParamValidators.inArray(Array("mse", "rmse", "r2", "mae"))
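A minimal sketch of picking a metric through this param, using the evaluator's standard `setMetricName` setter:

    import org.apache.spark.ml.evaluation.RegressionEvaluator

    // "rmse" is the documented default; the validator above also accepts
    // "mse", "r2", and "mae".
    val evaluator = new RegressionEvaluator().setMetricName("mae")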
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index 1afa59c994..473488dce9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -334,19 +334,6 @@ trait Params extends Identifiable with Serializable {
}
/**
- * Validates parameter values stored internally plus the input parameter map.
- * Raises an exception if any parameter is invalid.
- *
- * This only needs to check for interactions between parameters.
- * Parameter value checks which do not depend on other parameters are handled by
- * [[Param.validate()]]. This method does not handle input/output column parameters;
- * those are checked during schema validation.
- */
- def validateParams(paramMap: ParamMap): Unit = {
- copy(paramMap).validateParams()
- }
-
- /**
* Validates parameter values stored internally.
* Raise an exception if any parameter value is invalid.
*
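With the `ParamMap`-taking overload removed, the same check is spelled copy-then-validate at the call site. A sketch using `LogisticRegression`; any `Params` instance works the same way:

    import org.apache.spark.ml.classification.LogisticRegression
    import org.apache.spark.ml.param.ParamMap

    // Before: lr.validateParams(ParamMap(lr.maxIter -> 10))
    // After: apply the extra values to a copy, then validate the copy.
    val lr = new LogisticRegression()
    lr.copy(ParamMap(lr.maxIter -> 10)).validateParams()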
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 2e5a629561..6434b64aed 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -102,12 +102,6 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorModel]
/** @group setParam */
def setNumFolds(value: Int): this.type = set(numFolds, value)
- override def validateParams(paramMap: ParamMap): Unit = {
- getEstimatorParamMaps.foreach { eMap =>
- getEstimator.validateParams(eMap ++ paramMap)
- }
- }
-
override def fit(dataset: DataFrame): CrossValidatorModel = {
val schema = dataset.schema
transformSchema(schema, logging = true)
@@ -147,6 +141,14 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorModel]
override def transformSchema(schema: StructType): StructType = {
$(estimator).transformSchema(schema)
}
+
+ override def validateParams(): Unit = {
+ super.validateParams()
+ val est = $(estimator)
+ for (paramMap <- $(estimatorParamMaps)) {
+ est.copy(paramMap).validateParams()
+ }
+ }
}
/**
@@ -159,8 +161,8 @@ class CrossValidatorModel private[ml] (
val bestModel: Model[_])
extends Model[CrossValidatorModel] with CrossValidatorParams {
- override def validateParams(paramMap: ParamMap): Unit = {
- bestModel.validateParams(paramMap)
+ override def validateParams(): Unit = {
+ bestModel.validateParams()
}
override def transform(dataset: DataFrame): DataFrame = {
@@ -171,4 +173,9 @@ class CrossValidatorModel private[ml] (
override def transformSchema(schema: StructType): StructType = {
bestModel.transformSchema(schema)
}
+
+ override def copy(extra: ParamMap): CrossValidatorModel = {
+ val copied = new CrossValidatorModel(uid, bestModel.copy(extra).asInstanceOf[Model[_]])
+ copyValues(copied, extra)
+ }
}
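End to end, the new override makes a bad grid value fail before any folds are fit. A self-contained sketch; the param values are illustrative:

    import org.apache.spark.ml.classification.LogisticRegression
    import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
    import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

    // Each candidate ParamMap is applied to a copy of the estimator and the
    // copy's validateParams() runs, so invalid grid entries fail fast.
    val lr = new LogisticRegression()
    val grid = new ParamGridBuilder()
      .addGrid(lr.regParam, Array(0.1, 0.01))
      .build()
    val cv = new CrossValidator()
      .setEstimator(lr)
      .setEstimatorParamMaps(grid)
      .setEvaluator(new BinaryClassificationEvaluator())
    cv.validateParams()  // throws IllegalArgumentException on an invalid map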
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
index d270ad7613..04f2af4727 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
@@ -135,7 +135,7 @@ class ParamsSuite extends FunSuite {
intercept[IllegalArgumentException] {
solver.validateParams()
}
- solver.validateParams(ParamMap(inputCol -> "input"))
+ solver.copy(ParamMap(inputCol -> "input")).validateParams()
solver.setInputCol("input")
assert(solver.isSet(inputCol))
assert(solver.isDefined(inputCol))
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
index 05313d440f..65972ec79b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -19,11 +19,15 @@ package org.apache.spark.ml.tuning
import org.scalatest.FunSuite
+import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.classification.LogisticRegression
-import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
+import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator}
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.param.shared.HasInputCol
import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{SQLContext, DataFrame}
+import org.apache.spark.sql.{DataFrame, SQLContext}
+import org.apache.spark.sql.types.StructType
class CrossValidatorSuite extends FunSuite with MLlibTestSparkContext {
@@ -53,4 +57,48 @@ class CrossValidatorSuite extends FunSuite with MLlibTestSparkContext {
assert(parent.getRegParam === 0.001)
assert(parent.getMaxIter === 10)
}
+
+ test("validateParams should check estimatorParamMaps") {
+ import CrossValidatorSuite._
+
+ val est = new MyEstimator("est")
+ val eval = new MyEvaluator
+ val paramMaps = new ParamGridBuilder()
+ .addGrid(est.inputCol, Array("input1", "input2"))
+ .build()
+
+ val cv = new CrossValidator()
+ .setEstimator(est)
+ .setEstimatorParamMaps(paramMaps)
+ .setEvaluator(eval)
+
+ cv.validateParams() // This should pass.
+
+ val invalidParamMaps = paramMaps :+ ParamMap(est.inputCol -> "")
+ cv.setEstimatorParamMaps(invalidParamMaps)
+ intercept[IllegalArgumentException] {
+ cv.validateParams()
+ }
+ }
+}
+
+object CrossValidatorSuite {
+
+ abstract class MyModel extends Model[MyModel]
+
+ class MyEstimator(override val uid: String) extends Estimator[MyModel] with HasInputCol {
+
+ override def validateParams(): Unit = require($(inputCol).nonEmpty)
+
+ override def fit(dataset: DataFrame): MyModel = ???
+
+ override def transformSchema(schema: StructType): StructType = ???
+ }
+
+ class MyEvaluator extends Evaluator {
+
+ override def evaluate(dataset: DataFrame): Double = ???
+
+ override val uid: String = "eval"
+ }
}