diff options
author | Yanbo Liang <ybliang8@gmail.com> | 2016-03-09 11:59:22 -0800 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2016-03-09 11:59:22 -0800 |
commit | 0dd06485c4222a896c0d1ee6a04d30043de3626c (patch) | |
tree | f1ffc18c78e4dcb4caf3872c34a4a6fc2616e223 /mllib/src/test | |
parent | cad29a40b24a8e89f2d906e263866546f8ab6071 (diff) | |
download | spark-0dd06485c4222a896c0d1ee6a04d30043de3626c.tar.gz spark-0dd06485c4222a896c0d1ee6a04d30043de3626c.tar.bz2 spark-0dd06485c4222a896c0d1ee6a04d30043de3626c.zip |
[SPARK-13615][ML] GeneralizedLinearRegression supports save/load
## What changes were proposed in this pull request?
```GeneralizedLinearRegression``` supports ```save/load```.
cc mengxr
## How was this patch tested?
unit test.
Author: Yanbo Liang <ybliang8@gmail.com>
Closes #11465 from yanboliang/spark-13615.
Diffstat (limited to 'mllib/src/test')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala | 32 |
1 files changed, 30 insertions, 2 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index 8bfa9855ce..618304ad19 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -21,7 +21,7 @@ import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.classification.LogisticRegressionSuite._ import org.apache.spark.mllib.linalg.{BLAS, DenseVector, Vectors} import org.apache.spark.mllib.random._ @@ -30,7 +30,8 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row} -class GeneralizedLinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { +class GeneralizedLinearRegressionSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { private val seed: Int = 42 @transient var datasetGaussianIdentity: DataFrame = _ @@ -464,10 +465,37 @@ class GeneralizedLinearRegressionSuite extends SparkFunSuite with MLlibTestSpark } } } + + test("read/write") { + def checkModelData( + model: GeneralizedLinearRegressionModel, + model2: GeneralizedLinearRegressionModel): Unit = { + assert(model.intercept === model2.intercept) + assert(model.coefficients.toArray === model2.coefficients.toArray) + } + + val glr = new GeneralizedLinearRegression() + testEstimatorAndModelReadWrite(glr, datasetPoissonLog, + GeneralizedLinearRegressionSuite.allParamSettings, checkModelData) + } } object GeneralizedLinearRegressionSuite { + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + */ + val allParamSettings: Map[String, Any] = Map( + "family" -> "poisson", + "link" -> "log", + "fitIntercept" -> true, + "maxIter" -> 2, // intentionally small + "tol" -> 0.8, + "regParam" -> 0.01, + "predictionCol" -> "myPrediction") + def generateGeneralizedLinearRegressionInput( intercept: Double, coefficients: Array[Double], |