aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala25
1 files changed, 25 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index 374002c5b4..7cdda3db88 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.ml.regression
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.mllib.linalg.{DenseVector, Vectors}
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
import org.apache.spark.mllib.util.TestingUtils._
@@ -55,6 +56,30 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
}
+ test("params") {
+ ParamsSuite.checkParams(new LinearRegression)
+ val model = new LinearRegressionModel("linearReg", Vectors.dense(0.0), 0.0)
+ ParamsSuite.checkParams(model)
+ }
+
+ test("linear regression: default params") {
+ val lir = new LinearRegression
+ assert(lir.getLabelCol === "label")
+ assert(lir.getFeaturesCol === "features")
+ assert(lir.getPredictionCol === "prediction")
+ assert(lir.getRegParam === 0.0)
+ assert(lir.getElasticNetParam === 0.0)
+ assert(lir.getFitIntercept)
+ val model = lir.fit(dataset)
+ model.transform(dataset)
+ .select("label", "prediction")
+ .collect()
+ assert(model.getFeaturesCol === "features")
+ assert(model.getPredictionCol === "prediction")
+ assert(model.intercept !== 0.0)
+ assert(model.hasParent)
+ }
+
test("linear regression with intercept without regularization") {
val trainer = new LinearRegression
val model = trainer.fit(dataset)