aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java')
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java28
1 files changed, 6 insertions, 22 deletions
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
index 126aa6298f..6cdcdda1a6 100644
--- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
@@ -17,48 +17,32 @@
package org.apache.spark.ml.regression;
-import java.io.Serializable;
+import java.io.IOException;
import java.util.List;
-import org.junit.After;
-import org.junit.Before;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
+import org.apache.spark.SharedSparkSession;
import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
import static org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInputAsList;
import org.apache.spark.ml.feature.LabeledPoint;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SparkSession;
-public class JavaLinearRegressionSuite implements Serializable {
-
- private transient SparkSession spark;
- private transient JavaSparkContext jsc;
+public class JavaLinearRegressionSuite extends SharedSparkSession {
private transient Dataset<Row> dataset;
private transient JavaRDD<LabeledPoint> datasetRDD;
- @Before
- public void setUp() {
- spark = SparkSession.builder()
- .master("local")
- .appName("JavaLinearRegressionSuite")
- .getOrCreate();
- jsc = new JavaSparkContext(spark.sparkContext());
+ @Override
+ public void setUp() throws IOException {
+ super.setUp();
List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
datasetRDD = jsc.parallelize(points, 2);
dataset = spark.createDataFrame(datasetRDD, LabeledPoint.class);
dataset.createOrReplaceTempView("dataset");
}
- @After
- public void tearDown() {
- jsc.stop();
- jsc = null;
- }
-
@Test
public void linearRegressionDefaultParams() {
LinearRegression lr = new LinearRegression();