aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
authorXusen Yin <yinxusen@gmail.com>2016-03-31 11:17:32 -0700
committerJoseph K. Bradley <joseph@databricks.com>2016-03-31 11:17:32 -0700
commit8b207f3b6a0eb617d38091f3b9001830ac3651fe (patch)
treeb3bca571692fd67c3ae40d4e3af29a4ddecd056d /mllib/src/test
parenta0a1991580ed24230f88cae9f5a4dfbe58f03b28 (diff)
downloadspark-8b207f3b6a0eb617d38091f3b9001830ac3651fe.tar.gz
spark-8b207f3b6a0eb617d38091f3b9001830ac3651fe.tar.bz2
spark-8b207f3b6a0eb617d38091f3b9001830ac3651fe.zip
[SPARK-11892][ML] Model export/import for spark.ml: OneVsRest
# What changes were proposed in this pull request? https://issues.apache.org/jira/browse/SPARK-11892 Add save/load for spark ml.OneVsRest and its model. Also add OneVsRest and OneVsRestModel in MetaAlgorithmReadWrite. # How was this patch tested? Test with Scala unit test. Author: Xusen Yin <yinxusen@gmail.com> Closes #9934 from yinxusen/SPARK-11892.
Diffstat (limited to 'mllib/src/test')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala68
1 files changed, 66 insertions, 2 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
index 2ae74a2090..51c1baf682 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
@@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
-import org.apache.spark.ml.util.{MetadataUtils, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MetadataUtils, MLTestingUtils}
import org.apache.spark.mllib.classification.LogisticRegressionSuite._
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.evaluation.MulticlassMetrics
@@ -33,7 +33,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.Metadata
-class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
+class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@transient var dataset: DataFrame = _
@transient var rdd: RDD[LabeledPoint] = _
@@ -160,6 +160,70 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
require(m.getThreshold === 0.1, "copy should handle extra model params")
}
}
+
+ test("read/write: OneVsRest") {
+ val lr = new LogisticRegression().setMaxIter(10).setRegParam(0.01)
+
+ val ova = new OneVsRest()
+ .setClassifier(lr)
+ .setLabelCol("myLabel")
+ .setFeaturesCol("myFeature")
+ .setPredictionCol("myPrediction")
+
+ val ova2 = testDefaultReadWrite(ova, testParams = false)
+ assert(ova.uid === ova2.uid)
+ assert(ova.getFeaturesCol === ova2.getFeaturesCol)
+ assert(ova.getLabelCol === ova2.getLabelCol)
+ assert(ova.getPredictionCol === ova2.getPredictionCol)
+
+ ova2.getClassifier match {
+ case lr2: LogisticRegression =>
+ assert(lr.uid === lr2.uid)
+ assert(lr.getMaxIter === lr2.getMaxIter)
+ assert(lr.getRegParam === lr2.getRegParam)
+ case other =>
+ throw new AssertionError(s"Loaded OneVsRest expected classifier of type" +
+ s" LogisticRegression but found ${other.getClass.getName}")
+ }
+ }
+
+ test("read/write: OneVsRestModel") {
+ def checkModelData(model: OneVsRestModel, model2: OneVsRestModel): Unit = {
+ assert(model.uid === model2.uid)
+ assert(model.getFeaturesCol === model2.getFeaturesCol)
+ assert(model.getLabelCol === model2.getLabelCol)
+ assert(model.getPredictionCol === model2.getPredictionCol)
+
+ val classifier = model.getClassifier.asInstanceOf[LogisticRegression]
+
+ model2.getClassifier match {
+ case lr2: LogisticRegression =>
+ assert(classifier.uid === lr2.uid)
+ assert(classifier.getMaxIter === lr2.getMaxIter)
+ assert(classifier.getRegParam === lr2.getRegParam)
+ case other =>
+ throw new AssertionError(s"Loaded OneVsRestModel expected classifier of type" +
+ s" LogisticRegression but found ${other.getClass.getName}")
+ }
+
+ assert(model.labelMetadata === model2.labelMetadata)
+ model.models.zip(model2.models).foreach {
+ case (lrModel1: LogisticRegressionModel, lrModel2: LogisticRegressionModel) =>
+ assert(lrModel1.uid === lrModel2.uid)
+ assert(lrModel1.coefficients === lrModel2.coefficients)
+ assert(lrModel1.intercept === lrModel2.intercept)
+ case other =>
+ throw new AssertionError(s"Loaded OneVsRestModel expected model of type" +
+ s" LogisticRegressionModel but found ${other.getClass.getName}")
+ }
+ }
+
+ val lr = new LogisticRegression().setMaxIter(10).setRegParam(0.01)
+ val ova = new OneVsRest().setClassifier(lr)
+ val ovaModel = ova.fit(dataset)
+ val newOvaModel = testDefaultReadWrite(ovaModel, testParams = false)
+ checkModelData(ovaModel, newOvaModel)
+ }
}
private class MockLogisticRegression(uid: String) extends LogisticRegression(uid) {