aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
author: Yanbo Liang <ybliang8@gmail.com> 2016-03-09 11:59:22 -0800
committer: Joseph K. Bradley <joseph@databricks.com> 2016-03-09 11:59:22 -0800
commit: 0dd06485c4222a896c0d1ee6a04d30043de3626c (patch)
tree: f1ffc18c78e4dcb4caf3872c34a4a6fc2616e223 /mllib/src/test
parent: cad29a40b24a8e89f2d906e263866546f8ab6071 (diff)
downloadspark-0dd06485c4222a896c0d1ee6a04d30043de3626c.tar.gz
spark-0dd06485c4222a896c0d1ee6a04d30043de3626c.tar.bz2
spark-0dd06485c4222a896c0d1ee6a04d30043de3626c.zip
[SPARK-13615][ML] GeneralizedLinearRegression supports save/load
## What changes were proposed in this pull request? `GeneralizedLinearRegression` supports `save/load`. cc mengxr ## How was this patch tested? Unit test. Author: Yanbo Liang <ybliang8@gmail.com> Closes #11465 from yanboliang/spark-13615.
Diffstat (limited to 'mllib/src/test')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala32
1 files changed, 30 insertions, 2 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
index 8bfa9855ce..618304ad19 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
@@ -21,7 +21,7 @@ import scala.util.Random
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.classification.LogisticRegressionSuite._
import org.apache.spark.mllib.linalg.{BLAS, DenseVector, Vectors}
import org.apache.spark.mllib.random._
@@ -30,7 +30,8 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{DataFrame, Row}
-class GeneralizedLinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
+class GeneralizedLinearRegressionSuite
+ extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
private val seed: Int = 42
@transient var datasetGaussianIdentity: DataFrame = _
@@ -464,10 +465,37 @@ class GeneralizedLinearRegressionSuite extends SparkFunSuite with MLlibTestSpark
}
}
}
+
+ test("read/write") { // read/write test for the estimator and its fitted model
+ def checkModelData( // callback handed to the harness below; compares two models' fitted values
+ model: GeneralizedLinearRegressionModel,
+ model2: GeneralizedLinearRegressionModel): Unit = {
+ assert(model.intercept === model2.intercept)
+ assert(model.coefficients.toArray === model2.coefficients.toArray) // compared as plain arrays
+ }
+
+ val glr = new GeneralizedLinearRegression()
+ testEstimatorAndModelReadWrite(glr, datasetPoissonLog, // presumably supplied by DefaultReadWriteTest; TODO confirm
+ GeneralizedLinearRegressionSuite.allParamSettings, checkModelData) // settings come from the companion object
+ }
}
object GeneralizedLinearRegressionSuite {
+ /**
+ * Mapping from all Params to valid settings which differ from the defaults.
+ * This is useful for tests which need to exercise all Params, such as save/load.
+ * This excludes input columns to simplify some tests.
+ */
+ val allParamSettings: Map[String, Any] = Map( // keyed by Param name; passed to the "read/write" test
+ "family" -> "poisson",
+ "link" -> "log",
+ "fitIntercept" -> true, // NOTE(review): may equal the default, contrary to the Scaladoc; confirm
+ "maxIter" -> 2, // intentionally small
+ "tol" -> 0.8,
+ "regParam" -> 0.01,
+ "predictionCol" -> "myPrediction")
+
def generateGeneralizedLinearRegressionInput(
intercept: Double,
coefficients: Array[Double],