path: root/mllib/src/test
author    Joseph K. Bradley <joseph@databricks.com>  2015-02-05 23:43:47 -0800
committer Xiangrui Meng <meng@databricks.com>        2015-02-05 23:43:47 -0800
commit    dc0c4490a12ecedd8ca5a1bb256c7ccbdf0be04f (patch)
tree      745d33737eaddc95a0c55a814e84c7b96f9ecbcf /mllib/src/test
parent    6b88825a25a0a072c13bbcc57bbfdb102a3f133d (diff)
[SPARK-4789] [SPARK-4942] [SPARK-5031] [mllib] Standardize ML Prediction APIs
This is part (1a) of the updates from the design doc in [https://docs.google.com/document/d/1BH9el33kBX8JiDdgUJXdLW14CA2qhTCWIG46eXZVoJs]

**UPDATE**: Most of the APIs are being kept private[spark] to allow further discussion. Here is a list of changes which are public:
* new output columns: rawPrediction, probability
* The "score" column is now called "rawPrediction"
* Classifiers now provide numClasses
* Params.get and .set are now protected instead of private[ml].
* ParamMap now has a size method.
* new classes: LinearRegression, LinearRegressionModel
* LogisticRegression now has an intercept.

### Sketch of APIs (most of which are private[spark] for now)

Abstract classes for learning algorithms (+ corresponding Model abstractions):
* Classifier (+ ClassificationModel)
* ProbabilisticClassifier (+ ProbabilisticClassificationModel)
* Regressor (+ RegressionModel)
* Predictor (+ PredictionModel)
* *For all of these*:
  * There is no strongly typed training-time API.
  * There is a strongly typed test-time (prediction) API which helps developers implement new algorithms.

Concrete classes: learning algorithms
* LinearRegression
* LogisticRegression (updated to use new abstract classes)
  * Also, removed "score" in favor of "probability" output column. Changed BinaryClassificationEvaluator to match. (SPARK-5031)

Other updates:
* params.scala: Changed Params.set/get to be protected instead of private[ml]
  * This was needed for the example of defining a class from outside of the MLlib namespace.
* VectorUDT: Will later change from private[spark] to public.
  * This is needed for outside users to write their own validateAndTransformSchema() methods using vectors.
  * Also, added equals() method.
* SPARK-4942: ML Transformers should allow output cols to be turned on/off
  * Update validateAndTransformSchema
  * Update transform
* (Updated examples, test suites according to other changes)

New examples:
* DeveloperApiExample.scala (example of defining an algorithm from outside of the MLlib namespace)
  * Added Java version too

Test Suites:
* LinearRegressionSuite
* LogisticRegressionSuite
* + Java versions of above suites

CC: mengxr etrain shivaram

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #3637 from jkbradley/ml-api-part1 and squashes the following commits:

405bfb8 [Joseph K. Bradley] Last edits based on code review. Small cleanups
fec348a [Joseph K. Bradley] Added JavaDeveloperApiExample.java and fixed other issues: Made developer API private[spark] for now. Added constructors Java can understand to specialized Param types.
8316d5e [Joseph K. Bradley] fixes after rebasing on master
fc62406 [Joseph K. Bradley] fixed test suites after last commit
bcb9549 [Joseph K. Bradley] Fixed issues after rebasing from master (after move from SchemaRDD to DataFrame)
9872424 [Joseph K. Bradley] fixed JavaLinearRegressionSuite.java Java sql api
f542997 [Joseph K. Bradley] Added MIMA excludes for VectorUDT (now public), and added DeveloperApi annotation to it
216d199 [Joseph K. Bradley] fixed after sql datatypes PR got merged
f549e34 [Joseph K. Bradley] Updates based on code review. Major ones are:
* Created weakly typed Predictor.train() method which is called by fit() so that developers do not have to call schema validation or copy parameters.
* Made Predictor.featuresDataType have a default value of VectorUDT.
* NOTE: This could be dangerous since the FeaturesType type parameter cannot have a default value.
343e7bd [Joseph K. Bradley] added blanket mima exclude for ml package
82f340b [Joseph K. Bradley] Fixed bug in LogisticRegression (introduced in this PR). Fixed Java suites
0a16da9 [Joseph K. Bradley] Fixed Linear/Logistic RegressionSuites
c3c8da5 [Joseph K. Bradley] small cleanup
934f97b [Joseph K. Bradley] Fixed bugs from previous commit.
1c61723 [Joseph K. Bradley] * Made ProbabilisticClassificationModel into a subclass of ClassificationModel. Also introduced ProbabilisticClassifier.
* This was to support output column "probabilityCol" in transform().
4e2f711 [Joseph K. Bradley] rat fix
bc654e1 [Joseph K. Bradley] Added spark.ml LinearRegressionSuite
8d13233 [Joseph K. Bradley] Added methods:
* Classifier: batch predictRaw()
* Predictor: train() without paramMap
* ProbabilisticClassificationModel.predictProbabilities()
* Java versions of all above batch methods + others
1680905 [Joseph K. Bradley] Added JavaLabeledPointSuite.java for spark.ml, and added constructor to LabeledPoint which defaults weight to 1.0
adbe50a [Joseph K. Bradley] * fixed LinearRegression train() to use embedded paramMap
* added Predictor.predict(RDD[Vector]) method
* updated Linear/LogisticRegressionSuites
58802e3 [Joseph K. Bradley] added train() to Predictor subclasses which does not take a ParamMap.
57d54ab [Joseph K. Bradley] * Changed semantics of Predictor.train() to merge the given paramMap with the embedded paramMap.
* remove threshold_internal from logreg
* Added Predictor.copy()
* Extended LogisticRegressionSuite
e433872 [Joseph K. Bradley] Updated docs. Added LabeledPointSuite to spark.ml
54b7b31 [Joseph K. Bradley] Fixed issue with logreg threshold being set correctly
0617d61 [Joseph K. Bradley] Fixed bug from last commit (sorting paramMap by parameter names in toString). Fixed bug in persisting logreg data. Added threshold_internal to logreg for faster test-time prediction (avoiding map lookup).
601e792 [Joseph K. Bradley] Modified ParamMap to sort parameters in toString. Cleaned up classes in class hierarchy, before implementing tests and examples.
d705e87 [Joseph K. Bradley] Added LinearRegression and Regressor back from ml-api branch
52f4fde [Joseph K. Bradley] removing everything except for simple class hierarchy for classification
d35bb5d [Joseph K. Bradley] fixed compilation issues, but have not added tests yet
bfade12 [Joseph K. Bradley] Added lots of classes for new ML API:
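For orientation before the diffs, here is a minimal usage sketch of the standardized API, assembled from the test suites in this commit. Setter, column, and param names are taken verbatim from the diffs below; `dataset` is assumed to be a DataFrame of LabeledPoints, as built in the suites' setUp/beforeAll.

```scala
import org.apache.spark.ml.classification.LogisticRegression

// Strongly typed setters on the estimator.
val lr = new LogisticRegression()
  .setMaxIter(10)
  .setRegParam(1.0)
  .setThreshold(0.6)
  .setProbabilityCol("myProbability") // output columns can be renamed

// fit() returns the typed model; transform() appends the new output columns.
val model = lr.fit(dataset)
model.transform(dataset)
  .select("label", "rawPrediction", "myProbability", "prediction")

// Params may also be overridden per call, as varargs param -> value pairs.
val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1)
model.transform(dataset, model.threshold -> 0.0, model.probabilityCol -> "myProb")
```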
Diffstat (limited to 'mllib/src/test')
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java                            2
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java   91
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java         89
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala     86
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala           65
5 files changed, 310 insertions, 23 deletions
diff --git a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
index 56a9dbdd58..50995ffef9 100644
--- a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
@@ -65,7 +65,7 @@ public class JavaPipelineSuite {
.setStages(new PipelineStage[] {scaler, lr});
PipelineModel model = pipeline.fit(dataset);
model.transform(dataset).registerTempTable("prediction");
- DataFrame predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
+ DataFrame predictions = jsql.sql("SELECT label, probability, prediction FROM prediction");
predictions.collectAsList();
}
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
index f4ba23c445..26284023b0 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
@@ -18,17 +18,22 @@
package org.apache.spark.ml.classification;
import java.io.Serializable;
+import java.lang.Math;
import java.util.List;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
+import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
+import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
+import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;
-import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
+import org.apache.spark.sql.Row;
+
public class JavaLogisticRegressionSuite implements Serializable {
@@ -36,12 +41,17 @@ public class JavaLogisticRegressionSuite implements Serializable {
private transient SQLContext jsql;
private transient DataFrame dataset;
+ private transient JavaRDD<LabeledPoint> datasetRDD;
+ private double eps = 1e-5;
+
@Before
public void setUp() {
jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite");
jsql = new SQLContext(jsc);
List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
- dataset = jsql.applySchema(jsc.parallelize(points, 2), LabeledPoint.class);
+ datasetRDD = jsc.parallelize(points, 2);
+ dataset = jsql.applySchema(datasetRDD, LabeledPoint.class);
+ dataset.registerTempTable("dataset");
}
@After
@@ -51,29 +61,88 @@ public class JavaLogisticRegressionSuite implements Serializable {
}
@Test
- public void logisticRegression() {
+ public void logisticRegressionDefaultParams() {
LogisticRegression lr = new LogisticRegression();
+ assert(lr.getLabelCol().equals("label"));
LogisticRegressionModel model = lr.fit(dataset);
model.transform(dataset).registerTempTable("prediction");
- DataFrame predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
+ DataFrame predictions = jsql.sql("SELECT label, probability, prediction FROM prediction");
predictions.collectAsList();
+ // Check defaults
+ assert(model.getThreshold() == 0.5);
+ assert(model.getFeaturesCol().equals("features"));
+ assert(model.getPredictionCol().equals("prediction"));
+ assert(model.getProbabilityCol().equals("probability"));
}
@Test
public void logisticRegressionWithSetters() {
+ // Set params, train, and check as many params as we can.
LogisticRegression lr = new LogisticRegression()
.setMaxIter(10)
- .setRegParam(1.0);
+ .setRegParam(1.0)
+ .setThreshold(0.6)
+ .setProbabilityCol("myProbability");
LogisticRegressionModel model = lr.fit(dataset);
- model.transform(dataset, model.threshold().w(0.8)) // overwrite threshold
- .registerTempTable("prediction");
- DataFrame predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
- predictions.collectAsList();
+ assert(model.fittingParamMap().apply(lr.maxIter()) == 10);
+ assert(model.fittingParamMap().apply(lr.regParam()).equals(1.0));
+ assert(model.fittingParamMap().apply(lr.threshold()).equals(0.6));
+ assert(model.getThreshold() == 0.6);
+
+ // Modify model params, and check that the params worked.
+ model.setThreshold(1.0);
+ model.transform(dataset).registerTempTable("predAllZero");
+ DataFrame predAllZero = jsql.sql("SELECT prediction, myProbability FROM predAllZero");
+ for (Row r: predAllZero.collectAsList()) {
+ assert(r.getDouble(0) == 0.0);
+ }
+ // Call transform with params, and check that the params worked.
+ model.transform(dataset, model.threshold().w(0.0), model.probabilityCol().w("myProb"))
+ .registerTempTable("predNotAllZero");
+ DataFrame predNotAllZero = jsql.sql("SELECT prediction, myProb FROM predNotAllZero");
+ boolean foundNonZero = false;
+ for (Row r: predNotAllZero.collectAsList()) {
+ if (r.getDouble(0) != 0.0) foundNonZero = true;
+ }
+ assert(foundNonZero);
+
+ // Call fit() with new params, and check as many params as we can.
+ LogisticRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1),
+ lr.threshold().w(0.4), lr.probabilityCol().w("theProb"));
+ assert(model2.fittingParamMap().apply(lr.maxIter()) == 5);
+ assert(model2.fittingParamMap().apply(lr.regParam()).equals(0.1));
+ assert(model2.fittingParamMap().apply(lr.threshold()).equals(0.4));
+ assert(model2.getThreshold() == 0.4);
+ assert(model2.getProbabilityCol().equals("theProb"));
}
+ @SuppressWarnings("unchecked")
@Test
- public void logisticRegressionFitWithVarargs() {
+ public void logisticRegressionPredictorClassifierMethods() {
LogisticRegression lr = new LogisticRegression();
- lr.fit(dataset, lr.maxIter().w(10), lr.regParam().w(1.0));
+ LogisticRegressionModel model = lr.fit(dataset);
+ assert(model.numClasses() == 2);
+
+ model.transform(dataset).registerTempTable("transformed");
+ DataFrame trans1 = jsql.sql("SELECT rawPrediction, probability FROM transformed");
+ for (Row row: trans1.collect()) {
+ Vector raw = (Vector)row.get(0);
+ Vector prob = (Vector)row.get(1);
+ assert(raw.size() == 2);
+ assert(prob.size() == 2);
+ double probFromRaw1 = 1.0 / (1.0 + Math.exp(-raw.apply(1)));
+ assert(Math.abs(prob.apply(1) - probFromRaw1) < eps);
+ assert(Math.abs(prob.apply(0) - (1.0 - probFromRaw1)) < eps);
+ }
+
+ DataFrame trans2 = jsql.sql("SELECT prediction, probability FROM transformed");
+ for (Row row: trans2.collect()) {
+ double pred = row.getDouble(0);
+ Vector prob = (Vector)row.get(1);
+ double probOfPred = prob.apply((int)pred);
+ for (int i = 0; i < prob.size(); ++i) {
+ assert(probOfPred >= prob.apply(i));
+ }
+ }
}
}
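The new logisticRegressionPredictorClassifierMethods test above pins down the intended relationship between the two new output columns: for binary logistic regression, rawPrediction holds the class margins and probability is the logistic transform of the positive-class margin. A standalone sketch of exactly the invariant the loop checks (plain Scala, no Spark required; the margin value is an arbitrary example):

```scala
// Invariant from the test: prob(1) == sigmoid(raw(1)) and prob(0) == 1 - prob(1).
def sigmoid(margin: Double): Double = 1.0 / (1.0 + math.exp(-margin))

val raw  = Array(-1.3, 1.3)                        // example margins (-m, m)
val prob = Array(1.0 - sigmoid(raw(1)), sigmoid(raw(1)))

assert(math.abs(prob(1) - sigmoid(raw(1))) < 1e-5)        // what the suite asserts
assert(math.abs(prob(0) - (1.0 - sigmoid(raw(1)))) < 1e-5)
assert(math.abs(prob.sum - 1.0) < 1e-12)                  // probabilities sum to 1
```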
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
new file mode 100644
index 0000000000..5bd616e74d
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.regression;
+
+import java.io.Serializable;
+import java.util.List;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import static org.apache.spark.mllib.classification.LogisticRegressionSuite
+ .generateLogisticInputAsList;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.SQLContext;
+
+
+public class JavaLinearRegressionSuite implements Serializable {
+
+ private transient JavaSparkContext jsc;
+ private transient SQLContext jsql;
+ private transient DataFrame dataset;
+ private transient JavaRDD<LabeledPoint> datasetRDD;
+
+ @Before
+ public void setUp() {
+ jsc = new JavaSparkContext("local", "JavaLinearRegressionSuite");
+ jsql = new SQLContext(jsc);
+ List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
+ datasetRDD = jsc.parallelize(points, 2);
+ dataset = jsql.applySchema(datasetRDD, LabeledPoint.class);
+ dataset.registerTempTable("dataset");
+ }
+
+ @After
+ public void tearDown() {
+ jsc.stop();
+ jsc = null;
+ }
+
+ @Test
+ public void linearRegressionDefaultParams() {
+ LinearRegression lr = new LinearRegression();
+ assert(lr.getLabelCol().equals("label"));
+ LinearRegressionModel model = lr.fit(dataset);
+ model.transform(dataset).registerTempTable("prediction");
+ DataFrame predictions = jsql.sql("SELECT label, prediction FROM prediction");
+ predictions.collect();
+ // Check defaults
+ assert(model.getFeaturesCol().equals("features"));
+ assert(model.getPredictionCol().equals("prediction"));
+ }
+
+ @Test
+ public void linearRegressionWithSetters() {
+ // Set params, train, and check as many params as we can.
+ LinearRegression lr = new LinearRegression()
+ .setMaxIter(10)
+ .setRegParam(1.0);
+ LinearRegressionModel model = lr.fit(dataset);
+ assert(model.fittingParamMap().apply(lr.maxIter()) == 10);
+ assert(model.fittingParamMap().apply(lr.regParam()).equals(1.0));
+
+ // Call fit() with new params, and check as many params as we can.
+ LinearRegressionModel model2 =
+ lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), lr.predictionCol().w("thePred"));
+ assert(model2.fittingParamMap().apply(lr.maxIter()) == 5);
+ assert(model2.fittingParamMap().apply(lr.regParam()).equals(0.1));
+ assert(model2.getPredictionCol().equals("thePred"));
+ }
+}
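A note on the Java/Scala parallel these paired suites exercise: Scala params use operator sugar, while Java reaches the same ParamMap through explicit method calls, with Param.w(value) building a param pair. Both forms below appear verbatim in the suites in this diff (shown as a Scala sketch with the Java equivalents in comments; not runnable on its own, since it assumes the lr and dataset from the suites):

```scala
// Scala: param pair via '->'. Java: lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1))
val model = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1)

// Scala: Option-returning get. Java: model.fittingParamMap().apply(lr.maxIter()) == 5
assert(model.fittingParamMap.get(lr.maxIter) == Some(5))
```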
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 33e40dc741..b3d1bfcfbe 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -20,44 +20,108 @@ package org.apache.spark.ml.classification
import org.scalatest.FunSuite
import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
+import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{SQLContext, DataFrame}
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+
class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
@transient var sqlContext: SQLContext = _
@transient var dataset: DataFrame = _
+ private val eps: Double = 1e-5
override def beforeAll(): Unit = {
super.beforeAll()
sqlContext = new SQLContext(sc)
dataset = sqlContext.createDataFrame(
- sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2))
+ sc.parallelize(generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 42), 2))
}
- test("logistic regression") {
+ test("logistic regression: default params") {
val lr = new LogisticRegression
+ assert(lr.getLabelCol == "label")
+ assert(lr.getFeaturesCol == "features")
+ assert(lr.getPredictionCol == "prediction")
+ assert(lr.getRawPredictionCol == "rawPrediction")
+ assert(lr.getProbabilityCol == "probability")
val model = lr.fit(dataset)
model.transform(dataset)
- .select("label", "prediction")
+ .select("label", "probability", "prediction", "rawPrediction")
.collect()
+ assert(model.getThreshold === 0.5)
+ assert(model.getFeaturesCol == "features")
+ assert(model.getPredictionCol == "prediction")
+ assert(model.getRawPredictionCol == "rawPrediction")
+ assert(model.getProbabilityCol == "probability")
}
test("logistic regression with setters") {
+ // Set params, train, and check as many params as we can.
val lr = new LogisticRegression()
.setMaxIter(10)
.setRegParam(1.0)
+ .setThreshold(0.6)
+ .setProbabilityCol("myProbability")
val model = lr.fit(dataset)
- model.transform(dataset, model.threshold -> 0.8) // overwrite threshold
- .select("label", "score", "prediction")
+ assert(model.fittingParamMap.get(lr.maxIter) === Some(10))
+ assert(model.fittingParamMap.get(lr.regParam) === Some(1.0))
+ assert(model.fittingParamMap.get(lr.threshold) === Some(0.6))
+ assert(model.getThreshold === 0.6)
+
+ // Modify model params, and check that the params worked.
+ model.setThreshold(1.0)
+ val predAllZero = model.transform(dataset)
+ .select("prediction", "myProbability")
.collect()
+ .map { case Row(pred: Double, prob: Vector) => pred }
+ assert(predAllZero.forall(_ === 0),
+ s"With threshold=1.0, expected predictions to be all 0, but only" +
+ s" ${predAllZero.count(_ === 0)} of ${dataset.count()} were 0.")
+ // Call transform with params, and check that the params worked.
+ val predNotAllZero =
+ model.transform(dataset, model.threshold -> 0.0, model.probabilityCol -> "myProb")
+ .select("prediction", "myProb")
+ .collect()
+ .map { case Row(pred: Double, prob: Vector) => pred }
+ assert(predNotAllZero.exists(_ !== 0.0))
+
+ // Call fit() with new params, and check as many params as we can.
+ val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1, lr.threshold -> 0.4,
+ lr.probabilityCol -> "theProb")
+ assert(model2.fittingParamMap.get(lr.maxIter).get === 5)
+ assert(model2.fittingParamMap.get(lr.regParam).get === 0.1)
+ assert(model2.fittingParamMap.get(lr.threshold).get === 0.4)
+ assert(model2.getThreshold === 0.4)
+ assert(model2.getProbabilityCol == "theProb")
}
- test("logistic regression fit and transform with varargs") {
+ test("logistic regression: Predictor, Classifier methods") {
+ val sqlContext = this.sqlContext
val lr = new LogisticRegression
- val model = lr.fit(dataset, lr.maxIter -> 10, lr.regParam -> 1.0)
- model.transform(dataset, model.threshold -> 0.8, model.scoreCol -> "probability")
- .select("label", "probability", "prediction")
- .collect()
+
+ val model = lr.fit(dataset)
+ assert(model.numClasses === 2)
+
+ val threshold = model.getThreshold
+ val results = model.transform(dataset)
+
+ // Compare rawPrediction with probability
+ results.select("rawPrediction", "probability").collect().map {
+ case Row(raw: Vector, prob: Vector) =>
+ assert(raw.size === 2)
+ assert(prob.size === 2)
+ val probFromRaw1 = 1.0 / (1.0 + math.exp(-raw(1)))
+ assert(prob(1) ~== probFromRaw1 relTol eps)
+ assert(prob(0) ~== 1.0 - probFromRaw1 relTol eps)
+ }
+
+ // Compare prediction with probability
+ results.select("prediction", "probability").collect().map {
+ case Row(pred: Double, prob: Vector) =>
+ val predFromProb = prob.toArray.zipWithIndex.maxBy(_._1)._2
+ assert(pred == predFromProb)
+ }
}
}
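The threshold assertions in this suite fix the decision rule the model is expected to apply: predict class 1 iff P(class 1) exceeds the threshold, with the argmax comparison in the last test being the threshold = 0.5 special case. A minimal standalone sketch consistent with all three checks (plain Scala; strict '>' is an inference from the tests, which require threshold = 1.0 to yield all-zero predictions):

```scala
// Decision rule consistent with the suite: threshold = 1.0 gives all zeros,
// threshold = 0.0 flags anything with nonzero P(1), 0.5 matches argmax.
def predict(prob: Array[Double], threshold: Double): Double =
  if (prob(1) > threshold) 1.0 else 0.0

val prob = Array(0.3, 0.7)
assert(predict(prob, 1.0) == 0.0)  // the "predAllZero" case
assert(predict(prob, 0.0) == 1.0)  // the "predNotAllZero" case
assert(predict(prob, 0.5) == prob.zipWithIndex.maxBy(_._1)._2.toDouble) // argmax
```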
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
new file mode 100644
index 0000000000..bbb44c3e2d
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.regression
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.{DataFrame, SQLContext}
+
+class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext {
+
+ @transient var sqlContext: SQLContext = _
+ @transient var dataset: DataFrame = _
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ sqlContext = new SQLContext(sc)
+ dataset = sqlContext.createDataFrame(
+ sc.parallelize(generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 42), 2))
+ }
+
+ test("linear regression: default params") {
+ val lr = new LinearRegression
+ assert(lr.getLabelCol == "label")
+ val model = lr.fit(dataset)
+ model.transform(dataset)
+ .select("label", "prediction")
+ .collect()
+ // Check defaults
+ assert(model.getFeaturesCol == "features")
+ assert(model.getPredictionCol == "prediction")
+ }
+
+ test("linear regression with setters") {
+ // Set params, train, and check as many as we can.
+ val lr = new LinearRegression()
+ .setMaxIter(10)
+ .setRegParam(1.0)
+ val model = lr.fit(dataset)
+ assert(model.fittingParamMap.get(lr.maxIter).get === 10)
+ assert(model.fittingParamMap.get(lr.regParam).get === 1.0)
+
+ // Call fit() with new params, and check as many as we can.
+ val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1, lr.predictionCol -> "thePred")
+ assert(model2.fittingParamMap.get(lr.maxIter).get === 5)
+ assert(model2.fittingParamMap.get(lr.regParam).get === 0.1)
+ assert(model2.getPredictionCol == "thePred")
+ }
+}