about summary refs log tree commit diff
path: root/mllib/src/test
diff options
context:
space:
mode:
author Wenjian Huang <nextrush@163.com> 2015-11-18 13:06:25 -0800
committer Xiangrui Meng <meng@databricks.com> 2015-11-18 13:06:25 -0800
commit 045a4f045821dcf60442f0600c2df1b79bddb536 (patch)
tree a2ba86ade1009cef997bf0bfed4e9d4a34152384 /mllib/src/test
parent 09ad9533d5760652de59fa4830c24cb8667958ac (diff)
download spark-045a4f045821dcf60442f0600c2df1b79bddb536.tar.gz
spark-045a4f045821dcf60442f0600c2df1b79bddb536.tar.bz2
spark-045a4f045821dcf60442f0600c2df1b79bddb536.zip
[SPARK-6790][ML] Add spark.ml LinearRegression import/export
This replaces [https://github.com/apache/spark/pull/9656] with updates.

fayeshine should be the main author when this PR is committed.

CC: mengxr fayeshine

Author: Wenjian Huang <nextrush@163.com>
Author: Joseph K. Bradley <joseph@databricks.com>

Closes #9814 from jkbradley/fayeshine-patch-6790.
Diffstat (limited to 'mllib/src/test')
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala 34
1 files changed, 32 insertions, 2 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index a1d86fe8fe..2bdc0e184d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -22,14 +22,15 @@ import scala.util.Random
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.{Vector, DenseVector, Vectors}
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{DataFrame, Row}
-class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
+class LinearRegressionSuite
+ extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
private val seed: Int = 42
@transient var datasetWithDenseFeature: DataFrame = _
@@ -854,4 +855,33 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
model.summary.tValues.zip(tValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) }
model.summary.pValues.zip(pValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) }
}
+
+ test("read/write") {
+ def checkModelData(model: LinearRegressionModel, model2: LinearRegressionModel): Unit = {
+ assert(model.intercept === model2.intercept)
+ assert(model.coefficients === model2.coefficients)
+ }
+ val lr = new LinearRegression()
+ testEstimatorAndModelReadWrite(lr, datasetWithWeight, LinearRegressionSuite.allParamSettings,
+ checkModelData)
+ }
+}
+
+object LinearRegressionSuite {
+
+ /**
+ * Mapping from all Params to valid settings which differ from the defaults.
+ * This is useful for tests which need to exercise all Params, such as save/load.
+ * This excludes input columns to simplify some tests.
+ */
+ val allParamSettings: Map[String, Any] = Map(
+ "predictionCol" -> "myPrediction",
+ "regParam" -> 0.01,
+ "elasticNetParam" -> 0.1,
+ "maxIter" -> 2, // intentionally small
+ "fitIntercept" -> true,
+ "tol" -> 0.8,
+ "standardization" -> false,
+ "solver" -> "l-bfgs"
+ )
}