-rw-r--r--  examples/src/main/java/org/apache/spark/examples/ml/JavaGaussianMixtureExample.java | 2
-rw-r--r--  mllib/src/test/java/org/apache/spark/SharedSparkSession.java | 48
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java | 27
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java | 27
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java | 28
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java | 28
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java | 23
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java | 23
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java | 30
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java | 28
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java | 27
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java | 21
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java | 21
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java | 21
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java | 24
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java | 26
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java | 24
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java | 24
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java | 22
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java | 26
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java | 24
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java | 26
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java | 25
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java | 21
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java | 21
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java | 23
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java | 26
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java | 26
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java | 28
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java | 26
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java | 20
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java | 33
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java | 31
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java | 25
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java | 25
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java | 25
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java | 3
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java | 26
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java | 25
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java | 25
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java | 29
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java | 3
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java | 28
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java | 25
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java | 25
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java | 25
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java | 25
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java | 24
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java | 3
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java | 3
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java | 24
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java | 25
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java | 25
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java | 25
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java | 25
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java | 25
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java | 3
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java | 3
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java | 26
59 files changed, 207 insertions(+), 1148 deletions(-)
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaGaussianMixtureExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaGaussianMixtureExample.java
index 79b9909581..526bed93fb 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaGaussianMixtureExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaGaussianMixtureExample.java
@@ -37,7 +37,7 @@ public class JavaGaussianMixtureExample {
public static void main(String[] args) {
- // Creates a SparkSession
+ // Creates a SparkSession
SparkSession spark = SparkSession
.builder()
.appName("JavaGaussianMixtureExample")
diff --git a/mllib/src/test/java/org/apache/spark/SharedSparkSession.java b/mllib/src/test/java/org/apache/spark/SharedSparkSession.java
new file mode 100644
index 0000000000..4377987889
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/SharedSparkSession.java
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark;
+
+import java.io.IOException;
+import java.io.Serializable;
+
+import org.junit.After;
+import org.junit.Before;
+
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.SparkSession;
+
+public abstract class SharedSparkSession implements Serializable {
+
+ protected transient SparkSession spark;
+ protected transient JavaSparkContext jsc;
+
+ @Before
+ public void setUp() throws IOException {
+ spark = SparkSession.builder()
+ .master("local[2]")
+ .appName(getClass().getSimpleName())
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
+ }
+
+ @After
+ public void tearDown() {
+ spark.stop();
+ spark = null;
+ }
+}
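
The new SharedSparkSession base class above is what every migrated suite in this commit extends. As a usage reference, here is a minimal sketch of the resulting test pattern; the suite name JavaExampleSuite and its toy fixture are hypothetical, but the extend-and-delegate structure matches the suites changed below: extend SharedSparkSession, override setUp() only when extra fixtures are needed, and call super.setUp() first so the shared spark and jsc fields are initialized.

    package org.apache.spark.ml;

    import java.io.IOException;

    import org.junit.Assert;
    import org.junit.Test;

    import org.apache.spark.SharedSparkSession;
    import org.apache.spark.sql.Dataset;
    import org.apache.spark.sql.Row;

    // Hypothetical suite illustrating the migration pattern used in this commit.
    public class JavaExampleSuite extends SharedSparkSession {

      private transient Dataset<Row> dataset;

      @Override
      public void setUp() throws IOException {
        super.setUp();  // creates the shared `spark` session and `jsc` context
        dataset = spark.range(0, 10).toDF("id");  // toy fixture for this sketch
      }

      @Test
      public void sharedSessionIsUsable() {
        Assert.assertEquals(10, dataset.count());
      }
    }

Suites that allocate extra resources additionally override tearDown(), calling super.tearDown() before their own cleanup, as JavaLibSVMRelationSuite and JavaDefaultReadWriteSuite do further down in this diff.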
diff --git a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
index a81a36d1b1..9b209006bc 100644
--- a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
@@ -17,47 +17,34 @@
package org.apache.spark.ml;
-import org.junit.After;
-import org.junit.Before;
+import java.io.IOException;
+
import org.junit.Test;
+import org.apache.spark.SharedSparkSession;
import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.classification.LogisticRegression;
import static org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInputAsList;
import org.apache.spark.ml.feature.LabeledPoint;
import org.apache.spark.ml.feature.StandardScaler;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SparkSession;
/**
* Test Pipeline construction and fitting in Java.
*/
-public class JavaPipelineSuite {
+public class JavaPipelineSuite extends SharedSparkSession {
- private transient SparkSession spark;
- private transient JavaSparkContext jsc;
private transient Dataset<Row> dataset;
- @Before
- public void setUp() {
- spark = SparkSession.builder()
- .master("local")
- .appName("JavaPipelineSuite")
- .getOrCreate();
- jsc = new JavaSparkContext(spark.sparkContext());
+ @Override
+ public void setUp() throws IOException {
+ super.setUp();
JavaRDD<LabeledPoint> points =
jsc.parallelize(generateLogisticInputAsList(1.0, 1.0, 100, 42), 2);
dataset = spark.createDataFrame(points, LabeledPoint.class);
}
- @After
- public void tearDown() {
- spark.stop();
- spark = null;
- }
-
@Test
public void pipeline() {
StandardScaler scaler = new StandardScaler()
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java
index c76a1947c6..5aba4e8f7d 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java
@@ -17,42 +17,19 @@
package org.apache.spark.ml.classification;
-import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
-import org.junit.After;
-import org.junit.Before;
import org.junit.Test;
+import org.apache.spark.SharedSparkSession;
import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.ml.classification.LogisticRegressionSuite;
import org.apache.spark.ml.feature.LabeledPoint;
import org.apache.spark.ml.tree.impl.TreeTests;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SparkSession;
-public class JavaDecisionTreeClassifierSuite implements Serializable {
-
- private transient SparkSession spark;
- private transient JavaSparkContext jsc;
-
- @Before
- public void setUp() {
- spark = SparkSession.builder()
- .master("local")
- .appName("JavaDecisionTreeClassifierSuite")
- .getOrCreate();
- jsc = new JavaSparkContext(spark.sparkContext());
- }
-
- @After
- public void tearDown() {
- spark.stop();
- spark = null;
- }
+public class JavaDecisionTreeClassifierSuite extends SharedSparkSession {
@Test
public void runDT() {
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java
index 4648926c34..74bb46bd21 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java
@@ -17,43 +17,19 @@
package org.apache.spark.ml.classification;
-import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
-import org.junit.After;
-import org.junit.Before;
import org.junit.Test;
+import org.apache.spark.SharedSparkSession;
import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.ml.classification.LogisticRegressionSuite;
import org.apache.spark.ml.feature.LabeledPoint;
import org.apache.spark.ml.tree.impl.TreeTests;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SparkSession;
-
-public class JavaGBTClassifierSuite implements Serializable {
-
- private transient SparkSession spark;
- private transient JavaSparkContext jsc;
-
- @Before
- public void setUp() {
- spark = SparkSession.builder()
- .master("local")
- .appName("JavaGBTClassifierSuite")
- .getOrCreate();
- jsc = new JavaSparkContext(spark.sparkContext());
- }
-
- @After
- public void tearDown() {
- spark.stop();
- spark = null;
- }
+public class JavaGBTClassifierSuite extends SharedSparkSession {
@Test
public void runDT() {
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
index b8da04c26a..004102103d 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
@@ -17,52 +17,36 @@
package org.apache.spark.ml.classification;
-import java.io.Serializable;
+import java.io.IOException;
import java.util.List;
-import org.junit.After;
import org.junit.Assert;
-import org.junit.Before;
import org.junit.Test;
+import org.apache.spark.SharedSparkSession;
import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
import static org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInputAsList;
import org.apache.spark.ml.feature.LabeledPoint;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SparkSession;
-public class JavaLogisticRegressionSuite implements Serializable {
+public class JavaLogisticRegressionSuite extends SharedSparkSession {
- private transient SparkSession spark;
- private transient JavaSparkContext jsc;
private transient Dataset<Row> dataset;
private transient JavaRDD<LabeledPoint> datasetRDD;
private double eps = 1e-5;
- @Before
- public void setUp() {
- spark = SparkSession.builder()
- .master("local")
- .appName("JavaLogisticRegressionSuite")
- .getOrCreate();
- jsc = new JavaSparkContext(spark.sparkContext());
-
+ @Override
+ public void setUp() throws IOException {
+ super.setUp();
List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
datasetRDD = jsc.parallelize(points, 2);
dataset = spark.createDataFrame(datasetRDD, LabeledPoint.class);
dataset.createOrReplaceTempView("dataset");
}
- @After
- public void tearDown() {
- spark.stop();
- spark = null;
- }
-
@Test
public void logisticRegressionDefaultParams() {
LogisticRegression lr = new LogisticRegression();
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java
index 48edbc838c..6d0604d8f9 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java
@@ -17,38 +17,19 @@
package org.apache.spark.ml.classification;
-import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
-import org.junit.After;
import org.junit.Assert;
-import org.junit.Before;
import org.junit.Test;
+import org.apache.spark.SharedSparkSession;
import org.apache.spark.ml.feature.LabeledPoint;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SparkSession;
-public class JavaMultilayerPerceptronClassifierSuite implements Serializable {
-
- private transient SparkSession spark;
-
- @Before
- public void setUp() {
- spark = SparkSession.builder()
- .master("local")
- .appName("JavaLogisticRegressionSuite")
- .getOrCreate();
- }
-
- @After
- public void tearDown() {
- spark.stop();
- spark = null;
- }
+public class JavaMultilayerPerceptronClassifierSuite extends SharedSparkSession {
@Test
public void testMLPC() {
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
index 787909821b..c2a9e7b58b 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
@@ -17,43 +17,24 @@
package org.apache.spark.ml.classification;
-import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
-import org.junit.After;
-import org.junit.Before;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
+import org.apache.spark.SharedSparkSession;
import org.apache.spark.ml.linalg.VectorUDT;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
-public class JavaNaiveBayesSuite implements Serializable {
-
- private transient SparkSession spark;
-
- @Before
- public void setUp() {
- spark = SparkSession.builder()
- .master("local")
- .appName("JavaLogisticRegressionSuite")
- .getOrCreate();
- }
-
- @After
- public void tearDown() {
- spark.stop();
- spark = null;
- }
+public class JavaNaiveBayesSuite extends SharedSparkSession {
public void validatePrediction(Dataset<Row> predictionAndLabels) {
for (Row r : predictionAndLabels.collectAsList()) {
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java
index 58bc5a448a..6194167bda 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java
@@ -17,39 +17,29 @@
package org.apache.spark.ml.classification;
-import java.io.Serializable;
+import java.io.IOException;
import java.util.List;
import scala.collection.JavaConverters;
-import org.junit.After;
import org.junit.Assert;
-import org.junit.Before;
import org.junit.Test;
+import org.apache.spark.SharedSparkSession;
import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
-import static org.apache.spark.ml.classification.LogisticRegressionSuite.generateMultinomialLogisticInput;
import org.apache.spark.ml.feature.LabeledPoint;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SparkSession;
+import static org.apache.spark.ml.classification.LogisticRegressionSuite.generateMultinomialLogisticInput;
-public class JavaOneVsRestSuite implements Serializable {
+public class JavaOneVsRestSuite extends SharedSparkSession {
- private transient SparkSession spark;
- private transient JavaSparkContext jsc;
private transient Dataset<Row> dataset;
private transient JavaRDD<LabeledPoint> datasetRDD;
- @Before
- public void setUp() {
- spark = SparkSession.builder()
- .master("local")
- .appName("JavaLOneVsRestSuite")
- .getOrCreate();
- jsc = new JavaSparkContext(spark.sparkContext());
-
+ @Override
+ public void setUp() throws IOException {
+ super.setUp();
int nPoints = 3;
// The following coefficients and xMean/xVariance are computed from iris dataset with
@@ -68,12 +58,6 @@ public class JavaOneVsRestSuite implements Serializable {
dataset = spark.createDataFrame(datasetRDD, LabeledPoint.class);
}
- @After
- public void tearDown() {
- spark.stop();
- spark = null;
- }
-
@Test
public void oneVsRestDefaultParams() {
OneVsRest ova = new OneVsRest();
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
index 1ed20b1bfa..dd98513f37 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
@@ -17,45 +17,21 @@
package org.apache.spark.ml.classification;
-import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
-import org.junit.After;
import org.junit.Assert;
-import org.junit.Before;
import org.junit.Test;
+import org.apache.spark.SharedSparkSession;
import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.ml.classification.LogisticRegressionSuite;
import org.apache.spark.ml.feature.LabeledPoint;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.tree.impl.TreeTests;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SparkSession;
-
-public class JavaRandomForestClassifierSuite implements Serializable {
-
- private transient SparkSession spark;
- private transient JavaSparkContext jsc;
-
- @Before
- public void setUp() {
- spark = SparkSession.builder()
- .master("local")
- .appName("JavaRandomForestClassifierSuite")
- .getOrCreate();
- jsc = new JavaSparkContext(spark.sparkContext());
- }
-
- @After
- public void tearDown() {
- spark.stop();
- spark = null;
- }
+public class JavaRandomForestClassifierSuite extends SharedSparkSession {
@Test
public void runDT() {
diff --git a/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java b/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java
index 9d07170fa1..1be6f96f4c 100644
--- a/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java
@@ -17,43 +17,30 @@
package org.apache.spark.ml.clustering;
-import java.io.Serializable;
+import java.io.IOException;
import java.util.Arrays;
import java.util.List;
+import org.junit.Test;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
-import org.junit.After;
-import org.junit.Before;
-import org.junit.Test;
-
+import org.apache.spark.SharedSparkSession;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SparkSession;
-public class JavaKMeansSuite implements Serializable {
+public class JavaKMeansSuite extends SharedSparkSession {
private transient int k = 5;
private transient Dataset<Row> dataset;
- private transient SparkSession spark;
- @Before
- public void setUp() {
- spark = SparkSession.builder()
- .master("local")
- .appName("JavaKMeansSuite")
- .getOrCreate();
+ @Override
+ public void setUp() throws IOException {
+ super.setUp();
dataset = KMeansSuite.generateKMeansData(spark, 50, 3, k);
}
- @After
- public void tearDown() {
- spark.stop();
- spark = null;
- }
-
@Test
public void fitAndTransform() {
KMeans kmeans = new KMeans().setK(k).setSeed(1);
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java
index a96b43de15..87639380bd 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java
@@ -20,36 +20,19 @@ package org.apache.spark.ml.feature;
import java.util.Arrays;
import java.util.List;
-import org.junit.After;
import org.junit.Assert;
-import org.junit.Before;
import org.junit.Test;
+import org.apache.spark.SharedSparkSession;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
-public class JavaBucketizerSuite {
- private transient SparkSession spark;
-
- @Before
- public void setUp() {
- spark = SparkSession.builder()
- .master("local")
- .appName("JavaBucketizerSuite")
- .getOrCreate();
- }
-
- @After
- public void tearDown() {
- spark.stop();
- spark = null;
- }
+public class JavaBucketizerSuite extends SharedSparkSession {
@Test
public void bucketizerTest() {
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java
index 9d8c09b30c..b7956b6fd3 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java
@@ -22,38 +22,21 @@ import java.util.List;
import edu.emory.mathcs.jtransforms.dct.DoubleDCT_1D;
-import org.junit.After;
import org.junit.Assert;
-import org.junit.Before;
import org.junit.Test;
+import org.apache.spark.SharedSparkSession;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.VectorUDT;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
-public class JavaDCTSuite {
- private transient SparkSession spark;
-
- @Before
- public void setUp() {
- spark = SparkSession.builder()
- .master("local")
- .appName("JavaDCTSuite")
- .getOrCreate();
- }
-
- @After
- public void tearDown() {
- spark.stop();
- spark = null;
- }
+public class JavaDCTSuite extends SharedSparkSession {
@Test
public void javaCompatibilityTest() {
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java
index 3c37441a77..57696d0150 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java
@@ -20,38 +20,21 @@ package org.apache.spark.ml.feature;
import java.util.Arrays;
import java.util.List;
-import org.junit.After;
import org.junit.Assert;
-import org.junit.Before;
import org.junit.Test;
+import org.apache.spark.SharedSparkSession;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
-public class JavaHashingTFSuite {
- private transient SparkSession spark;
-
- @Before
- public void setUp() {
- spark = SparkSession.builder()
- .master("local")
- .appName("JavaHashingTFSuite")
- .getOrCreate();
- }
-
- @After
- public void tearDown() {
- spark.stop();
- spark = null;
- }
+public class JavaHashingTFSuite extends SharedSparkSession {
@Test
public void hashingTF() {
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java
index b3e213a497..6f877b5668 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java
@@ -19,35 +19,15 @@ package org.apache.spark.ml.feature;
import java.util.Arrays;
-import org.junit.After;
-import org.junit.Before;
import org.junit.Test;
+import org.apache.spark.SharedSparkSession;
import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SparkSession;
-public class JavaNormalizerSuite {
- private transient SparkSession spark;
- private transient JavaSparkContext jsc;
-
- @Before
- public void setUp() {
- spark = SparkSession.builder()
- .master("local")
- .appName("JavaNormalizerSuite")
- .getOrCreate();
- jsc = new JavaSparkContext(spark.sparkContext());
- }
-
- @After
- public void tearDown() {
- spark.stop();
- spark = null;
- }
+public class JavaNormalizerSuite extends SharedSparkSession {
@Test
public void normalizer() {
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java
index a4bce2283b..ac479c0841 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java
@@ -23,13 +23,11 @@ import java.util.List;
import scala.Tuple2;
-import org.junit.After;
import org.junit.Assert;
-import org.junit.Before;
import org.junit.Test;
+import org.apache.spark.SharedSparkSession;
import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.Vectors;
@@ -37,26 +35,8 @@ import org.apache.spark.mllib.linalg.Matrix;
import org.apache.spark.mllib.linalg.distributed.RowMatrix;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SparkSession;
-
-public class JavaPCASuite implements Serializable {
- private transient SparkSession spark;
- private transient JavaSparkContext jsc;
-
- @Before
- public void setUp() {
- spark = SparkSession.builder()
- .master("local")
- .appName("JavaPCASuite")
- .getOrCreate();
- jsc = new JavaSparkContext(spark.sparkContext());
- }
- @After
- public void tearDown() {
- spark.stop();
- spark = null;
- }
+public class JavaPCASuite extends SharedSparkSession {
public static class VectorPair implements Serializable {
private Vector features = Vectors.dense(0.0);
@@ -95,7 +75,7 @@ public class JavaPCASuite implements Serializable {
}
}
).rdd());
-
+
Matrix pc = mat.computePrincipalComponents(3);
mat.multiply(pc).rows().toJavaRDD();
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java
index a28f73f10a..df5d34fbe9 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java
@@ -20,41 +20,21 @@ package org.apache.spark.ml.feature;
import java.util.Arrays;
import java.util.List;
-import org.junit.After;
import org.junit.Assert;
-import org.junit.Before;
import org.junit.Test;
-import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.SharedSparkSession;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.VectorUDT;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
-public class JavaPolynomialExpansionSuite {
- private transient SparkSession spark;
- private transient JavaSparkContext jsc;
-
- @Before
- public void setUp() {
- spark = SparkSession.builder()
- .master("local")
- .appName("JavaPolynomialExpansionSuite")
- .getOrCreate();
- jsc = new JavaSparkContext(spark.sparkContext());
- }
-
- @After
- public void tearDown() {
- jsc.stop();
- jsc = null;
- }
+public class JavaPolynomialExpansionSuite extends SharedSparkSession {
@Test
public void polynomialExpansionTest() {
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java
index 8415fdb84f..dbc0b1db5c 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java
@@ -20,34 +20,14 @@ package org.apache.spark.ml.feature;
import java.util.Arrays;
import java.util.List;
-import org.junit.After;
-import org.junit.Before;
import org.junit.Test;
-import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.SharedSparkSession;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SparkSession;
-public class JavaStandardScalerSuite {
- private transient SparkSession spark;
- private transient JavaSparkContext jsc;
-
- @Before
- public void setUp() {
- spark = SparkSession.builder()
- .master("local")
- .appName("JavaStandardScalerSuite")
- .getOrCreate();
- jsc = new JavaSparkContext(spark.sparkContext());
- }
-
- @After
- public void tearDown() {
- spark.stop();
- spark = null;
- }
+public class JavaStandardScalerSuite extends SharedSparkSession {
@Test
public void standardScaler() {
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java
index 2b156f3bca..6480b57e1f 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java
@@ -20,37 +20,19 @@ package org.apache.spark.ml.feature;
import java.util.Arrays;
import java.util.List;
-import org.junit.After;
-import org.junit.Before;
import org.junit.Test;
+import org.apache.spark.SharedSparkSession;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
-public class JavaStopWordsRemoverSuite {
-
- private transient SparkSession spark;
-
- @Before
- public void setUp() {
- spark = SparkSession.builder()
- .master("local")
- .appName("JavaStopWordsRemoverSuite")
- .getOrCreate();
- }
-
- @After
- public void tearDown() {
- spark.stop();
- spark = null;
- }
+public class JavaStopWordsRemoverSuite extends SharedSparkSession {
@Test
public void javaCompatibilityTest() {
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java
index 52c0bde8f3..c1928a26b6 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java
@@ -20,37 +20,19 @@ package org.apache.spark.ml.feature;
import java.util.Arrays;
import java.util.List;
-import org.junit.After;
+import static org.apache.spark.sql.types.DataTypes.*;
+
import org.junit.Assert;
-import org.junit.Before;
import org.junit.Test;
-import org.apache.spark.SparkConf;
+import org.apache.spark.SharedSparkSession;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
-import static org.apache.spark.sql.types.DataTypes.*;
-
-public class JavaStringIndexerSuite {
- private transient SparkSession spark;
- @Before
- public void setUp() {
- SparkConf sparkConf = new SparkConf();
- sparkConf.setMaster("local");
- sparkConf.setAppName("JavaStringIndexerSuite");
-
- spark = SparkSession.builder().config(sparkConf).getOrCreate();
- }
-
- @After
- public void tearDown() {
- spark.stop();
- spark = null;
- }
+public class JavaStringIndexerSuite extends SharedSparkSession {
@Test
public void testStringIndexer() {
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java
index 0bac2839e1..27550a3d5c 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java
@@ -20,35 +20,15 @@ package org.apache.spark.ml.feature;
import java.util.Arrays;
import java.util.List;
-import org.junit.After;
import org.junit.Assert;
-import org.junit.Before;
import org.junit.Test;
+import org.apache.spark.SharedSparkSession;
import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SparkSession;
-public class JavaTokenizerSuite {
- private transient SparkSession spark;
- private transient JavaSparkContext jsc;
-
- @Before
- public void setUp() {
- spark = SparkSession.builder()
- .master("local")
- .appName("JavaTokenizerSuite")
- .getOrCreate();
- jsc = new JavaSparkContext(spark.sparkContext());
- }
-
- @After
- public void tearDown() {
- spark.stop();
- spark = null;
- }
+public class JavaTokenizerSuite extends SharedSparkSession {
@Test
public void regexTokenizer() {
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java
index fedaa77176..583652badb 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java
@@ -19,40 +19,22 @@ package org.apache.spark.ml.feature;
import java.util.Arrays;
-import org.junit.After;
+import static org.apache.spark.sql.types.DataTypes.*;
+
import org.junit.Assert;
-import org.junit.Before;
import org.junit.Test;
-import org.apache.spark.SparkConf;
+import org.apache.spark.SharedSparkSession;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.VectorUDT;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
-import static org.apache.spark.sql.types.DataTypes.*;
-
-public class JavaVectorAssemblerSuite {
- private transient SparkSession spark;
- @Before
- public void setUp() {
- SparkConf sparkConf = new SparkConf();
- sparkConf.setMaster("local");
- sparkConf.setAppName("JavaVectorAssemblerSuite");
-
- spark = SparkSession.builder().config(sparkConf).getOrCreate();
- }
-
- @After
- public void tearDown() {
- spark.stop();
- spark = null;
- }
+public class JavaVectorAssemblerSuite extends SharedSparkSession {
@Test
public void testVectorAssembler() {
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java
index a8dd44608d..ca8fae3a48 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java
@@ -17,42 +17,21 @@
package org.apache.spark.ml.feature;
-import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
-import org.junit.After;
import org.junit.Assert;
-import org.junit.Before;
import org.junit.Test;
-import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.SharedSparkSession;
import org.apache.spark.ml.feature.VectorIndexerSuite.FeatureData;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SparkSession;
-public class JavaVectorIndexerSuite implements Serializable {
- private transient SparkSession spark;
- private JavaSparkContext jsc;
-
- @Before
- public void setUp() {
- spark = SparkSession.builder()
- .master("local")
- .appName("JavaVectorIndexerSuite")
- .getOrCreate();
- jsc = new JavaSparkContext(spark.sparkContext());
- }
-
- @After
- public void tearDown() {
- spark.stop();
- spark = null;
- }
+public class JavaVectorIndexerSuite extends SharedSparkSession {
@Test
public void vectorIndexerAPI() {
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java
index a565c77af4..3dc2e1f896 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java
@@ -20,11 +20,10 @@ package org.apache.spark.ml.feature;
import java.util.Arrays;
import java.util.List;
-import org.junit.After;
import org.junit.Assert;
-import org.junit.Before;
import org.junit.Test;
+import org.apache.spark.SharedSparkSession;
import org.apache.spark.ml.attribute.Attribute;
import org.apache.spark.ml.attribute.AttributeGroup;
import org.apache.spark.ml.attribute.NumericAttribute;
@@ -33,26 +32,10 @@ import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.StructType;
-public class JavaVectorSlicerSuite {
- private transient SparkSession spark;
-
- @Before
- public void setUp() {
- spark = SparkSession.builder()
- .master("local")
- .appName("JavaVectorSlicerSuite")
- .getOrCreate();
- }
-
- @After
- public void tearDown() {
- spark.stop();
- spark = null;
- }
+public class JavaVectorSlicerSuite extends SharedSparkSession {
@Test
public void vectorSlice() {
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java
index bef7eb0f99..d0a849fd11 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java
@@ -19,34 +19,17 @@ package org.apache.spark.ml.feature;
import java.util.Arrays;
-import org.junit.After;
import org.junit.Assert;
-import org.junit.Before;
import org.junit.Test;
+import org.apache.spark.SharedSparkSession;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.*;
-public class JavaWord2VecSuite {
- private transient SparkSession spark;
-
- @Before
- public void setUp() {
- spark = SparkSession.builder()
- .master("local")
- .appName("JavaWord2VecSuite")
- .getOrCreate();
- }
-
- @After
- public void tearDown() {
- spark.stop();
- spark = null;
- }
+public class JavaWord2VecSuite extends SharedSparkSession {
@Test
public void testJavaWord2Vec() {
diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java
index a5b5dd4088..1077e103a3 100644
--- a/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java
@@ -19,37 +19,14 @@ package org.apache.spark.ml.param;
import java.util.Arrays;
-import org.junit.After;
import org.junit.Assert;
-import org.junit.Before;
import org.junit.Test;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.sql.SparkSession;
-
/**
* Test Param and related classes in Java
*/
public class JavaParamsSuite {
- private transient SparkSession spark;
- private transient JavaSparkContext jsc;
-
- @Before
- public void setUp() {
- spark = SparkSession.builder()
- .master("local")
- .appName("JavaParamsSuite")
- .getOrCreate();
- jsc = new JavaSparkContext(spark.sparkContext());
- }
-
- @After
- public void tearDown() {
- spark.stop();
- spark = null;
- }
-
@Test
public void testParams() {
JavaTestParams testParams = new JavaTestParams();
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
index 4ea3f2255e..1da85ed9da 100644
--- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
@@ -17,43 +17,21 @@
package org.apache.spark.ml.regression;
-import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
-import org.junit.After;
-import org.junit.Before;
import org.junit.Test;
+import org.apache.spark.SharedSparkSession;
import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.classification.LogisticRegressionSuite;
import org.apache.spark.ml.feature.LabeledPoint;
import org.apache.spark.ml.tree.impl.TreeTests;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SparkSession;
-public class JavaDecisionTreeRegressorSuite implements Serializable {
-
- private transient SparkSession spark;
- private transient JavaSparkContext jsc;
-
- @Before
- public void setUp() {
- spark = SparkSession.builder()
- .master("local")
- .appName("JavaDecisionTreeRegressorSuite")
- .getOrCreate();
- jsc = new JavaSparkContext(spark.sparkContext());
- }
-
- @After
- public void tearDown() {
- spark.stop();
- spark = null;
- }
+public class JavaDecisionTreeRegressorSuite extends SharedSparkSession {
@Test
public void runDT() {
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java
index 3b5edf1e15..7fd9b1feb7 100644
--- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java
@@ -17,43 +17,21 @@
package org.apache.spark.ml.regression;
-import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
-import org.junit.After;
-import org.junit.Before;
import org.junit.Test;
+import org.apache.spark.SharedSparkSession;
import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.classification.LogisticRegressionSuite;
import org.apache.spark.ml.feature.LabeledPoint;
import org.apache.spark.ml.tree.impl.TreeTests;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SparkSession;
-public class JavaGBTRegressorSuite implements Serializable {
-
- private transient SparkSession spark;
- private transient JavaSparkContext jsc;
-
- @Before
- public void setUp() {
- spark = SparkSession.builder()
- .master("local")
- .appName("JavaGBTRegressorSuite")
- .getOrCreate();
- jsc = new JavaSparkContext(spark.sparkContext());
- }
-
- @After
- public void tearDown() {
- spark.stop();
- spark = null;
- }
+public class JavaGBTRegressorSuite extends SharedSparkSession {
@Test
public void runDT() {
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
index 126aa6298f..6cdcdda1a6 100644
--- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
@@ -17,48 +17,32 @@
package org.apache.spark.ml.regression;
-import java.io.Serializable;
+import java.io.IOException;
import java.util.List;
-import org.junit.After;
-import org.junit.Before;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
+import org.apache.spark.SharedSparkSession;
import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
import static org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInputAsList;
import org.apache.spark.ml.feature.LabeledPoint;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SparkSession;
-public class JavaLinearRegressionSuite implements Serializable {
-
- private transient SparkSession spark;
- private transient JavaSparkContext jsc;
+public class JavaLinearRegressionSuite extends SharedSparkSession {
private transient Dataset<Row> dataset;
private transient JavaRDD<LabeledPoint> datasetRDD;
- @Before
- public void setUp() {
- spark = SparkSession.builder()
- .master("local")
- .appName("JavaLinearRegressionSuite")
- .getOrCreate();
- jsc = new JavaSparkContext(spark.sparkContext());
+ @Override
+ public void setUp() throws IOException {
+ super.setUp();
List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
datasetRDD = jsc.parallelize(points, 2);
dataset = spark.createDataFrame(datasetRDD, LabeledPoint.class);
dataset.createOrReplaceTempView("dataset");
}
- @After
- public void tearDown() {
- jsc.stop();
- jsc = null;
- }
-
@Test
public void linearRegressionDefaultParams() {
LinearRegression lr = new LinearRegression();
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
index d601e7c540..4ba13e2e06 100644
--- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
@@ -17,45 +17,23 @@
package org.apache.spark.ml.regression;
-import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
-import org.junit.After;
import org.junit.Assert;
-import org.junit.Before;
import org.junit.Test;
+import org.apache.spark.SharedSparkSession;
import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.classification.LogisticRegressionSuite;
import org.apache.spark.ml.feature.LabeledPoint;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.tree.impl.TreeTests;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SparkSession;
-public class JavaRandomForestRegressorSuite implements Serializable {
-
- private transient SparkSession spark;
- private transient JavaSparkContext jsc;
-
- @Before
- public void setUp() {
- spark = SparkSession.builder()
- .master("local")
- .appName("JavaRandomForestRegressorSuite")
- .getOrCreate();
- jsc = new JavaSparkContext(spark.sparkContext());
- }
-
- @After
- public void tearDown() {
- spark.stop();
- spark = null;
- }
+public class JavaRandomForestRegressorSuite extends SharedSparkSession {
@Test
public void runDT() {
diff --git a/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java b/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java
index 022dcf94bd..fa39f4560c 100644
--- a/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java
@@ -23,35 +23,28 @@ import java.nio.charset.StandardCharsets;
import com.google.common.io.Files;
-import org.junit.After;
import org.junit.Assert;
-import org.junit.Before;
import org.junit.Test;
+import org.apache.spark.SharedSparkSession;
import org.apache.spark.ml.linalg.DenseVector;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SparkSession;
import org.apache.spark.util.Utils;
/**
* Test LibSVMRelation in Java.
*/
-public class JavaLibSVMRelationSuite {
- private transient SparkSession spark;
+public class JavaLibSVMRelationSuite extends SharedSparkSession {
private File tempDir;
private String path;
- @Before
+ @Override
public void setUp() throws IOException {
- spark = SparkSession.builder()
- .master("local")
- .appName("JavaLibSVMRelationSuite")
- .getOrCreate();
-
+ super.setUp();
tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource");
File file = new File(tempDir, "part-00000");
String s = "1 1:1.0 3:2.0 5:3.0\n0\n0 2:4.0 4:5.0 6:6.0";
@@ -59,10 +52,9 @@ public class JavaLibSVMRelationSuite {
path = tempDir.toURI().toString();
}
- @After
+ @Override
public void tearDown() {
- spark.stop();
- spark = null;
+ super.tearDown();
Utils.deleteRecursively(tempDir);
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
index b874ccd48b..692d5ad591 100644
--- a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
@@ -17,48 +17,33 @@
package org.apache.spark.ml.tuning;
-import java.io.Serializable;
+import java.io.IOException;
import java.util.List;
-import org.junit.After;
import org.junit.Assert;
-import org.junit.Before;
import org.junit.Test;
-import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.SharedSparkSession;
import org.apache.spark.ml.classification.LogisticRegression;
-import static org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInputAsList;
-import org.apache.spark.ml.feature.LabeledPoint;
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;
+import org.apache.spark.ml.feature.LabeledPoint;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SparkSession;
+import static org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInputAsList;
-public class JavaCrossValidatorSuite implements Serializable {
- private transient SparkSession spark;
- private transient JavaSparkContext jsc;
- private transient Dataset<Row> dataset;
+public class JavaCrossValidatorSuite extends SharedSparkSession {
- @Before
- public void setUp() {
- spark = SparkSession.builder()
- .master("local")
- .appName("JavaCrossValidatorSuite")
- .getOrCreate();
- jsc = new JavaSparkContext(spark.sparkContext());
+ private transient Dataset<Row> dataset;
+ @Override
+ public void setUp() throws IOException {
+ super.setUp();
List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
dataset = spark.createDataFrame(jsc.parallelize(points, 2), LabeledPoint.class);
}
- @After
- public void tearDown() {
- jsc.stop();
- jsc = null;
- }
-
@Test
public void crossValidationWithLogisticRegression() {
LogisticRegression lr = new LogisticRegression();
diff --git a/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java b/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java
index 7151e27cde..da623d1d15 100644
--- a/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java
@@ -20,42 +20,25 @@ package org.apache.spark.ml.util;
import java.io.File;
import java.io.IOException;
-import org.junit.After;
import org.junit.Assert;
-import org.junit.Before;
import org.junit.Test;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.sql.SQLContext;
-import org.apache.spark.sql.SparkSession;
+import org.apache.spark.SharedSparkSession;
import org.apache.spark.util.Utils;
-public class JavaDefaultReadWriteSuite {
-
- JavaSparkContext jsc = null;
- SparkSession spark = null;
+public class JavaDefaultReadWriteSuite extends SharedSparkSession {
File tempDir = null;
- @Before
- public void setUp() {
- SQLContext.clearActive();
- spark = SparkSession.builder()
- .master("local[2]")
- .appName("JavaDefaultReadWriteSuite")
- .getOrCreate();
- SQLContext.setActive(spark.wrapped());
-
+ @Override
+ public void setUp() throws IOException {
+ super.setUp();
tempDir = Utils.createTempDir(
System.getProperty("java.io.tmpdir"), "JavaDefaultReadWriteSuite");
}
- @After
+ @Override
public void tearDown() {
- SQLContext.clearActive();
- if (spark != null) {
- spark.stop();
- spark = null;
- }
+ super.tearDown();
Utils.deleteRecursively(tempDir);
}
diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java
index 2f10d14da5..c04e2e6954 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java
@@ -17,37 +17,16 @@
package org.apache.spark.mllib.classification;
-import java.io.Serializable;
import java.util.List;
-import org.junit.After;
import org.junit.Assert;
-import org.junit.Before;
import org.junit.Test;
+import org.apache.spark.SharedSparkSession;
import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.sql.SparkSession;
-public class JavaLogisticRegressionSuite implements Serializable {
- private transient SparkSession spark;
- private transient JavaSparkContext jsc;
-
- @Before
- public void setUp() {
- spark = SparkSession.builder()
- .master("local")
- .appName("JavaLogisticRegressionSuite")
- .getOrCreate();
- jsc = new JavaSparkContext(spark.sparkContext());
- }
-
- @After
- public void tearDown() {
- spark.stop();
- spark = null;
- }
+public class JavaLogisticRegressionSuite extends SharedSparkSession {
int validatePrediction(List<LabeledPoint> validationData, LogisticRegressionModel model) {
int numAccurate = 0;
diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
index 5e212e2fc5..6ded42e928 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
@@ -17,42 +17,21 @@
package org.apache.spark.mllib.classification;
-import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
-import org.junit.After;
import org.junit.Assert;
-import org.junit.Before;
import org.junit.Test;
+import org.apache.spark.SharedSparkSession;
import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.sql.SparkSession;
-public class JavaNaiveBayesSuite implements Serializable {
- private transient SparkSession spark;
- private transient JavaSparkContext jsc;
-
- @Before
- public void setUp() {
- spark = SparkSession.builder()
- .master("local")
- .appName("JavaNaiveBayesSuite")
- .getOrCreate();
- jsc = new JavaSparkContext(spark.sparkContext());
- }
-
- @After
- public void tearDown() {
- spark.stop();
- spark = null;
- }
+public class JavaNaiveBayesSuite extends SharedSparkSession {
private static final List<LabeledPoint> POINTS = Arrays.asList(
new LabeledPoint(0, Vectors.dense(1.0, 0.0, 0.0)),
diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java
index 2a090c054f..0f54e684e4 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java
@@ -17,37 +17,16 @@
package org.apache.spark.mllib.classification;
-import java.io.Serializable;
import java.util.List;
-import org.junit.After;
import org.junit.Assert;
-import org.junit.Before;
import org.junit.Test;
+import org.apache.spark.SharedSparkSession;
import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.sql.SparkSession;
-public class JavaSVMSuite implements Serializable {
- private transient SparkSession spark;
- private transient JavaSparkContext jsc;
-
- @Before
- public void setUp() {
- spark = SparkSession.builder()
- .master("local")
- .appName("JavaSVMSuite")
- .getOrCreate();
- jsc = new JavaSparkContext(spark.sparkContext());
- }
-
- @After
- public void tearDown() {
- spark.stop();
- spark = null;
- }
+public class JavaSVMSuite extends SharedSparkSession {
int validatePrediction(List<LabeledPoint> validationData, SVMModel model) {
int numAccurate = 0;
diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java
index 62c6d9b7e3..8c6bced52d 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java
@@ -17,7 +17,6 @@
package org.apache.spark.mllib.classification;
-import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
@@ -37,7 +36,7 @@ import org.apache.spark.streaming.api.java.JavaPairDStream;
import org.apache.spark.streaming.api.java.JavaStreamingContext;
import static org.apache.spark.streaming.JavaTestUtils.*;
-public class JavaStreamingLogisticRegressionSuite implements Serializable {
+public class JavaStreamingLogisticRegressionSuite {
protected transient JavaStreamingContext ssc;
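The streaming suites get a smaller edit: they keep their own JavaStreamingContext and merely drop implements Serializable. That marker was only ever needed when a test closure captured the suite instance and Spark had to serialize it along with the task; a closure with no outer reference never drags the suite in. A hypothetical illustration of the capture-free style, not part of this commit:

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.regression.LabeledPoint;

public class ClosureCaptureExample {

  // A static nested Function holds no reference to an enclosing instance,
  // so the class that defines it never needs to implement Serializable.
  static class GetLabel implements Function<LabeledPoint, Double> {
    @Override
    public Double call(LabeledPoint p) {
      return p.label();
    }
  }

  static JavaRDD<Double> labels(JavaRDD<LabeledPoint> points) {
    // Only GetLabel is serialized with the task, not ClosureCaptureExample.
    return points.map(new GetLabel());
  }
}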
diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java
index 7f29b05047..3d62b273d2 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java
@@ -17,39 +17,17 @@
package org.apache.spark.mllib.clustering;
-import java.io.Serializable;
-
import com.google.common.collect.Lists;
-import org.junit.After;
import org.junit.Assert;
-import org.junit.Before;
import org.junit.Test;
+import org.apache.spark.SharedSparkSession;
import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.sql.SparkSession;
-
-public class JavaBisectingKMeansSuite implements Serializable {
- private transient SparkSession spark;
- private transient JavaSparkContext jsc;
- @Before
- public void setUp() {
- spark = SparkSession.builder()
- .master("local")
- .appName("JavaBisectingKMeansSuite")
- .getOrCreate();
- jsc = new JavaSparkContext(spark.sparkContext());
- }
-
- @After
- public void tearDown() {
- spark.stop();
- spark = null;
- }
+public class JavaBisectingKMeansSuite extends SharedSparkSession {
@Test
public void twoDimensionalData() {
diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java
index 20edd08a21..bf76719937 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java
@@ -17,40 +17,19 @@
package org.apache.spark.mllib.clustering;
-import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
import static org.junit.Assert.assertEquals;
-import org.junit.After;
-import org.junit.Before;
import org.junit.Test;
+import org.apache.spark.SharedSparkSession;
import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.sql.SparkSession;
-public class JavaGaussianMixtureSuite implements Serializable {
- private transient SparkSession spark;
- private transient JavaSparkContext jsc;
-
- @Before
- public void setUp() {
- spark = SparkSession.builder()
- .master("local")
- .appName("JavaGaussianMixture")
- .getOrCreate();
- jsc = new JavaSparkContext(spark.sparkContext());
- }
-
- @After
- public void tearDown() {
- spark.stop();
- spark = null;
- }
+public class JavaGaussianMixtureSuite extends SharedSparkSession {
@Test
public void runGaussianMixture() {
diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java
index 4e5b87f588..270e636f82 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java
@@ -17,40 +17,19 @@
package org.apache.spark.mllib.clustering;
-import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
import static org.junit.Assert.assertEquals;
-import org.junit.After;
-import org.junit.Before;
import org.junit.Test;
+import org.apache.spark.SharedSparkSession;
import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.sql.SparkSession;
-public class JavaKMeansSuite implements Serializable {
- private transient SparkSession spark;
- private transient JavaSparkContext jsc;
-
- @Before
- public void setUp() {
- spark = SparkSession.builder()
- .master("local")
- .appName("JavaKMeans")
- .getOrCreate();
- jsc = new JavaSparkContext(spark.sparkContext());
- }
-
- @After
- public void tearDown() {
- spark.stop();
- spark = null;
- }
+public class JavaKMeansSuite extends SharedSparkSession {
@Test
public void runKMeansUsingStaticMethods() {
diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
index f16585aff4..08d6713ab2 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
@@ -17,39 +17,28 @@
package org.apache.spark.mllib.clustering;
-import java.io.Serializable;
+import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import scala.Tuple2;
import scala.Tuple3;
-import org.junit.After;
-import org.junit.Before;
import org.junit.Test;
import static org.junit.Assert.*;
+import org.apache.spark.SharedSparkSession;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.linalg.Matrix;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.sql.SparkSession;
-
-public class JavaLDASuite implements Serializable {
- private transient SparkSession spark;
- private transient JavaSparkContext jsc;
-
- @Before
- public void setUp() {
- spark = SparkSession.builder()
- .master("local")
- .appName("JavaLDASuite")
- .getOrCreate();
- jsc = new JavaSparkContext(spark.sparkContext());
+public class JavaLDASuite extends SharedSparkSession {
+ @Override
+ public void setUp() throws IOException {
+ super.setUp();
ArrayList<Tuple2<Long, Vector>> tinyCorpus = new ArrayList<>();
for (int i = 0; i < LDASuite.tinyCorpus().length; i++) {
tinyCorpus.add(new Tuple2<>((Long) LDASuite.tinyCorpus()[i]._1(),
@@ -59,12 +48,6 @@ public class JavaLDASuite implements Serializable {
corpus = JavaPairRDD.fromJavaRDD(tmpCorpus);
}
- @After
- public void tearDown() {
- spark.stop();
- spark = null;
- }
-
@Test
public void localLDAModel() {
Matrix topics = LDASuite.tinyTopics();
diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java
index d1d618f7de..d41fc0e4dc 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java
@@ -17,7 +17,6 @@
package org.apache.spark.mllib.clustering;
-import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
@@ -36,7 +35,7 @@ import org.apache.spark.streaming.api.java.JavaPairDStream;
import org.apache.spark.streaming.api.java.JavaStreamingContext;
import static org.apache.spark.streaming.JavaTestUtils.*;
-public class JavaStreamingKMeansSuite implements Serializable {
+public class JavaStreamingKMeansSuite {
protected transient JavaStreamingContext ssc;
diff --git a/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java
index 6a096d6386..e9d7e4fdbe 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java
@@ -17,35 +17,25 @@
package org.apache.spark.mllib.evaluation;
-import java.io.Serializable;
+import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import scala.Tuple2;
import scala.Tuple2$;
-import org.junit.After;
import org.junit.Assert;
-import org.junit.Before;
import org.junit.Test;
+import org.apache.spark.SharedSparkSession;
import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.sql.SparkSession;
-public class JavaRankingMetricsSuite implements Serializable {
- private transient SparkSession spark;
- private transient JavaSparkContext jsc;
+public class JavaRankingMetricsSuite extends SharedSparkSession {
private transient JavaRDD<Tuple2<List<Integer>, List<Integer>>> predictionAndLabels;
- @Before
- public void setUp() {
- spark = SparkSession.builder()
- .master("local")
- .appName("JavaPCASuite")
- .getOrCreate();
- jsc = new JavaSparkContext(spark.sparkContext());
-
+ @Override
+ public void setUp() throws IOException {
+ super.setUp();
predictionAndLabels = jsc.parallelize(Arrays.asList(
Tuple2$.MODULE$.apply(
Arrays.asList(1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Arrays.asList(1, 2, 3, 4, 5)),
@@ -55,12 +45,6 @@ public class JavaRankingMetricsSuite implements Serializable {
Arrays.asList(1, 2, 3, 4, 5), Arrays.<Integer>asList())), 2);
}
- @After
- public void tearDown() {
- spark.stop();
- spark = null;
- }
-
@Test
public void rankingMetrics() {
@SuppressWarnings("unchecked")
diff --git a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java
index de50fb8c4f..05128ea343 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java
@@ -17,38 +17,17 @@
package org.apache.spark.mllib.feature;
-import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
-import org.junit.After;
import org.junit.Assert;
-import org.junit.Before;
import org.junit.Test;
+import org.apache.spark.SharedSparkSession;
import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector;
-import org.apache.spark.sql.SparkSession;
-public class JavaTfIdfSuite implements Serializable {
- private transient SparkSession spark;
- private transient JavaSparkContext jsc;
-
- @Before
- public void setUp() {
- spark = SparkSession.builder()
- .master("local")
- .appName("JavaPCASuite")
- .getOrCreate();
- jsc = new JavaSparkContext(spark.sparkContext());
- }
-
- @After
- public void tearDown() {
- spark.stop();
- spark = null;
- }
+public class JavaTfIdfSuite extends SharedSparkSession {
@Test
public void tfIdf() {
diff --git a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java
index 64885cc842..3e3abddbee 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java
@@ -17,7 +17,6 @@
package org.apache.spark.mllib.feature;
-import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
@@ -25,33 +24,13 @@ import com.google.common.base.Strings;
import scala.Tuple2;
-import org.junit.After;
import org.junit.Assert;
-import org.junit.Before;
import org.junit.Test;
+import org.apache.spark.SharedSparkSession;
import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.sql.SparkSession;
-public class JavaWord2VecSuite implements Serializable {
- private transient SparkSession spark;
- private transient JavaSparkContext jsc;
-
- @Before
- public void setUp() {
- spark = SparkSession.builder()
- .master("local")
- .appName("JavaPCASuite")
- .getOrCreate();
- jsc = new JavaSparkContext(spark.sparkContext());
- }
-
- @After
- public void tearDown() {
- spark.stop();
- spark = null;
- }
+public class JavaWord2VecSuite extends SharedSparkSession {
@Test
@SuppressWarnings("unchecked")
diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java
index fdc19a5b3d..3451e07737 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java
@@ -16,36 +16,15 @@
*/
package org.apache.spark.mllib.fpm;
-import java.io.Serializable;
import java.util.Arrays;
-import org.junit.After;
-import org.junit.Before;
import org.junit.Test;
+import org.apache.spark.SharedSparkSession;
import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset;
-import org.apache.spark.sql.SparkSession;
-public class JavaAssociationRulesSuite implements Serializable {
- private transient SparkSession spark;
- private transient JavaSparkContext jsc;
-
- @Before
- public void setUp() {
- spark = SparkSession.builder()
- .master("local")
- .appName("JavaAssociationRulesSuite")
- .getOrCreate();
- jsc = new JavaSparkContext(spark.sparkContext());
- }
-
- @After
- public void tearDown() {
- spark.stop();
- spark = null;
- }
+public class JavaAssociationRulesSuite extends SharedSparkSession {
@Test
public void runAssociationRules() {
diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
index f235251e61..46e9dd8b59 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
@@ -18,39 +18,18 @@
package org.apache.spark.mllib.fpm;
import java.io.File;
-import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
import static org.junit.Assert.assertEquals;
-import org.junit.After;
-import org.junit.Before;
import org.junit.Test;
+import org.apache.spark.SharedSparkSession;
import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.sql.SparkSession;
import org.apache.spark.util.Utils;
-public class JavaFPGrowthSuite implements Serializable {
- private transient SparkSession spark;
- private transient JavaSparkContext jsc;
-
- @Before
- public void setUp() {
- spark = SparkSession.builder()
- .master("local")
- .appName("JavaFPGrowth")
- .getOrCreate();
- jsc = new JavaSparkContext(spark.sparkContext());
- }
-
- @After
- public void tearDown() {
- spark.stop();
- spark = null;
- }
+public class JavaFPGrowthSuite extends SharedSparkSession {
@Test
public void runFPGrowth() {
diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java
index bf7f1fc71b..75b0ec6480 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java
@@ -21,35 +21,15 @@ import java.io.File;
import java.util.Arrays;
import java.util.List;
-import org.junit.After;
import org.junit.Assert;
-import org.junit.Before;
import org.junit.Test;
+import org.apache.spark.SharedSparkSession;
import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.fpm.PrefixSpan.FreqSequence;
-import org.apache.spark.sql.SparkSession;
import org.apache.spark.util.Utils;
-public class JavaPrefixSpanSuite {
- private transient SparkSession spark;
- private transient JavaSparkContext jsc;
-
- @Before
- public void setUp() {
- spark = SparkSession.builder()
- .master("local")
- .appName("JavaPrefixSpan")
- .getOrCreate();
- jsc = new JavaSparkContext(spark.sparkContext());
- }
-
- @After
- public void tearDown() {
- spark.stop();
- spark = null;
- }
+public class JavaPrefixSpanSuite extends SharedSparkSession {
@Test
public void runPrefixSpan() {
diff --git a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java
index 92fc57871c..f427846b9a 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java
@@ -17,7 +17,6 @@
package org.apache.spark.mllib.linalg;
-import java.io.Serializable;
import java.util.Random;
import static org.junit.Assert.assertArrayEquals;
@@ -25,7 +24,7 @@ import static org.junit.Assert.assertEquals;
import org.junit.Test;
-public class JavaMatricesSuite implements Serializable {
+public class JavaMatricesSuite {
@Test
public void randMatrixConstruction() {
diff --git a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java
index 817b962c75..f67f555e41 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java
@@ -17,7 +17,6 @@
package org.apache.spark.mllib.linalg;
-import java.io.Serializable;
import java.util.Arrays;
import static org.junit.Assert.assertArrayEquals;
@@ -26,7 +25,7 @@ import scala.Tuple2;
import org.junit.Test;
-public class JavaVectorsSuite implements Serializable {
+public class JavaVectorsSuite {
@Test
public void denseArrayConstruction() {
diff --git a/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java
index b449108a9b..6d114024c3 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java
@@ -20,36 +20,16 @@ package org.apache.spark.mllib.random;
import java.io.Serializable;
import java.util.Arrays;
-import org.junit.After;
import org.junit.Assert;
-import org.junit.Before;
import org.junit.Test;
+import org.apache.spark.SharedSparkSession;
import org.apache.spark.api.java.JavaDoubleRDD;
import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector;
-import org.apache.spark.sql.SparkSession;
import static org.apache.spark.mllib.random.RandomRDDs.*;
-public class JavaRandomRDDsSuite {
- private transient SparkSession spark;
- private transient JavaSparkContext jsc;
-
- @Before
- public void setUp() {
- spark = SparkSession.builder()
- .master("local")
- .appName("JavaRandomRDDsSuite")
- .getOrCreate();
- jsc = new JavaSparkContext(spark.sparkContext());
- }
-
- @After
- public void tearDown() {
- spark.stop();
- spark = null;
- }
+public class JavaRandomRDDsSuite extends SharedSparkSession {
@Test
public void testUniformRDD() {
diff --git a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java
index aa784054d5..363ab42546 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java
@@ -17,41 +17,20 @@
package org.apache.spark.mllib.recommendation;
-import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import scala.Tuple2;
import scala.Tuple3;
-import org.junit.After;
import org.junit.Assert;
-import org.junit.Before;
import org.junit.Test;
+import org.apache.spark.SharedSparkSession;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.sql.SparkSession;
-
-public class JavaALSSuite implements Serializable {
- private transient SparkSession spark;
- private transient JavaSparkContext jsc;
-
- @Before
- public void setUp() {
- spark = SparkSession.builder()
- .master("local")
- .appName("JavaALS")
- .getOrCreate();
- jsc = new JavaSparkContext(spark.sparkContext());
- }
- @After
- public void tearDown() {
- spark.stop();
- spark = null;
- }
+public class JavaALSSuite extends SharedSparkSession {
private void validatePrediction(
MatrixFactorizationModel model,
diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java
index 8b05675d65..dbd4cbfd2b 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java
@@ -17,26 +17,20 @@
package org.apache.spark.mllib.regression;
-import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import scala.Tuple3;
-import org.junit.After;
import org.junit.Assert;
-import org.junit.Before;
import org.junit.Test;
+import org.apache.spark.SharedSparkSession;
import org.apache.spark.api.java.JavaDoubleRDD;
import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.sql.SparkSession;
-public class JavaIsotonicRegressionSuite implements Serializable {
- private transient SparkSession spark;
- private transient JavaSparkContext jsc;
+public class JavaIsotonicRegressionSuite extends SharedSparkSession {
private static List<Tuple3<Double, Double, Double>> generateIsotonicInput(double[] labels) {
List<Tuple3<Double, Double, Double>> input = new ArrayList<>(labels.length);
@@ -55,21 +49,6 @@ public class JavaIsotonicRegressionSuite implements Serializable {
return new IsotonicRegression().run(trainRDD);
}
- @Before
- public void setUp() {
- spark = SparkSession.builder()
- .master("local")
- .appName("JavaLinearRegressionSuite")
- .getOrCreate();
- jsc = new JavaSparkContext(spark.sparkContext());
- }
-
- @After
- public void tearDown() {
- spark.stop();
- spark = null;
- }
-
@Test
public void testIsotonicRegressionJavaRDD() {
IsotonicRegressionModel model =
diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java
index 098bac3bed..1458cc72bc 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java
@@ -17,37 +17,16 @@
package org.apache.spark.mllib.regression;
-import java.io.Serializable;
import java.util.List;
-import org.junit.After;
import org.junit.Assert;
-import org.junit.Before;
import org.junit.Test;
+import org.apache.spark.SharedSparkSession;
import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.util.LinearDataGenerator;
-import org.apache.spark.sql.SparkSession;
-public class JavaLassoSuite implements Serializable {
- private transient SparkSession spark;
- private transient JavaSparkContext jsc;
-
- @Before
- public void setUp() {
- spark = SparkSession.builder()
- .master("local")
- .appName("JavaLassoSuite")
- .getOrCreate();
- jsc = new JavaSparkContext(spark.sparkContext());
- }
-
- @After
- public void tearDown() {
- spark.stop();
- spark = null;
- }
+public class JavaLassoSuite extends SharedSparkSession {
int validatePrediction(List<LabeledPoint> validationData, LassoModel model) {
int numAccurate = 0;
diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java
index 35087a5e46..a46b1321b3 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java
@@ -17,39 +17,18 @@
package org.apache.spark.mllib.regression;
-import java.io.Serializable;
import java.util.List;
-import org.junit.After;
import org.junit.Assert;
-import org.junit.Before;
import org.junit.Test;
+import org.apache.spark.SharedSparkSession;
import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.util.LinearDataGenerator;
-import org.apache.spark.sql.SparkSession;
-public class JavaLinearRegressionSuite implements Serializable {
- private transient SparkSession spark;
- private transient JavaSparkContext jsc;
-
- @Before
- public void setUp() {
- spark = SparkSession.builder()
- .master("local")
- .appName("JavaLinearRegressionSuite")
- .getOrCreate();
- jsc = new JavaSparkContext(spark.sparkContext());
- }
-
- @After
- public void tearDown() {
- spark.stop();
- spark = null;
- }
+public class JavaLinearRegressionSuite extends SharedSparkSession {
int validatePrediction(List<LabeledPoint> validationData, LinearRegressionModel model) {
int numAccurate = 0;
diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java
index b2efb2e72e..cb00977412 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java
@@ -17,38 +17,17 @@
package org.apache.spark.mllib.regression;
-import java.io.Serializable;
import java.util.List;
import java.util.Random;
-import org.junit.After;
import org.junit.Assert;
-import org.junit.Before;
import org.junit.Test;
+import org.apache.spark.SharedSparkSession;
import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.util.LinearDataGenerator;
-import org.apache.spark.sql.SparkSession;
-
-public class JavaRidgeRegressionSuite implements Serializable {
- private transient SparkSession spark;
- private transient JavaSparkContext jsc;
-
- @Before
- public void setUp() {
- spark = SparkSession.builder()
- .master("local")
- .appName("JavaRidgeRegressionSuite")
- .getOrCreate();
- jsc = new JavaSparkContext(spark.sparkContext());
- }
- @After
- public void tearDown() {
- spark.stop();
- spark = null;
- }
+public class JavaRidgeRegressionSuite extends SharedSparkSession {
private static double predictionError(List<LabeledPoint> validationData,
RidgeRegressionModel model) {
diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java
index ea0ccd7448..ab554475d5 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java
@@ -17,7 +17,6 @@
package org.apache.spark.mllib.regression;
-import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
@@ -36,7 +35,7 @@ import org.apache.spark.streaming.api.java.JavaPairDStream;
import org.apache.spark.streaming.api.java.JavaStreamingContext;
import static org.apache.spark.streaming.JavaTestUtils.*;
-public class JavaStreamingLinearRegressionSuite implements Serializable {
+public class JavaStreamingLinearRegressionSuite {
protected transient JavaStreamingContext ssc;
diff --git a/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java
index 373417d3ba..1abaa39ead 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java
@@ -17,7 +17,6 @@
package org.apache.spark.mllib.stat;
-import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
@@ -42,7 +41,7 @@ import org.apache.spark.streaming.api.java.JavaDStream;
import org.apache.spark.streaming.api.java.JavaStreamingContext;
import static org.apache.spark.streaming.JavaTestUtils.*;
-public class JavaStatisticsSuite implements Serializable {
+public class JavaStatisticsSuite {
private transient SparkSession spark;
private transient JavaSparkContext jsc;
private transient JavaStreamingContext ssc;
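JavaStatisticsSuite is the one suite in this stretch that keeps private spark and jsc fields instead of extending SharedSparkSession: it also owns a JavaStreamingContext, which has to be built on the same SparkContext and shut down in the right order relative to the session. A sketch of that coupled lifecycle, with the batch duration, app name, and stop ordering assumed rather than taken from the commit:

import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.streaming.Duration;
import org.apache.spark.streaming.api.java.JavaStreamingContext;

public class StreamingStatsLifecycle {

  private transient SparkSession spark;
  private transient JavaSparkContext jsc;
  private transient JavaStreamingContext ssc;

  public void setUp() {
    spark = SparkSession.builder()
      .master("local[2]")
      .appName("JavaStatistics")
      .getOrCreate();
    jsc = new JavaSparkContext(spark.sparkContext());
    // The streaming context reuses the same underlying SparkContext.
    ssc = new JavaStreamingContext(jsc, new Duration(1000));
  }

  public void tearDown() {
    // Stop streaming first but keep the SparkContext alive for the session.
    ssc.stop(false);
    spark.stop();
  }
}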
diff --git a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java
index 5b464a4722..1dcbbcaa02 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java
@@ -17,17 +17,14 @@
package org.apache.spark.mllib.tree;
-import java.io.Serializable;
import java.util.HashMap;
import java.util.List;
-import org.junit.After;
import org.junit.Assert;
-import org.junit.Before;
import org.junit.Test;
+import org.apache.spark.SharedSparkSession;
import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;
@@ -35,27 +32,8 @@ import org.apache.spark.mllib.tree.configuration.Algo;
import org.apache.spark.mllib.tree.configuration.Strategy;
import org.apache.spark.mllib.tree.impurity.Gini;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
-import org.apache.spark.sql.SparkSession;
-
-public class JavaDecisionTreeSuite implements Serializable {
- private transient SparkSession spark;
- private transient JavaSparkContext jsc;
-
- @Before
- public void setUp() {
- spark = SparkSession.builder()
- .master("local")
- .appName("JavaDecisionTreeSuite")
- .getOrCreate();
- jsc = new JavaSparkContext(spark.sparkContext());
- }
-
- @After
- public void tearDown() {
- spark.stop();
- spark = null;
- }
+public class JavaDecisionTreeSuite extends SharedSparkSession {
int validatePrediction(List<LabeledPoint> validationData, DecisionTreeModel model) {
int numCorrect = 0;