aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
author Xusen Yin <yinxusen@gmail.com> 2016-03-17 10:19:10 -0700
committer Joseph K. Bradley <joseph@databricks.com> 2016-03-17 10:19:10 -0700
commit edf8b8775b81f5522680094bf24f372aa0c61447 (patch)
tree a3c4255165d1674a23bacccdfa66e8f903714717 /mllib/src/test
parent 828213d4ca4b0e845c4d6d778455335f187158a4 (diff)
download spark-edf8b8775b81f5522680094bf24f372aa0c61447.tar.gz
spark-edf8b8775b81f5522680094bf24f372aa0c61447.tar.bz2
spark-edf8b8775b81f5522680094bf24f372aa0c61447.zip
[SPARK-11891] Model export/import for RFormula and RFormulaModel
https://issues.apache.org/jira/browse/SPARK-11891 Author: Xusen Yin <yinxusen@gmail.com> Closes #9884 from yinxusen/SPARK-11891.
Diffstat (limited to 'mllib/src/test')
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala 40
1 files changed, 39 insertions, 1 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
index 16e565d8b5..e1b269b5b6 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
@@ -20,10 +20,11 @@ package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext
-class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext {
+class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
test("params") {
ParamsSuite.checkParams(new RFormula())
}
@@ -252,4 +253,41 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext {
new NumericAttribute(Some("a_foo:b_zz"), Some(4))))
assert(attrs === expectedAttrs)
}
+
+  test("read/write: RFormula") {
+    // Round-trip an estimator with an explicit formula, features column and label column.
+    val estimator = new RFormula()
+      .setFormula("id ~ a:b")
+      .setFeaturesCol("myFeatures")
+      .setLabelCol("myLabels")
+    testDefaultReadWrite(estimator)
+  }
+
+  test("read/write: RFormulaModel") {
+    // Asserts that a reloaded model matches the original: uid, formula and pipeline stages.
+    def checkModelData(model: RFormulaModel, model2: RFormulaModel): Unit = {
+      assert(model.uid === model2.uid)
+      assert(model.resolvedFormula.label === model2.resolvedFormula.label)
+      assert(model.resolvedFormula.terms === model2.resolvedFormula.terms)
+      assert(model.resolvedFormula.hasIntercept === model2.resolvedFormula.hasIntercept)
+
+      assert(model.pipelineModel.uid === model2.pipelineModel.uid)
+      val stages = model.pipelineModel.stages
+      val stages2 = model2.pipelineModel.stages
+      // zip truncates to the shorter array, so a dropped stage would otherwise go unnoticed.
+      assert(stages.length === stages2.length)
+      stages.zip(stages2).foreach { case (transformer1, transformer2) =>
+        assert(transformer1.uid === transformer2.uid)
+        assert(transformer1.params === transformer2.params)
+      }
+    }
+
+    val dataset = sqlContext.createDataFrame(
+      Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz"))
+    ).toDF("id", "a", "b")
+    val rFormula = new RFormula().setFormula("id ~ a:b")
+    val model = rFormula.fit(dataset)
+    val newModel = testDefaultReadWrite(model)
+    checkModelData(model, newModel)
+  }
}