diff options
Diffstat (limited to 'mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java')
-rw-r--r-- | mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java | 28 |
1 files changed, 6 insertions, 22 deletions
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java index 126aa6298f..6cdcdda1a6 100644 --- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java @@ -17,48 +17,32 @@ package org.apache.spark.ml.regression; -import java.io.Serializable; +import java.io.IOException; import java.util.List; -import org.junit.After; -import org.junit.Before; import org.junit.Test; import static org.junit.Assert.assertEquals; +import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; import static org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInputAsList; import org.apache.spark.ml.feature.LabeledPoint; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SparkSession; -public class JavaLinearRegressionSuite implements Serializable { - - private transient SparkSession spark; - private transient JavaSparkContext jsc; +public class JavaLinearRegressionSuite extends SharedSparkSession { private transient Dataset<Row> dataset; private transient JavaRDD<LabeledPoint> datasetRDD; - @Before - public void setUp() { - spark = SparkSession.builder() - .master("local") - .appName("JavaLinearRegressionSuite") - .getOrCreate(); - jsc = new JavaSparkContext(spark.sparkContext()); + @Override + public void setUp() throws IOException { + super.setUp(); List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42); datasetRDD = jsc.parallelize(points, 2); dataset = spark.createDataFrame(datasetRDD, LabeledPoint.class); dataset.createOrReplaceTempView("dataset"); } - @After - public void tearDown() { - jsc.stop(); - jsc = null; - } - @Test public void linearRegressionDefaultParams() { LinearRegression lr = new LinearRegression(); |