diff options
author | Joseph K. Bradley <joseph@databricks.com> | 2015-11-10 18:45:48 -0800 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-11-10 18:45:48 -0800 |
commit | 6e101d2e9d6e08a6a63f7065c1e87a5338f763ea (patch) | |
tree | f93c013e57ee3644af985e1c5aae11659269e22e /mllib/src/test/scala/org/apache | |
parent | 745e45d5ff7fe251c0d5197b7e08b1f80807b005 (diff) | |
download | spark-6e101d2e9d6e08a6a63f7065c1e87a5338f763ea.tar.gz spark-6e101d2e9d6e08a6a63f7065c1e87a5338f763ea.tar.bz2 spark-6e101d2e9d6e08a6a63f7065c1e87a5338f763ea.zip |
[SPARK-6726][ML] Import/export for spark.ml LogisticRegressionModel
This PR adds model save/load for spark.ml's LogisticRegressionModel. It also does minor refactoring of the default save/load classes to reuse code.
CC: mengxr
Author: Joseph K. Bradley <joseph@databricks.com>
Closes #9606 from jkbradley/logreg-io2.
Diffstat (limited to 'mllib/src/test/scala/org/apache')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala | 17 | ||||
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala | 4 |
2 files changed, 18 insertions, 3 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 325faf37e8..51b06b7eb6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -23,7 +23,7 @@ import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{Identifiable, DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.classification.LogisticRegressionSuite._ import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint @@ -31,7 +31,8 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row} -class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { +class LogisticRegressionSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @transient var dataset: DataFrame = _ @transient var binaryDataset: DataFrame = _ @@ -869,6 +870,18 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { assert(model1a0.intercept !~= model1a1.intercept absTol 1E-3) assert(model1a0.coefficients ~== model1b.coefficients absTol 1E-3) assert(model1a0.intercept ~== model1b.intercept absTol 1E-3) + } + test("read/write") { + // Set some Params to make sure set Params are serialized. + val lr = new LogisticRegression() + .setElasticNetParam(0.1) + .setMaxIter(2) + .fit(dataset) + val lr2 = testDefaultReadWrite(lr) + assert(lr.intercept === lr2.intercept) + assert(lr.coefficients.toArray === lr2.coefficients.toArray) + assert(lr.numClasses === lr2.numClasses) + assert(lr.numFeatures === lr2.numFeatures) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala index 4545b0f281..cac4bd9aa3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala @@ -31,8 +31,9 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => * Checks "overwrite" option and params. * @param instance ML instance to test saving/loading * @tparam T ML instance type + * @return Instance loaded from file */ - def testDefaultReadWrite[T <: Params with Writable](instance: T): Unit = { + def testDefaultReadWrite[T <: Params with Writable](instance: T): T = { val uid = instance.uid val path = new File(tempDir, uid).getPath @@ -61,6 +62,7 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => val load = instance.getClass.getMethod("load", classOf[String]) val another = load.invoke(instance, path).asInstanceOf[T] assert(another.uid === instance.uid) + another } } |