aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/test')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala68
1 files changed, 66 insertions, 2 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
index 2ae74a2090..51c1baf682 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
@@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
-import org.apache.spark.ml.util.{MetadataUtils, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MetadataUtils, MLTestingUtils}
import org.apache.spark.mllib.classification.LogisticRegressionSuite._
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.evaluation.MulticlassMetrics
@@ -33,7 +33,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.Metadata
-class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
+class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@transient var dataset: DataFrame = _
@transient var rdd: RDD[LabeledPoint] = _
@@ -160,6 +160,70 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
require(m.getThreshold === 0.1, "copy should handle extra model params")
}
}
+
+ test("read/write: OneVsRest") {
+ val lr = new LogisticRegression().setMaxIter(10).setRegParam(0.01)
+
+ val ova = new OneVsRest()
+ .setClassifier(lr)
+ .setLabelCol("myLabel")
+ .setFeaturesCol("myFeature")
+ .setPredictionCol("myPrediction")
+
+ val ova2 = testDefaultReadWrite(ova, testParams = false)
+ assert(ova.uid === ova2.uid)
+ assert(ova.getFeaturesCol === ova2.getFeaturesCol)
+ assert(ova.getLabelCol === ova2.getLabelCol)
+ assert(ova.getPredictionCol === ova2.getPredictionCol)
+
+ ova2.getClassifier match {
+ case lr2: LogisticRegression =>
+ assert(lr.uid === lr2.uid)
+ assert(lr.getMaxIter === lr2.getMaxIter)
+ assert(lr.getRegParam === lr2.getRegParam)
+ case other =>
+ throw new AssertionError(s"Loaded OneVsRest expected classifier of type" +
+ s" LogisticRegression but found ${other.getClass.getName}")
+ }
+ }
+
+ test("read/write: OneVsRestModel") {
+ def checkModelData(model: OneVsRestModel, model2: OneVsRestModel): Unit = {
+ assert(model.uid === model2.uid)
+ assert(model.getFeaturesCol === model2.getFeaturesCol)
+ assert(model.getLabelCol === model2.getLabelCol)
+ assert(model.getPredictionCol === model2.getPredictionCol)
+
+ val classifier = model.getClassifier.asInstanceOf[LogisticRegression]
+
+ model2.getClassifier match {
+ case lr2: LogisticRegression =>
+ assert(classifier.uid === lr2.uid)
+ assert(classifier.getMaxIter === lr2.getMaxIter)
+ assert(classifier.getRegParam === lr2.getRegParam)
+ case other =>
+ throw new AssertionError(s"Loaded OneVsRestModel expected classifier of type" +
+ s" LogisticRegression but found ${other.getClass.getName}")
+ }
+
+ assert(model.labelMetadata === model2.labelMetadata)
+ model.models.zip(model2.models).foreach {
+ case (lrModel1: LogisticRegressionModel, lrModel2: LogisticRegressionModel) =>
+ assert(lrModel1.uid === lrModel2.uid)
+ assert(lrModel1.coefficients === lrModel2.coefficients)
+ assert(lrModel1.intercept === lrModel2.intercept)
+ case other =>
+ throw new AssertionError(s"Loaded OneVsRestModel expected model of type" +
+ s" LogisticRegressionModel but found ${other.getClass.getName}")
+ }
+ }
+
+ val lr = new LogisticRegression().setMaxIter(10).setRegParam(0.01)
+ val ova = new OneVsRest().setClassifier(lr)
+ val ovaModel = ova.fit(dataset)
+ val newOvaModel = testDefaultReadWrite(ovaModel, testParams = false)
+ checkModelData(ovaModel, newOvaModel)
+ }
}
private class MockLogisticRegression(uid: String) extends LogisticRegression(uid) {