about summary refs log tree commit diff
path: root/mllib/src/test/scala/org/apache
diff options
context:
space:
mode:
authorJoseph K. Bradley <joseph@databricks.com>2015-11-10 18:45:48 -0800
committerXiangrui Meng <meng@databricks.com>2015-11-10 18:45:48 -0800
commit6e101d2e9d6e08a6a63f7065c1e87a5338f763ea (patch)
treef93c013e57ee3644af985e1c5aae11659269e22e /mllib/src/test/scala/org/apache
parent745e45d5ff7fe251c0d5197b7e08b1f80807b005 (diff)
downloadspark-6e101d2e9d6e08a6a63f7065c1e87a5338f763ea.tar.gz
spark-6e101d2e9d6e08a6a63f7065c1e87a5338f763ea.tar.bz2
spark-6e101d2e9d6e08a6a63f7065c1e87a5338f763ea.zip
[SPARK-6726][ML] Import/export for spark.ml LogisticRegressionModel
This PR adds model save/load for spark.ml's LogisticRegressionModel. It also does minor refactoring of the default save/load classes to reuse code. CC: mengxr Author: Joseph K. Bradley <joseph@databricks.com> Closes #9606 from jkbradley/logreg-io2.
Diffstat (limited to 'mllib/src/test/scala/org/apache')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala17
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala4
2 files changed, 18 insertions, 3 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 325faf37e8..51b06b7eb6 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -23,7 +23,7 @@ import scala.util.Random
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.ml.util.{Identifiable, DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.classification.LogisticRegressionSuite._
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
@@ -31,7 +31,8 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{DataFrame, Row}
-class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
+class LogisticRegressionSuite
+ extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@transient var dataset: DataFrame = _
@transient var binaryDataset: DataFrame = _
@@ -869,6 +870,18 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(model1a0.intercept !~= model1a1.intercept absTol 1E-3)
assert(model1a0.coefficients ~== model1b.coefficients absTol 1E-3)
assert(model1a0.intercept ~== model1b.intercept absTol 1E-3)
+ }
+ test("read/write") {
+ // Set some Params to make sure set Params are serialized.
+ val lr = new LogisticRegression()
+ .setElasticNetParam(0.1)
+ .setMaxIter(2)
+ .fit(dataset)
+ val lr2 = testDefaultReadWrite(lr)
+ assert(lr.intercept === lr2.intercept)
+ assert(lr.coefficients.toArray === lr2.coefficients.toArray)
+ assert(lr.numClasses === lr2.numClasses)
+ assert(lr.numFeatures === lr2.numFeatures)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
index 4545b0f281..cac4bd9aa3 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
@@ -31,8 +31,9 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
* Checks "overwrite" option and params.
* @param instance ML instance to test saving/loading
* @tparam T ML instance type
+ * @return Instance loaded from file
*/
- def testDefaultReadWrite[T <: Params with Writable](instance: T): Unit = {
+ def testDefaultReadWrite[T <: Params with Writable](instance: T): T = {
val uid = instance.uid
val path = new File(tempDir, uid).getPath
@@ -61,6 +62,7 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
val load = instance.getClass.getMethod("load", classOf[String])
val another = load.invoke(instance, path).asInstanceOf[T]
assert(another.uid === instance.uid)
+ another
}
}