diff options
Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala | 13 |
1 files changed, 6 insertions, 7 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index 56545de14b..7af3c6d6ed 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLog import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{StructField, StructType} class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -96,7 +96,7 @@ class CrossValidatorSuite assert(cvModel2.avgMetrics.length === lrParamMaps.length) } - test("validateParams should check estimatorParamMaps") { + test("transformSchema should check estimatorParamMaps") { import CrossValidatorSuite.{MyEstimator, MyEvaluator} val est = new MyEstimator("est") @@ -110,12 +110,12 @@ class CrossValidatorSuite .setEstimatorParamMaps(paramMaps) .setEvaluator(eval) - cv.validateParams() // This should pass. + cv.transformSchema(new StructType()) // This should pass. val invalidParamMaps = paramMaps :+ ParamMap(est.inputCol -> "") cv.setEstimatorParamMaps(invalidParamMaps) intercept[IllegalArgumentException] { - cv.validateParams() + cv.transformSchema(new StructType()) } } @@ -311,14 +311,13 @@ object CrossValidatorSuite extends SparkFunSuite { class MyEstimator(override val uid: String) extends Estimator[MyModel] with HasInputCol { - override def validateParams(): Unit = require($(inputCol).nonEmpty) - override def fit(dataset: DataFrame): MyModel = { throw new UnsupportedOperationException } override def transformSchema(schema: StructType): StructType = { - throw new UnsupportedOperationException + require($(inputCol).nonEmpty) + schema } override def copy(extra: ParamMap): MyEstimator = defaultCopy(extra) |