From 8b207f3b6a0eb617d38091f3b9001830ac3651fe Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Thu, 31 Mar 2016 11:17:32 -0700 Subject: [SPARK-11892][ML] Model export/import for spark.ml: OneVsRest # What changes were proposed in this pull request? https://issues.apache.org/jira/browse/SPARK-11892 Add save/load for spark ml.OneVsRest and its model. Also add OneVsRest and OneVsRestModel in MetaAlgorithmReadWrite. # How was this patch tested? Test with Scala unit test. Author: Xusen Yin Closes #9934 from yinxusen/SPARK-11892. --- .../spark/ml/classification/OneVsRestSuite.scala | 68 +++++++++++++++++++++- 1 file changed, 66 insertions(+), 2 deletions(-) (limited to 'mllib/src/test') diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index 2ae74a2090..51c1baf682 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.feature.StringIndexer import org.apache.spark.ml.param.{ParamMap, ParamsSuite} -import org.apache.spark.ml.util.{MetadataUtils, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MetadataUtils, MLTestingUtils} import org.apache.spark.mllib.classification.LogisticRegressionSuite._ import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS import org.apache.spark.mllib.evaluation.MulticlassMetrics @@ -33,7 +33,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.Metadata -class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext { +class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @transient var dataset: DataFrame = _ @transient var rdd: 
RDD[LabeledPoint] = _
@@ -160,6 +160,80 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
       require(m.getThreshold === 0.1, "copy should handle extra model params")
     }
   }
+
+  // Saving and reloading an unfitted OneVsRest must preserve its uid, its column params and
+  // the nested classifier together with that classifier's own params.
+  test("read/write: OneVsRest") {
+    val lr = new LogisticRegression().setMaxIter(10).setRegParam(0.01)
+
+    val ova = new OneVsRest()
+      .setClassifier(lr)
+      .setLabelCol("myLabel")
+      .setFeaturesCol("myFeature")
+      .setPredictionCol("myPrediction")
+
+    // testParams = false, so the params of interest are asserted explicitly below.
+    val ova2 = testDefaultReadWrite(ova, testParams = false)
+    assert(ova.uid === ova2.uid)
+    assert(ova.getFeaturesCol === ova2.getFeaturesCol)
+    assert(ova.getLabelCol === ova2.getLabelCol)
+    assert(ova.getPredictionCol === ova2.getPredictionCol)
+
+    ova2.getClassifier match {
+      case lr2: LogisticRegression =>
+        assert(lr.uid === lr2.uid)
+        assert(lr.getMaxIter === lr2.getMaxIter)
+        assert(lr.getRegParam === lr2.getRegParam)
+      case other =>
+        throw new AssertionError(s"Loaded OneVsRest expected classifier of type" +
+          s" LogisticRegression but found ${other.getClass.getName}")
+    }
+  }
+
+  // Saving and reloading a fitted OneVsRestModel must preserve its params, classifier, label
+  // metadata and every sub-model held in `models`.
+  test("read/write: OneVsRestModel") {
+    // Asserts that the reloaded `model2` carries the same data as the original `model`.
+    def checkModelData(model: OneVsRestModel, model2: OneVsRestModel): Unit = {
+      assert(model.uid === model2.uid)
+      assert(model.getFeaturesCol === model2.getFeaturesCol)
+      assert(model.getLabelCol === model2.getLabelCol)
+      assert(model.getPredictionCol === model2.getPredictionCol)
+
+      val classifier = model.getClassifier.asInstanceOf[LogisticRegression]
+
+      model2.getClassifier match {
+        case lr2: LogisticRegression =>
+          assert(classifier.uid === lr2.uid)
+          assert(classifier.getMaxIter === lr2.getMaxIter)
+          assert(classifier.getRegParam === lr2.getRegParam)
+        case other =>
+          throw new AssertionError(s"Loaded OneVsRestModel expected classifier of type" +
+            s" LogisticRegression but found ${other.getClass.getName}")
+      }
+
+      assert(model.labelMetadata === model2.labelMetadata)
+      // zip stops at the shorter collection, so compare the lengths first; otherwise a
+      // reloaded model that lost sub-models would still pass the pairwise checks.
+      assert(model.models.length === model2.models.length)
+      model.models.zip(model2.models).foreach {
+        case (lrModel1: LogisticRegressionModel, lrModel2: LogisticRegressionModel) =>
+          assert(lrModel1.uid === lrModel2.uid)
+          assert(lrModel1.coefficients === lrModel2.coefficients)
+          assert(lrModel1.intercept === lrModel2.intercept)
+        case (_, loaded) =>
+          // Name the loaded element's class rather than the enclosing Tuple2.
+          throw new AssertionError(s"Loaded OneVsRestModel expected model of type" +
+            s" LogisticRegressionModel but found ${loaded.getClass.getName}")
+      }
+    }
+
+    val lr = new LogisticRegression().setMaxIter(10).setRegParam(0.01)
+    val ova = new OneVsRest().setClassifier(lr)
+    val ovaModel = ova.fit(dataset)
+    val newOvaModel = testDefaultReadWrite(ovaModel, testParams = false)
+    checkModelData(ovaModel, newOvaModel)
+  }
 }
 
 private class MockLogisticRegression(uid: String) extends LogisticRegression(uid) {
-- 
cgit v1.2.3