diff options
Diffstat (limited to 'mllib/src/test/scala')
4 files changed, 11 insertions, 23 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index 748868554f..a3366c0e59 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -268,15 +268,10 @@ class ParamsSuite extends SparkFunSuite { solver.getParam("abc") } - intercept[IllegalArgumentException] { - solver.validateParams() - } - solver.copy(ParamMap(inputCol -> "input")).validateParams() solver.setInputCol("input") assert(solver.isSet(inputCol)) assert(solver.isDefined(inputCol)) assert(solver.getInputCol === "input") - solver.validateParams() intercept[IllegalArgumentException] { ParamMap(maxIter -> -10) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala index 9d23547f28..7d990ce0bc 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala @@ -34,10 +34,5 @@ class TestParams(override val uid: String) extends Params with HasHandleInvalid def clearMaxIter(): this.type = clear(maxIter) - override def validateParams(): Unit = { - super.validateParams() - require(isDefined(inputCol)) - } - override def copy(extra: ParamMap): TestParams = defaultCopy(extra) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index 56545de14b..7af3c6d6ed 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLog import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{StructField, StructType} class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -96,7 +96,7 @@ class CrossValidatorSuite assert(cvModel2.avgMetrics.length === lrParamMaps.length) } - test("validateParams should check estimatorParamMaps") { + test("transformSchema should check estimatorParamMaps") { import CrossValidatorSuite.{MyEstimator, MyEvaluator} val est = new MyEstimator("est") @@ -110,12 +110,12 @@ class CrossValidatorSuite .setEstimatorParamMaps(paramMaps) .setEvaluator(eval) - cv.validateParams() // This should pass. + cv.transformSchema(new StructType()) // This should pass. val invalidParamMaps = paramMaps :+ ParamMap(est.inputCol -> "") cv.setEstimatorParamMaps(invalidParamMaps) intercept[IllegalArgumentException] { - cv.validateParams() + cv.transformSchema(new StructType()) } } @@ -311,14 +311,13 @@ object CrossValidatorSuite extends SparkFunSuite { class MyEstimator(override val uid: String) extends Estimator[MyModel] with HasInputCol { - override def validateParams(): Unit = require($(inputCol).nonEmpty) - override def fit(dataset: DataFrame): MyModel = { throw new UnsupportedOperationException } override def transformSchema(schema: StructType): StructType = { - throw new UnsupportedOperationException + require($(inputCol).nonEmpty) + schema } override def copy(extra: ParamMap): MyEstimator = defaultCopy(extra) diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala index 5fb80091d0..cf8dcefebc 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala @@ -83,7 +83,7 @@ class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext assert(cvModel2.validationMetrics.length === lrParamMaps.length) } - test("validateParams should check estimatorParamMaps") { + test("transformSchema should check estimatorParamMaps") { import TrainValidationSplitSuite._ val est = new MyEstimator("est") @@ -97,12 +97,12 @@ class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext .setEstimatorParamMaps(paramMaps) .setEvaluator(eval) .setTrainRatio(0.5) - cv.validateParams() // This should pass. + cv.transformSchema(new StructType()) // This should pass. val invalidParamMaps = paramMaps :+ ParamMap(est.inputCol -> "") cv.setEstimatorParamMaps(invalidParamMaps) intercept[IllegalArgumentException] { - cv.validateParams() + cv.transformSchema(new StructType()) } } } @@ -113,14 +113,13 @@ object TrainValidationSplitSuite { class MyEstimator(override val uid: String) extends Estimator[MyModel] with HasInputCol { - override def validateParams(): Unit = require($(inputCol).nonEmpty) - override def fit(dataset: DataFrame): MyModel = { throw new UnsupportedOperationException } override def transformSchema(schema: StructType): StructType = { - throw new UnsupportedOperationException + require($(inputCol).nonEmpty) + schema } override def copy(extra: ParamMap): MyEstimator = defaultCopy(extra) |