aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala11
1 files changed, 5 insertions, 6 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
index 5fb80091d0..cf8dcefebc 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
@@ -83,7 +83,7 @@ class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext
assert(cvModel2.validationMetrics.length === lrParamMaps.length)
}
- test("validateParams should check estimatorParamMaps") {
+ test("transformSchema should check estimatorParamMaps") {
import TrainValidationSplitSuite._
val est = new MyEstimator("est")
@@ -97,12 +97,12 @@ class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext
.setEstimatorParamMaps(paramMaps)
.setEvaluator(eval)
.setTrainRatio(0.5)
- cv.validateParams() // This should pass.
+ cv.transformSchema(new StructType()) // This should pass.
val invalidParamMaps = paramMaps :+ ParamMap(est.inputCol -> "")
cv.setEstimatorParamMaps(invalidParamMaps)
intercept[IllegalArgumentException] {
- cv.validateParams()
+ cv.transformSchema(new StructType())
}
}
}
@@ -113,14 +113,13 @@ object TrainValidationSplitSuite {
class MyEstimator(override val uid: String) extends Estimator[MyModel] with HasInputCol {
- override def validateParams(): Unit = require($(inputCol).nonEmpty)
-
override def fit(dataset: DataFrame): MyModel = {
throw new UnsupportedOperationException
}
override def transformSchema(schema: StructType): StructType = {
- throw new UnsupportedOperationException
+ require($(inputCol).nonEmpty)
+ schema
}
override def copy(extra: ParamMap): MyEstimator = defaultCopy(extra)