aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
author: Xusen Yin <yinxusen@gmail.com> 2016-03-28 15:40:06 -0700
committer: Joseph K. Bradley <joseph@databricks.com> 2016-03-28 15:40:06 -0700
commit: 8c11d1aab8522c75d78bc6b30402c64e8d9ff065 (patch)
tree: 7e31c6256fcb49db9eb0f47b86b24daa24764650 /mllib/src/test
parent: 39f743a6231cbd8cc770a28f43ee601eff28d597 (diff)
downloadspark-8c11d1aab8522c75d78bc6b30402c64e8d9ff065.tar.gz
spark-8c11d1aab8522c75d78bc6b30402c64e8d9ff065.tar.bz2
spark-8c11d1aab8522c75d78bc6b30402c64e8d9ff065.zip
[SPARK-11893] Model export/import for spark.ml: TrainValidationSplit
https://issues.apache.org/jira/browse/SPARK-11893 jkbradley In order to share read/write code with `TrainValidationSplit`, I moved `SharedReadWrite` out of `CrossValidator` into a new trait `SharedReadWrite` in the tuning package. To reduce repeated tests, I moved the complex tests from `CrossValidatorSuite` to `SharedReadWriteSuite`, and created a fake validator called `MyValidator` to test the shared code. With `SharedReadWrite`, any newly added `Validator` can share the common read/write code and only needs to implement save/load for its extra params. Author: Xusen Yin <yinxusen@gmail.com> Author: Joseph K. Bradley <joseph@databricks.com> Closes #9971 from yinxusen/SPARK-11893.
Diffstat (limited to 'mllib/src/test')
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala 45
1 files changed, 43 insertions, 2 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
index cf8dcefebc..7cf7b3e087 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
@@ -19,17 +19,20 @@ package org.apache.spark.ml.tuning
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.{Estimator, Model}
-import org.apache.spark.ml.classification.LogisticRegression
+import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared.HasInputCol
import org.apache.spark.ml.regression.LinearRegression
+import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
+import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType
-class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext {
+class TrainValidationSplitSuite
+ extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
test("train validation with logistic regression") {
val dataset = sqlContext.createDataFrame(
sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2))
@@ -105,6 +108,44 @@ class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext
cv.transformSchema(new StructType())
}
}
+
+ test("read/write: TrainValidationSplit") { // Checks that trainRatio matches on the instance returned by testDefaultReadWrite.
+ val lr = new LogisticRegression().setMaxIter(3)
+ val evaluator = new BinaryClassificationEvaluator()
+ val paramMaps = new ParamGridBuilder()
+ .addGrid(lr.regParam, Array(0.1, 0.2))
+ .build() // Grid over lr.regParam with the two candidate values 0.1 and 0.2.
+ val tvs = new TrainValidationSplit()
+ .setEstimator(lr)
+ .setEvaluator(evaluator)
+ .setTrainRatio(0.5)
+ .setEstimatorParamMaps(paramMaps) // tvs is only configured here; fit() is never called in this test.
+
+ val tvs2 = testDefaultReadWrite(tvs, testParams = false) // NOTE(review): helper from DefaultReadWriteTest; presumably saves and reloads tvs, and testParams = false presumably skips its generic param comparison - confirm.
+
+ assert(tvs.getTrainRatio === tvs2.getTrainRatio) // The only property compared explicitly in this test.
+ }
+
+ test("read/write: TrainValidationSplitModel") { // Checks that trainRatio and validationMetrics match on the model returned by testDefaultReadWrite.
+ val lr = new LogisticRegression()
+ .setThreshold(0.6)
+ val lrModel = new LogisticRegressionModel(lr.uid, Vectors.dense(1.0, 2.0), 1.2) // Model is constructed by hand (sharing lr's uid) rather than produced by fitting.
+ .setThreshold(0.6)
+ val evaluator = new BinaryClassificationEvaluator()
+ val paramMaps = new ParamGridBuilder()
+ .addGrid(lr.regParam, Array(0.1, 0.2))
+ .build() // Grid over lr.regParam with the two candidate values 0.1 and 0.2.
+ val tvs = new TrainValidationSplitModel("cvUid", lrModel, Array(0.3, 0.6)) // NOTE(review): uid "cvUid" looks copied from the CrossValidator tests - confirm it is intended; Array(0.3, 0.6) is compared below as validationMetrics.
+ tvs.set(tvs.estimator, lr)
+ .set(tvs.evaluator, evaluator)
+ .set(tvs.trainRatio, 0.5)
+ .set(tvs.estimatorParamMaps, paramMaps) // Params are set directly on the model via set(...).
+
+ val tvs2 = testDefaultReadWrite(tvs, testParams = false) // NOTE(review): helper from DefaultReadWriteTest; presumably saves and reloads tvs, and testParams = false presumably skips its generic param comparison - confirm.
+
+ assert(tvs.getTrainRatio === tvs2.getTrainRatio)
+ assert(tvs.validationMetrics === tvs2.validationMetrics)
+ }
}
object TrainValidationSplitSuite {