diff options
Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala | 11 |
1 files changed, 5 insertions, 6 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala index 5fb80091d0..cf8dcefebc 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala @@ -83,7 +83,7 @@ class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext assert(cvModel2.validationMetrics.length === lrParamMaps.length) } - test("validateParams should check estimatorParamMaps") { + test("transformSchema should check estimatorParamMaps") { import TrainValidationSplitSuite._ val est = new MyEstimator("est") @@ -97,12 +97,12 @@ class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext .setEstimatorParamMaps(paramMaps) .setEvaluator(eval) .setTrainRatio(0.5) - cv.validateParams() // This should pass. + cv.transformSchema(new StructType()) // This should pass. val invalidParamMaps = paramMaps :+ ParamMap(est.inputCol -> "") cv.setEstimatorParamMaps(invalidParamMaps) intercept[IllegalArgumentException] { - cv.validateParams() + cv.transformSchema(new StructType()) } } } @@ -113,14 +113,13 @@ object TrainValidationSplitSuite { class MyEstimator(override val uid: String) extends Estimator[MyModel] with HasInputCol { - override def validateParams(): Unit = require($(inputCol).nonEmpty) - override def fit(dataset: DataFrame): MyModel = { throw new UnsupportedOperationException } override def transformSchema(schema: StructType): StructType = { - throw new UnsupportedOperationException + require($(inputCol).nonEmpty) + schema } override def copy(extra: ParamMap): MyEstimator = defaultCopy(extra) |