author     Joseph K. Bradley <joseph@databricks.com>  2015-11-22 21:48:48 -0800
committer  Xiangrui Meng <meng@databricks.com>        2015-11-22 21:48:48 -0800
commit     a6fda0bfc16a13b28b1cecc96f1ff91363089144 (patch)
tree       cad3ab93b533e4611aa86d79298ae34f376d7dd5 /mllib/src/test/scala
parent     fe89c1817d668e46adf70d0896c42c22a547c76a (diff)
[SPARK-6791][ML] Add read/write for CrossValidator and Evaluators
I believe this works for general estimators within CrossValidator, including compound estimators. (See the complex unit test.)

Added read/write for all 3 Evaluators as well.

CC: mengxr yanboliang

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #9848 from jkbradley/cv-io.
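As a quick illustration of the persistence API this patch adds, here is a minimal sketch of a save/load round trip for a CrossValidator; the save path is a placeholder chosen for illustration, and overwrite() is used only so the snippet can be re-run:

import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

// Build a CrossValidator over a simple estimator, mirroring the new tests.
val lr = new LogisticRegression().setMaxIter(3)
val paramMaps = new ParamGridBuilder()
  .addGrid(lr.regParam, Array(0.1, 0.2))
  .build()
val cv = new CrossValidator()
  .setEstimator(lr)
  .setEvaluator(new BinaryClassificationEvaluator())
  .setEstimatorParamMaps(paramMaps)

// Round-trip through the read/write support added in this patch.
// "/tmp/cv" is a hypothetical path.
cv.write.overwrite().save("/tmp/cv")
val cv2 = CrossValidator.load("/tmp/cv")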
Diffstat (limited to 'mllib/src/test/scala')
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala                                     |   4
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala     |  13
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala |  13
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala               |  12
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala                        | 202
5 files changed, 231 insertions(+), 13 deletions(-)
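All four test suites follow the same recipe: mix in DefaultReadWriteTest, set non-default Params on an instance, and call testDefaultReadWrite. A rough sketch of what such a round-trip check amounts to is below; the helper name and body are illustrative only, not the actual DefaultReadWriteTest implementation:

import org.apache.spark.ml.param.Params
import org.apache.spark.ml.util.{MLReadable, MLWritable}

// Illustrative round-trip check: save the instance, load it back via the
// companion's reader, and verify identity is preserved.
def roundTrip[T <: Params with MLWritable](
    instance: T, companion: MLReadable[T], path: String): T = {
  instance.write.overwrite().save(path)
  val loaded = companion.load(path)
  assert(loaded.uid == instance.uid)  // same uid after reload
  loaded
}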
diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
index 12aba6bc6d..8c86767456 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
@@ -17,11 +17,9 @@
package org.apache.spark.ml
-import java.io.File
-
import scala.collection.JavaConverters._
-import org.apache.hadoop.fs.{FileSystem, Path}
+import org.apache.hadoop.fs.Path
import org.mockito.Matchers.{any, eq => meq}
import org.mockito.Mockito.when
import org.scalatest.mock.MockitoSugar.mock
diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala
index def869fe66..a535c1218e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala
@@ -19,10 +19,21 @@ package org.apache.spark.ml.evaluation
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.mllib.util.MLlibTestSparkContext
-class BinaryClassificationEvaluatorSuite extends SparkFunSuite {
+class BinaryClassificationEvaluatorSuite
+ extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
test("params") {
ParamsSuite.checkParams(new BinaryClassificationEvaluator)
}
+
+ test("read/write") {
+ val evaluator = new BinaryClassificationEvaluator()
+ .setRawPredictionCol("myRawPrediction")
+ .setLabelCol("myLabel")
+ .setMetricName("areaUnderPR")
+ testDefaultReadWrite(evaluator)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala
index 6d8412b0b3..7ee65975d2 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala
@@ -19,10 +19,21 @@ package org.apache.spark.ml.evaluation
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.mllib.util.MLlibTestSparkContext
-class MulticlassClassificationEvaluatorSuite extends SparkFunSuite {
+class MulticlassClassificationEvaluatorSuite
+ extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
test("params") {
ParamsSuite.checkParams(new MulticlassClassificationEvaluator)
}
+
+ test("read/write") {
+ val evaluator = new MulticlassClassificationEvaluator()
+ .setPredictionCol("myPrediction")
+ .setLabelCol("myLabel")
+ .setMetricName("recall")
+ testDefaultReadWrite(evaluator)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala
index aa722da323..60886bf77d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala
@@ -20,10 +20,12 @@ package org.apache.spark.ml.evaluation
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.regression.LinearRegression
+import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
import org.apache.spark.mllib.util.TestingUtils._
-class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext {
+class RegressionEvaluatorSuite
+ extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
test("params") {
ParamsSuite.checkParams(new RegressionEvaluator)
@@ -73,4 +75,12 @@ class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext
evaluator.setMetricName("mae")
assert(evaluator.evaluate(predictions) ~== 0.08036075 absTol 0.001)
}
+
+ test("read/write") {
+ val evaluator = new RegressionEvaluator()
+ .setPredictionCol("myPrediction")
+ .setLabelCol("myLabel")
+ .setMetricName("r2")
+ testDefaultReadWrite(evaluator)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
index cbe09292a0..dd6366050c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -18,19 +18,22 @@
package org.apache.spark.ml.tuning
import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.util.MLTestingUtils
-import org.apache.spark.ml.{Estimator, Model}
-import org.apache.spark.ml.classification.LogisticRegression
+import org.apache.spark.ml.feature.HashingTF
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.{Pipeline, Estimator, Model}
+import org.apache.spark.ml.classification.{LogisticRegressionModel, LogisticRegression}
import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator}
-import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.param.{ParamPair, ParamMap}
import org.apache.spark.ml.param.shared.HasInputCol
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
+import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.types.StructType
-class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext {
+class CrossValidatorSuite
+ extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@transient var dataset: DataFrame = _
@@ -95,7 +98,7 @@ class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext {
}
test("validateParams should check estimatorParamMaps") {
- import CrossValidatorSuite._
+ import CrossValidatorSuite.{MyEstimator, MyEvaluator}
val est = new MyEstimator("est")
val eval = new MyEvaluator
@@ -116,9 +119,194 @@ class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext {
cv.validateParams()
}
}
+
+ test("read/write: CrossValidator with simple estimator") {
+ val lr = new LogisticRegression().setMaxIter(3)
+ val evaluator = new BinaryClassificationEvaluator()
+ .setMetricName("areaUnderPR") // not default metric
+ val paramMaps = new ParamGridBuilder()
+ .addGrid(lr.regParam, Array(0.1, 0.2))
+ .build()
+ val cv = new CrossValidator()
+ .setEstimator(lr)
+ .setEvaluator(evaluator)
+ .setNumFolds(20)
+ .setEstimatorParamMaps(paramMaps)
+
+ val cv2 = testDefaultReadWrite(cv, testParams = false)
+
+ assert(cv.uid === cv2.uid)
+ assert(cv.getNumFolds === cv2.getNumFolds)
+
+ assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator])
+ val evaluator2 = cv2.getEvaluator.asInstanceOf[BinaryClassificationEvaluator]
+ assert(evaluator.uid === evaluator2.uid)
+ assert(evaluator.getMetricName === evaluator2.getMetricName)
+
+ cv2.getEstimator match {
+ case lr2: LogisticRegression =>
+ assert(lr.uid === lr2.uid)
+ assert(lr.getMaxIter === lr2.getMaxIter)
+ case other =>
+ throw new AssertionError(s"Loaded CrossValidator expected estimator of type" +
+ s" LogisticRegression but found ${other.getClass.getName}")
+ }
+
+ CrossValidatorSuite.compareParamMaps(cv.getEstimatorParamMaps, cv2.getEstimatorParamMaps)
+ }
+
+ test("read/write: CrossValidator with complex estimator") {
+ // workflow: CrossValidator[Pipeline[HashingTF, CrossValidator[LogisticRegression]]]
+ val lrEvaluator = new BinaryClassificationEvaluator()
+ .setMetricName("areaUnderPR") // not default metric
+
+ val lr = new LogisticRegression().setMaxIter(3)
+ val lrParamMaps = new ParamGridBuilder()
+ .addGrid(lr.regParam, Array(0.1, 0.2))
+ .build()
+ val lrcv = new CrossValidator()
+ .setEstimator(lr)
+ .setEvaluator(lrEvaluator)
+ .setEstimatorParamMaps(lrParamMaps)
+
+ val hashingTF = new HashingTF()
+ val pipeline = new Pipeline().setStages(Array(hashingTF, lrcv))
+ val paramMaps = new ParamGridBuilder()
+ .addGrid(hashingTF.numFeatures, Array(10, 20))
+ .addGrid(lr.elasticNetParam, Array(0.0, 1.0))
+ .build()
+ val evaluator = new BinaryClassificationEvaluator()
+
+ val cv = new CrossValidator()
+ .setEstimator(pipeline)
+ .setEvaluator(evaluator)
+ .setNumFolds(20)
+ .setEstimatorParamMaps(paramMaps)
+
+ val cv2 = testDefaultReadWrite(cv, testParams = false)
+
+ assert(cv.uid === cv2.uid)
+ assert(cv.getNumFolds === cv2.getNumFolds)
+
+ assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator])
+ assert(cv.getEvaluator.uid === cv2.getEvaluator.uid)
+
+ CrossValidatorSuite.compareParamMaps(cv.getEstimatorParamMaps, cv2.getEstimatorParamMaps)
+
+ cv2.getEstimator match {
+ case pipeline2: Pipeline =>
+ assert(pipeline.uid === pipeline2.uid)
+ pipeline2.getStages match {
+ case Array(hashingTF2: HashingTF, lrcv2: CrossValidator) =>
+ assert(hashingTF.uid === hashingTF2.uid)
+ lrcv2.getEstimator match {
+ case lr2: LogisticRegression =>
+ assert(lr.uid === lr2.uid)
+ assert(lr.getMaxIter === lr2.getMaxIter)
+ case other =>
+ throw new AssertionError(s"Loaded internal CrossValidator expected to be" +
+ s" LogisticRegression but found type ${other.getClass.getName}")
+ }
+ assert(lrcv.uid === lrcv2.uid)
+ assert(lrcv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator])
+ assert(lrEvaluator.uid === lrcv2.getEvaluator.uid)
+ CrossValidatorSuite.compareParamMaps(lrParamMaps, lrcv2.getEstimatorParamMaps)
+ case other =>
+ throw new AssertionError("Loaded Pipeline expected stages (HashingTF, CrossValidator)" +
+ " but found: " + other.map(_.getClass.getName).mkString(", "))
+ }
+ case other =>
+ throw new AssertionError(s"Loaded CrossValidator expected estimator of type" +
+ s" CrossValidator but found ${other.getClass.getName}")
+ }
+ }
+
+ test("read/write: CrossValidator fails for extraneous Param") {
+ val lr = new LogisticRegression()
+ val lr2 = new LogisticRegression()
+ val evaluator = new BinaryClassificationEvaluator()
+ val paramMaps = new ParamGridBuilder()
+ .addGrid(lr.regParam, Array(0.1, 0.2))
+ .addGrid(lr2.regParam, Array(0.1, 0.2))
+ .build()
+ val cv = new CrossValidator()
+ .setEstimator(lr)
+ .setEvaluator(evaluator)
+ .setEstimatorParamMaps(paramMaps)
+ withClue("CrossValidator.write failed to catch extraneous Param error") {
+ intercept[IllegalArgumentException] {
+ cv.write
+ }
+ }
+ }
+
+ test("read/write: CrossValidatorModel") {
+ val lr = new LogisticRegression()
+ .setThreshold(0.6)
+ val lrModel = new LogisticRegressionModel(lr.uid, Vectors.dense(1.0, 2.0), 1.2)
+ .setThreshold(0.6)
+ val evaluator = new BinaryClassificationEvaluator()
+ .setMetricName("areaUnderPR") // not default metric
+ val paramMaps = new ParamGridBuilder()
+ .addGrid(lr.regParam, Array(0.1, 0.2))
+ .build()
+ val cv = new CrossValidatorModel("cvUid", lrModel, Array(0.3, 0.6))
+ cv.set(cv.estimator, lr)
+ .set(cv.evaluator, evaluator)
+ .set(cv.numFolds, 20)
+ .set(cv.estimatorParamMaps, paramMaps)
+
+ val cv2 = testDefaultReadWrite(cv, testParams = false)
+
+ assert(cv.uid === cv2.uid)
+ assert(cv.getNumFolds === cv2.getNumFolds)
+
+ assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator])
+ val evaluator2 = cv2.getEvaluator.asInstanceOf[BinaryClassificationEvaluator]
+ assert(evaluator.uid === evaluator2.uid)
+ assert(evaluator.getMetricName === evaluator2.getMetricName)
+
+ cv2.getEstimator match {
+ case lr2: LogisticRegression =>
+ assert(lr.uid === lr2.uid)
+ assert(lr.getThreshold === lr2.getThreshold)
+ case other =>
+ throw new AssertionError(s"Loaded CrossValidator expected estimator of type" +
+ s" LogisticRegression but found ${other.getClass.getName}")
+ }
+
+ CrossValidatorSuite.compareParamMaps(cv.getEstimatorParamMaps, cv2.getEstimatorParamMaps)
+
+ cv2.bestModel match {
+ case lrModel2: LogisticRegressionModel =>
+ assert(lrModel.uid === lrModel2.uid)
+ assert(lrModel.getThreshold === lrModel2.getThreshold)
+ assert(lrModel.coefficients === lrModel2.coefficients)
+ assert(lrModel.intercept === lrModel2.intercept)
+ case other =>
+ throw new AssertionError(s"Loaded CrossValidator expected bestModel of type" +
+ s" LogisticRegressionModel but found ${other.getClass.getName}")
+ }
+ assert(cv.avgMetrics === cv2.avgMetrics)
+ }
}
-object CrossValidatorSuite {
+object CrossValidatorSuite extends SparkFunSuite {
+
+ /**
+ * Assert sequences of estimatorParamMaps are identical.
+ * Params must be simple types comparable with `===`.
+ */
+ def compareParamMaps(pMaps: Array[ParamMap], pMaps2: Array[ParamMap]): Unit = {
+ assert(pMaps.length === pMaps2.length)
+ pMaps.zip(pMaps2).foreach { case (pMap, pMap2) =>
+ assert(pMap.size === pMap2.size)
+ pMap.toSeq.foreach { case ParamPair(p, v) =>
+ assert(pMap2.contains(p))
+ assert(pMap2(p) === v)
+ }
+ }
+ }
abstract class MyModel extends Model[MyModel]