aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala20
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala20
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala14
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala5
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala5
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala13
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala11
7 files changed, 27 insertions, 61 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 3d7a91dd39..963f81cb3e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -131,19 +131,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
}
@Since("1.4.0")
- override def transformSchema(schema: StructType): StructType = {
- validateParams()
- $(estimator).transformSchema(schema)
- }
-
- @Since("1.4.0")
- override def validateParams(): Unit = {
- super.validateParams()
- val est = $(estimator)
- for (paramMap <- $(estimatorParamMaps)) {
- est.copy(paramMap).validateParams()
- }
- }
+ override def transformSchema(schema: StructType): StructType = transformSchemaImpl(schema)
@Since("1.4.0")
override def copy(extra: ParamMap): CrossValidator = {
@@ -332,11 +320,6 @@ class CrossValidatorModel private[ml] (
extends Model[CrossValidatorModel] with CrossValidatorParams with MLWritable {
@Since("1.4.0")
- override def validateParams(): Unit = {
- bestModel.validateParams()
- }
-
- @Since("1.4.0")
override def transform(dataset: DataFrame): DataFrame = {
transformSchema(dataset.schema, logging = true)
bestModel.transform(dataset)
@@ -344,7 +327,6 @@ class CrossValidatorModel private[ml] (
@Since("1.4.0")
override def transformSchema(schema: StructType): StructType = {
- validateParams()
bestModel.transformSchema(schema)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
index 4587e259e8..70fa5f0234 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
@@ -117,19 +117,7 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
}
@Since("1.5.0")
- override def transformSchema(schema: StructType): StructType = {
- validateParams()
- $(estimator).transformSchema(schema)
- }
-
- @Since("1.5.0")
- override def validateParams(): Unit = {
- super.validateParams()
- val est = $(estimator)
- for (paramMap <- $(estimatorParamMaps)) {
- est.copy(paramMap).validateParams()
- }
- }
+ override def transformSchema(schema: StructType): StructType = transformSchemaImpl(schema)
@Since("1.5.0")
override def copy(extra: ParamMap): TrainValidationSplit = {
@@ -161,11 +149,6 @@ class TrainValidationSplitModel private[ml] (
extends Model[TrainValidationSplitModel] with TrainValidationSplitParams {
@Since("1.5.0")
- override def validateParams(): Unit = {
- bestModel.validateParams()
- }
-
- @Since("1.5.0")
override def transform(dataset: DataFrame): DataFrame = {
transformSchema(dataset.schema, logging = true)
bestModel.transform(dataset)
@@ -173,7 +156,6 @@ class TrainValidationSplitModel private[ml] (
@Since("1.5.0")
override def transformSchema(schema: StructType): StructType = {
- validateParams()
bestModel.transformSchema(schema)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
index 553f254172..c004644ad8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
@@ -21,6 +21,7 @@ import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml.Estimator
import org.apache.spark.ml.evaluation.Evaluator
import org.apache.spark.ml.param.{Param, ParamMap, Params}
+import org.apache.spark.sql.types.StructType
/**
* :: DeveloperApi ::
@@ -31,6 +32,7 @@ private[ml] trait ValidatorParams extends Params {
/**
* param for the estimator to be validated
+ *
* @group param
*/
val estimator: Param[Estimator[_]] = new Param(this, "estimator", "estimator for selection")
@@ -40,6 +42,7 @@ private[ml] trait ValidatorParams extends Params {
/**
* param for estimator param maps
+ *
* @group param
*/
val estimatorParamMaps: Param[Array[ParamMap]] =
@@ -50,6 +53,7 @@ private[ml] trait ValidatorParams extends Params {
/**
* param for the evaluator used to select hyper-parameters that maximize the validated metric
+ *
* @group param
*/
val evaluator: Param[Evaluator] = new Param(this, "evaluator",
@@ -57,4 +61,14 @@ private[ml] trait ValidatorParams extends Params {
/** @group getParam */
def getEvaluator: Evaluator = $(evaluator)
+
+ protected def transformSchemaImpl(schema: StructType): StructType = {
+ require($(estimatorParamMaps).nonEmpty, s"Validator requires non-empty estimatorParamMaps")
+ val firstEstimatorParamMap = $(estimatorParamMaps).head
+ val est = $(estimator)
+ for (paramMap <- $(estimatorParamMaps).tail) {
+ est.copy(paramMap).transformSchema(schema)
+ }
+ est.copy(firstEstimatorParamMap).transformSchema(schema)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
index 748868554f..a3366c0e59 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
@@ -268,15 +268,10 @@ class ParamsSuite extends SparkFunSuite {
solver.getParam("abc")
}
- intercept[IllegalArgumentException] {
- solver.validateParams()
- }
- solver.copy(ParamMap(inputCol -> "input")).validateParams()
solver.setInputCol("input")
assert(solver.isSet(inputCol))
assert(solver.isDefined(inputCol))
assert(solver.getInputCol === "input")
- solver.validateParams()
intercept[IllegalArgumentException] {
ParamMap(maxIter -> -10)
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
index 9d23547f28..7d990ce0bc 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
@@ -34,10 +34,5 @@ class TestParams(override val uid: String) extends Params with HasHandleInvalid
def clearMaxIter(): this.type = clear(maxIter)
- override def validateParams(): Unit = {
- super.validateParams()
- require(isDefined(inputCol))
- }
-
override def copy(extra: ParamMap): TestParams = defaultCopy(extra)
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
index 56545de14b..7af3c6d6ed 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -30,7 +30,7 @@ import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLog
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{StructField, StructType}
class CrossValidatorSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -96,7 +96,7 @@ class CrossValidatorSuite
assert(cvModel2.avgMetrics.length === lrParamMaps.length)
}
- test("validateParams should check estimatorParamMaps") {
+ test("transformSchema should check estimatorParamMaps") {
import CrossValidatorSuite.{MyEstimator, MyEvaluator}
val est = new MyEstimator("est")
@@ -110,12 +110,12 @@ class CrossValidatorSuite
.setEstimatorParamMaps(paramMaps)
.setEvaluator(eval)
- cv.validateParams() // This should pass.
+ cv.transformSchema(new StructType()) // This should pass.
val invalidParamMaps = paramMaps :+ ParamMap(est.inputCol -> "")
cv.setEstimatorParamMaps(invalidParamMaps)
intercept[IllegalArgumentException] {
- cv.validateParams()
+ cv.transformSchema(new StructType())
}
}
@@ -311,14 +311,13 @@ object CrossValidatorSuite extends SparkFunSuite {
class MyEstimator(override val uid: String) extends Estimator[MyModel] with HasInputCol {
- override def validateParams(): Unit = require($(inputCol).nonEmpty)
-
override def fit(dataset: DataFrame): MyModel = {
throw new UnsupportedOperationException
}
override def transformSchema(schema: StructType): StructType = {
- throw new UnsupportedOperationException
+ require($(inputCol).nonEmpty)
+ schema
}
override def copy(extra: ParamMap): MyEstimator = defaultCopy(extra)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
index 5fb80091d0..cf8dcefebc 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
@@ -83,7 +83,7 @@ class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext
assert(cvModel2.validationMetrics.length === lrParamMaps.length)
}
- test("validateParams should check estimatorParamMaps") {
+ test("transformSchema should check estimatorParamMaps") {
import TrainValidationSplitSuite._
val est = new MyEstimator("est")
@@ -97,12 +97,12 @@ class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext
.setEstimatorParamMaps(paramMaps)
.setEvaluator(eval)
.setTrainRatio(0.5)
- cv.validateParams() // This should pass.
+ cv.transformSchema(new StructType()) // This should pass.
val invalidParamMaps = paramMaps :+ ParamMap(est.inputCol -> "")
cv.setEstimatorParamMaps(invalidParamMaps)
intercept[IllegalArgumentException] {
- cv.validateParams()
+ cv.transformSchema(new StructType())
}
}
}
@@ -113,14 +113,13 @@ object TrainValidationSplitSuite {
class MyEstimator(override val uid: String) extends Estimator[MyModel] with HasInputCol {
- override def validateParams(): Unit = require($(inputCol).nonEmpty)
-
override def fit(dataset: DataFrame): MyModel = {
throw new UnsupportedOperationException
}
override def transformSchema(schema: StructType): StructType = {
- throw new UnsupportedOperationException
+ require($(inputCol).nonEmpty)
+ schema
}
override def copy(extra: ParamMap): MyEstimator = defaultCopy(extra)