diff options
author | Holden Karau <holden@pigscanfly.ca> | 2015-07-20 22:15:10 -0700 |
---|---|---|
committer | DB Tsai <dbt@netflix.com> | 2015-07-20 22:15:10 -0700 |
commit | 4d97be95300f729391c17b4c162e3c7fba09b8bf (patch) | |
tree | 456974487257569d23434823855f891519aaf8a5 | |
parent | a3c7a3ce32697ad293b8bcaf29f9384c8255b37f (diff) | |
download | spark-4d97be95300f729391c17b4c162e3c7fba09b8bf.tar.gz spark-4d97be95300f729391c17b4c162e3c7fba09b8bf.tar.bz2 spark-4d97be95300f729391c17b4c162e3c7fba09b8bf.zip |
[SPARK-9204][ML] Add default params test for linearyregression suite
Author: Holden Karau <holden@pigscanfly.ca>
Closes #7553 from holdenk/SPARK-9204-add-default-params-test-to-linear-regression and squashes the following commits:
630ba19 [Holden Karau] style fix
faa08a3 [Holden Karau] Add default params test for linearyregression suite
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala | 25 |
1 files changed, 25 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 374002c5b4..7cdda3db88 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.mllib.linalg.{DenseVector, Vectors} import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ @@ -55,6 +56,30 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { } + test("params") { + ParamsSuite.checkParams(new LinearRegression) + val model = new LinearRegressionModel("linearReg", Vectors.dense(0.0), 0.0) + ParamsSuite.checkParams(model) + } + + test("linear regression: default params") { + val lir = new LinearRegression + assert(lir.getLabelCol === "label") + assert(lir.getFeaturesCol === "features") + assert(lir.getPredictionCol === "prediction") + assert(lir.getRegParam === 0.0) + assert(lir.getElasticNetParam === 0.0) + assert(lir.getFitIntercept) + val model = lir.fit(dataset) + model.transform(dataset) + .select("label", "prediction") + .collect() + assert(model.getFeaturesCol === "features") + assert(model.getPredictionCol === "prediction") + assert(model.intercept !== 0.0) + assert(model.hasParent) + } + test("linear regression with intercept without regularization") { val trainer = new LinearRegression val model = trainer.fit(dataset) |