aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHolden Karau <holden@pigscanfly.ca>2015-07-20 22:15:10 -0700
committerDB Tsai <dbt@netflix.com>2015-07-20 22:15:10 -0700
commit4d97be95300f729391c17b4c162e3c7fba09b8bf (patch)
tree456974487257569d23434823855f891519aaf8a5
parenta3c7a3ce32697ad293b8bcaf29f9384c8255b37f (diff)
downloadspark-4d97be95300f729391c17b4c162e3c7fba09b8bf.tar.gz
spark-4d97be95300f729391c17b4c162e3c7fba09b8bf.tar.bz2
spark-4d97be95300f729391c17b4c162e3c7fba09b8bf.zip
[SPARK-9204][ML] Add default params test for linearyregression suite
Author: Holden Karau <holden@pigscanfly.ca> Closes #7553 from holdenk/SPARK-9204-add-default-params-test-to-linear-regression and squashes the following commits: 630ba19 [Holden Karau] style fix faa08a3 [Holden Karau] Add default params test for linearyregression suite
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala25
1 files changed, 25 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index 374002c5b4..7cdda3db88 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.ml.regression
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.mllib.linalg.{DenseVector, Vectors}
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
import org.apache.spark.mllib.util.TestingUtils._
@@ -55,6 +56,30 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
}
+ test("params") {
+ ParamsSuite.checkParams(new LinearRegression)
+ val model = new LinearRegressionModel("linearReg", Vectors.dense(0.0), 0.0)
+ ParamsSuite.checkParams(model)
+ }
+
+ test("linear regression: default params") {
+ val lir = new LinearRegression
+ assert(lir.getLabelCol === "label")
+ assert(lir.getFeaturesCol === "features")
+ assert(lir.getPredictionCol === "prediction")
+ assert(lir.getRegParam === 0.0)
+ assert(lir.getElasticNetParam === 0.0)
+ assert(lir.getFitIntercept)
+ val model = lir.fit(dataset)
+ model.transform(dataset)
+ .select("label", "prediction")
+ .collect()
+ assert(model.getFeaturesCol === "features")
+ assert(model.getPredictionCol === "prediction")
+ assert(model.intercept !== 0.0)
+ assert(model.hasParent)
+ }
+
test("linear regression with intercept without regularization") {
val trainer = new LinearRegression
val model = trainer.fit(dataset)