From 01cf649c4f96f64fb4bd09e0e1811cabcc5ead2e Mon Sep 17 00:00:00 2001 From: Sandeep Singh Date: Thu, 19 May 2016 20:38:44 -0700 Subject: [SPARK-15296][MLLIB] Refactor All Java Tests that use SparkSession ## What changes were proposed in this pull request? Refactor All Java Tests that use SparkSession, to extend SharedSparkSesion ## How was this patch tested? Existing Tests Author: Sandeep Singh Closes #13101 from techaddict/SPARK-15296. --- .../examples/ml/JavaGaussianMixtureExample.java | 2 +- .../java/org/apache/spark/SharedSparkSession.java | 48 ++++++++++++++++++++++ .../org/apache/spark/ml/JavaPipelineSuite.java | 27 ++++-------- .../JavaDecisionTreeClassifierSuite.java | 27 +----------- .../ml/classification/JavaGBTClassifierSuite.java | 28 +------------ .../JavaLogisticRegressionSuite.java | 28 +++---------- .../JavaMultilayerPerceptronClassifierSuite.java | 23 +---------- .../ml/classification/JavaNaiveBayesSuite.java | 23 +---------- .../ml/classification/JavaOneVsRestSuite.java | 30 ++++---------- .../JavaRandomForestClassifierSuite.java | 28 +------------ .../spark/ml/clustering/JavaKMeansSuite.java | 27 ++++-------- .../spark/ml/feature/JavaBucketizerSuite.java | 21 +--------- .../org/apache/spark/ml/feature/JavaDCTSuite.java | 21 +--------- .../spark/ml/feature/JavaHashingTFSuite.java | 21 +--------- .../spark/ml/feature/JavaNormalizerSuite.java | 24 +---------- .../org/apache/spark/ml/feature/JavaPCASuite.java | 26 ++---------- .../ml/feature/JavaPolynomialExpansionSuite.java | 24 +---------- .../spark/ml/feature/JavaStandardScalerSuite.java | 24 +---------- .../ml/feature/JavaStopWordsRemoverSuite.java | 22 +--------- .../spark/ml/feature/JavaStringIndexerSuite.java | 26 ++---------- .../spark/ml/feature/JavaTokenizerSuite.java | 24 +---------- .../spark/ml/feature/JavaVectorAssemblerSuite.java | 26 ++---------- .../spark/ml/feature/JavaVectorIndexerSuite.java | 25 +---------- .../spark/ml/feature/JavaVectorSlicerSuite.java | 21 +--------- .../apache/spark/ml/feature/JavaWord2VecSuite.java | 21 +--------- .../org/apache/spark/ml/param/JavaParamsSuite.java | 23 ----------- .../regression/JavaDecisionTreeRegressorSuite.java | 26 +----------- .../spark/ml/regression/JavaGBTRegressorSuite.java | 26 +----------- .../ml/regression/JavaLinearRegressionSuite.java | 28 +++---------- .../regression/JavaRandomForestRegressorSuite.java | 26 +----------- .../ml/source/libsvm/JavaLibSVMRelationSuite.java | 20 +++------ .../spark/ml/tuning/JavaCrossValidatorSuite.java | 33 ++++----------- .../spark/ml/util/JavaDefaultReadWriteSuite.java | 31 ++++---------- .../JavaLogisticRegressionSuite.java | 25 +---------- .../mllib/classification/JavaNaiveBayesSuite.java | 25 +---------- .../spark/mllib/classification/JavaSVMSuite.java | 25 +---------- .../JavaStreamingLogisticRegressionSuite.java | 3 +- .../mllib/clustering/JavaBisectingKMeansSuite.java | 26 +----------- .../mllib/clustering/JavaGaussianMixtureSuite.java | 25 +---------- .../spark/mllib/clustering/JavaKMeansSuite.java | 25 +---------- .../spark/mllib/clustering/JavaLDASuite.java | 29 +++---------- .../mllib/clustering/JavaStreamingKMeansSuite.java | 3 +- .../mllib/evaluation/JavaRankingMetricsSuite.java | 28 +++---------- .../apache/spark/mllib/feature/JavaTfIdfSuite.java | 25 +---------- .../spark/mllib/feature/JavaWord2VecSuite.java | 25 +---------- .../spark/mllib/fpm/JavaAssociationRulesSuite.java | 25 +---------- .../apache/spark/mllib/fpm/JavaFPGrowthSuite.java | 25 +---------- .../spark/mllib/fpm/JavaPrefixSpanSuite.java | 24 +---------- .../spark/mllib/linalg/JavaMatricesSuite.java | 3 +- .../spark/mllib/linalg/JavaVectorsSuite.java | 3 +- .../spark/mllib/random/JavaRandomRDDsSuite.java | 24 +---------- .../spark/mllib/recommendation/JavaALSSuite.java | 25 +---------- .../regression/JavaIsotonicRegressionSuite.java | 25 +---------- .../spark/mllib/regression/JavaLassoSuite.java | 25 +---------- .../regression/JavaLinearRegressionSuite.java | 25 +---------- .../mllib/regression/JavaRidgeRegressionSuite.java | 25 +---------- .../JavaStreamingLinearRegressionSuite.java | 3 +- .../spark/mllib/stat/JavaStatisticsSuite.java | 3 +- .../spark/mllib/tree/JavaDecisionTreeSuite.java | 26 +----------- 59 files changed, 207 insertions(+), 1148 deletions(-) create mode 100644 mllib/src/test/java/org/apache/spark/SharedSparkSession.java diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaGaussianMixtureExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaGaussianMixtureExample.java index 79b9909581..526bed93fb 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaGaussianMixtureExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaGaussianMixtureExample.java @@ -37,7 +37,7 @@ public class JavaGaussianMixtureExample { public static void main(String[] args) { - // Creates a SparkSession + // Creates a SparkSession SparkSession spark = SparkSession .builder() .appName("JavaGaussianMixtureExample") diff --git a/mllib/src/test/java/org/apache/spark/SharedSparkSession.java b/mllib/src/test/java/org/apache/spark/SharedSparkSession.java new file mode 100644 index 0000000000..4377987889 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/SharedSparkSession.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark; + +import java.io.IOException; +import java.io.Serializable; + +import org.junit.After; +import org.junit.Before; + +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SparkSession; + +public abstract class SharedSparkSession implements Serializable { + + protected transient SparkSession spark; + protected transient JavaSparkContext jsc; + + @Before + public void setUp() throws IOException { + spark = SparkSession.builder() + .master("local[2]") + .appName(getClass().getSimpleName()) + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); + } + + @After + public void tearDown() { + spark.stop(); + spark = null; + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java index a81a36d1b1..9b209006bc 100644 --- a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java @@ -17,47 +17,34 @@ package org.apache.spark.ml; -import org.junit.After; -import org.junit.Before; +import java.io.IOException; + import org.junit.Test; +import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.ml.classification.LogisticRegression; import static org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInputAsList; import org.apache.spark.ml.feature.LabeledPoint; import org.apache.spark.ml.feature.StandardScaler; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SparkSession; /** * Test Pipeline construction and fitting in Java. */ -public class JavaPipelineSuite { +public class JavaPipelineSuite extends SharedSparkSession { - private transient SparkSession spark; - private transient JavaSparkContext jsc; private transient Dataset dataset; - @Before - public void setUp() { - spark = SparkSession.builder() - .master("local") - .appName("JavaPipelineSuite") - .getOrCreate(); - jsc = new JavaSparkContext(spark.sparkContext()); + @Override + public void setUp() throws IOException { + super.setUp(); JavaRDD points = jsc.parallelize(generateLogisticInputAsList(1.0, 1.0, 100, 42), 2); dataset = spark.createDataFrame(points, LabeledPoint.class); } - @After - public void tearDown() { - spark.stop(); - spark = null; - } - @Test public void pipeline() { StandardScaler scaler = new StandardScaler() diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java index c76a1947c6..5aba4e8f7d 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java @@ -17,42 +17,19 @@ package org.apache.spark.ml.classification; -import java.io.Serializable; import java.util.HashMap; import java.util.Map; -import org.junit.After; -import org.junit.Before; import org.junit.Test; +import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.ml.classification.LogisticRegressionSuite; import org.apache.spark.ml.feature.LabeledPoint; import org.apache.spark.ml.tree.impl.TreeTests; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SparkSession; -public class JavaDecisionTreeClassifierSuite implements Serializable { - - private transient SparkSession spark; - private transient JavaSparkContext jsc; - - @Before - public void setUp() { - spark = SparkSession.builder() - .master("local") - .appName("JavaDecisionTreeClassifierSuite") - .getOrCreate(); - jsc = new JavaSparkContext(spark.sparkContext()); - } - - @After - public void tearDown() { - spark.stop(); - spark = null; - } +public class JavaDecisionTreeClassifierSuite extends SharedSparkSession { @Test public void runDT() { diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java index 4648926c34..74bb46bd21 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java @@ -17,43 +17,19 @@ package org.apache.spark.ml.classification; -import java.io.Serializable; import java.util.HashMap; import java.util.Map; -import org.junit.After; -import org.junit.Before; import org.junit.Test; +import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.ml.classification.LogisticRegressionSuite; import org.apache.spark.ml.feature.LabeledPoint; import org.apache.spark.ml.tree.impl.TreeTests; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SparkSession; - -public class JavaGBTClassifierSuite implements Serializable { - - private transient SparkSession spark; - private transient JavaSparkContext jsc; - - @Before - public void setUp() { - spark = SparkSession.builder() - .master("local") - .appName("JavaGBTClassifierSuite") - .getOrCreate(); - jsc = new JavaSparkContext(spark.sparkContext()); - } - - @After - public void tearDown() { - spark.stop(); - spark = null; - } +public class JavaGBTClassifierSuite extends SharedSparkSession { @Test public void runDT() { diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java index b8da04c26a..004102103d 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java @@ -17,52 +17,36 @@ package org.apache.spark.ml.classification; -import java.io.Serializable; +import java.io.IOException; import java.util.List; -import org.junit.After; import org.junit.Assert; -import org.junit.Before; import org.junit.Test; +import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; import static org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInputAsList; import org.apache.spark.ml.feature.LabeledPoint; import org.apache.spark.ml.linalg.Vector; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SparkSession; -public class JavaLogisticRegressionSuite implements Serializable { +public class JavaLogisticRegressionSuite extends SharedSparkSession { - private transient SparkSession spark; - private transient JavaSparkContext jsc; private transient Dataset dataset; private transient JavaRDD datasetRDD; private double eps = 1e-5; - @Before - public void setUp() { - spark = SparkSession.builder() - .master("local") - .appName("JavaLogisticRegressionSuite") - .getOrCreate(); - jsc = new JavaSparkContext(spark.sparkContext()); - + @Override + public void setUp() throws IOException { + super.setUp(); List points = generateLogisticInputAsList(1.0, 1.0, 100, 42); datasetRDD = jsc.parallelize(points, 2); dataset = spark.createDataFrame(datasetRDD, LabeledPoint.class); dataset.createOrReplaceTempView("dataset"); } - @After - public void tearDown() { - spark.stop(); - spark = null; - } - @Test public void logisticRegressionDefaultParams() { LogisticRegression lr = new LogisticRegression(); diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java index 48edbc838c..6d0604d8f9 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java @@ -17,38 +17,19 @@ package org.apache.spark.ml.classification; -import java.io.Serializable; import java.util.Arrays; import java.util.List; -import org.junit.After; import org.junit.Assert; -import org.junit.Before; import org.junit.Test; +import org.apache.spark.SharedSparkSession; import org.apache.spark.ml.feature.LabeledPoint; import org.apache.spark.ml.linalg.Vectors; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SparkSession; -public class JavaMultilayerPerceptronClassifierSuite implements Serializable { - - private transient SparkSession spark; - - @Before - public void setUp() { - spark = SparkSession.builder() - .master("local") - .appName("JavaLogisticRegressionSuite") - .getOrCreate(); - } - - @After - public void tearDown() { - spark.stop(); - spark = null; - } +public class JavaMultilayerPerceptronClassifierSuite extends SharedSparkSession { @Test public void testMLPC() { diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java index 787909821b..c2a9e7b58b 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java @@ -17,43 +17,24 @@ package org.apache.spark.ml.classification; -import java.io.Serializable; import java.util.Arrays; import java.util.List; -import org.junit.After; -import org.junit.Before; import org.junit.Test; import static org.junit.Assert.assertEquals; +import org.apache.spark.SharedSparkSession; import org.apache.spark.ml.linalg.VectorUDT; import org.apache.spark.ml.linalg.Vectors; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.Metadata; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; -public class JavaNaiveBayesSuite implements Serializable { - - private transient SparkSession spark; - - @Before - public void setUp() { - spark = SparkSession.builder() - .master("local") - .appName("JavaLogisticRegressionSuite") - .getOrCreate(); - } - - @After - public void tearDown() { - spark.stop(); - spark = null; - } +public class JavaNaiveBayesSuite extends SharedSparkSession { public void validatePrediction(Dataset predictionAndLabels) { for (Row r : predictionAndLabels.collectAsList()) { diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java index 58bc5a448a..6194167bda 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java @@ -17,39 +17,29 @@ package org.apache.spark.ml.classification; -import java.io.Serializable; +import java.io.IOException; import java.util.List; import scala.collection.JavaConverters; -import org.junit.After; import org.junit.Assert; -import org.junit.Before; import org.junit.Test; +import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import static org.apache.spark.ml.classification.LogisticRegressionSuite.generateMultinomialLogisticInput; import org.apache.spark.ml.feature.LabeledPoint; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SparkSession; +import static org.apache.spark.ml.classification.LogisticRegressionSuite.generateMultinomialLogisticInput; -public class JavaOneVsRestSuite implements Serializable { +public class JavaOneVsRestSuite extends SharedSparkSession { - private transient SparkSession spark; - private transient JavaSparkContext jsc; private transient Dataset dataset; private transient JavaRDD datasetRDD; - @Before - public void setUp() { - spark = SparkSession.builder() - .master("local") - .appName("JavaLOneVsRestSuite") - .getOrCreate(); - jsc = new JavaSparkContext(spark.sparkContext()); - + @Override + public void setUp() throws IOException { + super.setUp(); int nPoints = 3; // The following coefficients and xMean/xVariance are computed from iris dataset with @@ -68,12 +58,6 @@ public class JavaOneVsRestSuite implements Serializable { dataset = spark.createDataFrame(datasetRDD, LabeledPoint.class); } - @After - public void tearDown() { - spark.stop(); - spark = null; - } - @Test public void oneVsRestDefaultParams() { OneVsRest ova = new OneVsRest(); diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java index 1ed20b1bfa..dd98513f37 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java @@ -17,45 +17,21 @@ package org.apache.spark.ml.classification; -import java.io.Serializable; import java.util.HashMap; import java.util.Map; -import org.junit.After; import org.junit.Assert; -import org.junit.Before; import org.junit.Test; +import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.ml.classification.LogisticRegressionSuite; import org.apache.spark.ml.feature.LabeledPoint; import org.apache.spark.ml.linalg.Vector; import org.apache.spark.ml.tree.impl.TreeTests; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SparkSession; - -public class JavaRandomForestClassifierSuite implements Serializable { - - private transient SparkSession spark; - private transient JavaSparkContext jsc; - - @Before - public void setUp() { - spark = SparkSession.builder() - .master("local") - .appName("JavaRandomForestClassifierSuite") - .getOrCreate(); - jsc = new JavaSparkContext(spark.sparkContext()); - } - - @After - public void tearDown() { - spark.stop(); - spark = null; - } +public class JavaRandomForestClassifierSuite extends SharedSparkSession { @Test public void runDT() { diff --git a/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java b/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java index 9d07170fa1..1be6f96f4c 100644 --- a/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java @@ -17,43 +17,30 @@ package org.apache.spark.ml.clustering; -import java.io.Serializable; +import java.io.IOException; import java.util.Arrays; import java.util.List; +import org.junit.Test; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; -import org.junit.After; -import org.junit.Before; -import org.junit.Test; - +import org.apache.spark.SharedSparkSession; import org.apache.spark.ml.linalg.Vector; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SparkSession; -public class JavaKMeansSuite implements Serializable { +public class JavaKMeansSuite extends SharedSparkSession { private transient int k = 5; private transient Dataset dataset; - private transient SparkSession spark; - @Before - public void setUp() { - spark = SparkSession.builder() - .master("local") - .appName("JavaKMeansSuite") - .getOrCreate(); + @Override + public void setUp() throws IOException { + super.setUp(); dataset = KMeansSuite.generateKMeansData(spark, 50, 3, k); } - @After - public void tearDown() { - spark.stop(); - spark = null; - } - @Test public void fitAndTransform() { KMeans kmeans = new KMeans().setK(k).setSeed(1); diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java index a96b43de15..87639380bd 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java @@ -20,36 +20,19 @@ package org.apache.spark.ml.feature; import java.util.Arrays; import java.util.List; -import org.junit.After; import org.junit.Assert; -import org.junit.Before; import org.junit.Test; +import org.apache.spark.SharedSparkSession; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.Metadata; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; -public class JavaBucketizerSuite { - private transient SparkSession spark; - - @Before - public void setUp() { - spark = SparkSession.builder() - .master("local") - .appName("JavaBucketizerSuite") - .getOrCreate(); - } - - @After - public void tearDown() { - spark.stop(); - spark = null; - } +public class JavaBucketizerSuite extends SharedSparkSession { @Test public void bucketizerTest() { diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java index 9d8c09b30c..b7956b6fd3 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java @@ -22,38 +22,21 @@ import java.util.List; import edu.emory.mathcs.jtransforms.dct.DoubleDCT_1D; -import org.junit.After; import org.junit.Assert; -import org.junit.Before; import org.junit.Test; +import org.apache.spark.SharedSparkSession; import org.apache.spark.ml.linalg.Vector; import org.apache.spark.ml.linalg.VectorUDT; import org.apache.spark.ml.linalg.Vectors; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.types.Metadata; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; -public class JavaDCTSuite { - private transient SparkSession spark; - - @Before - public void setUp() { - spark = SparkSession.builder() - .master("local") - .appName("JavaDCTSuite") - .getOrCreate(); - } - - @After - public void tearDown() { - spark.stop(); - spark = null; - } +public class JavaDCTSuite extends SharedSparkSession { @Test public void javaCompatibilityTest() { diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java index 3c37441a77..57696d0150 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java @@ -20,38 +20,21 @@ package org.apache.spark.ml.feature; import java.util.Arrays; import java.util.List; -import org.junit.After; import org.junit.Assert; -import org.junit.Before; import org.junit.Test; +import org.apache.spark.SharedSparkSession; import org.apache.spark.ml.linalg.Vector; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.Metadata; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; -public class JavaHashingTFSuite { - private transient SparkSession spark; - - @Before - public void setUp() { - spark = SparkSession.builder() - .master("local") - .appName("JavaHashingTFSuite") - .getOrCreate(); - } - - @After - public void tearDown() { - spark.stop(); - spark = null; - } +public class JavaHashingTFSuite extends SharedSparkSession { @Test public void hashingTF() { diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java index b3e213a497..6f877b5668 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java @@ -19,35 +19,15 @@ package org.apache.spark.ml.feature; import java.util.Arrays; -import org.junit.After; -import org.junit.Before; import org.junit.Test; +import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.ml.linalg.Vectors; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SparkSession; -public class JavaNormalizerSuite { - private transient SparkSession spark; - private transient JavaSparkContext jsc; - - @Before - public void setUp() { - spark = SparkSession.builder() - .master("local") - .appName("JavaNormalizerSuite") - .getOrCreate(); - jsc = new JavaSparkContext(spark.sparkContext()); - } - - @After - public void tearDown() { - spark.stop(); - spark = null; - } +public class JavaNormalizerSuite extends SharedSparkSession { @Test public void normalizer() { diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java index a4bce2283b..ac479c0841 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java @@ -23,13 +23,11 @@ import java.util.List; import scala.Tuple2; -import org.junit.After; import org.junit.Assert; -import org.junit.Before; import org.junit.Test; +import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; import org.apache.spark.ml.linalg.Vector; import org.apache.spark.ml.linalg.Vectors; @@ -37,26 +35,8 @@ import org.apache.spark.mllib.linalg.Matrix; import org.apache.spark.mllib.linalg.distributed.RowMatrix; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SparkSession; - -public class JavaPCASuite implements Serializable { - private transient SparkSession spark; - private transient JavaSparkContext jsc; - - @Before - public void setUp() { - spark = SparkSession.builder() - .master("local") - .appName("JavaPCASuite") - .getOrCreate(); - jsc = new JavaSparkContext(spark.sparkContext()); - } - @After - public void tearDown() { - spark.stop(); - spark = null; - } +public class JavaPCASuite extends SharedSparkSession { public static class VectorPair implements Serializable { private Vector features = Vectors.dense(0.0); @@ -95,7 +75,7 @@ public class JavaPCASuite implements Serializable { } } ).rdd()); - + Matrix pc = mat.computePrincipalComponents(3); mat.multiply(pc).rows().toJavaRDD(); diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java index a28f73f10a..df5d34fbe9 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java @@ -20,41 +20,21 @@ package org.apache.spark.ml.feature; import java.util.Arrays; import java.util.List; -import org.junit.After; import org.junit.Assert; -import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.SharedSparkSession; import org.apache.spark.ml.linalg.Vector; import org.apache.spark.ml.linalg.VectorUDT; import org.apache.spark.ml.linalg.Vectors; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.types.Metadata; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; -public class JavaPolynomialExpansionSuite { - private transient SparkSession spark; - private transient JavaSparkContext jsc; - - @Before - public void setUp() { - spark = SparkSession.builder() - .master("local") - .appName("JavaPolynomialExpansionSuite") - .getOrCreate(); - jsc = new JavaSparkContext(spark.sparkContext()); - } - - @After - public void tearDown() { - jsc.stop(); - jsc = null; - } +public class JavaPolynomialExpansionSuite extends SharedSparkSession { @Test public void polynomialExpansionTest() { diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java index 8415fdb84f..dbc0b1db5c 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java @@ -20,34 +20,14 @@ package org.apache.spark.ml.feature; import java.util.Arrays; import java.util.List; -import org.junit.After; -import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.SharedSparkSession; import org.apache.spark.ml.linalg.Vectors; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SparkSession; -public class JavaStandardScalerSuite { - private transient SparkSession spark; - private transient JavaSparkContext jsc; - - @Before - public void setUp() { - spark = SparkSession.builder() - .master("local") - .appName("JavaStandardScalerSuite") - .getOrCreate(); - jsc = new JavaSparkContext(spark.sparkContext()); - } - - @After - public void tearDown() { - spark.stop(); - spark = null; - } +public class JavaStandardScalerSuite extends SharedSparkSession { @Test public void standardScaler() { diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java index 2b156f3bca..6480b57e1f 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java @@ -20,37 +20,19 @@ package org.apache.spark.ml.feature; import java.util.Arrays; import java.util.List; -import org.junit.After; -import org.junit.Before; import org.junit.Test; +import org.apache.spark.SharedSparkSession; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.Metadata; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; -public class JavaStopWordsRemoverSuite { - - private transient SparkSession spark; - - @Before - public void setUp() { - spark = SparkSession.builder() - .master("local") - .appName("JavaStopWordsRemoverSuite") - .getOrCreate(); - } - - @After - public void tearDown() { - spark.stop(); - spark = null; - } +public class JavaStopWordsRemoverSuite extends SharedSparkSession { @Test public void javaCompatibilityTest() { diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java index 52c0bde8f3..c1928a26b6 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java @@ -20,37 +20,19 @@ package org.apache.spark.ml.feature; import java.util.Arrays; import java.util.List; -import org.junit.After; +import static org.apache.spark.sql.types.DataTypes.*; + import org.junit.Assert; -import org.junit.Before; import org.junit.Test; -import org.apache.spark.SparkConf; +import org.apache.spark.SharedSparkSession; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; -import static org.apache.spark.sql.types.DataTypes.*; - -public class JavaStringIndexerSuite { - private transient SparkSession spark; - @Before - public void setUp() { - SparkConf sparkConf = new SparkConf(); - sparkConf.setMaster("local"); - sparkConf.setAppName("JavaStringIndexerSuite"); - - spark = SparkSession.builder().config(sparkConf).getOrCreate(); - } - - @After - public void tearDown() { - spark.stop(); - spark = null; - } +public class JavaStringIndexerSuite extends SharedSparkSession { @Test public void testStringIndexer() { diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java index 0bac2839e1..27550a3d5c 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java @@ -20,35 +20,15 @@ package org.apache.spark.ml.feature; import java.util.Arrays; import java.util.List; -import org.junit.After; import org.junit.Assert; -import org.junit.Before; import org.junit.Test; +import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SparkSession; -public class JavaTokenizerSuite { - private transient SparkSession spark; - private transient JavaSparkContext jsc; - - @Before - public void setUp() { - spark = SparkSession.builder() - .master("local") - .appName("JavaTokenizerSuite") - .getOrCreate(); - jsc = new JavaSparkContext(spark.sparkContext()); - } - - @After - public void tearDown() { - spark.stop(); - spark = null; - } +public class JavaTokenizerSuite extends SharedSparkSession { @Test public void regexTokenizer() { diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java index fedaa77176..583652badb 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java @@ -19,40 +19,22 @@ package org.apache.spark.ml.feature; import java.util.Arrays; -import org.junit.After; +import static org.apache.spark.sql.types.DataTypes.*; + import org.junit.Assert; -import org.junit.Before; import org.junit.Test; -import org.apache.spark.SparkConf; +import org.apache.spark.SharedSparkSession; import org.apache.spark.ml.linalg.Vector; import org.apache.spark.ml.linalg.VectorUDT; import org.apache.spark.ml.linalg.Vectors; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; -import static org.apache.spark.sql.types.DataTypes.*; - -public class JavaVectorAssemblerSuite { - private transient SparkSession spark; - @Before - public void setUp() { - SparkConf sparkConf = new SparkConf(); - sparkConf.setMaster("local"); - sparkConf.setAppName("JavaVectorAssemblerSuite"); - - spark = SparkSession.builder().config(sparkConf).getOrCreate(); - } - - @After - public void tearDown() { - spark.stop(); - spark = null; - } +public class JavaVectorAssemblerSuite extends SharedSparkSession { @Test public void testVectorAssembler() { diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java index a8dd44608d..ca8fae3a48 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java @@ -17,42 +17,21 @@ package org.apache.spark.ml.feature; -import java.io.Serializable; import java.util.Arrays; import java.util.List; import java.util.Map; -import org.junit.After; import org.junit.Assert; -import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.SharedSparkSession; import org.apache.spark.ml.feature.VectorIndexerSuite.FeatureData; import org.apache.spark.ml.linalg.Vectors; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SparkSession; -public class JavaVectorIndexerSuite implements Serializable { - private transient SparkSession spark; - private JavaSparkContext jsc; - - @Before - public void setUp() { - spark = SparkSession.builder() - .master("local") - .appName("JavaVectorIndexerSuite") - .getOrCreate(); - jsc = new JavaSparkContext(spark.sparkContext()); - } - - @After - public void tearDown() { - spark.stop(); - spark = null; - } +public class JavaVectorIndexerSuite extends SharedSparkSession { @Test public void vectorIndexerAPI() { diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java index a565c77af4..3dc2e1f896 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java @@ -20,11 +20,10 @@ package org.apache.spark.ml.feature; import java.util.Arrays; import java.util.List; -import org.junit.After; import org.junit.Assert; -import org.junit.Before; import org.junit.Test; +import org.apache.spark.SharedSparkSession; import org.apache.spark.ml.attribute.Attribute; import org.apache.spark.ml.attribute.AttributeGroup; import org.apache.spark.ml.attribute.NumericAttribute; @@ -33,26 +32,10 @@ import org.apache.spark.ml.linalg.Vectors; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.types.StructType; -public class JavaVectorSlicerSuite { - private transient SparkSession spark; - - @Before - public void setUp() { - spark = SparkSession.builder() - .master("local") - .appName("JavaVectorSlicerSuite") - .getOrCreate(); - } - - @After - public void tearDown() { - spark.stop(); - spark = null; - } +public class JavaVectorSlicerSuite extends SharedSparkSession { @Test public void vectorSlice() { diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java index bef7eb0f99..d0a849fd11 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java @@ -19,34 +19,17 @@ package org.apache.spark.ml.feature; import java.util.Arrays; -import org.junit.After; import org.junit.Assert; -import org.junit.Before; import org.junit.Test; +import org.apache.spark.SharedSparkSession; import org.apache.spark.ml.linalg.Vector; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.types.*; -public class JavaWord2VecSuite { - private transient SparkSession spark; - - @Before - public void setUp() { - spark = SparkSession.builder() - .master("local") - .appName("JavaWord2VecSuite") - .getOrCreate(); - } - - @After - public void tearDown() { - spark.stop(); - spark = null; - } +public class JavaWord2VecSuite extends SharedSparkSession { @Test public void testJavaWord2Vec() { diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java index a5b5dd4088..1077e103a3 100644 --- a/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java @@ -19,37 +19,14 @@ package org.apache.spark.ml.param; import java.util.Arrays; -import org.junit.After; import org.junit.Assert; -import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SparkSession; - /** * Test Param and related classes in Java */ public class JavaParamsSuite { - private transient SparkSession spark; - private transient JavaSparkContext jsc; - - @Before - public void setUp() { - spark = SparkSession.builder() - .master("local") - .appName("JavaParamsSuite") - .getOrCreate(); - jsc = new JavaSparkContext(spark.sparkContext()); - } - - @After - public void tearDown() { - spark.stop(); - spark = null; - } - @Test public void testParams() { JavaTestParams testParams = new JavaTestParams(); diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java index 4ea3f2255e..1da85ed9da 100644 --- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java @@ -17,43 +17,21 @@ package org.apache.spark.ml.regression; -import java.io.Serializable; import java.util.HashMap; import java.util.Map; -import org.junit.After; -import org.junit.Before; import org.junit.Test; +import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.ml.classification.LogisticRegressionSuite; import org.apache.spark.ml.feature.LabeledPoint; import org.apache.spark.ml.tree.impl.TreeTests; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SparkSession; -public class JavaDecisionTreeRegressorSuite implements Serializable { - - private transient SparkSession spark; - private transient JavaSparkContext jsc; - - @Before - public void setUp() { - spark = SparkSession.builder() - .master("local") - .appName("JavaDecisionTreeRegressorSuite") - .getOrCreate(); - jsc = new JavaSparkContext(spark.sparkContext()); - } - - @After - public void tearDown() { - spark.stop(); - spark = null; - } +public class JavaDecisionTreeRegressorSuite extends SharedSparkSession { @Test public void runDT() { diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java index 3b5edf1e15..7fd9b1feb7 100644 --- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java @@ -17,43 +17,21 @@ package org.apache.spark.ml.regression; -import java.io.Serializable; import java.util.HashMap; import java.util.Map; -import org.junit.After; -import org.junit.Before; import org.junit.Test; +import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.ml.classification.LogisticRegressionSuite; import org.apache.spark.ml.feature.LabeledPoint; import org.apache.spark.ml.tree.impl.TreeTests; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SparkSession; -public class JavaGBTRegressorSuite implements Serializable { - - private transient SparkSession spark; - private transient JavaSparkContext jsc; - - @Before - public void setUp() { - spark = SparkSession.builder() - .master("local") - .appName("JavaGBTRegressorSuite") - .getOrCreate(); - jsc = new JavaSparkContext(spark.sparkContext()); - } - - @After - public void tearDown() { - spark.stop(); - spark = null; - } +public class JavaGBTRegressorSuite extends SharedSparkSession { @Test public void runDT() { diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java index 126aa6298f..6cdcdda1a6 100644 --- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java @@ -17,48 +17,32 @@ package org.apache.spark.ml.regression; -import java.io.Serializable; +import java.io.IOException; import java.util.List; -import org.junit.After; -import org.junit.Before; import org.junit.Test; import static org.junit.Assert.assertEquals; +import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; import static org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInputAsList; import org.apache.spark.ml.feature.LabeledPoint; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SparkSession; -public class JavaLinearRegressionSuite implements Serializable { - - private transient SparkSession spark; - private transient JavaSparkContext jsc; +public class JavaLinearRegressionSuite extends SharedSparkSession { private transient Dataset dataset; private transient JavaRDD datasetRDD; - @Before - public void setUp() { - spark = SparkSession.builder() - .master("local") - .appName("JavaLinearRegressionSuite") - .getOrCreate(); - jsc = new JavaSparkContext(spark.sparkContext()); + @Override + public void setUp() throws IOException { + super.setUp(); List points = generateLogisticInputAsList(1.0, 1.0, 100, 42); datasetRDD = jsc.parallelize(points, 2); dataset = spark.createDataFrame(datasetRDD, LabeledPoint.class); dataset.createOrReplaceTempView("dataset"); } - @After - public void tearDown() { - jsc.stop(); - jsc = null; - } - @Test public void linearRegressionDefaultParams() { LinearRegression lr = new LinearRegression(); diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java index d601e7c540..4ba13e2e06 100644 --- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java @@ -17,45 +17,23 @@ package org.apache.spark.ml.regression; -import java.io.Serializable; import java.util.HashMap; import java.util.Map; -import org.junit.After; import org.junit.Assert; -import org.junit.Before; import org.junit.Test; +import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.ml.classification.LogisticRegressionSuite; import org.apache.spark.ml.feature.LabeledPoint; import org.apache.spark.ml.linalg.Vector; import org.apache.spark.ml.tree.impl.TreeTests; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SparkSession; -public class JavaRandomForestRegressorSuite implements Serializable { - - private transient SparkSession spark; - private transient JavaSparkContext jsc; - - @Before - public void setUp() { - spark = SparkSession.builder() - .master("local") - .appName("JavaRandomForestRegressorSuite") - .getOrCreate(); - jsc = new JavaSparkContext(spark.sparkContext()); - } - - @After - public void tearDown() { - spark.stop(); - spark = null; - } +public class JavaRandomForestRegressorSuite extends SharedSparkSession { @Test public void runDT() { diff --git a/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java b/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java index 022dcf94bd..fa39f4560c 100644 --- a/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java @@ -23,35 +23,28 @@ import java.nio.charset.StandardCharsets; import com.google.common.io.Files; -import org.junit.After; import org.junit.Assert; -import org.junit.Before; import org.junit.Test; +import org.apache.spark.SharedSparkSession; import org.apache.spark.ml.linalg.DenseVector; import org.apache.spark.ml.linalg.Vectors; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SparkSession; import org.apache.spark.util.Utils; /** * Test LibSVMRelation in Java. */ -public class JavaLibSVMRelationSuite { - private transient SparkSession spark; +public class JavaLibSVMRelationSuite extends SharedSparkSession { private File tempDir; private String path; - @Before + @Override public void setUp() throws IOException { - spark = SparkSession.builder() - .master("local") - .appName("JavaLibSVMRelationSuite") - .getOrCreate(); - + super.setUp(); tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource"); File file = new File(tempDir, "part-00000"); String s = "1 1:1.0 3:2.0 5:3.0\n0\n0 2:4.0 4:5.0 6:6.0"; @@ -59,10 +52,9 @@ public class JavaLibSVMRelationSuite { path = tempDir.toURI().toString(); } - @After + @Override public void tearDown() { - spark.stop(); - spark = null; + super.tearDown(); Utils.deleteRecursively(tempDir); } diff --git a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java index b874ccd48b..692d5ad591 100644 --- a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java @@ -17,48 +17,33 @@ package org.apache.spark.ml.tuning; -import java.io.Serializable; +import java.io.IOException; import java.util.List; -import org.junit.After; import org.junit.Assert; -import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.SharedSparkSession; import org.apache.spark.ml.classification.LogisticRegression; -import static org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInputAsList; -import org.apache.spark.ml.feature.LabeledPoint; import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator; +import org.apache.spark.ml.feature.LabeledPoint; import org.apache.spark.ml.param.ParamMap; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SparkSession; +import static org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInputAsList; -public class JavaCrossValidatorSuite implements Serializable { - private transient SparkSession spark; - private transient JavaSparkContext jsc; - private transient Dataset dataset; +public class JavaCrossValidatorSuite extends SharedSparkSession { - @Before - public void setUp() { - spark = SparkSession.builder() - .master("local") - .appName("JavaCrossValidatorSuite") - .getOrCreate(); - jsc = new JavaSparkContext(spark.sparkContext()); + private transient Dataset dataset; + @Override + public void setUp() throws IOException { + super.setUp(); List points = generateLogisticInputAsList(1.0, 1.0, 100, 42); dataset = spark.createDataFrame(jsc.parallelize(points, 2), LabeledPoint.class); } - @After - public void tearDown() { - jsc.stop(); - jsc = null; - } - @Test public void crossValidationWithLogisticRegression() { LogisticRegression lr = new LogisticRegression(); diff --git a/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java b/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java index 7151e27cde..da623d1d15 100644 --- a/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java @@ -20,42 +20,25 @@ package org.apache.spark.ml.util; import java.io.File; import java.io.IOException; -import org.junit.After; import org.junit.Assert; -import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.SparkSession; +import org.apache.spark.SharedSparkSession; import org.apache.spark.util.Utils; -public class JavaDefaultReadWriteSuite { - - JavaSparkContext jsc = null; - SparkSession spark = null; +public class JavaDefaultReadWriteSuite extends SharedSparkSession { File tempDir = null; - @Before - public void setUp() { - SQLContext.clearActive(); - spark = SparkSession.builder() - .master("local[2]") - .appName("JavaDefaultReadWriteSuite") - .getOrCreate(); - SQLContext.setActive(spark.wrapped()); - + @Override + public void setUp() throws IOException { + super.setUp(); tempDir = Utils.createTempDir( System.getProperty("java.io.tmpdir"), "JavaDefaultReadWriteSuite"); } - @After + @Override public void tearDown() { - SQLContext.clearActive(); - if (spark != null) { - spark.stop(); - spark = null; - } + super.tearDown(); Utils.deleteRecursively(tempDir); } diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java index 2f10d14da5..c04e2e6954 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java @@ -17,37 +17,16 @@ package org.apache.spark.mllib.classification; -import java.io.Serializable; import java.util.List; -import org.junit.After; import org.junit.Assert; -import org.junit.Before; import org.junit.Test; +import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.SparkSession; -public class JavaLogisticRegressionSuite implements Serializable { - private transient SparkSession spark; - private transient JavaSparkContext jsc; - - @Before - public void setUp() { - spark = SparkSession.builder() - .master("local") - .appName("JavaLogisticRegressionSuite") - .getOrCreate(); - jsc = new JavaSparkContext(spark.sparkContext()); - } - - @After - public void tearDown() { - spark.stop(); - spark = null; - } +public class JavaLogisticRegressionSuite extends SharedSparkSession { int validatePrediction(List validationData, LogisticRegressionModel model) { int numAccurate = 0; diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java index 5e212e2fc5..6ded42e928 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java @@ -17,42 +17,21 @@ package org.apache.spark.mllib.classification; -import java.io.Serializable; import java.util.Arrays; import java.util.List; -import org.junit.After; import org.junit.Assert; -import org.junit.Before; import org.junit.Test; +import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.SparkSession; -public class JavaNaiveBayesSuite implements Serializable { - private transient SparkSession spark; - private transient JavaSparkContext jsc; - - @Before - public void setUp() { - spark = SparkSession.builder() - .master("local") - .appName("JavaNaiveBayesSuite") - .getOrCreate(); - jsc = new JavaSparkContext(spark.sparkContext()); - } - - @After - public void tearDown() { - spark.stop(); - spark = null; - } +public class JavaNaiveBayesSuite extends SharedSparkSession { private static final List POINTS = Arrays.asList( new LabeledPoint(0, Vectors.dense(1.0, 0.0, 0.0)), diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java index 2a090c054f..0f54e684e4 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java @@ -17,37 +17,16 @@ package org.apache.spark.mllib.classification; -import java.io.Serializable; import java.util.List; -import org.junit.After; import org.junit.Assert; -import org.junit.Before; import org.junit.Test; +import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.SparkSession; -public class JavaSVMSuite implements Serializable { - private transient SparkSession spark; - private transient JavaSparkContext jsc; - - @Before - public void setUp() { - spark = SparkSession.builder() - .master("local") - .appName("JavaSVMSuite") - .getOrCreate(); - jsc = new JavaSparkContext(spark.sparkContext()); - } - - @After - public void tearDown() { - spark.stop(); - spark = null; - } +public class JavaSVMSuite extends SharedSparkSession { int validatePrediction(List validationData, SVMModel model) { int numAccurate = 0; diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java index 62c6d9b7e3..8c6bced52d 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java @@ -17,7 +17,6 @@ package org.apache.spark.mllib.classification; -import java.io.Serializable; import java.util.Arrays; import java.util.List; @@ -37,7 +36,7 @@ import org.apache.spark.streaming.api.java.JavaPairDStream; import org.apache.spark.streaming.api.java.JavaStreamingContext; import static org.apache.spark.streaming.JavaTestUtils.*; -public class JavaStreamingLogisticRegressionSuite implements Serializable { +public class JavaStreamingLogisticRegressionSuite { protected transient JavaStreamingContext ssc; diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java index 7f29b05047..3d62b273d2 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java @@ -17,39 +17,17 @@ package org.apache.spark.mllib.clustering; -import java.io.Serializable; - import com.google.common.collect.Lists; -import org.junit.After; import org.junit.Assert; -import org.junit.Before; import org.junit.Test; +import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.SparkSession; - -public class JavaBisectingKMeansSuite implements Serializable { - private transient SparkSession spark; - private transient JavaSparkContext jsc; - @Before - public void setUp() { - spark = SparkSession.builder() - .master("local") - .appName("JavaBisectingKMeansSuite") - .getOrCreate(); - jsc = new JavaSparkContext(spark.sparkContext()); - } - - @After - public void tearDown() { - spark.stop(); - spark = null; - } +public class JavaBisectingKMeansSuite extends SharedSparkSession { @Test public void twoDimensionalData() { diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java index 20edd08a21..bf76719937 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java @@ -17,40 +17,19 @@ package org.apache.spark.mllib.clustering; -import java.io.Serializable; import java.util.Arrays; import java.util.List; import static org.junit.Assert.assertEquals; -import org.junit.After; -import org.junit.Before; import org.junit.Test; +import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.SparkSession; -public class JavaGaussianMixtureSuite implements Serializable { - private transient SparkSession spark; - private transient JavaSparkContext jsc; - - @Before - public void setUp() { - spark = SparkSession.builder() - .master("local") - .appName("JavaGaussianMixture") - .getOrCreate(); - jsc = new JavaSparkContext(spark.sparkContext()); - } - - @After - public void tearDown() { - spark.stop(); - spark = null; - } +public class JavaGaussianMixtureSuite extends SharedSparkSession { @Test public void runGaussianMixture() { diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java index 4e5b87f588..270e636f82 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java @@ -17,40 +17,19 @@ package org.apache.spark.mllib.clustering; -import java.io.Serializable; import java.util.Arrays; import java.util.List; import static org.junit.Assert.assertEquals; -import org.junit.After; -import org.junit.Before; import org.junit.Test; +import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.SparkSession; -public class JavaKMeansSuite implements Serializable { - private transient SparkSession spark; - private transient JavaSparkContext jsc; - - @Before - public void setUp() { - spark = SparkSession.builder() - .master("local") - .appName("JavaKMeans") - .getOrCreate(); - jsc = new JavaSparkContext(spark.sparkContext()); - } - - @After - public void tearDown() { - spark.stop(); - spark = null; - } +public class JavaKMeansSuite extends SharedSparkSession { @Test public void runKMeansUsingStaticMethods() { diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java index f16585aff4..08d6713ab2 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java @@ -17,39 +17,28 @@ package org.apache.spark.mllib.clustering; -import java.io.Serializable; +import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import scala.Tuple2; import scala.Tuple3; -import org.junit.After; -import org.junit.Before; import org.junit.Test; import static org.junit.Assert.*; +import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; import org.apache.spark.mllib.linalg.Matrix; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.SparkSession; - -public class JavaLDASuite implements Serializable { - private transient SparkSession spark; - private transient JavaSparkContext jsc; - - @Before - public void setUp() { - spark = SparkSession.builder() - .master("local") - .appName("JavaLDASuite") - .getOrCreate(); - jsc = new JavaSparkContext(spark.sparkContext()); +public class JavaLDASuite extends SharedSparkSession { + @Override + public void setUp() throws IOException { + super.setUp(); ArrayList> tinyCorpus = new ArrayList<>(); for (int i = 0; i < LDASuite.tinyCorpus().length; i++) { tinyCorpus.add(new Tuple2<>((Long) LDASuite.tinyCorpus()[i]._1(), @@ -59,12 +48,6 @@ public class JavaLDASuite implements Serializable { corpus = JavaPairRDD.fromJavaRDD(tmpCorpus); } - @After - public void tearDown() { - spark.stop(); - spark = null; - } - @Test public void localLDAModel() { Matrix topics = LDASuite.tinyTopics(); diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java index d1d618f7de..d41fc0e4dc 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java @@ -17,7 +17,6 @@ package org.apache.spark.mllib.clustering; -import java.io.Serializable; import java.util.Arrays; import java.util.List; @@ -36,7 +35,7 @@ import org.apache.spark.streaming.api.java.JavaPairDStream; import org.apache.spark.streaming.api.java.JavaStreamingContext; import static org.apache.spark.streaming.JavaTestUtils.*; -public class JavaStreamingKMeansSuite implements Serializable { +public class JavaStreamingKMeansSuite { protected transient JavaStreamingContext ssc; diff --git a/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java index 6a096d6386..e9d7e4fdbe 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java @@ -17,35 +17,25 @@ package org.apache.spark.mllib.evaluation; -import java.io.Serializable; +import java.io.IOException; import java.util.Arrays; import java.util.List; import scala.Tuple2; import scala.Tuple2$; -import org.junit.After; import org.junit.Assert; -import org.junit.Before; import org.junit.Test; +import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SparkSession; -public class JavaRankingMetricsSuite implements Serializable { - private transient SparkSession spark; - private transient JavaSparkContext jsc; +public class JavaRankingMetricsSuite extends SharedSparkSession { private transient JavaRDD, List>> predictionAndLabels; - @Before - public void setUp() { - spark = SparkSession.builder() - .master("local") - .appName("JavaPCASuite") - .getOrCreate(); - jsc = new JavaSparkContext(spark.sparkContext()); - + @Override + public void setUp() throws IOException { + super.setUp(); predictionAndLabels = jsc.parallelize(Arrays.asList( Tuple2$.MODULE$.apply( Arrays.asList(1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Arrays.asList(1, 2, 3, 4, 5)), @@ -55,12 +45,6 @@ public class JavaRankingMetricsSuite implements Serializable { Arrays.asList(1, 2, 3, 4, 5), Arrays.asList())), 2); } - @After - public void tearDown() { - spark.stop(); - spark = null; - } - @Test public void rankingMetrics() { @SuppressWarnings("unchecked") diff --git a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java index de50fb8c4f..05128ea343 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java @@ -17,38 +17,17 @@ package org.apache.spark.mllib.feature; -import java.io.Serializable; import java.util.Arrays; import java.util.List; -import org.junit.After; import org.junit.Assert; -import org.junit.Before; import org.junit.Test; +import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.sql.SparkSession; -public class JavaTfIdfSuite implements Serializable { - private transient SparkSession spark; - private transient JavaSparkContext jsc; - - @Before - public void setUp() { - spark = SparkSession.builder() - .master("local") - .appName("JavaPCASuite") - .getOrCreate(); - jsc = new JavaSparkContext(spark.sparkContext()); - } - - @After - public void tearDown() { - spark.stop(); - spark = null; - } +public class JavaTfIdfSuite extends SharedSparkSession { @Test public void tfIdf() { diff --git a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java index 64885cc842..3e3abddbee 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java @@ -17,7 +17,6 @@ package org.apache.spark.mllib.feature; -import java.io.Serializable; import java.util.Arrays; import java.util.List; @@ -25,33 +24,13 @@ import com.google.common.base.Strings; import scala.Tuple2; -import org.junit.After; import org.junit.Assert; -import org.junit.Before; import org.junit.Test; +import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SparkSession; -public class JavaWord2VecSuite implements Serializable { - private transient SparkSession spark; - private transient JavaSparkContext jsc; - - @Before - public void setUp() { - spark = SparkSession.builder() - .master("local") - .appName("JavaPCASuite") - .getOrCreate(); - jsc = new JavaSparkContext(spark.sparkContext()); - } - - @After - public void tearDown() { - spark.stop(); - spark = null; - } +public class JavaWord2VecSuite extends SharedSparkSession { @Test @SuppressWarnings("unchecked") diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java index fdc19a5b3d..3451e07737 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java @@ -16,36 +16,15 @@ */ package org.apache.spark.mllib.fpm; -import java.io.Serializable; import java.util.Arrays; -import org.junit.After; -import org.junit.Before; import org.junit.Test; +import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset; -import org.apache.spark.sql.SparkSession; -public class JavaAssociationRulesSuite implements Serializable { - private transient SparkSession spark; - private transient JavaSparkContext jsc; - - @Before - public void setUp() { - spark = SparkSession.builder() - .master("local") - .appName("JavaAssociationRulesSuite") - .getOrCreate(); - jsc = new JavaSparkContext(spark.sparkContext()); - } - - @After - public void tearDown() { - spark.stop(); - spark = null; - } +public class JavaAssociationRulesSuite extends SharedSparkSession { @Test public void runAssociationRules() { diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java index f235251e61..46e9dd8b59 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java @@ -18,39 +18,18 @@ package org.apache.spark.mllib.fpm; import java.io.File; -import java.io.Serializable; import java.util.Arrays; import java.util.List; import static org.junit.Assert.assertEquals; -import org.junit.After; -import org.junit.Before; import org.junit.Test; +import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SparkSession; import org.apache.spark.util.Utils; -public class JavaFPGrowthSuite implements Serializable { - private transient SparkSession spark; - private transient JavaSparkContext jsc; - - @Before - public void setUp() { - spark = SparkSession.builder() - .master("local") - .appName("JavaFPGrowth") - .getOrCreate(); - jsc = new JavaSparkContext(spark.sparkContext()); - } - - @After - public void tearDown() { - spark.stop(); - spark = null; - } +public class JavaFPGrowthSuite extends SharedSparkSession { @Test public void runFPGrowth() { diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java index bf7f1fc71b..75b0ec6480 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java @@ -21,35 +21,15 @@ import java.io.File; import java.util.Arrays; import java.util.List; -import org.junit.After; import org.junit.Assert; -import org.junit.Before; import org.junit.Test; +import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.fpm.PrefixSpan.FreqSequence; -import org.apache.spark.sql.SparkSession; import org.apache.spark.util.Utils; -public class JavaPrefixSpanSuite { - private transient SparkSession spark; - private transient JavaSparkContext jsc; - - @Before - public void setUp() { - spark = SparkSession.builder() - .master("local") - .appName("JavaPrefixSpan") - .getOrCreate(); - jsc = new JavaSparkContext(spark.sparkContext()); - } - - @After - public void tearDown() { - spark.stop(); - spark = null; - } +public class JavaPrefixSpanSuite extends SharedSparkSession { @Test public void runPrefixSpan() { diff --git a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java index 92fc57871c..f427846b9a 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java @@ -17,7 +17,6 @@ package org.apache.spark.mllib.linalg; -import java.io.Serializable; import java.util.Random; import static org.junit.Assert.assertArrayEquals; @@ -25,7 +24,7 @@ import static org.junit.Assert.assertEquals; import org.junit.Test; -public class JavaMatricesSuite implements Serializable { +public class JavaMatricesSuite { @Test public void randMatrixConstruction() { diff --git a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java index 817b962c75..f67f555e41 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java @@ -17,7 +17,6 @@ package org.apache.spark.mllib.linalg; -import java.io.Serializable; import java.util.Arrays; import static org.junit.Assert.assertArrayEquals; @@ -26,7 +25,7 @@ import scala.Tuple2; import org.junit.Test; -public class JavaVectorsSuite implements Serializable { +public class JavaVectorsSuite { @Test public void denseArrayConstruction() { diff --git a/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java index b449108a9b..6d114024c3 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java @@ -20,36 +20,16 @@ package org.apache.spark.mllib.random; import java.io.Serializable; import java.util.Arrays; -import org.junit.After; import org.junit.Assert; -import org.junit.Before; import org.junit.Test; +import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaDoubleRDD; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.sql.SparkSession; import static org.apache.spark.mllib.random.RandomRDDs.*; -public class JavaRandomRDDsSuite { - private transient SparkSession spark; - private transient JavaSparkContext jsc; - - @Before - public void setUp() { - spark = SparkSession.builder() - .master("local") - .appName("JavaRandomRDDsSuite") - .getOrCreate(); - jsc = new JavaSparkContext(spark.sparkContext()); - } - - @After - public void tearDown() { - spark.stop(); - spark = null; - } +public class JavaRandomRDDsSuite extends SharedSparkSession { @Test public void testUniformRDD() { diff --git a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java index aa784054d5..363ab42546 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java @@ -17,41 +17,20 @@ package org.apache.spark.mllib.recommendation; -import java.io.Serializable; import java.util.ArrayList; import java.util.List; import scala.Tuple2; import scala.Tuple3; -import org.junit.After; import org.junit.Assert; -import org.junit.Before; import org.junit.Test; +import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SparkSession; - -public class JavaALSSuite implements Serializable { - private transient SparkSession spark; - private transient JavaSparkContext jsc; - - @Before - public void setUp() { - spark = SparkSession.builder() - .master("local") - .appName("JavaALS") - .getOrCreate(); - jsc = new JavaSparkContext(spark.sparkContext()); - } - @After - public void tearDown() { - spark.stop(); - spark = null; - } +public class JavaALSSuite extends SharedSparkSession { private void validatePrediction( MatrixFactorizationModel model, diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java index 8b05675d65..dbd4cbfd2b 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java @@ -17,26 +17,20 @@ package org.apache.spark.mllib.regression; -import java.io.Serializable; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import scala.Tuple3; -import org.junit.After; import org.junit.Assert; -import org.junit.Before; import org.junit.Test; +import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaDoubleRDD; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SparkSession; -public class JavaIsotonicRegressionSuite implements Serializable { - private transient SparkSession spark; - private transient JavaSparkContext jsc; +public class JavaIsotonicRegressionSuite extends SharedSparkSession { private static List> generateIsotonicInput(double[] labels) { List> input = new ArrayList<>(labels.length); @@ -55,21 +49,6 @@ public class JavaIsotonicRegressionSuite implements Serializable { return new IsotonicRegression().run(trainRDD); } - @Before - public void setUp() { - spark = SparkSession.builder() - .master("local") - .appName("JavaLinearRegressionSuite") - .getOrCreate(); - jsc = new JavaSparkContext(spark.sparkContext()); - } - - @After - public void tearDown() { - spark.stop(); - spark = null; - } - @Test public void testIsotonicRegressionJavaRDD() { IsotonicRegressionModel model = diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java index 098bac3bed..1458cc72bc 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java @@ -17,37 +17,16 @@ package org.apache.spark.mllib.regression; -import java.io.Serializable; import java.util.List; -import org.junit.After; import org.junit.Assert; -import org.junit.Before; import org.junit.Test; +import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.util.LinearDataGenerator; -import org.apache.spark.sql.SparkSession; -public class JavaLassoSuite implements Serializable { - private transient SparkSession spark; - private transient JavaSparkContext jsc; - - @Before - public void setUp() { - spark = SparkSession.builder() - .master("local") - .appName("JavaLassoSuite") - .getOrCreate(); - jsc = new JavaSparkContext(spark.sparkContext()); - } - - @After - public void tearDown() { - spark.stop(); - spark = null; - } +public class JavaLassoSuite extends SharedSparkSession { int validatePrediction(List validationData, LassoModel model) { int numAccurate = 0; diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java index 35087a5e46..a46b1321b3 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java @@ -17,39 +17,18 @@ package org.apache.spark.mllib.regression; -import java.io.Serializable; import java.util.List; -import org.junit.After; import org.junit.Assert; -import org.junit.Before; import org.junit.Test; +import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.util.LinearDataGenerator; -import org.apache.spark.sql.SparkSession; -public class JavaLinearRegressionSuite implements Serializable { - private transient SparkSession spark; - private transient JavaSparkContext jsc; - - @Before - public void setUp() { - spark = SparkSession.builder() - .master("local") - .appName("JavaLinearRegressionSuite") - .getOrCreate(); - jsc = new JavaSparkContext(spark.sparkContext()); - } - - @After - public void tearDown() { - spark.stop(); - spark = null; - } +public class JavaLinearRegressionSuite extends SharedSparkSession { int validatePrediction(List validationData, LinearRegressionModel model) { int numAccurate = 0; diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java index b2efb2e72e..cb00977412 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java @@ -17,38 +17,17 @@ package org.apache.spark.mllib.regression; -import java.io.Serializable; import java.util.List; import java.util.Random; -import org.junit.After; import org.junit.Assert; -import org.junit.Before; import org.junit.Test; +import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.util.LinearDataGenerator; -import org.apache.spark.sql.SparkSession; - -public class JavaRidgeRegressionSuite implements Serializable { - private transient SparkSession spark; - private transient JavaSparkContext jsc; - - @Before - public void setUp() { - spark = SparkSession.builder() - .master("local") - .appName("JavaRidgeRegressionSuite") - .getOrCreate(); - jsc = new JavaSparkContext(spark.sparkContext()); - } - @After - public void tearDown() { - spark.stop(); - spark = null; - } +public class JavaRidgeRegressionSuite extends SharedSparkSession { private static double predictionError(List validationData, RidgeRegressionModel model) { diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java index ea0ccd7448..ab554475d5 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java @@ -17,7 +17,6 @@ package org.apache.spark.mllib.regression; -import java.io.Serializable; import java.util.Arrays; import java.util.List; @@ -36,7 +35,7 @@ import org.apache.spark.streaming.api.java.JavaPairDStream; import org.apache.spark.streaming.api.java.JavaStreamingContext; import static org.apache.spark.streaming.JavaTestUtils.*; -public class JavaStreamingLinearRegressionSuite implements Serializable { +public class JavaStreamingLinearRegressionSuite { protected transient JavaStreamingContext ssc; diff --git a/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java index 373417d3ba..1abaa39ead 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java @@ -17,7 +17,6 @@ package org.apache.spark.mllib.stat; -import java.io.Serializable; import java.util.Arrays; import java.util.List; @@ -42,7 +41,7 @@ import org.apache.spark.streaming.api.java.JavaDStream; import org.apache.spark.streaming.api.java.JavaStreamingContext; import static org.apache.spark.streaming.JavaTestUtils.*; -public class JavaStatisticsSuite implements Serializable { +public class JavaStatisticsSuite { private transient SparkSession spark; private transient JavaSparkContext jsc; private transient JavaStreamingContext ssc; diff --git a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java index 5b464a4722..1dcbbcaa02 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java @@ -17,17 +17,14 @@ package org.apache.spark.mllib.tree; -import java.io.Serializable; import java.util.HashMap; import java.util.List; -import org.junit.After; import org.junit.Assert; -import org.junit.Before; import org.junit.Test; +import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.regression.LabeledPoint; @@ -35,27 +32,8 @@ import org.apache.spark.mllib.tree.configuration.Algo; import org.apache.spark.mllib.tree.configuration.Strategy; import org.apache.spark.mllib.tree.impurity.Gini; import org.apache.spark.mllib.tree.model.DecisionTreeModel; -import org.apache.spark.sql.SparkSession; - -public class JavaDecisionTreeSuite implements Serializable { - private transient SparkSession spark; - private transient JavaSparkContext jsc; - - @Before - public void setUp() { - spark = SparkSession.builder() - .master("local") - .appName("JavaDecisionTreeSuite") - .getOrCreate(); - jsc = new JavaSparkContext(spark.sparkContext()); - } - - @After - public void tearDown() { - spark.stop(); - spark = null; - } +public class JavaDecisionTreeSuite extends SharedSparkSession { int validatePrediction(List validationData, DecisionTreeModel model) { int numCorrect = 0; -- cgit v1.2.3