From 8c11d1aab8522c75d78bc6b30402c64e8d9ff065 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Mon, 28 Mar 2016 15:40:06 -0700 Subject: [SPARK-11893] Model export/import for spark.ml: TrainValidationSplit https://issues.apache.org/jira/browse/SPARK-11893 jkbradley In order to share read/write with `TrainValidationSplit`, I move the `SharedReadWrite` out of `CrossValidator` into a new trait `SharedReadWrite` in the tunning package. To reduce the repeated tests, I move the complex tests from `CrossValidatorSuite` to `SharedReadWriteSuite`, and create a fake validator called `MyValidator` to test the shared code. With `SharedReadWrite`, potential newly added `Validator` can share the read/write common part, and only need to implement their extra params save/load. Author: Xusen Yin Author: Joseph K. Bradley Closes #9971 from yinxusen/SPARK-11893. --- .../ml/tuning/TrainValidationSplitSuite.scala | 45 +++++++++++++++++++++- 1 file changed, 43 insertions(+), 2 deletions(-) (limited to 'mllib/src/test') diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala index cf8dcefebc..7cf7b3e087 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala @@ -19,17 +19,20 @@ package org.apache.spark.ml.tuning import org.apache.spark.SparkFunSuite import org.apache.spark.ml.{Estimator, Model} -import org.apache.spark.ml.classification.LogisticRegression +import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel} import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared.HasInputCol import org.apache.spark.ml.regression.LinearRegression +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput +import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType -class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext { +class TrainValidationSplitSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("train validation with logistic regression") { val dataset = sqlContext.createDataFrame( sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2)) @@ -105,6 +108,44 @@ class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext cv.transformSchema(new StructType()) } } + + test("read/write: TrainValidationSplit") { + val lr = new LogisticRegression().setMaxIter(3) + val evaluator = new BinaryClassificationEvaluator() + val paramMaps = new ParamGridBuilder() + .addGrid(lr.regParam, Array(0.1, 0.2)) + .build() + val tvs = new TrainValidationSplit() + .setEstimator(lr) + .setEvaluator(evaluator) + .setTrainRatio(0.5) + .setEstimatorParamMaps(paramMaps) + + val tvs2 = testDefaultReadWrite(tvs, testParams = false) + + assert(tvs.getTrainRatio === tvs2.getTrainRatio) + } + + test("read/write: TrainValidationSplitModel") { + val lr = new LogisticRegression() + .setThreshold(0.6) + val lrModel = new LogisticRegressionModel(lr.uid, Vectors.dense(1.0, 2.0), 1.2) + .setThreshold(0.6) + val evaluator = new BinaryClassificationEvaluator() + val paramMaps = new ParamGridBuilder() + .addGrid(lr.regParam, Array(0.1, 0.2)) + .build() + val tvs = new TrainValidationSplitModel("cvUid", lrModel, Array(0.3, 0.6)) + tvs.set(tvs.estimator, lr) + .set(tvs.evaluator, evaluator) + .set(tvs.trainRatio, 0.5) + .set(tvs.estimatorParamMaps, paramMaps) + + val tvs2 = testDefaultReadWrite(tvs, testParams = false) + + assert(tvs.getTrainRatio === tvs2.getTrainRatio) + assert(tvs.validationMetrics === tvs2.validationMetrics) + } } object TrainValidationSplitSuite { -- cgit v1.2.3