author    Sandeep Singh <sandeep@techaddict.me>  2016-05-10 11:17:47 -0700
committer Andrew Or <andrew@databricks.com>      2016-05-10 11:17:47 -0700
commit    ed0b4070fb50054b1ecf66ff6c32458a4967dfd3 (patch)
tree      68b3ad1a3ca22f2e0b5966db517c9bc42da3d254 /mllib/src/test/java/org/apache
parent    bcfee153b1cacfe617e602f3b72c0877e0bdf1f7 (diff)
[SPARK-15037][SQL][MLLIB] Use SparkSession instead of SQLContext in Scala/Java TestSuites
## What changes were proposed in this pull request?

Use SparkSession instead of SQLContext in the Scala/Java TestSuites. As this PR is already very big, the Python TestSuites will be handled in a separate PR.

## How was this patch tested?

Existing tests.

Author: Sandeep Singh <sandeep@techaddict.me>

Closes #12907 from techaddict/SPARK-15037.
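Every suite in this patch follows the same mechanical rewrite: drop the `SQLContext` field, build a `SparkSession` in `setUp()`, and derive a `JavaSparkContext` from it where one is still needed. A minimal sketch of the pattern, distilled from the hunks below (the suite name `JavaExampleSuite` is illustrative, not a file in this patch):

```java
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SparkSession;
import org.junit.After;
import org.junit.Before;

public class JavaExampleSuite {
  private transient SparkSession spark;
  private transient JavaSparkContext jsc;

  @Before
  public void setUp() {
    // Replaces the old pair:
    //   jsc  = new JavaSparkContext("local", "JavaExampleSuite");
    //   jsql = new SQLContext(jsc);
    spark = SparkSession.builder()
      .master("local")
      .appName("JavaExampleSuite")
      .getOrCreate();
    // Suites that still parallelize test data wrap the session's SparkContext.
    jsc = new JavaSparkContext(spark.sparkContext());
  }

  @After
  public void tearDown() {
    // Stopping the session also stops the underlying SparkContext,
    // so there is no separate jsc.stop() anymore.
    spark.stop();
    spark = null;
  }
}
```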
Diffstat (limited to 'mllib/src/test/java/org/apache')
-rw-r--r-- mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java | 27
-rw-r--r-- mllib/src/test/java/org/apache/spark/ml/attribute/JavaAttributeSuite.java | 2
-rw-r--r-- mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java | 23
-rw-r--r-- mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java | 18
-rw-r--r-- mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java | 49
-rw-r--r-- mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java | 36
-rw-r--r-- mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java | 18
-rw-r--r-- mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java | 90
-rw-r--r-- mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java | 26
-rw-r--r-- mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java | 26
-rw-r--r-- mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java | 20
-rw-r--r-- mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java | 21
-rw-r--r-- mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java | 18
-rw-r--r-- mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java | 19
-rw-r--r-- mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java | 21
-rw-r--r-- mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java | 19
-rw-r--r-- mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java | 17
-rw-r--r-- mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java | 22
-rw-r--r-- mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java | 26
-rw-r--r-- mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java | 21
-rw-r--r-- mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java | 31
-rw-r--r-- mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java | 18
-rw-r--r-- mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java | 18
-rw-r--r-- mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java | 22
-rw-r--r-- mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java | 14
-rw-r--r-- mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java | 38
-rw-r--r-- mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java | 18
-rw-r--r-- mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java | 18
-rw-r--r-- mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java | 25
-rw-r--r-- mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java | 28
-rw-r--r-- mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java | 18
-rw-r--r-- mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java | 18
-rw-r--r-- mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala | 1
-rw-r--r-- mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java | 21
-rw-r--r-- mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java | 35
-rw-r--r-- mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java | 25
-rw-r--r-- mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java | 32
-rw-r--r-- mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java | 27
-rw-r--r-- mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java | 20
-rw-r--r-- mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java | 23
-rw-r--r-- mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java | 37
-rw-r--r-- mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java | 3
-rw-r--r-- mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java | 21
-rw-r--r-- mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java | 22
-rw-r--r-- mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java | 19
-rw-r--r-- mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java | 23
-rw-r--r-- mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java | 29
-rw-r--r-- mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java | 26
-rw-r--r-- mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java | 278
-rw-r--r-- mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java | 7
-rw-r--r-- mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java | 136
-rw-r--r-- mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java | 64
-rw-r--r-- mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java | 22
-rw-r--r-- mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java | 32
-rw-r--r-- mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java | 42
-rw-r--r-- mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java | 22
-rw-r--r-- mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java | 32
-rw-r--r-- mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java | 24
58 files changed, 1023 insertions, 785 deletions
diff --git a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
index 60a4a1d2ea..e0c4363597 100644
--- a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
@@ -17,18 +17,18 @@
package org.apache.spark.ml;
-import org.apache.spark.sql.Dataset;
-import org.apache.spark.sql.Row;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.feature.StandardScaler;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
/**
@@ -36,23 +36,26 @@ import static org.apache.spark.mllib.classification.LogisticRegressionSuite.gene
*/
public class JavaPipelineSuite {
+ private transient SparkSession spark;
private transient JavaSparkContext jsc;
- private transient SQLContext jsql;
private transient Dataset<Row> dataset;
@Before
public void setUp() {
- jsc = new JavaSparkContext("local", "JavaPipelineSuite");
- jsql = new SQLContext(jsc);
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaPipelineSuite")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
JavaRDD<LabeledPoint> points =
jsc.parallelize(generateLogisticInputAsList(1.0, 1.0, 100, 42), 2);
- dataset = jsql.createDataFrame(points, LabeledPoint.class);
+ dataset = spark.createDataFrame(points, LabeledPoint.class);
}
@After
public void tearDown() {
- jsc.stop();
- jsc = null;
+ spark.stop();
+ spark = null;
}
@Test
@@ -63,10 +66,10 @@ public class JavaPipelineSuite {
LogisticRegression lr = new LogisticRegression()
.setFeaturesCol("scaledFeatures");
Pipeline pipeline = new Pipeline()
- .setStages(new PipelineStage[] {scaler, lr});
+ .setStages(new PipelineStage[]{scaler, lr});
PipelineModel model = pipeline.fit(dataset);
model.transform(dataset).registerTempTable("prediction");
- Dataset<Row> predictions = jsql.sql("SELECT label, probability, prediction FROM prediction");
+ Dataset<Row> predictions = spark.sql("SELECT label, probability, prediction FROM prediction");
predictions.collectAsList();
}
}
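Note that `registerTempTable` keeps working unchanged against the new session, since `SparkSession` owns the catalog the temp table lands in. In Spark 2.0 the same step can also be written with the newer `createOrReplaceTempView`; a hedged sketch of the equivalent call, not part of this patch:

```java
// Equivalent to the registerTempTable call above, using the Spark 2.0 view API:
model.transform(dataset).createOrReplaceTempView("prediction");
Dataset<Row> predictions = spark.sql("SELECT label, probability, prediction FROM prediction");
```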
diff --git a/mllib/src/test/java/org/apache/spark/ml/attribute/JavaAttributeSuite.java b/mllib/src/test/java/org/apache/spark/ml/attribute/JavaAttributeSuite.java
index b74bbed231..15cde0d3c0 100644
--- a/mllib/src/test/java/org/apache/spark/ml/attribute/JavaAttributeSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/attribute/JavaAttributeSuite.java
@@ -17,8 +17,8 @@
package org.apache.spark.ml.attribute;
-import org.junit.Test;
import org.junit.Assert;
+import org.junit.Test;
public class JavaAttributeSuite {
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java
index 1f23682621..8b89991327 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java
@@ -21,8 +21,6 @@ import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
-import org.apache.spark.sql.Dataset;
-import org.apache.spark.sql.Row;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
@@ -32,21 +30,28 @@ import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.tree.impl.TreeTests;
import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.mllib.regression.LabeledPoint;
-
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
public class JavaDecisionTreeClassifierSuite implements Serializable {
- private transient JavaSparkContext sc;
+ private transient SparkSession spark;
+ private transient JavaSparkContext jsc;
@Before
public void setUp() {
- sc = new JavaSparkContext("local", "JavaDecisionTreeClassifierSuite");
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaDecisionTreeClassifierSuite")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
- sc.stop();
- sc = null;
+ spark.stop();
+ spark = null;
}
@Test
@@ -55,7 +60,7 @@ public class JavaDecisionTreeClassifierSuite implements Serializable {
double A = 2.0;
double B = -1.5;
- JavaRDD<LabeledPoint> data = sc.parallelize(
+ JavaRDD<LabeledPoint> data = jsc.parallelize(
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
Map<Integer, Integer> categoricalFeatures = new HashMap<>();
Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
@@ -70,7 +75,7 @@ public class JavaDecisionTreeClassifierSuite implements Serializable {
.setCacheNodeIds(false)
.setCheckpointInterval(10)
.setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
- for (String impurity: DecisionTreeClassifier.supportedImpurities()) {
+ for (String impurity : DecisionTreeClassifier.supportedImpurities()) {
dt.setImpurity(impurity);
}
DecisionTreeClassificationModel model = dt.fit(dataFrame);
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java
index 74841058a2..682371eb9e 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java
@@ -32,21 +32,27 @@ import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
public class JavaGBTClassifierSuite implements Serializable {
- private transient JavaSparkContext sc;
+ private transient SparkSession spark;
+ private transient JavaSparkContext jsc;
@Before
public void setUp() {
- sc = new JavaSparkContext("local", "JavaGBTClassifierSuite");
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaGBTClassifierSuite")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
- sc.stop();
- sc = null;
+ spark.stop();
+ spark = null;
}
@Test
@@ -55,7 +61,7 @@ public class JavaGBTClassifierSuite implements Serializable {
double A = 2.0;
double B = -1.5;
- JavaRDD<LabeledPoint> data = sc.parallelize(
+ JavaRDD<LabeledPoint> data = jsc.parallelize(
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
Map<Integer, Integer> categoricalFeatures = new HashMap<>();
Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
@@ -74,7 +80,7 @@ public class JavaGBTClassifierSuite implements Serializable {
.setMaxIter(3)
.setStepSize(0.1)
.setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
- for (String lossType: GBTClassifier.supportedLossTypes()) {
+ for (String lossType : GBTClassifier.supportedLossTypes()) {
rf.setLossType(lossType);
}
GBTClassificationModel model = rf.fit(dataFrame);
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
index e160a5a47e..e3ff68364e 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
@@ -27,18 +27,17 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
-import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SQLContext;
-
+import org.apache.spark.sql.SparkSession;
+import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
public class JavaLogisticRegressionSuite implements Serializable {
+ private transient SparkSession spark;
private transient JavaSparkContext jsc;
- private transient SQLContext jsql;
private transient Dataset<Row> dataset;
private transient JavaRDD<LabeledPoint> datasetRDD;
@@ -46,18 +45,22 @@ public class JavaLogisticRegressionSuite implements Serializable {
@Before
public void setUp() {
- jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite");
- jsql = new SQLContext(jsc);
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaLogisticRegressionSuite")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
+
List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
datasetRDD = jsc.parallelize(points, 2);
- dataset = jsql.createDataFrame(datasetRDD, LabeledPoint.class);
+ dataset = spark.createDataFrame(datasetRDD, LabeledPoint.class);
dataset.registerTempTable("dataset");
}
@After
public void tearDown() {
- jsc.stop();
- jsc = null;
+ spark.stop();
+ spark = null;
}
@Test
@@ -66,7 +69,7 @@ public class JavaLogisticRegressionSuite implements Serializable {
Assert.assertEquals(lr.getLabelCol(), "label");
LogisticRegressionModel model = lr.fit(dataset);
model.transform(dataset).registerTempTable("prediction");
- Dataset<Row> predictions = jsql.sql("SELECT label, probability, prediction FROM prediction");
+ Dataset<Row> predictions = spark.sql("SELECT label, probability, prediction FROM prediction");
predictions.collectAsList();
// Check defaults
Assert.assertEquals(0.5, model.getThreshold(), eps);
@@ -95,23 +98,23 @@ public class JavaLogisticRegressionSuite implements Serializable {
// Modify model params, and check that the params worked.
model.setThreshold(1.0);
model.transform(dataset).registerTempTable("predAllZero");
- Dataset<Row> predAllZero = jsql.sql("SELECT prediction, myProbability FROM predAllZero");
- for (Row r: predAllZero.collectAsList()) {
+ Dataset<Row> predAllZero = spark.sql("SELECT prediction, myProbability FROM predAllZero");
+ for (Row r : predAllZero.collectAsList()) {
Assert.assertEquals(0.0, r.getDouble(0), eps);
}
// Call transform with params, and check that the params worked.
model.transform(dataset, model.threshold().w(0.0), model.probabilityCol().w("myProb"))
.registerTempTable("predNotAllZero");
- Dataset<Row> predNotAllZero = jsql.sql("SELECT prediction, myProb FROM predNotAllZero");
+ Dataset<Row> predNotAllZero = spark.sql("SELECT prediction, myProb FROM predNotAllZero");
boolean foundNonZero = false;
- for (Row r: predNotAllZero.collectAsList()) {
+ for (Row r : predNotAllZero.collectAsList()) {
if (r.getDouble(0) != 0.0) foundNonZero = true;
}
Assert.assertTrue(foundNonZero);
// Call fit() with new params, and check as many params as we can.
LogisticRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1),
- lr.threshold().w(0.4), lr.probabilityCol().w("theProb"));
+ lr.threshold().w(0.4), lr.probabilityCol().w("theProb"));
LogisticRegression parent2 = (LogisticRegression) model2.parent();
Assert.assertEquals(5, parent2.getMaxIter());
Assert.assertEquals(0.1, parent2.getRegParam(), eps);
@@ -128,10 +131,10 @@ public class JavaLogisticRegressionSuite implements Serializable {
Assert.assertEquals(2, model.numClasses());
model.transform(dataset).registerTempTable("transformed");
- Dataset<Row> trans1 = jsql.sql("SELECT rawPrediction, probability FROM transformed");
- for (Row row: trans1.collectAsList()) {
- Vector raw = (Vector)row.get(0);
- Vector prob = (Vector)row.get(1);
+ Dataset<Row> trans1 = spark.sql("SELECT rawPrediction, probability FROM transformed");
+ for (Row row : trans1.collectAsList()) {
+ Vector raw = (Vector) row.get(0);
+ Vector prob = (Vector) row.get(1);
Assert.assertEquals(raw.size(), 2);
Assert.assertEquals(prob.size(), 2);
double probFromRaw1 = 1.0 / (1.0 + Math.exp(-raw.apply(1)));
@@ -139,11 +142,11 @@ public class JavaLogisticRegressionSuite implements Serializable {
Assert.assertEquals(0, Math.abs(prob.apply(0) - (1.0 - probFromRaw1)), eps);
}
- Dataset<Row> trans2 = jsql.sql("SELECT prediction, probability FROM transformed");
- for (Row row: trans2.collectAsList()) {
+ Dataset<Row> trans2 = spark.sql("SELECT prediction, probability FROM transformed");
+ for (Row row : trans2.collectAsList()) {
double pred = row.getDouble(0);
- Vector prob = (Vector)row.get(1);
- double probOfPred = prob.apply((int)pred);
+ Vector prob = (Vector) row.get(1);
+ double probOfPred = prob.apply((int) pred);
for (int i = 0; i < prob.size(); ++i) {
Assert.assertTrue(probOfPred >= prob.apply(i));
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java
index bc955f3cf6..b0624cea3e 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java
@@ -26,49 +26,49 @@ import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
-import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
public class JavaMultilayerPerceptronClassifierSuite implements Serializable {
- private transient JavaSparkContext jsc;
- private transient SQLContext sqlContext;
+ private transient SparkSession spark;
@Before
public void setUp() {
- jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite");
- sqlContext = new SQLContext(jsc);
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaLogisticRegressionSuite")
+ .getOrCreate();
}
@After
public void tearDown() {
- jsc.stop();
- jsc = null;
- sqlContext = null;
+ spark.stop();
+ spark = null;
}
@Test
public void testMLPC() {
- Dataset<Row> dataFrame = sqlContext.createDataFrame(
- jsc.parallelize(Arrays.asList(
- new LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
- new LabeledPoint(1.0, Vectors.dense(0.0, 1.0)),
- new LabeledPoint(1.0, Vectors.dense(1.0, 0.0)),
- new LabeledPoint(0.0, Vectors.dense(1.0, 1.0)))),
- LabeledPoint.class);
+ List<LabeledPoint> data = Arrays.asList(
+ new LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
+ new LabeledPoint(1.0, Vectors.dense(0.0, 1.0)),
+ new LabeledPoint(1.0, Vectors.dense(1.0, 0.0)),
+ new LabeledPoint(0.0, Vectors.dense(1.0, 1.0))
+ );
+ Dataset<Row> dataFrame = spark.createDataFrame(data, LabeledPoint.class);
+
MultilayerPerceptronClassifier mlpc = new MultilayerPerceptronClassifier()
- .setLayers(new int[] {2, 5, 2})
+ .setLayers(new int[]{2, 5, 2})
.setBlockSize(1)
.setSeed(123L)
.setMaxIter(100);
MultilayerPerceptronClassificationModel model = mlpc.fit(dataFrame);
Dataset<Row> result = model.transform(dataFrame);
List<Row> predictionAndLabels = result.select("prediction", "label").collectAsList();
- for (Row r: predictionAndLabels) {
+ for (Row r : predictionAndLabels) {
Assert.assertEquals((int) r.getDouble(0), (int) r.getDouble(1));
}
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
index 45101f286c..3fc3648627 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
@@ -26,13 +26,12 @@ import org.junit.Before;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
-import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.VectorUDT;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
@@ -40,19 +39,20 @@ import org.apache.spark.sql.types.StructType;
public class JavaNaiveBayesSuite implements Serializable {
- private transient JavaSparkContext jsc;
- private transient SQLContext jsql;
+ private transient SparkSession spark;
@Before
public void setUp() {
- jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite");
- jsql = new SQLContext(jsc);
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaLogisticRegressionSuite")
+ .getOrCreate();
}
@After
public void tearDown() {
- jsc.stop();
- jsc = null;
+ spark.stop();
+ spark = null;
}
public void validatePrediction(Dataset<Row> predictionAndLabels) {
@@ -88,7 +88,7 @@ public class JavaNaiveBayesSuite implements Serializable {
new StructField("features", new VectorUDT(), false, Metadata.empty())
});
- Dataset<Row> dataset = jsql.createDataFrame(data, schema);
+ Dataset<Row> dataset = spark.createDataFrame(data, schema);
NaiveBayes nb = new NaiveBayes().setSmoothing(0.5).setModelType("multinomial");
NaiveBayesModel model = nb.fit(dataset);
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java
index 00f4476841..486fbbd58c 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java
@@ -20,7 +20,6 @@ package org.apache.spark.ml.classification;
import java.io.Serializable;
import java.util.List;
-import org.apache.spark.sql.Row;
import scala.collection.JavaConverters;
import org.junit.After;
@@ -30,56 +29,61 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
-import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateMultinomialLogisticInput;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.Dataset;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
+import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateMultinomialLogisticInput;
public class JavaOneVsRestSuite implements Serializable {
- private transient JavaSparkContext jsc;
- private transient SQLContext jsql;
- private transient Dataset<Row> dataset;
- private transient JavaRDD<LabeledPoint> datasetRDD;
+ private transient SparkSession spark;
+ private transient JavaSparkContext jsc;
+ private transient Dataset<Row> dataset;
+ private transient JavaRDD<LabeledPoint> datasetRDD;
+
+ @Before
+ public void setUp() {
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaLOneVsRestSuite")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
- @Before
- public void setUp() {
- jsc = new JavaSparkContext("local", "JavaLOneVsRestSuite");
- jsql = new SQLContext(jsc);
- int nPoints = 3;
+ int nPoints = 3;
- // The following coefficients and xMean/xVariance are computed from iris dataset with
- // lambda=0.2.
- // As a result, we are drawing samples from probability distribution of an actual model.
- double[] coefficients = {
- -0.57997, 0.912083, -0.371077, -0.819866, 2.688191,
- -0.16624, -0.84355, -0.048509, -0.301789, 4.170682 };
+ // The following coefficients and xMean/xVariance are computed from iris dataset with
+ // lambda=0.2.
+ // As a result, we are drawing samples from probability distribution of an actual model.
+ double[] coefficients = {
+ -0.57997, 0.912083, -0.371077, -0.819866, 2.688191,
+ -0.16624, -0.84355, -0.048509, -0.301789, 4.170682};
- double[] xMean = {5.843, 3.057, 3.758, 1.199};
- double[] xVariance = {0.6856, 0.1899, 3.116, 0.581};
- List<LabeledPoint> points = JavaConverters.seqAsJavaListConverter(
- generateMultinomialLogisticInput(coefficients, xMean, xVariance, true, nPoints, 42)
- ).asJava();
- datasetRDD = jsc.parallelize(points, 2);
- dataset = jsql.createDataFrame(datasetRDD, LabeledPoint.class);
- }
+ double[] xMean = {5.843, 3.057, 3.758, 1.199};
+ double[] xVariance = {0.6856, 0.1899, 3.116, 0.581};
+ List<LabeledPoint> points = JavaConverters.seqAsJavaListConverter(
+ generateMultinomialLogisticInput(coefficients, xMean, xVariance, true, nPoints, 42)
+ ).asJava();
+ datasetRDD = jsc.parallelize(points, 2);
+ dataset = spark.createDataFrame(datasetRDD, LabeledPoint.class);
+ }
- @After
- public void tearDown() {
- jsc.stop();
- jsc = null;
- }
+ @After
+ public void tearDown() {
+ spark.stop();
+ spark = null;
+ }
- @Test
- public void oneVsRestDefaultParams() {
- OneVsRest ova = new OneVsRest();
- ova.setClassifier(new LogisticRegression());
- Assert.assertEquals(ova.getLabelCol() , "label");
- Assert.assertEquals(ova.getPredictionCol() , "prediction");
- OneVsRestModel ovaModel = ova.fit(dataset);
- Dataset<Row> predictions = ovaModel.transform(dataset).select("label", "prediction");
- predictions.collectAsList();
- Assert.assertEquals(ovaModel.getLabelCol(), "label");
- Assert.assertEquals(ovaModel.getPredictionCol() , "prediction");
- }
+ @Test
+ public void oneVsRestDefaultParams() {
+ OneVsRest ova = new OneVsRest();
+ ova.setClassifier(new LogisticRegression());
+ Assert.assertEquals(ova.getLabelCol(), "label");
+ Assert.assertEquals(ova.getPredictionCol(), "prediction");
+ OneVsRestModel ovaModel = ova.fit(dataset);
+ Dataset<Row> predictions = ovaModel.transform(dataset).select("label", "prediction");
+ predictions.collectAsList();
+ Assert.assertEquals(ovaModel.getLabelCol(), "label");
+ Assert.assertEquals(ovaModel.getPredictionCol(), "prediction");
+ }
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
index 4f40fd65b9..e3855662fb 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
@@ -34,21 +34,27 @@ import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
public class JavaRandomForestClassifierSuite implements Serializable {
- private transient JavaSparkContext sc;
+ private transient SparkSession spark;
+ private transient JavaSparkContext jsc;
@Before
public void setUp() {
- sc = new JavaSparkContext("local", "JavaRandomForestClassifierSuite");
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaRandomForestClassifierSuite")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
- sc.stop();
- sc = null;
+ spark.stop();
+ spark = null;
}
@Test
@@ -57,7 +63,7 @@ public class JavaRandomForestClassifierSuite implements Serializable {
double A = 2.0;
double B = -1.5;
- JavaRDD<LabeledPoint> data = sc.parallelize(
+ JavaRDD<LabeledPoint> data = jsc.parallelize(
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
Map<Integer, Integer> categoricalFeatures = new HashMap<>();
Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
@@ -75,22 +81,22 @@ public class JavaRandomForestClassifierSuite implements Serializable {
.setSeed(1234)
.setNumTrees(3)
.setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
- for (String impurity: RandomForestClassifier.supportedImpurities()) {
+ for (String impurity : RandomForestClassifier.supportedImpurities()) {
rf.setImpurity(impurity);
}
- for (String featureSubsetStrategy: RandomForestClassifier.supportedFeatureSubsetStrategies()) {
+ for (String featureSubsetStrategy : RandomForestClassifier.supportedFeatureSubsetStrategies()) {
rf.setFeatureSubsetStrategy(featureSubsetStrategy);
}
String[] realStrategies = {".1", ".10", "0.10", "0.1", "0.9", "1.0"};
- for (String strategy: realStrategies) {
+ for (String strategy : realStrategies) {
rf.setFeatureSubsetStrategy(strategy);
}
String[] integerStrategies = {"1", "10", "100", "1000", "10000"};
- for (String strategy: integerStrategies) {
+ for (String strategy : integerStrategies) {
rf.setFeatureSubsetStrategy(strategy);
}
String[] invalidStrategies = {"-.1", "-.10", "-0.10", ".0", "0.0", "1.1", "0"};
- for (String strategy: invalidStrategies) {
+ for (String strategy : invalidStrategies) {
try {
rf.setFeatureSubsetStrategy(strategy);
Assert.fail("Expected exception to be thrown for invalid strategies");
diff --git a/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java b/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java
index a3fcdb54ee..3ab09ac27d 100644
--- a/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java
@@ -21,37 +21,37 @@ import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertTrue;
-import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
public class JavaKMeansSuite implements Serializable {
private transient int k = 5;
- private transient JavaSparkContext sc;
private transient Dataset<Row> dataset;
- private transient SQLContext sql;
+ private transient SparkSession spark;
@Before
public void setUp() {
- sc = new JavaSparkContext("local", "JavaKMeansSuite");
- sql = new SQLContext(sc);
-
- dataset = KMeansSuite.generateKMeansData(sql, 50, 3, k);
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaKMeansSuite")
+ .getOrCreate();
+ dataset = KMeansSuite.generateKMeansData(spark, 50, 3, k);
}
@After
public void tearDown() {
- sc.stop();
- sc = null;
+ spark.stop();
+ spark = null;
}
@Test
@@ -65,7 +65,7 @@ public class JavaKMeansSuite implements Serializable {
Dataset<Row> transformed = model.transform(dataset);
List<String> columns = Arrays.asList(transformed.columns());
List<String> expectedColumns = Arrays.asList("features", "prediction");
- for (String column: expectedColumns) {
+ for (String column : expectedColumns) {
assertTrue(columns.contains(column));
}
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java
index 77e3a489a9..a96b43de15 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java
@@ -25,40 +25,40 @@ import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
-import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
public class JavaBucketizerSuite {
- private transient JavaSparkContext jsc;
- private transient SQLContext jsql;
+ private transient SparkSession spark;
@Before
public void setUp() {
- jsc = new JavaSparkContext("local", "JavaBucketizerSuite");
- jsql = new SQLContext(jsc);
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaBucketizerSuite")
+ .getOrCreate();
}
@After
public void tearDown() {
- jsc.stop();
- jsc = null;
+ spark.stop();
+ spark = null;
}
@Test
public void bucketizerTest() {
double[] splits = {-0.5, 0.0, 0.5};
- StructType schema = new StructType(new StructField[] {
+ StructType schema = new StructType(new StructField[]{
new StructField("feature", DataTypes.DoubleType, false, Metadata.empty())
});
- Dataset<Row> dataset = jsql.createDataFrame(
+ Dataset<Row> dataset = spark.createDataFrame(
Arrays.asList(
RowFactory.create(-0.5),
RowFactory.create(-0.3),
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java
index ed1ad4c3a3..06482d8f0d 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java
@@ -21,43 +21,44 @@ import java.util.Arrays;
import java.util.List;
import edu.emory.mathcs.jtransforms.dct.DoubleDCT_1D;
+
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
-import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.VectorUDT;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
public class JavaDCTSuite {
- private transient JavaSparkContext jsc;
- private transient SQLContext jsql;
+ private transient SparkSession spark;
@Before
public void setUp() {
- jsc = new JavaSparkContext("local", "JavaDCTSuite");
- jsql = new SQLContext(jsc);
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaDCTSuite")
+ .getOrCreate();
}
@After
public void tearDown() {
- jsc.stop();
- jsc = null;
+ spark.stop();
+ spark = null;
}
@Test
public void javaCompatibilityTest() {
- double[] input = new double[] {1D, 2D, 3D, 4D};
- Dataset<Row> dataset = jsql.createDataFrame(
+ double[] input = new double[]{1D, 2D, 3D, 4D};
+ Dataset<Row> dataset = spark.createDataFrame(
Arrays.asList(RowFactory.create(Vectors.dense(input))),
new StructType(new StructField[]{
new StructField("vec", (new VectorUDT()), false, Metadata.empty())
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java
index 6e2cc7e887..0e21d4a94f 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java
@@ -25,12 +25,11 @@ import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
-import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
@@ -38,19 +37,20 @@ import org.apache.spark.sql.types.StructType;
public class JavaHashingTFSuite {
- private transient JavaSparkContext jsc;
- private transient SQLContext jsql;
+ private transient SparkSession spark;
@Before
public void setUp() {
- jsc = new JavaSparkContext("local", "JavaHashingTFSuite");
- jsql = new SQLContext(jsc);
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaHashingTFSuite")
+ .getOrCreate();
}
@After
public void tearDown() {
- jsc.stop();
- jsc = null;
+ spark.stop();
+ spark = null;
}
@Test
@@ -65,7 +65,7 @@ public class JavaHashingTFSuite {
new StructField("sentence", DataTypes.StringType, false, Metadata.empty())
});
- Dataset<Row> sentenceData = jsql.createDataFrame(data, schema);
+ Dataset<Row> sentenceData = spark.createDataFrame(data, schema);
Tokenizer tokenizer = new Tokenizer()
.setInputCol("sentence")
.setOutputCol("words");
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java
index 5bbd9634b2..04b2897b18 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java
@@ -23,27 +23,30 @@ import org.junit.After;
import org.junit.Before;
import org.junit.Test;
+import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
public class JavaNormalizerSuite {
+ private transient SparkSession spark;
private transient JavaSparkContext jsc;
- private transient SQLContext jsql;
@Before
public void setUp() {
- jsc = new JavaSparkContext("local", "JavaNormalizerSuite");
- jsql = new SQLContext(jsc);
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaNormalizerSuite")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
- jsc.stop();
- jsc = null;
+ spark.stop();
+ spark = null;
}
@Test
@@ -54,7 +57,7 @@ public class JavaNormalizerSuite {
new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 3.0)),
new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 4.0))
));
- Dataset<Row> dataFrame = jsql.createDataFrame(points, VectorIndexerSuite.FeatureData.class);
+ Dataset<Row> dataFrame = spark.createDataFrame(points, VectorIndexerSuite.FeatureData.class);
Normalizer normalizer = new Normalizer()
.setInputCol("features")
.setOutputCol("normFeatures");
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java
index 1389d17e7e..32f6b4375e 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java
@@ -28,31 +28,34 @@ import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
-import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.mllib.linalg.distributed.RowMatrix;
+import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.linalg.Matrix;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.mllib.linalg.distributed.RowMatrix;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
public class JavaPCASuite implements Serializable {
+ private transient SparkSession spark;
private transient JavaSparkContext jsc;
- private transient SQLContext sqlContext;
@Before
public void setUp() {
- jsc = new JavaSparkContext("local", "JavaPCASuite");
- sqlContext = new SQLContext(jsc);
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaPCASuite")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
- jsc.stop();
- jsc = null;
+ spark.stop();
+ spark = null;
}
public static class VectorPair implements Serializable {
@@ -100,7 +103,7 @@ public class JavaPCASuite implements Serializable {
}
);
- Dataset<Row> df = sqlContext.createDataFrame(featuresExpected, VectorPair.class);
+ Dataset<Row> df = spark.createDataFrame(featuresExpected, VectorPair.class);
PCAModel pca = new PCA()
.setInputCol("features")
.setOutputCol("pca_features")
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java
index 6a8bb64801..8f726077a2 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java
@@ -32,19 +32,22 @@ import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
public class JavaPolynomialExpansionSuite {
+ private transient SparkSession spark;
private transient JavaSparkContext jsc;
- private transient SQLContext jsql;
@Before
public void setUp() {
- jsc = new JavaSparkContext("local", "JavaPolynomialExpansionSuite");
- jsql = new SQLContext(jsc);
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaPolynomialExpansionSuite")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
}
@After
@@ -72,20 +75,20 @@ public class JavaPolynomialExpansionSuite {
)
);
- StructType schema = new StructType(new StructField[] {
+ StructType schema = new StructType(new StructField[]{
new StructField("features", new VectorUDT(), false, Metadata.empty()),
new StructField("expected", new VectorUDT(), false, Metadata.empty())
});
- Dataset<Row> dataset = jsql.createDataFrame(data, schema);
+ Dataset<Row> dataset = spark.createDataFrame(data, schema);
List<Row> pairs = polyExpansion.transform(dataset)
.select("polyFeatures", "expected")
.collectAsList();
for (Row r : pairs) {
- double[] polyFeatures = ((Vector)r.get(0)).toArray();
- double[] expected = ((Vector)r.get(1)).toArray();
+ double[] polyFeatures = ((Vector) r.get(0)).toArray();
+ double[] expected = ((Vector) r.get(1)).toArray();
Assert.assertArrayEquals(polyFeatures, expected, 1e-1);
}
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java
index 3f6fc333e4..c7397bdd68 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java
@@ -28,22 +28,25 @@ import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
public class JavaStandardScalerSuite {
+ private transient SparkSession spark;
private transient JavaSparkContext jsc;
- private transient SQLContext jsql;
@Before
public void setUp() {
- jsc = new JavaSparkContext("local", "JavaStandardScalerSuite");
- jsql = new SQLContext(jsc);
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaStandardScalerSuite")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
- jsc.stop();
- jsc = null;
+ spark.stop();
+ spark = null;
}
@Test
@@ -54,7 +57,7 @@ public class JavaStandardScalerSuite {
new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 3.0)),
new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 4.0))
);
- Dataset<Row> dataFrame = jsql.createDataFrame(jsc.parallelize(points, 2),
+ Dataset<Row> dataFrame = spark.createDataFrame(jsc.parallelize(points, 2),
VectorIndexerSuite.FeatureData.class);
StandardScaler scaler = new StandardScaler()
.setInputCol("features")
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java
index bdcbde5e26..2b156f3bca 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java
@@ -24,11 +24,10 @@ import org.junit.After;
import org.junit.Before;
import org.junit.Test;
-import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
@@ -37,19 +36,20 @@ import org.apache.spark.sql.types.StructType;
public class JavaStopWordsRemoverSuite {
- private transient JavaSparkContext jsc;
- private transient SQLContext jsql;
+ private transient SparkSession spark;
@Before
public void setUp() {
- jsc = new JavaSparkContext("local", "JavaStopWordsRemoverSuite");
- jsql = new SQLContext(jsc);
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaStopWordsRemoverSuite")
+ .getOrCreate();
}
@After
public void tearDown() {
- jsc.stop();
- jsc = null;
+ spark.stop();
+ spark = null;
}
@Test
@@ -62,11 +62,11 @@ public class JavaStopWordsRemoverSuite {
RowFactory.create(Arrays.asList("I", "saw", "the", "red", "baloon")),
RowFactory.create(Arrays.asList("Mary", "had", "a", "little", "lamb"))
);
- StructType schema = new StructType(new StructField[] {
+ StructType schema = new StructType(new StructField[]{
new StructField("raw", DataTypes.createArrayType(DataTypes.StringType), false,
- Metadata.empty())
+ Metadata.empty())
});
- Dataset<Row> dataset = jsql.createDataFrame(data, schema);
+ Dataset<Row> dataset = spark.createDataFrame(data, schema);
remover.transform(dataset).collect();
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java
index 431779cd2e..52c0bde8f3 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java
@@ -25,40 +25,42 @@ import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
-import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.SparkConf;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import static org.apache.spark.sql.types.DataTypes.*;
public class JavaStringIndexerSuite {
- private transient JavaSparkContext jsc;
- private transient SQLContext sqlContext;
+ private transient SparkSession spark;
@Before
public void setUp() {
- jsc = new JavaSparkContext("local", "JavaStringIndexerSuite");
- sqlContext = new SQLContext(jsc);
+ SparkConf sparkConf = new SparkConf();
+ sparkConf.setMaster("local");
+ sparkConf.setAppName("JavaStringIndexerSuite");
+
+ spark = SparkSession.builder().config(sparkConf).getOrCreate();
}
@After
public void tearDown() {
- jsc.stop();
- sqlContext = null;
+ spark.stop();
+ spark = null;
}
@Test
public void testStringIndexer() {
- StructType schema = createStructType(new StructField[] {
+ StructType schema = createStructType(new StructField[]{
createStructField("id", IntegerType, false),
createStructField("label", StringType, false)
});
List<Row> data = Arrays.asList(
cr(0, "a"), cr(1, "b"), cr(2, "c"), cr(3, "a"), cr(4, "a"), cr(5, "c"));
- Dataset<Row> dataset = sqlContext.createDataFrame(data, schema);
+ Dataset<Row> dataset = spark.createDataFrame(data, schema);
StringIndexer indexer = new StringIndexer()
.setInputCol("label")
@@ -70,7 +72,9 @@ public class JavaStringIndexerSuite {
output.orderBy("id").select("id", "labelIndex").collectAsList());
}
- /** An alias for RowFactory.create. */
+ /**
+ * An alias for RowFactory.create.
+ */
private Row cr(Object... values) {
return RowFactory.create(values);
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java
index 83d16cbd0e..0bac2839e1 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java
@@ -29,22 +29,25 @@ import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
public class JavaTokenizerSuite {
+ private transient SparkSession spark;
private transient JavaSparkContext jsc;
- private transient SQLContext jsql;
@Before
public void setUp() {
- jsc = new JavaSparkContext("local", "JavaTokenizerSuite");
- jsql = new SQLContext(jsc);
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaTokenizerSuite")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
- jsc.stop();
- jsc = null;
+ spark.stop();
+ spark = null;
}
@Test
@@ -59,10 +62,10 @@ public class JavaTokenizerSuite {
JavaRDD<TokenizerTestData> rdd = jsc.parallelize(Arrays.asList(
- new TokenizerTestData("Test of tok.", new String[] {"Test", "tok."}),
- new TokenizerTestData("Te,st. punct", new String[] {"Te,st.", "punct"})
+ new TokenizerTestData("Test of tok.", new String[]{"Test", "tok."}),
+ new TokenizerTestData("Te,st. punct", new String[]{"Te,st.", "punct"})
));
- Dataset<Row> dataset = jsql.createDataFrame(rdd, TokenizerTestData.class);
+ Dataset<Row> dataset = spark.createDataFrame(rdd, TokenizerTestData.class);
List<Row> pairs = myRegExTokenizer.transform(dataset)
.select("tokens", "wantedTokens")
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java
index e45e198043..8774cd0c69 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java
@@ -24,36 +24,39 @@ import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
-import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.SparkConf;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.VectorUDT;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SQLContext;
-import org.apache.spark.sql.types.*;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
import static org.apache.spark.sql.types.DataTypes.*;
public class JavaVectorAssemblerSuite {
- private transient JavaSparkContext jsc;
- private transient SQLContext sqlContext;
+ private transient SparkSession spark;
@Before
public void setUp() {
- jsc = new JavaSparkContext("local", "JavaVectorAssemblerSuite");
- sqlContext = new SQLContext(jsc);
+ SparkConf sparkConf = new SparkConf();
+ sparkConf.setMaster("local");
+ sparkConf.setAppName("JavaVectorAssemblerSuite");
+
+ spark = SparkSession.builder().config(sparkConf).getOrCreate();
}
@After
public void tearDown() {
- jsc.stop();
- jsc = null;
+ spark.stop();
+ spark = null;
}
@Test
public void testVectorAssembler() {
- StructType schema = createStructType(new StructField[] {
+ StructType schema = createStructType(new StructField[]{
createStructField("id", IntegerType, false),
createStructField("x", DoubleType, false),
createStructField("y", new VectorUDT(), false),
@@ -63,14 +66,14 @@ public class JavaVectorAssemblerSuite {
});
Row row = RowFactory.create(
0, 0.0, Vectors.dense(1.0, 2.0), "a",
- Vectors.sparse(2, new int[] {1}, new double[] {3.0}), 10L);
- Dataset<Row> dataset = sqlContext.createDataFrame(Arrays.asList(row), schema);
+ Vectors.sparse(2, new int[]{1}, new double[]{3.0}), 10L);
+ Dataset<Row> dataset = spark.createDataFrame(Arrays.asList(row), schema);
VectorAssembler assembler = new VectorAssembler()
- .setInputCols(new String[] {"x", "y", "z", "n"})
+ .setInputCols(new String[]{"x", "y", "z", "n"})
.setOutputCol("features");
Dataset<Row> output = assembler.transform(dataset);
Assert.assertEquals(
- Vectors.sparse(6, new int[] {1, 2, 4, 5}, new double[] {1.0, 2.0, 3.0, 10.0}),
+ Vectors.sparse(6, new int[]{1, 2, 4, 5}, new double[]{1.0, 2.0, 3.0, 10.0}),
output.select("features").first().<Vector>getAs(0));
}
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java
index fec6cac8be..c386c9a45b 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java
@@ -32,21 +32,26 @@ import org.apache.spark.ml.feature.VectorIndexerSuite.FeatureData;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
public class JavaVectorIndexerSuite implements Serializable {
- private transient JavaSparkContext sc;
+ private transient SparkSession spark;
+ private JavaSparkContext jsc;
@Before
public void setUp() {
- sc = new JavaSparkContext("local", "JavaVectorIndexerSuite");
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaVectorIndexerSuite")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
- sc.stop();
- sc = null;
+ spark.stop();
+ spark = null;
}
@Test
@@ -57,8 +62,7 @@ public class JavaVectorIndexerSuite implements Serializable {
new FeatureData(Vectors.dense(1.0, 3.0)),
new FeatureData(Vectors.dense(1.0, 4.0))
);
- SQLContext sqlContext = new SQLContext(sc);
- Dataset<Row> data = sqlContext.createDataFrame(sc.parallelize(points, 2), FeatureData.class);
+ Dataset<Row> data = spark.createDataFrame(jsc.parallelize(points, 2), FeatureData.class);
VectorIndexer indexer = new VectorIndexer()
.setInputCol("features")
.setOutputCol("indexed")
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java
index e2da11183b..59ad3c2f61 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java
@@ -25,7 +25,6 @@ import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
-import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.attribute.Attribute;
import org.apache.spark.ml.attribute.AttributeGroup;
import org.apache.spark.ml.attribute.NumericAttribute;
@@ -34,24 +33,25 @@ import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.StructType;
public class JavaVectorSlicerSuite {
- private transient JavaSparkContext jsc;
- private transient SQLContext jsql;
+ private transient SparkSession spark;
@Before
public void setUp() {
- jsc = new JavaSparkContext("local", "JavaVectorSlicerSuite");
- jsql = new SQLContext(jsc);
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaVectorSlicerSuite")
+ .getOrCreate();
}
@After
public void tearDown() {
- jsc.stop();
- jsc = null;
+ spark.stop();
+ spark = null;
}
@Test
@@ -69,7 +69,7 @@ public class JavaVectorSlicerSuite {
);
Dataset<Row> dataset =
- jsql.createDataFrame(data, (new StructType()).add(group.toStructField()));
+ spark.createDataFrame(data, (new StructType()).add(group.toStructField()));
VectorSlicer vectorSlicer = new VectorSlicer()
.setInputCol("userFeatures").setOutputCol("features");
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java
index 7517b70cc9..392aabc96d 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java
@@ -24,28 +24,28 @@ import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
-import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.*;
public class JavaWord2VecSuite {
- private transient JavaSparkContext jsc;
- private transient SQLContext sqlContext;
+ private transient SparkSession spark;
@Before
public void setUp() {
- jsc = new JavaSparkContext("local", "JavaWord2VecSuite");
- sqlContext = new SQLContext(jsc);
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaWord2VecSuite")
+ .getOrCreate();
}
@After
public void tearDown() {
- jsc.stop();
- jsc = null;
+ spark.stop();
+ spark = null;
}
@Test
@@ -53,7 +53,7 @@ public class JavaWord2VecSuite {
StructType schema = new StructType(new StructField[]{
new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty())
});
- Dataset<Row> documentDF = sqlContext.createDataFrame(
+ Dataset<Row> documentDF = spark.createDataFrame(
Arrays.asList(
RowFactory.create(Arrays.asList("Hi I heard about Spark".split(" "))),
RowFactory.create(Arrays.asList("I wish Java could use case classes".split(" "))),
@@ -68,8 +68,8 @@ public class JavaWord2VecSuite {
Word2VecModel model = word2Vec.fit(documentDF);
Dataset<Row> result = model.transform(documentDF);
- for (Row r: result.select("result").collectAsList()) {
- double[] polyFeatures = ((Vector)r.get(0)).toArray();
+ for (Row r : result.select("result").collectAsList()) {
+ double[] polyFeatures = ((Vector) r.get(0)).toArray();
Assert.assertEquals(polyFeatures.length, 3);
}
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java
index fa777f3d42..a5b5dd4088 100644
--- a/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java
@@ -25,23 +25,29 @@ import org.junit.Before;
import org.junit.Test;
import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.SparkSession;
/**
* Test Param and related classes in Java
*/
public class JavaParamsSuite {
+ private transient SparkSession spark;
private transient JavaSparkContext jsc;
@Before
public void setUp() {
- jsc = new JavaSparkContext("local", "JavaParamsSuite");
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaParamsSuite")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
- jsc.stop();
- jsc = null;
+ spark.stop();
+ spark = null;
}
@Test
@@ -51,7 +57,7 @@ public class JavaParamsSuite {
testParams.setMyIntParam(2).setMyDoubleParam(0.4).setMyStringParam("a");
Assert.assertEquals(testParams.getMyDoubleParam(), 0.4, 0.0);
Assert.assertEquals(testParams.getMyStringParam(), "a");
- Assert.assertArrayEquals(testParams.getMyDoubleArrayParam(), new double[] {1.0, 2.0}, 0.0);
+ Assert.assertArrayEquals(testParams.getMyDoubleArrayParam(), new double[]{1.0, 2.0}, 0.0);
}
@Test
diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
index 06f7fbb86e..1ad5f7a442 100644
--- a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
+++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
@@ -45,9 +45,14 @@ public class JavaTestParams extends JavaParams {
}
private IntParam myIntParam_;
- public IntParam myIntParam() { return myIntParam_; }
- public int getMyIntParam() { return (Integer)getOrDefault(myIntParam_); }
+ public IntParam myIntParam() {
+ return myIntParam_;
+ }
+
+ public int getMyIntParam() {
+ return (Integer) getOrDefault(myIntParam_);
+ }
public JavaTestParams setMyIntParam(int value) {
set(myIntParam_, value);
@@ -55,9 +60,14 @@ public class JavaTestParams extends JavaParams {
}
private DoubleParam myDoubleParam_;
- public DoubleParam myDoubleParam() { return myDoubleParam_; }
- public double getMyDoubleParam() { return (Double)getOrDefault(myDoubleParam_); }
+ public DoubleParam myDoubleParam() {
+ return myDoubleParam_;
+ }
+
+ public double getMyDoubleParam() {
+ return (Double) getOrDefault(myDoubleParam_);
+ }
public JavaTestParams setMyDoubleParam(double value) {
set(myDoubleParam_, value);
@@ -65,9 +75,14 @@ public class JavaTestParams extends JavaParams {
}
private Param<String> myStringParam_;
- public Param<String> myStringParam() { return myStringParam_; }
- public String getMyStringParam() { return getOrDefault(myStringParam_); }
+ public Param<String> myStringParam() {
+ return myStringParam_;
+ }
+
+ public String getMyStringParam() {
+ return getOrDefault(myStringParam_);
+ }
public JavaTestParams setMyStringParam(String value) {
set(myStringParam_, value);
@@ -75,9 +90,14 @@ public class JavaTestParams extends JavaParams {
}
private DoubleArrayParam myDoubleArrayParam_;
- public DoubleArrayParam myDoubleArrayParam() { return myDoubleArrayParam_; }
- public double[] getMyDoubleArrayParam() { return getOrDefault(myDoubleArrayParam_); }
+ public DoubleArrayParam myDoubleArrayParam() {
+ return myDoubleArrayParam_;
+ }
+
+ public double[] getMyDoubleArrayParam() {
+ return getOrDefault(myDoubleArrayParam_);
+ }
public JavaTestParams setMyDoubleArrayParam(double[] value) {
set(myDoubleArrayParam_, value);
@@ -96,7 +116,7 @@ public class JavaTestParams extends JavaParams {
setDefault(myIntParam(), 1);
setDefault(myDoubleParam(), 0.5);
- setDefault(myDoubleArrayParam(), new double[] {1.0, 2.0});
+ setDefault(myDoubleArrayParam(), new double[]{1.0, 2.0});
}
@Override
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
index fa3b28ed4f..bbd59a04ec 100644
--- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
@@ -32,21 +32,27 @@ import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
public class JavaDecisionTreeRegressorSuite implements Serializable {
- private transient JavaSparkContext sc;
+ private transient SparkSession spark;
+ private transient JavaSparkContext jsc;
@Before
public void setUp() {
- sc = new JavaSparkContext("local", "JavaDecisionTreeRegressorSuite");
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaDecisionTreeRegressorSuite")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
- sc.stop();
- sc = null;
+ spark.stop();
+ spark = null;
}
@Test
@@ -55,7 +61,7 @@ public class JavaDecisionTreeRegressorSuite implements Serializable {
double A = 2.0;
double B = -1.5;
- JavaRDD<LabeledPoint> data = sc.parallelize(
+ JavaRDD<LabeledPoint> data = jsc.parallelize(
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
Map<Integer, Integer> categoricalFeatures = new HashMap<>();
Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0);
@@ -70,7 +76,7 @@ public class JavaDecisionTreeRegressorSuite implements Serializable {
.setCacheNodeIds(false)
.setCheckpointInterval(10)
.setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
- for (String impurity: DecisionTreeRegressor.supportedImpurities()) {
+ for (String impurity : DecisionTreeRegressor.supportedImpurities()) {
dt.setImpurity(impurity);
}
DecisionTreeRegressionModel model = dt.fit(dataFrame);
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java
index 8413ea0e0a..5370b58e8f 100644
--- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java
@@ -32,21 +32,27 @@ import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
public class JavaGBTRegressorSuite implements Serializable {
- private transient JavaSparkContext sc;
+ private transient SparkSession spark;
+ private transient JavaSparkContext jsc;
@Before
public void setUp() {
- sc = new JavaSparkContext("local", "JavaGBTRegressorSuite");
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaGBTRegressorSuite")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
- sc.stop();
- sc = null;
+ spark.stop();
+ spark = null;
}
@Test
@@ -55,7 +61,7 @@ public class JavaGBTRegressorSuite implements Serializable {
double A = 2.0;
double B = -1.5;
- JavaRDD<LabeledPoint> data = sc.parallelize(
+ JavaRDD<LabeledPoint> data = jsc.parallelize(
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
Map<Integer, Integer> categoricalFeatures = new HashMap<>();
Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0);
@@ -73,7 +79,7 @@ public class JavaGBTRegressorSuite implements Serializable {
.setMaxIter(3)
.setStepSize(0.1)
.setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
- for (String lossType: GBTRegressor.supportedLossTypes()) {
+ for (String lossType : GBTRegressor.supportedLossTypes()) {
rf.setLossType(lossType);
}
GBTRegressionModel model = rf.fit(dataFrame);
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
index 9f817515eb..00c59f08b6 100644
--- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
@@ -30,25 +30,26 @@ import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SQLContext;
-import static org.apache.spark.mllib.classification.LogisticRegressionSuite
- .generateLogisticInputAsList;
-
+import org.apache.spark.sql.SparkSession;
+import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
public class JavaLinearRegressionSuite implements Serializable {
+ private transient SparkSession spark;
private transient JavaSparkContext jsc;
- private transient SQLContext jsql;
private transient Dataset<Row> dataset;
private transient JavaRDD<LabeledPoint> datasetRDD;
@Before
public void setUp() {
- jsc = new JavaSparkContext("local", "JavaLinearRegressionSuite");
- jsql = new SQLContext(jsc);
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaLinearRegressionSuite")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
datasetRDD = jsc.parallelize(points, 2);
- dataset = jsql.createDataFrame(datasetRDD, LabeledPoint.class);
+ dataset = spark.createDataFrame(datasetRDD, LabeledPoint.class);
dataset.registerTempTable("dataset");
}
@@ -65,7 +66,7 @@ public class JavaLinearRegressionSuite implements Serializable {
assertEquals("auto", lr.getSolver());
LinearRegressionModel model = lr.fit(dataset);
model.transform(dataset).registerTempTable("prediction");
- Dataset<Row> predictions = jsql.sql("SELECT label, prediction FROM prediction");
+ Dataset<Row> predictions = spark.sql("SELECT label, prediction FROM prediction");
predictions.collect();
// Check defaults
assertEquals("features", model.getFeaturesCol());
@@ -76,8 +77,8 @@ public class JavaLinearRegressionSuite implements Serializable {
public void linearRegressionWithSetters() {
// Set params, train, and check as many params as we can.
LinearRegression lr = new LinearRegression()
- .setMaxIter(10)
- .setRegParam(1.0).setSolver("l-bfgs");
+ .setMaxIter(10)
+ .setRegParam(1.0).setSolver("l-bfgs");
LinearRegressionModel model = lr.fit(dataset);
LinearRegression parent = (LinearRegression) model.parent();
assertEquals(10, parent.getMaxIter());
@@ -85,7 +86,7 @@ public class JavaLinearRegressionSuite implements Serializable {
// Call fit() with new params, and check as many params as we can.
LinearRegressionModel model2 =
- lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), lr.predictionCol().w("thePred"));
+ lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), lr.predictionCol().w("thePred"));
LinearRegression parent2 = (LinearRegression) model2.parent();
assertEquals(5, parent2.getMaxIter());
assertEquals(0.1, parent2.getRegParam(), 0.0);
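
SQL access moves with the session as well: queries previously issued through the SQLContext now go through SparkSession.sql(). A hedged sketch of the flow exercised above, with the table names as in the test:

Dataset<Row> dataset = spark.createDataFrame(datasetRDD, LabeledPoint.class);
dataset.registerTempTable("dataset");  // 2.0-preview API, kept as-is by the patch
LinearRegressionModel model = new LinearRegression().fit(dataset);
model.transform(dataset).registerTempTable("prediction");
Dataset<Row> predictions = spark.sql("SELECT label, prediction FROM prediction");
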
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
index 38b895f1fd..fdb41ffc10 100644
--- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
@@ -28,27 +28,33 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.ml.tree.impl.TreeTests;
+import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
public class JavaRandomForestRegressorSuite implements Serializable {
- private transient JavaSparkContext sc;
+ private transient SparkSession spark;
+ private transient JavaSparkContext jsc;
@Before
public void setUp() {
- sc = new JavaSparkContext("local", "JavaRandomForestRegressorSuite");
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaRandomForestRegressorSuite")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
- sc.stop();
- sc = null;
+ spark.stop();
+ spark = null;
}
@Test
@@ -57,7 +63,7 @@ public class JavaRandomForestRegressorSuite implements Serializable {
double A = 2.0;
double B = -1.5;
- JavaRDD<LabeledPoint> data = sc.parallelize(
+ JavaRDD<LabeledPoint> data = jsc.parallelize(
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
Map<Integer, Integer> categoricalFeatures = new HashMap<>();
Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0);
@@ -75,22 +81,22 @@ public class JavaRandomForestRegressorSuite implements Serializable {
.setSeed(1234)
.setNumTrees(3)
.setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
- for (String impurity: RandomForestRegressor.supportedImpurities()) {
+ for (String impurity : RandomForestRegressor.supportedImpurities()) {
rf.setImpurity(impurity);
}
- for (String featureSubsetStrategy: RandomForestRegressor.supportedFeatureSubsetStrategies()) {
+ for (String featureSubsetStrategy : RandomForestRegressor.supportedFeatureSubsetStrategies()) {
rf.setFeatureSubsetStrategy(featureSubsetStrategy);
}
String[] realStrategies = {".1", ".10", "0.10", "0.1", "0.9", "1.0"};
- for (String strategy: realStrategies) {
+ for (String strategy : realStrategies) {
rf.setFeatureSubsetStrategy(strategy);
}
String[] integerStrategies = {"1", "10", "100", "1000", "10000"};
- for (String strategy: integerStrategies) {
+ for (String strategy : integerStrategies) {
rf.setFeatureSubsetStrategy(strategy);
}
String[] invalidStrategies = {"-.1", "-.10", "-0.10", ".0", "0.0", "1.1", "0"};
- for (String strategy: invalidStrategies) {
+ for (String strategy : invalidStrategies) {
try {
rf.setFeatureSubsetStrategy(strategy);
Assert.fail("Expected exception to be thrown for invalid strategies");
diff --git a/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java b/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java
index 1c18b2b266..058f2ddafd 100644
--- a/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java
@@ -28,12 +28,11 @@ import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
-import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.DenseVector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
import org.apache.spark.util.Utils;
@@ -41,16 +40,17 @@ import org.apache.spark.util.Utils;
* Test LibSVMRelation in Java.
*/
public class JavaLibSVMRelationSuite {
- private transient JavaSparkContext jsc;
- private transient SQLContext sqlContext;
+ private transient SparkSession spark;
private File tempDir;
private String path;
@Before
public void setUp() throws IOException {
- jsc = new JavaSparkContext("local", "JavaLibSVMRelationSuite");
- sqlContext = new SQLContext(jsc);
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaLibSVMRelationSuite")
+ .getOrCreate();
tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource");
File file = new File(tempDir, "part-00000");
@@ -61,14 +61,14 @@ public class JavaLibSVMRelationSuite {
@After
public void tearDown() {
- jsc.stop();
- jsc = null;
+ spark.stop();
+ spark = null;
Utils.deleteRecursively(tempDir);
}
@Test
public void verifyLibSVMDF() {
- Dataset<Row> dataset = sqlContext.read().format("libsvm").option("vectorType", "dense")
+ Dataset<Row> dataset = spark.read().format("libsvm").option("vectorType", "dense")
.load(path);
Assert.assertEquals("label", dataset.columns()[0]);
Assert.assertEquals("features", dataset.columns()[1]);
diff --git a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
index 24b0097454..8b4d034ffe 100644
--- a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
@@ -32,21 +32,25 @@ import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
public class JavaCrossValidatorSuite implements Serializable {
+ private transient SparkSession spark;
private transient JavaSparkContext jsc;
- private transient SQLContext jsql;
private transient Dataset<Row> dataset;
@Before
public void setUp() {
- jsc = new JavaSparkContext("local", "JavaCrossValidatorSuite");
- jsql = new SQLContext(jsc);
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaCrossValidatorSuite")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
+
List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
- dataset = jsql.createDataFrame(jsc.parallelize(points, 2), LabeledPoint.class);
+ dataset = spark.createDataFrame(jsc.parallelize(points, 2), LabeledPoint.class);
}
@After
@@ -59,8 +63,8 @@ public class JavaCrossValidatorSuite implements Serializable {
public void crossValidationWithLogisticRegression() {
LogisticRegression lr = new LogisticRegression();
ParamMap[] lrParamMaps = new ParamGridBuilder()
- .addGrid(lr.regParam(), new double[] {0.001, 1000.0})
- .addGrid(lr.maxIter(), new int[] {0, 10})
+ .addGrid(lr.regParam(), new double[]{0.001, 1000.0})
+ .addGrid(lr.maxIter(), new int[]{0, 10})
.build();
BinaryClassificationEvaluator eval = new BinaryClassificationEvaluator();
CrossValidator cv = new CrossValidator()
diff --git a/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala b/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala
index 928301523f..878bc66ee3 100644
--- a/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala
+++ b/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala
@@ -37,4 +37,5 @@ object IdentifiableSuite {
class Test(override val uid: String) extends Identifiable {
def this() = this(Identifiable.randomUID("test"))
}
+
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java b/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java
index 01ff1ea658..7151e27cde 100644
--- a/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java
@@ -27,31 +27,34 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
import org.apache.spark.util.Utils;
public class JavaDefaultReadWriteSuite {
JavaSparkContext jsc = null;
- SQLContext sqlContext = null;
+ SparkSession spark = null;
File tempDir = null;
@Before
public void setUp() {
- jsc = new JavaSparkContext("local[2]", "JavaDefaultReadWriteSuite");
SQLContext.clearActive();
- sqlContext = new SQLContext(jsc);
- SQLContext.setActive(sqlContext);
+ spark = SparkSession.builder()
+ .master("local[2]")
+ .appName("JavaDefaultReadWriteSuite")
+ .getOrCreate();
+ SQLContext.setActive(spark.wrapped());
+
tempDir = Utils.createTempDir(
System.getProperty("java.io.tmpdir"), "JavaDefaultReadWriteSuite");
}
@After
public void tearDown() {
- sqlContext = null;
SQLContext.clearActive();
- if (jsc != null) {
- jsc.stop();
- jsc = null;
+ if (spark != null) {
+ spark.stop();
+ spark = null;
}
Utils.deleteRecursively(tempDir);
}
@@ -70,7 +73,7 @@ public class JavaDefaultReadWriteSuite {
} catch (IOException e) {
// expected
}
- instance.write().context(sqlContext).overwrite().save(outputPath);
+ instance.write().context(spark.wrapped()).overwrite().save(outputPath);
MyParams newInstance = MyParams.load(outputPath);
Assert.assertEquals("UID should match.", instance.uid(), newInstance.uid());
Assert.assertEquals("Params should be preserved.",
diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java
index 862221d487..2f10d14da5 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java
@@ -27,26 +27,31 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
-
import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.sql.SparkSession;
public class JavaLogisticRegressionSuite implements Serializable {
- private transient JavaSparkContext sc;
+ private transient SparkSession spark;
+ private transient JavaSparkContext jsc;
@Before
public void setUp() {
- sc = new JavaSparkContext("local", "JavaLogisticRegressionSuite");
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaLogisticRegressionSuite")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
- sc.stop();
- sc = null;
+ spark.stop();
+ spark = null;
}
int validatePrediction(List<LabeledPoint> validationData, LogisticRegressionModel model) {
int numAccurate = 0;
- for (LabeledPoint point: validationData) {
+ for (LabeledPoint point : validationData) {
Double prediction = model.predict(point.features());
if (prediction == point.label()) {
numAccurate++;
@@ -61,16 +66,16 @@ public class JavaLogisticRegressionSuite implements Serializable {
double A = 2.0;
double B = -1.5;
- JavaRDD<LabeledPoint> testRDD = sc.parallelize(
- LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
+ JavaRDD<LabeledPoint> testRDD = jsc.parallelize(
+ LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
List<LabeledPoint> validationData =
- LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 17);
+ LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 17);
LogisticRegressionWithSGD lrImpl = new LogisticRegressionWithSGD();
lrImpl.setIntercept(true);
lrImpl.optimizer().setStepSize(1.0)
- .setRegParam(1.0)
- .setNumIterations(100);
+ .setRegParam(1.0)
+ .setNumIterations(100);
LogisticRegressionModel model = lrImpl.run(testRDD.rdd());
int numAccurate = validatePrediction(validationData, model);
@@ -83,13 +88,13 @@ public class JavaLogisticRegressionSuite implements Serializable {
double A = 0.0;
double B = -2.5;
- JavaRDD<LabeledPoint> testRDD = sc.parallelize(
- LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
+ JavaRDD<LabeledPoint> testRDD = jsc.parallelize(
+ LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
List<LabeledPoint> validationData =
- LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 17);
+ LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 17);
LogisticRegressionModel model = LogisticRegressionWithSGD.train(
- testRDD.rdd(), 100, 1.0, 1.0);
+ testRDD.rdd(), 100, 1.0, 1.0);
int numAccurate = validatePrediction(validationData, model);
Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0);
diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
index 3771c0ea7a..5e212e2fc5 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
@@ -32,20 +32,26 @@ import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.sql.SparkSession;
public class JavaNaiveBayesSuite implements Serializable {
- private transient JavaSparkContext sc;
+ private transient SparkSession spark;
+ private transient JavaSparkContext jsc;
@Before
public void setUp() {
- sc = new JavaSparkContext("local", "JavaNaiveBayesSuite");
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaNaiveBayesSuite")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
- sc.stop();
- sc = null;
+ spark.stop();
+ spark = null;
}
private static final List<LabeledPoint> POINTS = Arrays.asList(
@@ -59,7 +65,7 @@ public class JavaNaiveBayesSuite implements Serializable {
private int validatePrediction(List<LabeledPoint> points, NaiveBayesModel model) {
int correct = 0;
- for (LabeledPoint p: points) {
+ for (LabeledPoint p : points) {
if (model.predict(p.features()) == p.label()) {
correct += 1;
}
@@ -69,7 +75,7 @@ public class JavaNaiveBayesSuite implements Serializable {
@Test
public void runUsingConstructor() {
- JavaRDD<LabeledPoint> testRDD = sc.parallelize(POINTS, 2).cache();
+ JavaRDD<LabeledPoint> testRDD = jsc.parallelize(POINTS, 2).cache();
NaiveBayes nb = new NaiveBayes().setLambda(1.0);
NaiveBayesModel model = nb.run(testRDD.rdd());
@@ -80,7 +86,7 @@ public class JavaNaiveBayesSuite implements Serializable {
@Test
public void runUsingStaticMethods() {
- JavaRDD<LabeledPoint> testRDD = sc.parallelize(POINTS, 2).cache();
+ JavaRDD<LabeledPoint> testRDD = jsc.parallelize(POINTS, 2).cache();
NaiveBayesModel model1 = NaiveBayes.train(testRDD.rdd());
int numAccurate1 = validatePrediction(POINTS, model1);
@@ -93,13 +99,14 @@ public class JavaNaiveBayesSuite implements Serializable {
@Test
public void testPredictJavaRDD() {
- JavaRDD<LabeledPoint> examples = sc.parallelize(POINTS, 2).cache();
+ JavaRDD<LabeledPoint> examples = jsc.parallelize(POINTS, 2).cache();
NaiveBayesModel model = NaiveBayes.train(examples.rdd());
JavaRDD<Vector> vectors = examples.map(new Function<LabeledPoint, Vector>() {
@Override
public Vector call(LabeledPoint v) throws Exception {
return v.features();
- }});
+ }
+ });
JavaRDD<Double> predictions = model.predict(vectors);
// Should be able to get the first prediction.
predictions.first();
diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java
index 31b9f3e8d4..2a090c054f 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java
@@ -28,24 +28,30 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.sql.SparkSession;
public class JavaSVMSuite implements Serializable {
- private transient JavaSparkContext sc;
+ private transient SparkSession spark;
+ private transient JavaSparkContext jsc;
@Before
public void setUp() {
- sc = new JavaSparkContext("local", "JavaSVMSuite");
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaSVMSuite")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
- sc.stop();
- sc = null;
+ spark.stop();
+ spark = null;
}
int validatePrediction(List<LabeledPoint> validationData, SVMModel model) {
int numAccurate = 0;
- for (LabeledPoint point: validationData) {
+ for (LabeledPoint point : validationData) {
Double prediction = model.predict(point.features());
if (prediction == point.label()) {
numAccurate++;
@@ -60,16 +66,16 @@ public class JavaSVMSuite implements Serializable {
double A = 2.0;
double[] weights = {-1.5, 1.0};
- JavaRDD<LabeledPoint> testRDD = sc.parallelize(SVMSuite.generateSVMInputAsList(A,
- weights, nPoints, 42), 2).cache();
+ JavaRDD<LabeledPoint> testRDD = jsc.parallelize(SVMSuite.generateSVMInputAsList(A,
+ weights, nPoints, 42), 2).cache();
List<LabeledPoint> validationData =
- SVMSuite.generateSVMInputAsList(A, weights, nPoints, 17);
+ SVMSuite.generateSVMInputAsList(A, weights, nPoints, 17);
SVMWithSGD svmSGDImpl = new SVMWithSGD();
svmSGDImpl.setIntercept(true);
svmSGDImpl.optimizer().setStepSize(1.0)
- .setRegParam(1.0)
- .setNumIterations(100);
+ .setRegParam(1.0)
+ .setNumIterations(100);
SVMModel model = svmSGDImpl.run(testRDD.rdd());
int numAccurate = validatePrediction(validationData, model);
@@ -82,10 +88,10 @@ public class JavaSVMSuite implements Serializable {
double A = 0.0;
double[] weights = {-1.5, 1.0};
- JavaRDD<LabeledPoint> testRDD = sc.parallelize(SVMSuite.generateSVMInputAsList(A,
- weights, nPoints, 42), 2).cache();
+ JavaRDD<LabeledPoint> testRDD = jsc.parallelize(SVMSuite.generateSVMInputAsList(A,
+ weights, nPoints, 42), 2).cache();
List<LabeledPoint> validationData =
- SVMSuite.generateSVMInputAsList(A, weights, nPoints, 17);
+ SVMSuite.generateSVMInputAsList(A, weights, nPoints, 17);
SVMModel model = SVMWithSGD.train(testRDD.rdd(), 100, 1.0, 1.0, 1.0);
diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java
index a714620ff7..7f29b05047 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java
@@ -20,6 +20,7 @@ package org.apache.spark.mllib.clustering;
import java.io.Serializable;
import com.google.common.collect.Lists;
+
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
@@ -29,27 +30,33 @@ import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.sql.SparkSession;
public class JavaBisectingKMeansSuite implements Serializable {
- private transient JavaSparkContext sc;
+ private transient SparkSession spark;
+ private transient JavaSparkContext jsc;
@Before
public void setUp() {
- sc = new JavaSparkContext("local", this.getClass().getSimpleName());
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaBisectingKMeansSuite")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
- sc.stop();
- sc = null;
+ spark.stop();
+ spark = null;
}
@Test
public void twoDimensionalData() {
- JavaRDD<Vector> points = sc.parallelize(Lists.newArrayList(
+ JavaRDD<Vector> points = jsc.parallelize(Lists.newArrayList(
Vectors.dense(4, -1),
Vectors.dense(4, 1),
- Vectors.sparse(2, new int[] {0}, new double[] {1.0})
+ Vectors.sparse(2, new int[]{0}, new double[]{1.0})
), 2);
BisectingKMeans bkm = new BisectingKMeans()
@@ -58,15 +65,15 @@ public class JavaBisectingKMeansSuite implements Serializable {
.setSeed(1L);
BisectingKMeansModel model = bkm.run(points);
Assert.assertEquals(3, model.k());
- Assert.assertArrayEquals(new double[] {3.0, 0.0}, model.root().center().toArray(), 1e-12);
- for (ClusteringTreeNode child: model.root().children()) {
+ Assert.assertArrayEquals(new double[]{3.0, 0.0}, model.root().center().toArray(), 1e-12);
+ for (ClusteringTreeNode child : model.root().children()) {
double[] center = child.center().toArray();
if (center[0] > 2) {
Assert.assertEquals(2, child.size());
- Assert.assertArrayEquals(new double[] {4.0, 0.0}, center, 1e-12);
+ Assert.assertArrayEquals(new double[]{4.0, 0.0}, center, 1e-12);
} else {
Assert.assertEquals(1, child.size());
- Assert.assertArrayEquals(new double[] {1.0, 0.0}, center, 1e-12);
+ Assert.assertArrayEquals(new double[]{1.0, 0.0}, center, 1e-12);
}
}
}
diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java
index 123f78da54..20edd08a21 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java
@@ -21,29 +21,35 @@ import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
+import static org.junit.Assert.assertEquals;
+
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
-import static org.junit.Assert.assertEquals;
-
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.sql.SparkSession;
public class JavaGaussianMixtureSuite implements Serializable {
- private transient JavaSparkContext sc;
+ private transient SparkSession spark;
+ private transient JavaSparkContext jsc;
@Before
public void setUp() {
- sc = new JavaSparkContext("local", "JavaGaussianMixture");
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaGaussianMixture")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
- sc.stop();
- sc = null;
+ spark.stop();
+ spark = null;
}
@Test
@@ -54,7 +60,7 @@ public class JavaGaussianMixtureSuite implements Serializable {
Vectors.dense(1.0, 4.0, 6.0)
);
- JavaRDD<Vector> data = sc.parallelize(points, 2);
+ JavaRDD<Vector> data = jsc.parallelize(points, 2);
GaussianMixtureModel model = new GaussianMixture().setK(2).setMaxIterations(1).setSeed(1234)
.run(data);
assertEquals(model.gaussians().length, 2);
diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java
index ad06676c72..4e5b87f588 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java
@@ -21,28 +21,35 @@ import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
+import static org.junit.Assert.assertEquals;
+
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
-import static org.junit.Assert.*;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.sql.SparkSession;
public class JavaKMeansSuite implements Serializable {
- private transient JavaSparkContext sc;
+ private transient SparkSession spark;
+ private transient JavaSparkContext jsc;
@Before
public void setUp() {
- sc = new JavaSparkContext("local", "JavaKMeans");
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaKMeans")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
- sc.stop();
- sc = null;
+ spark.stop();
+ spark = null;
}
@Test
@@ -55,7 +62,7 @@ public class JavaKMeansSuite implements Serializable {
Vector expectedCenter = Vectors.dense(1.0, 3.0, 4.0);
- JavaRDD<Vector> data = sc.parallelize(points, 2);
+ JavaRDD<Vector> data = jsc.parallelize(points, 2);
KMeansModel model = KMeans.train(data.rdd(), 1, 1, 1, KMeans.K_MEANS_PARALLEL());
assertEquals(1, model.clusterCenters().length);
assertEquals(expectedCenter, model.clusterCenters()[0]);
@@ -74,7 +81,7 @@ public class JavaKMeansSuite implements Serializable {
Vector expectedCenter = Vectors.dense(1.0, 3.0, 4.0);
- JavaRDD<Vector> data = sc.parallelize(points, 2);
+ JavaRDD<Vector> data = jsc.parallelize(points, 2);
KMeansModel model = new KMeans().setK(1).setMaxIterations(5).run(data.rdd());
assertEquals(1, model.clusterCenters().length);
assertEquals(expectedCenter, model.clusterCenters()[0]);
@@ -94,7 +101,7 @@ public class JavaKMeansSuite implements Serializable {
Vectors.dense(1.0, 3.0, 0.0),
Vectors.dense(1.0, 4.0, 6.0)
);
- JavaRDD<Vector> data = sc.parallelize(points, 2);
+ JavaRDD<Vector> data = jsc.parallelize(points, 2);
KMeansModel model = new KMeans().setK(1).setMaxIterations(5).run(data.rdd());
JavaRDD<Integer> predictions = model.predict(data);
// Should be able to get the first prediction.
diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
index db19b309f6..f16585aff4 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
@@ -27,37 +27,42 @@ import scala.Tuple3;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
-import static org.junit.Assert.assertArrayEquals;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.*;
-import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.linalg.Matrix;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.sql.SparkSession;
public class JavaLDASuite implements Serializable {
- private transient JavaSparkContext sc;
+ private transient SparkSession spark;
+ private transient JavaSparkContext jsc;
@Before
public void setUp() {
- sc = new JavaSparkContext("local", "JavaLDA");
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaLDASuite")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
+
ArrayList<Tuple2<Long, Vector>> tinyCorpus = new ArrayList<>();
for (int i = 0; i < LDASuite.tinyCorpus().length; i++) {
- tinyCorpus.add(new Tuple2<>((Long)LDASuite.tinyCorpus()[i]._1(),
- LDASuite.tinyCorpus()[i]._2()));
+ tinyCorpus.add(new Tuple2<>((Long) LDASuite.tinyCorpus()[i]._1(),
+ LDASuite.tinyCorpus()[i]._2()));
}
- JavaRDD<Tuple2<Long, Vector>> tmpCorpus = sc.parallelize(tinyCorpus, 2);
+ JavaRDD<Tuple2<Long, Vector>> tmpCorpus = jsc.parallelize(tinyCorpus, 2);
corpus = JavaPairRDD.fromJavaRDD(tmpCorpus);
}
@After
public void tearDown() {
- sc.stop();
- sc = null;
+ spark.stop();
+ spark = null;
}
@Test
@@ -95,7 +100,7 @@ public class JavaLDASuite implements Serializable {
.setMaxIterations(5)
.setSeed(12345);
- DistributedLDAModel model = (DistributedLDAModel)lda.run(corpus);
+ DistributedLDAModel model = (DistributedLDAModel) lda.run(corpus);
// Check: basic parameters
LocalLDAModel localModel = model.toLocal();
@@ -124,7 +129,7 @@ public class JavaLDASuite implements Serializable {
public Boolean call(Tuple2<Long, Vector> tuple2) {
return Vectors.norm(tuple2._2(), 1.0) != 0.0;
}
- });
+ });
assertEquals(topicDistributions.count(), nonEmptyCorpus.count());
// Check: javaTopTopicsPerDocuments
@@ -179,7 +184,7 @@ public class JavaLDASuite implements Serializable {
@Test
public void localLdaMethods() {
- JavaRDD<Tuple2<Long, Vector>> docs = sc.parallelize(toyData, 2);
+ JavaRDD<Tuple2<Long, Vector>> docs = jsc.parallelize(toyData, 2);
JavaPairRDD<Long, Vector> pairedDocs = JavaPairRDD.fromJavaRDD(docs);
// check: topicDistributions
@@ -191,7 +196,7 @@ public class JavaLDASuite implements Serializable {
// check: logLikelihood.
ArrayList<Tuple2<Long, Vector>> docsSingleWord = new ArrayList<>();
docsSingleWord.add(new Tuple2<>(0L, Vectors.dense(1.0, 0.0, 0.0)));
- JavaPairRDD<Long, Vector> single = JavaPairRDD.fromJavaRDD(sc.parallelize(docsSingleWord));
+ JavaPairRDD<Long, Vector> single = JavaPairRDD.fromJavaRDD(jsc.parallelize(docsSingleWord));
double logLikelihood = toyModel.logLikelihood(single);
}
@@ -199,7 +204,7 @@ public class JavaLDASuite implements Serializable {
private static int tinyVocabSize = LDASuite.tinyVocabSize();
private static Matrix tinyTopics = LDASuite.tinyTopics();
private static Tuple2<int[], double[]>[] tinyTopicDescription =
- LDASuite.tinyTopicDescription();
+ LDASuite.tinyTopicDescription();
private JavaPairRDD<Long, Vector> corpus;
private LocalLDAModel toyModel = LDASuite.toyModel();
private ArrayList<Tuple2<Long, Vector>> toyData = LDASuite.javaToyData();
diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java
index 62edbd3a29..d1d618f7de 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java
@@ -27,8 +27,6 @@ import org.junit.After;
import org.junit.Before;
import org.junit.Test;
-import static org.apache.spark.streaming.JavaTestUtils.*;
-
import org.apache.spark.SparkConf;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
@@ -36,6 +34,7 @@ import org.apache.spark.streaming.Duration;
import org.apache.spark.streaming.api.java.JavaDStream;
import org.apache.spark.streaming.api.java.JavaPairDStream;
import org.apache.spark.streaming.api.java.JavaStreamingContext;
+import static org.apache.spark.streaming.JavaTestUtils.*;
public class JavaStreamingKMeansSuite implements Serializable {
diff --git a/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java
index fa4d334801..6a096d6386 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java
@@ -31,27 +31,34 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.SparkSession;
public class JavaRankingMetricsSuite implements Serializable {
- private transient JavaSparkContext sc;
+ private transient SparkSession spark;
+ private transient JavaSparkContext jsc;
private transient JavaRDD<Tuple2<List<Integer>, List<Integer>>> predictionAndLabels;
@Before
public void setUp() {
- sc = new JavaSparkContext("local", "JavaRankingMetricsSuite");
- predictionAndLabels = sc.parallelize(Arrays.asList(
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaPCASuite")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
+
+ predictionAndLabels = jsc.parallelize(Arrays.asList(
Tuple2$.MODULE$.apply(
Arrays.asList(1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Arrays.asList(1, 2, 3, 4, 5)),
Tuple2$.MODULE$.apply(
- Arrays.asList(4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Arrays.asList(1, 2, 3)),
+ Arrays.asList(4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Arrays.asList(1, 2, 3)),
Tuple2$.MODULE$.apply(
- Arrays.asList(1, 2, 3, 4, 5), Arrays.<Integer>asList())), 2);
+ Arrays.asList(1, 2, 3, 4, 5), Arrays.<Integer>asList())), 2);
}
@After
public void tearDown() {
- sc.stop();
- sc = null;
+ spark.stop();
+ spark = null;
}
@Test
diff --git a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java
index 8a320afa4b..de50fb8c4f 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java
@@ -29,19 +29,25 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.sql.SparkSession;
public class JavaTfIdfSuite implements Serializable {
- private transient JavaSparkContext sc;
+ private transient SparkSession spark;
+ private transient JavaSparkContext jsc;
@Before
public void setUp() {
- sc = new JavaSparkContext("local", "JavaTfIdfSuite");
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaPCASuite")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
- sc.stop();
- sc = null;
+ spark.stop();
+ spark = null;
}
@Test
@@ -49,7 +55,7 @@ public class JavaTfIdfSuite implements Serializable {
// The tests are to check Java compatibility.
HashingTF tf = new HashingTF();
@SuppressWarnings("unchecked")
- JavaRDD<List<String>> documents = sc.parallelize(Arrays.asList(
+ JavaRDD<List<String>> documents = jsc.parallelize(Arrays.asList(
Arrays.asList("this is a sentence".split(" ")),
Arrays.asList("this is another sentence".split(" ")),
Arrays.asList("this is still a sentence".split(" "))), 2);
@@ -59,7 +65,7 @@ public class JavaTfIdfSuite implements Serializable {
JavaRDD<Vector> tfIdfs = idf.fit(termFreqs).transform(termFreqs);
List<Vector> localTfIdfs = tfIdfs.collect();
int indexOfThis = tf.indexOf("this");
- for (Vector v: localTfIdfs) {
+ for (Vector v : localTfIdfs) {
Assert.assertEquals(0.0, v.apply(indexOfThis), 1e-15);
}
}
@@ -69,7 +75,7 @@ public class JavaTfIdfSuite implements Serializable {
// The tests are to check Java compatibility.
HashingTF tf = new HashingTF();
@SuppressWarnings("unchecked")
- JavaRDD<List<String>> documents = sc.parallelize(Arrays.asList(
+ JavaRDD<List<String>> documents = jsc.parallelize(Arrays.asList(
Arrays.asList("this is a sentence".split(" ")),
Arrays.asList("this is another sentence".split(" ")),
Arrays.asList("this is still a sentence".split(" "))), 2);
@@ -79,7 +85,7 @@ public class JavaTfIdfSuite implements Serializable {
JavaRDD<Vector> tfIdfs = idf.fit(termFreqs).transform(termFreqs);
List<Vector> localTfIdfs = tfIdfs.collect();
int indexOfThis = tf.indexOf("this");
- for (Vector v: localTfIdfs) {
+ for (Vector v : localTfIdfs) {
Assert.assertEquals(0.0, v.apply(indexOfThis), 1e-15);
}
}
diff --git a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java
index e13ed07e28..64885cc842 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java
@@ -21,9 +21,10 @@ import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
+import com.google.common.base.Strings;
+
import scala.Tuple2;
-import com.google.common.base.Strings;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
@@ -31,19 +32,25 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.SparkSession;
public class JavaWord2VecSuite implements Serializable {
- private transient JavaSparkContext sc;
+ private transient SparkSession spark;
+ private transient JavaSparkContext jsc;
@Before
public void setUp() {
- sc = new JavaSparkContext("local", "JavaWord2VecSuite");
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaPCASuite")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
- sc.stop();
- sc = null;
+ spark.stop();
+ spark = null;
}
@Test
@@ -53,7 +60,7 @@ public class JavaWord2VecSuite implements Serializable {
String sentence = Strings.repeat("a b ", 100) + Strings.repeat("a c ", 10);
List<String> words = Arrays.asList(sentence.split(" "));
List<List<String>> localDoc = Arrays.asList(words, words);
- JavaRDD<List<String>> doc = sc.parallelize(localDoc);
+ JavaRDD<List<String>> doc = jsc.parallelize(localDoc);
Word2Vec word2vec = new Word2Vec()
.setVectorSize(10)
.setSeed(42L);
diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java
index 2bef7a8609..fdc19a5b3d 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java
@@ -26,32 +26,37 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset;
+import org.apache.spark.sql.SparkSession;
public class JavaAssociationRulesSuite implements Serializable {
- private transient JavaSparkContext sc;
+ private transient SparkSession spark;
+ private transient JavaSparkContext jsc;
@Before
public void setUp() {
- sc = new JavaSparkContext("local", "JavaFPGrowth");
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaAssociationRulesSuite")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
- sc.stop();
- sc = null;
+ spark.stop();
+ spark = null;
}
@Test
public void runAssociationRules() {
@SuppressWarnings("unchecked")
- JavaRDD<FPGrowth.FreqItemset<String>> freqItemsets = sc.parallelize(Arrays.asList(
- new FreqItemset<String>(new String[] {"a"}, 15L),
- new FreqItemset<String>(new String[] {"b"}, 35L),
- new FreqItemset<String>(new String[] {"a", "b"}, 12L)
+ JavaRDD<FPGrowth.FreqItemset<String>> freqItemsets = jsc.parallelize(Arrays.asList(
+ new FreqItemset<String>(new String[]{"a"}, 15L),
+ new FreqItemset<String>(new String[]{"b"}, 35L),
+ new FreqItemset<String>(new String[]{"a", "b"}, 12L)
));
JavaRDD<AssociationRules.Rule<String>> results = (new AssociationRules()).run(freqItemsets);
}
}
-
diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
index 916fff14a7..f235251e61 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
@@ -22,34 +22,41 @@ import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
+import static org.junit.Assert.assertEquals;
+
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
-import static org.junit.Assert.*;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.SparkSession;
import org.apache.spark.util.Utils;
public class JavaFPGrowthSuite implements Serializable {
- private transient JavaSparkContext sc;
+ private transient SparkSession spark;
+ private transient JavaSparkContext jsc;
@Before
public void setUp() {
- sc = new JavaSparkContext("local", "JavaFPGrowth");
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaFPGrowth")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
- sc.stop();
- sc = null;
+ spark.stop();
+ spark = null;
}
@Test
public void runFPGrowth() {
@SuppressWarnings("unchecked")
- JavaRDD<List<String>> rdd = sc.parallelize(Arrays.asList(
+ JavaRDD<List<String>> rdd = jsc.parallelize(Arrays.asList(
Arrays.asList("r z h k p".split(" ")),
Arrays.asList("z y x w v u t s".split(" ")),
Arrays.asList("s x o n r".split(" ")),
@@ -65,7 +72,7 @@ public class JavaFPGrowthSuite implements Serializable {
List<FPGrowth.FreqItemset<String>> freqItemsets = model.freqItemsets().toJavaRDD().collect();
assertEquals(18, freqItemsets.size());
- for (FPGrowth.FreqItemset<String> itemset: freqItemsets) {
+ for (FPGrowth.FreqItemset<String> itemset : freqItemsets) {
// Test return types.
List<String> items = itemset.javaItems();
long freq = itemset.freq();
@@ -76,7 +83,7 @@ public class JavaFPGrowthSuite implements Serializable {
public void runFPGrowthSaveLoad() {
@SuppressWarnings("unchecked")
- JavaRDD<List<String>> rdd = sc.parallelize(Arrays.asList(
+ JavaRDD<List<String>> rdd = jsc.parallelize(Arrays.asList(
Arrays.asList("r z h k p".split(" ")),
Arrays.asList("z y x w v u t s".split(" ")),
Arrays.asList("s x o n r".split(" ")),
@@ -94,15 +101,15 @@ public class JavaFPGrowthSuite implements Serializable {
String outputPath = tempDir.getPath();
try {
- model.save(sc.sc(), outputPath);
+ model.save(spark.sparkContext(), outputPath);
@SuppressWarnings("unchecked")
FPGrowthModel<String> newModel =
- (FPGrowthModel<String>) FPGrowthModel.load(sc.sc(), outputPath);
+ (FPGrowthModel<String>) FPGrowthModel.load(spark.sparkContext(), outputPath);
List<FPGrowth.FreqItemset<String>> freqItemsets = newModel.freqItemsets().toJavaRDD()
.collect();
assertEquals(18, freqItemsets.size());
- for (FPGrowth.FreqItemset<String> itemset: freqItemsets) {
+ for (FPGrowth.FreqItemset<String> itemset : freqItemsets) {
// Test return types.
List<String> items = itemset.javaItems();
long freq = itemset.freq();
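
For reference, the FPGrowth API exercised above, condensed into a sketch against the new entry point (the minSupport value and class name are illustrative):

    import java.util.Arrays;
    import java.util.List;

    import org.apache.spark.api.java.JavaRDD;
    import org.apache.spark.api.java.JavaSparkContext;
    import org.apache.spark.mllib.fpm.FPGrowth;
    import org.apache.spark.mllib.fpm.FPGrowthModel;
    import org.apache.spark.sql.SparkSession;

    public class FPGrowthSketch {
      public static void main(String[] args) {
        SparkSession spark = SparkSession.builder()
          .master("local").appName("FPGrowthSketch").getOrCreate();
        JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext());

        JavaRDD<List<String>> transactions = jsc.parallelize(Arrays.asList(
          Arrays.asList("r z h k p".split(" ")),
          Arrays.asList("z y x w v u t s".split(" ")),
          Arrays.asList("s x o n r".split(" "))));

        FPGrowthModel<String> model = new FPGrowth()
          .setMinSupport(0.5)   // illustrative; the suite uses its own value
          .setNumPartitions(2)
          .run(transactions);

        for (FPGrowth.FreqItemset<String> itemset
            : model.freqItemsets().toJavaRDD().collect()) {
          System.out.println(itemset.javaItems() + ": " + itemset.freq());
        }

        // Note: model.save(...) and FPGrowthModel.load(...) take the raw
        // SparkContext, which is why the test now passes spark.sparkContext().
        spark.stop();
      }
    }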
diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java
index 8a67793abc..bf7f1fc71b 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java
@@ -29,25 +29,31 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.fpm.PrefixSpan.FreqSequence;
+import org.apache.spark.sql.SparkSession;
import org.apache.spark.util.Utils;
public class JavaPrefixSpanSuite {
- private transient JavaSparkContext sc;
+ private transient SparkSession spark;
+ private transient JavaSparkContext jsc;
@Before
public void setUp() {
- sc = new JavaSparkContext("local", "JavaPrefixSpan");
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaPrefixSpan")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
- sc.stop();
- sc = null;
+ spark.stop();
+ spark = null;
}
@Test
public void runPrefixSpan() {
- JavaRDD<List<List<Integer>>> sequences = sc.parallelize(Arrays.asList(
+ JavaRDD<List<List<Integer>>> sequences = jsc.parallelize(Arrays.asList(
Arrays.asList(Arrays.asList(1, 2), Arrays.asList(3)),
Arrays.asList(Arrays.asList(1), Arrays.asList(3, 2), Arrays.asList(1, 2)),
Arrays.asList(Arrays.asList(1, 2), Arrays.asList(5)),
@@ -61,7 +67,7 @@ public class JavaPrefixSpanSuite {
List<FreqSequence<Integer>> localFreqSeqs = freqSeqs.collect();
Assert.assertEquals(5, localFreqSeqs.size());
// Check that each frequent sequence could be materialized.
- for (PrefixSpan.FreqSequence<Integer> freqSeq: localFreqSeqs) {
+ for (PrefixSpan.FreqSequence<Integer> freqSeq : localFreqSeqs) {
List<List<Integer>> seq = freqSeq.javaSequence();
long freq = freqSeq.freq();
}
@@ -69,7 +75,7 @@ public class JavaPrefixSpanSuite {
@Test
public void runPrefixSpanSaveLoad() {
- JavaRDD<List<List<Integer>>> sequences = sc.parallelize(Arrays.asList(
+ JavaRDD<List<List<Integer>>> sequences = jsc.parallelize(Arrays.asList(
Arrays.asList(Arrays.asList(1, 2), Arrays.asList(3)),
Arrays.asList(Arrays.asList(1), Arrays.asList(3, 2), Arrays.asList(1, 2)),
Arrays.asList(Arrays.asList(1, 2), Arrays.asList(5)),
@@ -85,13 +91,13 @@ public class JavaPrefixSpanSuite {
String outputPath = tempDir.getPath();
try {
- model.save(sc.sc(), outputPath);
- PrefixSpanModel newModel = PrefixSpanModel.load(sc.sc(), outputPath);
+ model.save(spark.sparkContext(), outputPath);
+ PrefixSpanModel newModel = PrefixSpanModel.load(spark.sparkContext(), outputPath);
JavaRDD<FreqSequence<Integer>> freqSeqs = newModel.freqSequences().toJavaRDD();
List<FreqSequence<Integer>> localFreqSeqs = freqSeqs.collect();
Assert.assertEquals(5, localFreqSeqs.size());
// Check that each frequent sequence could be materialized.
- for (PrefixSpan.FreqSequence<Integer> freqSeq: localFreqSeqs) {
+ for (PrefixSpan.FreqSequence<Integer> freqSeq : localFreqSeqs) {
List<List<Integer>> seq = freqSeq.javaSequence();
long freq = freqSeq.freq();
}
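
And the PrefixSpan counterpart under the same migration, as a sketch (the minSupport and maxPatternLength values are illustrative):

    import java.util.Arrays;
    import java.util.List;

    import org.apache.spark.api.java.JavaRDD;
    import org.apache.spark.api.java.JavaSparkContext;
    import org.apache.spark.mllib.fpm.PrefixSpan;
    import org.apache.spark.mllib.fpm.PrefixSpanModel;
    import org.apache.spark.sql.SparkSession;

    public class PrefixSpanSketch {
      public static void main(String[] args) {
        SparkSession spark = SparkSession.builder()
          .master("local").appName("PrefixSpanSketch").getOrCreate();
        JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext());

        // Each sequence is a list of itemsets, matching the test data above.
        JavaRDD<List<List<Integer>>> sequences = jsc.parallelize(Arrays.asList(
          Arrays.asList(Arrays.asList(1, 2), Arrays.asList(3)),
          Arrays.asList(Arrays.asList(1), Arrays.asList(3, 2), Arrays.asList(1, 2)),
          Arrays.asList(Arrays.asList(1, 2), Arrays.asList(5))));

        PrefixSpanModel<Integer> model = new PrefixSpan()
          .setMinSupport(0.5)        // illustrative
          .setMaxPatternLength(5)    // illustrative
          .run(sequences);

        for (PrefixSpan.FreqSequence<Integer> freqSeq
            : model.freqSequences().toJavaRDD().collect()) {
          System.out.println(freqSeq.javaSequence() + ": " + freqSeq.freq());
        }

        spark.stop();
      }
    }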
diff --git a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java
index 8beea102ef..92fc57871c 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java
@@ -17,147 +17,149 @@
package org.apache.spark.mllib.linalg;
-import static org.junit.Assert.*;
-import org.junit.Test;
-
import java.io.Serializable;
import java.util.Random;
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+
+import org.junit.Test;
+
public class JavaMatricesSuite implements Serializable {
- @Test
- public void randMatrixConstruction() {
- Random rng = new Random(24);
- Matrix r = Matrices.rand(3, 4, rng);
- rng.setSeed(24);
- DenseMatrix dr = DenseMatrix.rand(3, 4, rng);
- assertArrayEquals(r.toArray(), dr.toArray(), 0.0);
-
- rng.setSeed(24);
- Matrix rn = Matrices.randn(3, 4, rng);
- rng.setSeed(24);
- DenseMatrix drn = DenseMatrix.randn(3, 4, rng);
- assertArrayEquals(rn.toArray(), drn.toArray(), 0.0);
-
- rng.setSeed(24);
- Matrix s = Matrices.sprand(3, 4, 0.5, rng);
- rng.setSeed(24);
- SparseMatrix sr = SparseMatrix.sprand(3, 4, 0.5, rng);
- assertArrayEquals(s.toArray(), sr.toArray(), 0.0);
-
- rng.setSeed(24);
- Matrix sn = Matrices.sprandn(3, 4, 0.5, rng);
- rng.setSeed(24);
- SparseMatrix srn = SparseMatrix.sprandn(3, 4, 0.5, rng);
- assertArrayEquals(sn.toArray(), srn.toArray(), 0.0);
- }
-
- @Test
- public void identityMatrixConstruction() {
- Matrix r = Matrices.eye(2);
- DenseMatrix dr = DenseMatrix.eye(2);
- SparseMatrix sr = SparseMatrix.speye(2);
- assertArrayEquals(r.toArray(), dr.toArray(), 0.0);
- assertArrayEquals(sr.toArray(), dr.toArray(), 0.0);
- assertArrayEquals(r.toArray(), new double[]{1.0, 0.0, 0.0, 1.0}, 0.0);
- }
-
- @Test
- public void diagonalMatrixConstruction() {
- Vector v = Vectors.dense(1.0, 0.0, 2.0);
- Vector sv = Vectors.sparse(3, new int[]{0, 2}, new double[]{1.0, 2.0});
-
- Matrix m = Matrices.diag(v);
- Matrix sm = Matrices.diag(sv);
- DenseMatrix d = DenseMatrix.diag(v);
- DenseMatrix sd = DenseMatrix.diag(sv);
- SparseMatrix s = SparseMatrix.spdiag(v);
- SparseMatrix ss = SparseMatrix.spdiag(sv);
-
- assertArrayEquals(m.toArray(), sm.toArray(), 0.0);
- assertArrayEquals(d.toArray(), sm.toArray(), 0.0);
- assertArrayEquals(d.toArray(), sd.toArray(), 0.0);
- assertArrayEquals(sd.toArray(), s.toArray(), 0.0);
- assertArrayEquals(s.toArray(), ss.toArray(), 0.0);
- assertArrayEquals(s.values(), ss.values(), 0.0);
- assertEquals(2, s.values().length);
- assertEquals(2, ss.values().length);
- assertEquals(4, s.colPtrs().length);
- assertEquals(4, ss.colPtrs().length);
- }
-
- @Test
- public void zerosMatrixConstruction() {
- Matrix z = Matrices.zeros(2, 2);
- Matrix one = Matrices.ones(2, 2);
- DenseMatrix dz = DenseMatrix.zeros(2, 2);
- DenseMatrix done = DenseMatrix.ones(2, 2);
-
- assertArrayEquals(z.toArray(), new double[]{0.0, 0.0, 0.0, 0.0}, 0.0);
- assertArrayEquals(dz.toArray(), new double[]{0.0, 0.0, 0.0, 0.0}, 0.0);
- assertArrayEquals(one.toArray(), new double[]{1.0, 1.0, 1.0, 1.0}, 0.0);
- assertArrayEquals(done.toArray(), new double[]{1.0, 1.0, 1.0, 1.0}, 0.0);
- }
-
- @Test
- public void sparseDenseConversion() {
- int m = 3;
- int n = 2;
- double[] values = new double[]{1.0, 2.0, 4.0, 5.0};
- double[] allValues = new double[]{1.0, 2.0, 0.0, 0.0, 4.0, 5.0};
- int[] colPtrs = new int[]{0, 2, 4};
- int[] rowIndices = new int[]{0, 1, 1, 2};
-
- SparseMatrix spMat1 = new SparseMatrix(m, n, colPtrs, rowIndices, values);
- DenseMatrix deMat1 = new DenseMatrix(m, n, allValues);
-
- SparseMatrix spMat2 = deMat1.toSparse();
- DenseMatrix deMat2 = spMat1.toDense();
-
- assertArrayEquals(spMat1.toArray(), spMat2.toArray(), 0.0);
- assertArrayEquals(deMat1.toArray(), deMat2.toArray(), 0.0);
- }
-
- @Test
- public void concatenateMatrices() {
- int m = 3;
- int n = 2;
-
- Random rng = new Random(42);
- SparseMatrix spMat1 = SparseMatrix.sprand(m, n, 0.5, rng);
- rng.setSeed(42);
- DenseMatrix deMat1 = DenseMatrix.rand(m, n, rng);
- Matrix deMat2 = Matrices.eye(3);
- Matrix spMat2 = Matrices.speye(3);
- Matrix deMat3 = Matrices.eye(2);
- Matrix spMat3 = Matrices.speye(2);
-
- Matrix spHorz = Matrices.horzcat(new Matrix[]{spMat1, spMat2});
- Matrix deHorz1 = Matrices.horzcat(new Matrix[]{deMat1, deMat2});
- Matrix deHorz2 = Matrices.horzcat(new Matrix[]{spMat1, deMat2});
- Matrix deHorz3 = Matrices.horzcat(new Matrix[]{deMat1, spMat2});
-
- assertEquals(3, deHorz1.numRows());
- assertEquals(3, deHorz2.numRows());
- assertEquals(3, deHorz3.numRows());
- assertEquals(3, spHorz.numRows());
- assertEquals(5, deHorz1.numCols());
- assertEquals(5, deHorz2.numCols());
- assertEquals(5, deHorz3.numCols());
- assertEquals(5, spHorz.numCols());
-
- Matrix spVert = Matrices.vertcat(new Matrix[]{spMat1, spMat3});
- Matrix deVert1 = Matrices.vertcat(new Matrix[]{deMat1, deMat3});
- Matrix deVert2 = Matrices.vertcat(new Matrix[]{spMat1, deMat3});
- Matrix deVert3 = Matrices.vertcat(new Matrix[]{deMat1, spMat3});
-
- assertEquals(5, deVert1.numRows());
- assertEquals(5, deVert2.numRows());
- assertEquals(5, deVert3.numRows());
- assertEquals(5, spVert.numRows());
- assertEquals(2, deVert1.numCols());
- assertEquals(2, deVert2.numCols());
- assertEquals(2, deVert3.numCols());
- assertEquals(2, spVert.numCols());
- }
+ @Test
+ public void randMatrixConstruction() {
+ Random rng = new Random(24);
+ Matrix r = Matrices.rand(3, 4, rng);
+ rng.setSeed(24);
+ DenseMatrix dr = DenseMatrix.rand(3, 4, rng);
+ assertArrayEquals(r.toArray(), dr.toArray(), 0.0);
+
+ rng.setSeed(24);
+ Matrix rn = Matrices.randn(3, 4, rng);
+ rng.setSeed(24);
+ DenseMatrix drn = DenseMatrix.randn(3, 4, rng);
+ assertArrayEquals(rn.toArray(), drn.toArray(), 0.0);
+
+ rng.setSeed(24);
+ Matrix s = Matrices.sprand(3, 4, 0.5, rng);
+ rng.setSeed(24);
+ SparseMatrix sr = SparseMatrix.sprand(3, 4, 0.5, rng);
+ assertArrayEquals(s.toArray(), sr.toArray(), 0.0);
+
+ rng.setSeed(24);
+ Matrix sn = Matrices.sprandn(3, 4, 0.5, rng);
+ rng.setSeed(24);
+ SparseMatrix srn = SparseMatrix.sprandn(3, 4, 0.5, rng);
+ assertArrayEquals(sn.toArray(), srn.toArray(), 0.0);
+ }
+
+ @Test
+ public void identityMatrixConstruction() {
+ Matrix r = Matrices.eye(2);
+ DenseMatrix dr = DenseMatrix.eye(2);
+ SparseMatrix sr = SparseMatrix.speye(2);
+ assertArrayEquals(r.toArray(), dr.toArray(), 0.0);
+ assertArrayEquals(sr.toArray(), dr.toArray(), 0.0);
+ assertArrayEquals(r.toArray(), new double[]{1.0, 0.0, 0.0, 1.0}, 0.0);
+ }
+
+ @Test
+ public void diagonalMatrixConstruction() {
+ Vector v = Vectors.dense(1.0, 0.0, 2.0);
+ Vector sv = Vectors.sparse(3, new int[]{0, 2}, new double[]{1.0, 2.0});
+
+ Matrix m = Matrices.diag(v);
+ Matrix sm = Matrices.diag(sv);
+ DenseMatrix d = DenseMatrix.diag(v);
+ DenseMatrix sd = DenseMatrix.diag(sv);
+ SparseMatrix s = SparseMatrix.spdiag(v);
+ SparseMatrix ss = SparseMatrix.spdiag(sv);
+
+ assertArrayEquals(m.toArray(), sm.toArray(), 0.0);
+ assertArrayEquals(d.toArray(), sm.toArray(), 0.0);
+ assertArrayEquals(d.toArray(), sd.toArray(), 0.0);
+ assertArrayEquals(sd.toArray(), s.toArray(), 0.0);
+ assertArrayEquals(s.toArray(), ss.toArray(), 0.0);
+ assertArrayEquals(s.values(), ss.values(), 0.0);
+ assertEquals(2, s.values().length);
+ assertEquals(2, ss.values().length);
+ assertEquals(4, s.colPtrs().length);
+ assertEquals(4, ss.colPtrs().length);
+ }
+
+ @Test
+ public void zerosMatrixConstruction() {
+ Matrix z = Matrices.zeros(2, 2);
+ Matrix one = Matrices.ones(2, 2);
+ DenseMatrix dz = DenseMatrix.zeros(2, 2);
+ DenseMatrix done = DenseMatrix.ones(2, 2);
+
+ assertArrayEquals(z.toArray(), new double[]{0.0, 0.0, 0.0, 0.0}, 0.0);
+ assertArrayEquals(dz.toArray(), new double[]{0.0, 0.0, 0.0, 0.0}, 0.0);
+ assertArrayEquals(one.toArray(), new double[]{1.0, 1.0, 1.0, 1.0}, 0.0);
+ assertArrayEquals(done.toArray(), new double[]{1.0, 1.0, 1.0, 1.0}, 0.0);
+ }
+
+ @Test
+ public void sparseDenseConversion() {
+ int m = 3;
+ int n = 2;
+ double[] values = new double[]{1.0, 2.0, 4.0, 5.0};
+ double[] allValues = new double[]{1.0, 2.0, 0.0, 0.0, 4.0, 5.0};
+ int[] colPtrs = new int[]{0, 2, 4};
+ int[] rowIndices = new int[]{0, 1, 1, 2};
+
+ SparseMatrix spMat1 = new SparseMatrix(m, n, colPtrs, rowIndices, values);
+ DenseMatrix deMat1 = new DenseMatrix(m, n, allValues);
+
+ SparseMatrix spMat2 = deMat1.toSparse();
+ DenseMatrix deMat2 = spMat1.toDense();
+
+ assertArrayEquals(spMat1.toArray(), spMat2.toArray(), 0.0);
+ assertArrayEquals(deMat1.toArray(), deMat2.toArray(), 0.0);
+ }
+
+ @Test
+ public void concatenateMatrices() {
+ int m = 3;
+ int n = 2;
+
+ Random rng = new Random(42);
+ SparseMatrix spMat1 = SparseMatrix.sprand(m, n, 0.5, rng);
+ rng.setSeed(42);
+ DenseMatrix deMat1 = DenseMatrix.rand(m, n, rng);
+ Matrix deMat2 = Matrices.eye(3);
+ Matrix spMat2 = Matrices.speye(3);
+ Matrix deMat3 = Matrices.eye(2);
+ Matrix spMat3 = Matrices.speye(2);
+
+ Matrix spHorz = Matrices.horzcat(new Matrix[]{spMat1, spMat2});
+ Matrix deHorz1 = Matrices.horzcat(new Matrix[]{deMat1, deMat2});
+ Matrix deHorz2 = Matrices.horzcat(new Matrix[]{spMat1, deMat2});
+ Matrix deHorz3 = Matrices.horzcat(new Matrix[]{deMat1, spMat2});
+
+ assertEquals(3, deHorz1.numRows());
+ assertEquals(3, deHorz2.numRows());
+ assertEquals(3, deHorz3.numRows());
+ assertEquals(3, spHorz.numRows());
+ assertEquals(5, deHorz1.numCols());
+ assertEquals(5, deHorz2.numCols());
+ assertEquals(5, deHorz3.numCols());
+ assertEquals(5, spHorz.numCols());
+
+ Matrix spVert = Matrices.vertcat(new Matrix[]{spMat1, spMat3});
+ Matrix deVert1 = Matrices.vertcat(new Matrix[]{deMat1, deMat3});
+ Matrix deVert2 = Matrices.vertcat(new Matrix[]{spMat1, deMat3});
+ Matrix deVert3 = Matrices.vertcat(new Matrix[]{deMat1, spMat3});
+
+ assertEquals(5, deVert1.numRows());
+ assertEquals(5, deVert2.numRows());
+ assertEquals(5, deVert3.numRows());
+ assertEquals(5, spVert.numRows());
+ assertEquals(2, deVert1.numCols());
+ assertEquals(2, deVert2.numCols());
+ assertEquals(2, deVert3.numCols());
+ assertEquals(2, spVert.numCols());
+ }
}
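
The sparseDenseConversion test above depends on SparseMatrix's column-major CSC layout; a small standalone sketch using the same numbers as the test:

    import java.util.Arrays;

    import org.apache.spark.mllib.linalg.DenseMatrix;
    import org.apache.spark.mllib.linalg.SparseMatrix;

    public class CscLayoutSketch {
      public static void main(String[] args) {
        // 3x2 matrix in CSC form: the entries of column j live at positions
        // colPtrs[j] (inclusive) .. colPtrs[j+1] (exclusive) of
        // rowIndices/values.
        SparseMatrix sp = new SparseMatrix(
          3, 2,
          new int[]{0, 2, 4},                    // colPtrs
          new int[]{0, 1, 1, 2},                 // rowIndices
          new double[]{1.0, 2.0, 4.0, 5.0});     // values
        DenseMatrix de = sp.toDense();
        // toArray() is column-major, so this prints
        // [1.0, 2.0, 0.0, 0.0, 4.0, 5.0]
        System.out.println(Arrays.toString(de.toArray()));
      }
    }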
diff --git a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java
index 4ba8e543a9..817b962c75 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java
@@ -20,10 +20,11 @@ package org.apache.spark.mllib.linalg;
import java.io.Serializable;
import java.util.Arrays;
+import static org.junit.Assert.assertArrayEquals;
+
import scala.Tuple2;
import org.junit.Test;
-import static org.junit.Assert.*;
public class JavaVectorsSuite implements Serializable {
@@ -37,8 +38,8 @@ public class JavaVectorsSuite implements Serializable {
public void sparseArrayConstruction() {
@SuppressWarnings("unchecked")
Vector v = Vectors.sparse(3, Arrays.asList(
- new Tuple2<>(0, 2.0),
- new Tuple2<>(2, 3.0)));
+ new Tuple2<>(0, 2.0),
+ new Tuple2<>(2, 3.0)));
assertArrayEquals(new double[]{2.0, 0.0, 3.0}, v.toArray(), 0.0);
}
}
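
Equivalently, the (index, value) tuple construction reindented above can be written with parallel index and value arrays; both overloads build the same sparse vector:

    import org.apache.spark.mllib.linalg.Vector;
    import org.apache.spark.mllib.linalg.Vectors;

    public class SparseVectorSketch {
      public static void main(String[] args) {
        // Same vector as the test: size 3, nonzeros at indices 0 and 2.
        Vector v = Vectors.sparse(3, new int[]{0, 2}, new double[]{2.0, 3.0});
        System.out.println(java.util.Arrays.toString(v.toArray()));
        // prints [2.0, 0.0, 3.0]
      }
    }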
diff --git a/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java
index be58691f4d..b449108a9b 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java
@@ -20,29 +20,35 @@ package org.apache.spark.mllib.random;
import java.io.Serializable;
import java.util.Arrays;
-import org.apache.spark.api.java.JavaRDD;
-import org.junit.Assert;
import org.junit.After;
+import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.apache.spark.api.java.JavaDoubleRDD;
+import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.sql.SparkSession;
import static org.apache.spark.mllib.random.RandomRDDs.*;
public class JavaRandomRDDsSuite {
- private transient JavaSparkContext sc;
+ private transient SparkSession spark;
+ private transient JavaSparkContext jsc;
@Before
public void setUp() {
- sc = new JavaSparkContext("local", "JavaRandomRDDsSuite");
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaRandomRDDsSuite")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
- sc.stop();
- sc = null;
+ spark.stop();
+ spark = null;
}
@Test
@@ -50,10 +56,10 @@ public class JavaRandomRDDsSuite {
long m = 1000L;
int p = 2;
long seed = 1L;
- JavaDoubleRDD rdd1 = uniformJavaRDD(sc, m);
- JavaDoubleRDD rdd2 = uniformJavaRDD(sc, m, p);
- JavaDoubleRDD rdd3 = uniformJavaRDD(sc, m, p, seed);
- for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
+ JavaDoubleRDD rdd1 = uniformJavaRDD(jsc, m);
+ JavaDoubleRDD rdd2 = uniformJavaRDD(jsc, m, p);
+ JavaDoubleRDD rdd3 = uniformJavaRDD(jsc, m, p, seed);
+ for (JavaDoubleRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
Assert.assertEquals(m, rdd.count());
}
}
@@ -63,10 +69,10 @@ public class JavaRandomRDDsSuite {
long m = 1000L;
int p = 2;
long seed = 1L;
- JavaDoubleRDD rdd1 = normalJavaRDD(sc, m);
- JavaDoubleRDD rdd2 = normalJavaRDD(sc, m, p);
- JavaDoubleRDD rdd3 = normalJavaRDD(sc, m, p, seed);
- for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
+ JavaDoubleRDD rdd1 = normalJavaRDD(jsc, m);
+ JavaDoubleRDD rdd2 = normalJavaRDD(jsc, m, p);
+ JavaDoubleRDD rdd3 = normalJavaRDD(jsc, m, p, seed);
+ for (JavaDoubleRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
Assert.assertEquals(m, rdd.count());
}
}
@@ -78,10 +84,10 @@ public class JavaRandomRDDsSuite {
long m = 1000L;
int p = 2;
long seed = 1L;
- JavaDoubleRDD rdd1 = logNormalJavaRDD(sc, mean, std, m);
- JavaDoubleRDD rdd2 = logNormalJavaRDD(sc, mean, std, m, p);
- JavaDoubleRDD rdd3 = logNormalJavaRDD(sc, mean, std, m, p, seed);
- for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
+ JavaDoubleRDD rdd1 = logNormalJavaRDD(jsc, mean, std, m);
+ JavaDoubleRDD rdd2 = logNormalJavaRDD(jsc, mean, std, m, p);
+ JavaDoubleRDD rdd3 = logNormalJavaRDD(jsc, mean, std, m, p, seed);
+ for (JavaDoubleRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
Assert.assertEquals(m, rdd.count());
}
}
@@ -92,10 +98,10 @@ public class JavaRandomRDDsSuite {
long m = 1000L;
int p = 2;
long seed = 1L;
- JavaDoubleRDD rdd1 = poissonJavaRDD(sc, mean, m);
- JavaDoubleRDD rdd2 = poissonJavaRDD(sc, mean, m, p);
- JavaDoubleRDD rdd3 = poissonJavaRDD(sc, mean, m, p, seed);
- for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
+ JavaDoubleRDD rdd1 = poissonJavaRDD(jsc, mean, m);
+ JavaDoubleRDD rdd2 = poissonJavaRDD(jsc, mean, m, p);
+ JavaDoubleRDD rdd3 = poissonJavaRDD(jsc, mean, m, p, seed);
+ for (JavaDoubleRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
Assert.assertEquals(m, rdd.count());
}
}
@@ -106,10 +112,10 @@ public class JavaRandomRDDsSuite {
long m = 1000L;
int p = 2;
long seed = 1L;
- JavaDoubleRDD rdd1 = exponentialJavaRDD(sc, mean, m);
- JavaDoubleRDD rdd2 = exponentialJavaRDD(sc, mean, m, p);
- JavaDoubleRDD rdd3 = exponentialJavaRDD(sc, mean, m, p, seed);
- for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
+ JavaDoubleRDD rdd1 = exponentialJavaRDD(jsc, mean, m);
+ JavaDoubleRDD rdd2 = exponentialJavaRDD(jsc, mean, m, p);
+ JavaDoubleRDD rdd3 = exponentialJavaRDD(jsc, mean, m, p, seed);
+ for (JavaDoubleRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
Assert.assertEquals(m, rdd.count());
}
}
@@ -117,14 +123,14 @@ public class JavaRandomRDDsSuite {
@Test
public void testGammaRDD() {
double shape = 1.0;
- double scale = 2.0;
+ double jscale = 2.0;
long m = 1000L;
int p = 2;
long seed = 1L;
- JavaDoubleRDD rdd1 = gammaJavaRDD(sc, shape, scale, m);
- JavaDoubleRDD rdd2 = gammaJavaRDD(sc, shape, scale, m, p);
- JavaDoubleRDD rdd3 = gammaJavaRDD(sc, shape, scale, m, p, seed);
- for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
+ JavaDoubleRDD rdd1 = gammaJavaRDD(jsc, shape, jscale, m);
+ JavaDoubleRDD rdd2 = gammaJavaRDD(jsc, shape, jscale, m, p);
+ JavaDoubleRDD rdd3 = gammaJavaRDD(jsc, shape, jscale, m, p, seed);
+ for (JavaDoubleRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
Assert.assertEquals(m, rdd.count());
}
}
@@ -137,10 +143,10 @@ public class JavaRandomRDDsSuite {
int n = 10;
int p = 2;
long seed = 1L;
- JavaRDD<Vector> rdd1 = uniformJavaVectorRDD(sc, m, n);
- JavaRDD<Vector> rdd2 = uniformJavaVectorRDD(sc, m, n, p);
- JavaRDD<Vector> rdd3 = uniformJavaVectorRDD(sc, m, n, p, seed);
- for (JavaRDD<Vector> rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
+ JavaRDD<Vector> rdd1 = uniformJavaVectorRDD(jsc, m, n);
+ JavaRDD<Vector> rdd2 = uniformJavaVectorRDD(jsc, m, n, p);
+ JavaRDD<Vector> rdd3 = uniformJavaVectorRDD(jsc, m, n, p, seed);
+ for (JavaRDD<Vector> rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
Assert.assertEquals(m, rdd.count());
Assert.assertEquals(n, rdd.first().size());
}
@@ -153,10 +159,10 @@ public class JavaRandomRDDsSuite {
int n = 10;
int p = 2;
long seed = 1L;
- JavaRDD<Vector> rdd1 = normalJavaVectorRDD(sc, m, n);
- JavaRDD<Vector> rdd2 = normalJavaVectorRDD(sc, m, n, p);
- JavaRDD<Vector> rdd3 = normalJavaVectorRDD(sc, m, n, p, seed);
- for (JavaRDD<Vector> rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
+ JavaRDD<Vector> rdd1 = normalJavaVectorRDD(jsc, m, n);
+ JavaRDD<Vector> rdd2 = normalJavaVectorRDD(jsc, m, n, p);
+ JavaRDD<Vector> rdd3 = normalJavaVectorRDD(jsc, m, n, p, seed);
+ for (JavaRDD<Vector> rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
Assert.assertEquals(m, rdd.count());
Assert.assertEquals(n, rdd.first().size());
}
@@ -171,10 +177,10 @@ public class JavaRandomRDDsSuite {
int n = 10;
int p = 2;
long seed = 1L;
- JavaRDD<Vector> rdd1 = logNormalJavaVectorRDD(sc, mean, std, m, n);
- JavaRDD<Vector> rdd2 = logNormalJavaVectorRDD(sc, mean, std, m, n, p);
- JavaRDD<Vector> rdd3 = logNormalJavaVectorRDD(sc, mean, std, m, n, p, seed);
- for (JavaRDD<Vector> rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
+ JavaRDD<Vector> rdd1 = logNormalJavaVectorRDD(jsc, mean, std, m, n);
+ JavaRDD<Vector> rdd2 = logNormalJavaVectorRDD(jsc, mean, std, m, n, p);
+ JavaRDD<Vector> rdd3 = logNormalJavaVectorRDD(jsc, mean, std, m, n, p, seed);
+ for (JavaRDD<Vector> rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
Assert.assertEquals(m, rdd.count());
Assert.assertEquals(n, rdd.first().size());
}
@@ -188,10 +194,10 @@ public class JavaRandomRDDsSuite {
int n = 10;
int p = 2;
long seed = 1L;
- JavaRDD<Vector> rdd1 = poissonJavaVectorRDD(sc, mean, m, n);
- JavaRDD<Vector> rdd2 = poissonJavaVectorRDD(sc, mean, m, n, p);
- JavaRDD<Vector> rdd3 = poissonJavaVectorRDD(sc, mean, m, n, p, seed);
- for (JavaRDD<Vector> rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
+ JavaRDD<Vector> rdd1 = poissonJavaVectorRDD(jsc, mean, m, n);
+ JavaRDD<Vector> rdd2 = poissonJavaVectorRDD(jsc, mean, m, n, p);
+ JavaRDD<Vector> rdd3 = poissonJavaVectorRDD(jsc, mean, m, n, p, seed);
+ for (JavaRDD<Vector> rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
Assert.assertEquals(m, rdd.count());
Assert.assertEquals(n, rdd.first().size());
}
@@ -205,10 +211,10 @@ public class JavaRandomRDDsSuite {
int n = 10;
int p = 2;
long seed = 1L;
- JavaRDD<Vector> rdd1 = exponentialJavaVectorRDD(sc, mean, m, n);
- JavaRDD<Vector> rdd2 = exponentialJavaVectorRDD(sc, mean, m, n, p);
- JavaRDD<Vector> rdd3 = exponentialJavaVectorRDD(sc, mean, m, n, p, seed);
- for (JavaRDD<Vector> rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
+ JavaRDD<Vector> rdd1 = exponentialJavaVectorRDD(jsc, mean, m, n);
+ JavaRDD<Vector> rdd2 = exponentialJavaVectorRDD(jsc, mean, m, n, p);
+ JavaRDD<Vector> rdd3 = exponentialJavaVectorRDD(jsc, mean, m, n, p, seed);
+ for (JavaRDD<Vector> rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
Assert.assertEquals(m, rdd.count());
Assert.assertEquals(n, rdd.first().size());
}
@@ -218,15 +224,15 @@ public class JavaRandomRDDsSuite {
@SuppressWarnings("unchecked")
public void testGammaVectorRDD() {
double shape = 1.0;
- double scale = 2.0;
+ double jscale = 2.0;
long m = 100L;
int n = 10;
int p = 2;
long seed = 1L;
- JavaRDD<Vector> rdd1 = gammaJavaVectorRDD(sc, shape, scale, m, n);
- JavaRDD<Vector> rdd2 = gammaJavaVectorRDD(sc, shape, scale, m, n, p);
- JavaRDD<Vector> rdd3 = gammaJavaVectorRDD(sc, shape, scale, m, n, p, seed);
- for (JavaRDD<Vector> rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
+ JavaRDD<Vector> rdd1 = gammaJavaVectorRDD(jsc, shape, jscale, m, n);
+ JavaRDD<Vector> rdd2 = gammaJavaVectorRDD(jsc, shape, jscale, m, n, p);
+ JavaRDD<Vector> rdd3 = gammaJavaVectorRDD(jsc, shape, jscale, m, n, p, seed);
+ for (JavaRDD<Vector> rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
Assert.assertEquals(m, rdd.count());
Assert.assertEquals(n, rdd.first().size());
}
@@ -238,10 +244,10 @@ public class JavaRandomRDDsSuite {
long seed = 1L;
int numPartitions = 0;
StringGenerator gen = new StringGenerator();
- JavaRDD<String> rdd1 = randomJavaRDD(sc, gen, size);
- JavaRDD<String> rdd2 = randomJavaRDD(sc, gen, size, numPartitions);
- JavaRDD<String> rdd3 = randomJavaRDD(sc, gen, size, numPartitions, seed);
- for (JavaRDD<String> rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
+ JavaRDD<String> rdd1 = randomJavaRDD(jsc, gen, size);
+ JavaRDD<String> rdd2 = randomJavaRDD(jsc, gen, size, numPartitions);
+ JavaRDD<String> rdd3 = randomJavaRDD(jsc, gen, size, numPartitions, seed);
+ for (JavaRDD<String> rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
Assert.assertEquals(size, rdd.count());
Assert.assertEquals(2, rdd.first().length());
}
@@ -255,10 +261,10 @@ public class JavaRandomRDDsSuite {
int n = 10;
int p = 2;
long seed = 1L;
- JavaRDD<Vector> rdd1 = randomJavaVectorRDD(sc, generator, m, n);
- JavaRDD<Vector> rdd2 = randomJavaVectorRDD(sc, generator, m, n, p);
- JavaRDD<Vector> rdd3 = randomJavaVectorRDD(sc, generator, m, n, p, seed);
- for (JavaRDD<Vector> rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
+ JavaRDD<Vector> rdd1 = randomJavaVectorRDD(jsc, generator, m, n);
+ JavaRDD<Vector> rdd2 = randomJavaVectorRDD(jsc, generator, m, n, p);
+ JavaRDD<Vector> rdd3 = randomJavaVectorRDD(jsc, generator, m, n, p, seed);
+ for (JavaRDD<Vector> rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
Assert.assertEquals(m, rdd.count());
Assert.assertEquals(n, rdd.first().size());
}
@@ -271,10 +277,12 @@ class StringGenerator implements RandomDataGenerator<String>, Serializable {
public String nextValue() {
return "42";
}
+
@Override
public StringGenerator copy() {
return new StringGenerator();
}
+
@Override
public void setSeed(long seed) {
}
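
The random-RDD factories renamed from sc to jsc above all follow one signature shape: context first, then distribution parameters, size, optional partition count, and optional seed. A sketch of the migrated calls (statically imported as in the suite; sizes are illustrative):

    import org.apache.spark.api.java.JavaDoubleRDD;
    import org.apache.spark.api.java.JavaSparkContext;
    import org.apache.spark.sql.SparkSession;

    import static org.apache.spark.mllib.random.RandomRDDs.*;

    public class RandomRDDsSketch {
      public static void main(String[] args) {
        SparkSession spark = SparkSession.builder()
          .master("local").appName("RandomRDDsSketch").getOrCreate();
        JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext());

        long size = 1000L;
        int partitions = 2;
        long seed = 1L;
        // Overloads add the partition count and seed from left to right.
        JavaDoubleRDD u1 = uniformJavaRDD(jsc, size);
        JavaDoubleRDD u2 = uniformJavaRDD(jsc, size, partitions);
        JavaDoubleRDD u3 = uniformJavaRDD(jsc, size, partitions, seed);
        System.out.println(u1.count() + " " + u2.count() + " " + u3.count());

        spark.stop();
      }
    }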
diff --git a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java
index d0bf7f556d..aa784054d5 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java
@@ -32,40 +32,46 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.SparkSession;
public class JavaALSSuite implements Serializable {
- private transient JavaSparkContext sc;
+ private transient SparkSession spark;
+ private transient JavaSparkContext jsc;
@Before
public void setUp() {
- sc = new JavaSparkContext("local", "JavaALS");
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaALS")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
- sc.stop();
- sc = null;
+ spark.stop();
+ spark = null;
}
private void validatePrediction(
- MatrixFactorizationModel model,
- int users,
- int products,
- double[] trueRatings,
- double matchThreshold,
- boolean implicitPrefs,
- double[] truePrefs) {
+ MatrixFactorizationModel model,
+ int users,
+ int products,
+ double[] trueRatings,
+ double matchThreshold,
+ boolean implicitPrefs,
+ double[] truePrefs) {
List<Tuple2<Integer, Integer>> localUsersProducts = new ArrayList<>(users * products);
- for (int u=0; u < users; ++u) {
- for (int p=0; p < products; ++p) {
+ for (int u = 0; u < users; ++u) {
+ for (int p = 0; p < products; ++p) {
localUsersProducts.add(new Tuple2<>(u, p));
}
}
- JavaPairRDD<Integer, Integer> usersProducts = sc.parallelizePairs(localUsersProducts);
+ JavaPairRDD<Integer, Integer> usersProducts = jsc.parallelizePairs(localUsersProducts);
List<Rating> predictedRatings = model.predict(usersProducts).collect();
Assert.assertEquals(users * products, predictedRatings.size());
if (!implicitPrefs) {
- for (Rating r: predictedRatings) {
+ for (Rating r : predictedRatings) {
double prediction = r.rating();
double correct = trueRatings[r.product() * users + r.user()];
Assert.assertTrue(String.format("Prediction=%2.4f not below match threshold of %2.2f",
@@ -76,7 +82,7 @@ public class JavaALSSuite implements Serializable {
// (ref Mahout's implicit ALS tests)
double sqErr = 0.0;
double denom = 0.0;
- for (Rating r: predictedRatings) {
+ for (Rating r : predictedRatings) {
double prediction = r.rating();
double truePref = truePrefs[r.product() * users + r.user()];
double confidence = 1.0 +
@@ -98,9 +104,9 @@ public class JavaALSSuite implements Serializable {
int users = 50;
int products = 100;
Tuple3<List<Rating>, double[], double[]> testData =
- ALSSuite.generateRatingsAsJava(users, products, features, 0.7, false, false);
+ ALSSuite.generateRatingsAsJava(users, products, features, 0.7, false, false);
- JavaRDD<Rating> data = sc.parallelize(testData._1());
+ JavaRDD<Rating> data = jsc.parallelize(testData._1());
MatrixFactorizationModel model = ALS.train(data.rdd(), features, iterations);
validatePrediction(model, users, products, testData._2(), 0.3, false, testData._3());
}
@@ -112,9 +118,9 @@ public class JavaALSSuite implements Serializable {
int users = 100;
int products = 200;
Tuple3<List<Rating>, double[], double[]> testData =
- ALSSuite.generateRatingsAsJava(users, products, features, 0.7, false, false);
+ ALSSuite.generateRatingsAsJava(users, products, features, 0.7, false, false);
- JavaRDD<Rating> data = sc.parallelize(testData._1());
+ JavaRDD<Rating> data = jsc.parallelize(testData._1());
MatrixFactorizationModel model = new ALS().setRank(features)
.setIterations(iterations)
@@ -129,9 +135,9 @@ public class JavaALSSuite implements Serializable {
int users = 80;
int products = 160;
Tuple3<List<Rating>, double[], double[]> testData =
- ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, false);
+ ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, false);
- JavaRDD<Rating> data = sc.parallelize(testData._1());
+ JavaRDD<Rating> data = jsc.parallelize(testData._1());
MatrixFactorizationModel model = ALS.trainImplicit(data.rdd(), features, iterations);
validatePrediction(model, users, products, testData._2(), 0.4, true, testData._3());
}
@@ -143,9 +149,9 @@ public class JavaALSSuite implements Serializable {
int users = 100;
int products = 200;
Tuple3<List<Rating>, double[], double[]> testData =
- ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, false);
+ ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, false);
- JavaRDD<Rating> data = sc.parallelize(testData._1());
+ JavaRDD<Rating> data = jsc.parallelize(testData._1());
MatrixFactorizationModel model = new ALS().setRank(features)
.setIterations(iterations)
@@ -161,9 +167,9 @@ public class JavaALSSuite implements Serializable {
int users = 80;
int products = 160;
Tuple3<List<Rating>, double[], double[]> testData =
- ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, true);
+ ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, true);
- JavaRDD<Rating> data = sc.parallelize(testData._1());
+ JavaRDD<Rating> data = jsc.parallelize(testData._1());
MatrixFactorizationModel model = new ALS().setRank(features)
.setIterations(iterations)
.setImplicitPrefs(true)
@@ -179,8 +185,8 @@ public class JavaALSSuite implements Serializable {
int users = 200;
int products = 50;
List<Rating> testData = ALSSuite.generateRatingsAsJava(
- users, products, features, 0.7, true, false)._1();
- JavaRDD<Rating> data = sc.parallelize(testData);
+ users, products, features, 0.7, true, false)._1();
+ JavaRDD<Rating> data = jsc.parallelize(testData);
MatrixFactorizationModel model = new ALS().setRank(features)
.setIterations(iterations)
.setImplicitPrefs(true)
@@ -193,7 +199,7 @@ public class JavaALSSuite implements Serializable {
private static void validateRecommendations(Rating[] recommendations, int howMany) {
Assert.assertEquals(howMany, recommendations.length);
for (int i = 1; i < recommendations.length; i++) {
- Assert.assertTrue(recommendations[i-1].rating() >= recommendations[i].rating());
+ Assert.assertTrue(recommendations[i - 1].rating() >= recommendations[i].rating());
}
Assert.assertTrue(recommendations[0].rating() > 0.7);
}
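
For orientation, ALS training and prediction against the new entry point, as a sketch (rank and iteration values are illustrative, and the ratings are hardcoded rather than produced by ALSSuite.generateRatingsAsJava):

    import java.util.Arrays;

    import org.apache.spark.api.java.JavaRDD;
    import org.apache.spark.api.java.JavaSparkContext;
    import org.apache.spark.mllib.recommendation.ALS;
    import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
    import org.apache.spark.mllib.recommendation.Rating;
    import org.apache.spark.sql.SparkSession;

    public class ALSSketch {
      public static void main(String[] args) {
        SparkSession spark = SparkSession.builder()
          .master("local").appName("ALSSketch").getOrCreate();
        JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext());

        JavaRDD<Rating> ratings = jsc.parallelize(Arrays.asList(
          new Rating(0, 0, 5.0),
          new Rating(0, 1, 1.0),
          new Rating(1, 0, 4.0),
          new Rating(1, 1, 2.0)));

        int rank = 2;        // illustrative
        int iterations = 5;  // illustrative
        // ALS.train runs on the underlying Scala RDD, hence ratings.rdd().
        MatrixFactorizationModel model = ALS.train(ratings.rdd(), rank, iterations);
        System.out.println("predicted(0, 1) = " + model.predict(0, 1));

        spark.stop();
      }
    }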
diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java
index 3db9b39e74..8b05675d65 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java
@@ -32,15 +32,17 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaDoubleRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.SparkSession;
public class JavaIsotonicRegressionSuite implements Serializable {
- private transient JavaSparkContext sc;
+ private transient SparkSession spark;
+ private transient JavaSparkContext jsc;
private static List<Tuple3<Double, Double, Double>> generateIsotonicInput(double[] labels) {
List<Tuple3<Double, Double, Double>> input = new ArrayList<>(labels.length);
for (int i = 1; i <= labels.length; i++) {
- input.add(new Tuple3<>(labels[i-1], (double) i, 1.0));
+ input.add(new Tuple3<>(labels[i - 1], (double) i, 1.0));
}
return input;
@@ -48,20 +50,24 @@ public class JavaIsotonicRegressionSuite implements Serializable {
private IsotonicRegressionModel runIsotonicRegression(double[] labels) {
JavaRDD<Tuple3<Double, Double, Double>> trainRDD =
- sc.parallelize(generateIsotonicInput(labels), 2).cache();
+ jsc.parallelize(generateIsotonicInput(labels), 2).cache();
return new IsotonicRegression().run(trainRDD);
}
@Before
public void setUp() {
- sc = new JavaSparkContext("local", "JavaLinearRegressionSuite");
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaLinearRegressionSuite")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
- sc.stop();
- sc = null;
+ spark.stop();
+ spark = null;
}
@Test
@@ -70,7 +76,7 @@ public class JavaIsotonicRegressionSuite implements Serializable {
runIsotonicRegression(new double[]{1, 2, 3, 3, 1, 6, 7, 8, 11, 9, 10, 12});
Assert.assertArrayEquals(
- new double[] {1, 2, 7.0/3, 7.0/3, 6, 7, 8, 10, 10, 12}, model.predictions(), 1.0e-14);
+ new double[]{1, 2, 7.0 / 3, 7.0 / 3, 6, 7, 8, 10, 10, 12}, model.predictions(), 1.0e-14);
}
@Test
@@ -78,7 +84,7 @@ public class JavaIsotonicRegressionSuite implements Serializable {
IsotonicRegressionModel model =
runIsotonicRegression(new double[]{1, 2, 3, 3, 1, 6, 7, 8, 11, 9, 10, 12});
- JavaDoubleRDD testRDD = sc.parallelizeDoubles(Arrays.asList(0.0, 1.0, 9.5, 12.0, 13.0));
+ JavaDoubleRDD testRDD = jsc.parallelizeDoubles(Arrays.asList(0.0, 1.0, 9.5, 12.0, 13.0));
List<Double> predictions = model.predict(testRDD).collect();
Assert.assertEquals(1.0, predictions.get(0).doubleValue(), 1.0e-14);
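
Isotonic regression usage, condensed from the suite's helpers above (the labels array is the one the tests use; the test points are a subset):

    import java.util.ArrayList;
    import java.util.Arrays;
    import java.util.List;

    import scala.Tuple3;

    import org.apache.spark.api.java.JavaDoubleRDD;
    import org.apache.spark.api.java.JavaSparkContext;
    import org.apache.spark.mllib.regression.IsotonicRegression;
    import org.apache.spark.mllib.regression.IsotonicRegressionModel;
    import org.apache.spark.sql.SparkSession;

    public class IsotonicRegressionSketch {
      public static void main(String[] args) {
        SparkSession spark = SparkSession.builder()
          .master("local").appName("IsotonicRegressionSketch").getOrCreate();
        JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext());

        // (label, feature, weight) triples with features 1..n, as in the suite.
        double[] labels = {1, 2, 3, 3, 1, 6, 7, 8, 11, 9, 10, 12};
        List<Tuple3<Double, Double, Double>> input = new ArrayList<>(labels.length);
        for (int i = 1; i <= labels.length; i++) {
          input.add(new Tuple3<>(labels[i - 1], (double) i, 1.0));
        }

        IsotonicRegressionModel model =
          new IsotonicRegression().run(jsc.parallelize(input, 2).cache());

        JavaDoubleRDD test = jsc.parallelizeDoubles(Arrays.asList(0.0, 9.5, 13.0));
        System.out.println(model.predict(test).collect());

        spark.stop();
      }
    }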
diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java
index 8950b48888..098bac3bed 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java
@@ -28,24 +28,30 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.util.LinearDataGenerator;
+import org.apache.spark.sql.SparkSession;
public class JavaLassoSuite implements Serializable {
- private transient JavaSparkContext sc;
+ private transient SparkSession spark;
+ private transient JavaSparkContext jsc;
@Before
public void setUp() {
- sc = new JavaSparkContext("local", "JavaLassoSuite");
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaLassoSuite")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
- sc.stop();
- sc = null;
+ spark.stop();
+ spark = null;
}
int validatePrediction(List<LabeledPoint> validationData, LassoModel model) {
int numAccurate = 0;
- for (LabeledPoint point: validationData) {
+ for (LabeledPoint point : validationData) {
Double prediction = model.predict(point.features());
// A prediction is off if the prediction is more than 0.5 away from expected value.
if (Math.abs(prediction - point.label()) <= 0.5) {
@@ -61,15 +67,15 @@ public class JavaLassoSuite implements Serializable {
double A = 0.0;
double[] weights = {-1.5, 1.0e-2};
- JavaRDD<LabeledPoint> testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A,
- weights, nPoints, 42, 0.1), 2).cache();
+ JavaRDD<LabeledPoint> testRDD = jsc.parallelize(LinearDataGenerator.generateLinearInputAsList(A,
+ weights, nPoints, 42, 0.1), 2).cache();
List<LabeledPoint> validationData =
- LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1);
+ LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1);
LassoWithSGD lassoSGDImpl = new LassoWithSGD();
lassoSGDImpl.optimizer().setStepSize(1.0)
- .setRegParam(0.01)
- .setNumIterations(20);
+ .setRegParam(0.01)
+ .setNumIterations(20);
LassoModel model = lassoSGDImpl.run(testRDD.rdd());
int numAccurate = validatePrediction(validationData, model);
@@ -82,10 +88,10 @@ public class JavaLassoSuite implements Serializable {
double A = 0.0;
double[] weights = {-1.5, 1.0e-2};
- JavaRDD<LabeledPoint> testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A,
- weights, nPoints, 42, 0.1), 2).cache();
+ JavaRDD<LabeledPoint> testRDD = jsc.parallelize(LinearDataGenerator.generateLinearInputAsList(A,
+ weights, nPoints, 42, 0.1), 2).cache();
List<LabeledPoint> validationData =
- LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1);
+ LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1);
LassoModel model = LassoWithSGD.train(testRDD.rdd(), 100, 1.0, 0.01, 1.0);
diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java
index 24c4c20d9a..35087a5e46 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java
@@ -25,34 +25,40 @@ import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
-import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.util.LinearDataGenerator;
+import org.apache.spark.sql.SparkSession;
public class JavaLinearRegressionSuite implements Serializable {
- private transient JavaSparkContext sc;
+ private transient SparkSession spark;
+ private transient JavaSparkContext jsc;
@Before
public void setUp() {
- sc = new JavaSparkContext("local", "JavaLinearRegressionSuite");
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaLinearRegressionSuite")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
- sc.stop();
- sc = null;
+ spark.stop();
+ spark = null;
}
int validatePrediction(List<LabeledPoint> validationData, LinearRegressionModel model) {
int numAccurate = 0;
- for (LabeledPoint point: validationData) {
- Double prediction = model.predict(point.features());
- // A prediction is off if the prediction is more than 0.5 away from expected value.
- if (Math.abs(prediction - point.label()) <= 0.5) {
- numAccurate++;
- }
+ for (LabeledPoint point : validationData) {
+ Double prediction = model.predict(point.features());
+ // A prediction is off if the prediction is more than 0.5 away from expected value.
+ if (Math.abs(prediction - point.label()) <= 0.5) {
+ numAccurate++;
+ }
}
return numAccurate;
}
@@ -63,10 +69,10 @@ public class JavaLinearRegressionSuite implements Serializable {
double A = 3.0;
double[] weights = {10, 10};
- JavaRDD<LabeledPoint> testRDD = sc.parallelize(
- LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache();
+ JavaRDD<LabeledPoint> testRDD = jsc.parallelize(
+ LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache();
List<LabeledPoint> validationData =
- LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1);
+ LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1);
LinearRegressionWithSGD linSGDImpl = new LinearRegressionWithSGD();
linSGDImpl.setIntercept(true);
@@ -82,10 +88,10 @@ public class JavaLinearRegressionSuite implements Serializable {
double A = 0.0;
double[] weights = {10, 10};
- JavaRDD<LabeledPoint> testRDD = sc.parallelize(
- LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache();
+ JavaRDD<LabeledPoint> testRDD = jsc.parallelize(
+ LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache();
List<LabeledPoint> validationData =
- LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1);
+ LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1);
LinearRegressionModel model = LinearRegressionWithSGD.train(testRDD.rdd(), 100);
@@ -98,7 +104,7 @@ public class JavaLinearRegressionSuite implements Serializable {
int nPoints = 100;
double A = 0.0;
double[] weights = {10, 10};
- JavaRDD<LabeledPoint> testRDD = sc.parallelize(
+ JavaRDD<LabeledPoint> testRDD = jsc.parallelize(
LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache();
LinearRegressionWithSGD linSGDImpl = new LinearRegressionWithSGD();
LinearRegressionModel model = linSGDImpl.run(testRDD.rdd());
diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java
index c56db703ea..b2efb2e72e 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java
@@ -29,25 +29,31 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.util.LinearDataGenerator;
+import org.apache.spark.sql.SparkSession;
public class JavaRidgeRegressionSuite implements Serializable {
- private transient JavaSparkContext sc;
+ private transient SparkSession spark;
+ private transient JavaSparkContext jsc;
@Before
public void setUp() {
- sc = new JavaSparkContext("local", "JavaRidgeRegressionSuite");
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaRidgeRegressionSuite")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
- sc.stop();
- sc = null;
+ spark.stop();
+ spark = null;
}
private static double predictionError(List<LabeledPoint> validationData,
RidgeRegressionModel model) {
double errorSum = 0;
- for (LabeledPoint point: validationData) {
+ for (LabeledPoint point : validationData) {
Double prediction = model.predict(point.features());
errorSum += (prediction - point.label()) * (prediction - point.label());
}
@@ -68,9 +74,9 @@ public class JavaRidgeRegressionSuite implements Serializable {
public void runRidgeRegressionUsingConstructor() {
int numExamples = 50;
int numFeatures = 20;
- List<LabeledPoint> data = generateRidgeData(2*numExamples, numFeatures, 10.0);
+ List<LabeledPoint> data = generateRidgeData(2 * numExamples, numFeatures, 10.0);
- JavaRDD<LabeledPoint> testRDD = sc.parallelize(data.subList(0, numExamples));
+ JavaRDD<LabeledPoint> testRDD = jsc.parallelize(data.subList(0, numExamples));
List<LabeledPoint> validationData = data.subList(numExamples, 2 * numExamples);
RidgeRegressionWithSGD ridgeSGDImpl = new RidgeRegressionWithSGD();
@@ -94,7 +100,7 @@ public class JavaRidgeRegressionSuite implements Serializable {
int numFeatures = 20;
List<LabeledPoint> data = generateRidgeData(2 * numExamples, numFeatures, 10.0);
- JavaRDD<LabeledPoint> testRDD = sc.parallelize(data.subList(0, numExamples));
+ JavaRDD<LabeledPoint> testRDD = jsc.parallelize(data.subList(0, numExamples));
List<LabeledPoint> validationData = data.subList(numExamples, 2 * numExamples);
RidgeRegressionModel model = RidgeRegressionWithSGD.train(testRDD.rdd(), 200, 1.0, 0.0);
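
The Lasso, linear, and ridge regression suites above all share one shape: generate labeled points, parallelize them through the derived JavaSparkContext, and train an SGD-based learner on the underlying RDD. A condensed sketch of that shape, with hand-made data standing in for LinearDataGenerator:

    import java.util.Arrays;

    import org.apache.spark.api.java.JavaRDD;
    import org.apache.spark.api.java.JavaSparkContext;
    import org.apache.spark.mllib.linalg.Vectors;
    import org.apache.spark.mllib.regression.LabeledPoint;
    import org.apache.spark.mllib.regression.LinearRegressionModel;
    import org.apache.spark.mllib.regression.LinearRegressionWithSGD;
    import org.apache.spark.sql.SparkSession;

    public class RegressionSketch {
      public static void main(String[] args) {
        SparkSession spark = SparkSession.builder()
          .master("local").appName("RegressionSketch").getOrCreate();
        JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext());

        // Tiny hand-made data set; the suites use LinearDataGenerator instead.
        JavaRDD<LabeledPoint> train = jsc.parallelize(Arrays.asList(
          new LabeledPoint(1.0, Vectors.dense(0.1, 0.2)),
          new LabeledPoint(2.0, Vectors.dense(0.2, 0.4)),
          new LabeledPoint(3.0, Vectors.dense(0.3, 0.6)))).cache();

        // The learners run on the Scala RDD, hence train.rdd().
        LinearRegressionModel model =
          LinearRegressionWithSGD.train(train.rdd(), 100);
        System.out.println(model.predict(Vectors.dense(0.2, 0.4)));

        spark.stop();
      }
    }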
diff --git a/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java
index 5f1d5987e8..373417d3ba 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java
@@ -24,13 +24,11 @@ import java.util.List;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
-
-import static org.apache.spark.streaming.JavaTestUtils.*;
import static org.junit.Assert.assertEquals;
import org.apache.spark.SparkConf;
-import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaDoubleRDD;
+import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
@@ -38,36 +36,42 @@ import org.apache.spark.mllib.stat.test.BinarySample;
import org.apache.spark.mllib.stat.test.ChiSqTestResult;
import org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult;
import org.apache.spark.mllib.stat.test.StreamingTest;
+import org.apache.spark.sql.SparkSession;
import org.apache.spark.streaming.Duration;
import org.apache.spark.streaming.api.java.JavaDStream;
import org.apache.spark.streaming.api.java.JavaStreamingContext;
+import static org.apache.spark.streaming.JavaTestUtils.*;
public class JavaStatisticsSuite implements Serializable {
- private transient JavaSparkContext sc;
+ private transient SparkSession spark;
+ private transient JavaSparkContext jsc;
private transient JavaStreamingContext ssc;
@Before
public void setUp() {
SparkConf conf = new SparkConf()
- .setMaster("local[2]")
- .setAppName("JavaStatistics")
.set("spark.streaming.clock", "org.apache.spark.util.ManualClock");
- sc = new JavaSparkContext(conf);
- ssc = new JavaStreamingContext(sc, new Duration(1000));
+ spark = SparkSession.builder()
+ .master("local[2]")
+ .appName("JavaStatistics")
+ .config(conf)
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
+ ssc = new JavaStreamingContext(jsc, new Duration(1000));
ssc.checkpoint("checkpoint");
}
@After
public void tearDown() {
+ spark.stop();
ssc.stop();
- ssc = null;
- sc = null;
+ spark = null;
}
@Test
public void testCorr() {
- JavaRDD<Double> x = sc.parallelize(Arrays.asList(1.0, 2.0, 3.0, 4.0));
- JavaRDD<Double> y = sc.parallelize(Arrays.asList(1.1, 2.2, 3.1, 4.3));
+ JavaRDD<Double> x = jsc.parallelize(Arrays.asList(1.0, 2.0, 3.0, 4.0));
+ JavaRDD<Double> y = jsc.parallelize(Arrays.asList(1.1, 2.2, 3.1, 4.3));
Double corr1 = Statistics.corr(x, y);
Double corr2 = Statistics.corr(x, y, "pearson");
@@ -77,7 +81,7 @@ public class JavaStatisticsSuite implements Serializable {
@Test
public void kolmogorovSmirnovTest() {
- JavaDoubleRDD data = sc.parallelizeDoubles(Arrays.asList(0.2, 1.0, -1.0, 2.0));
+ JavaDoubleRDD data = jsc.parallelizeDoubles(Arrays.asList(0.2, 1.0, -1.0, 2.0));
KolmogorovSmirnovTestResult testResult1 = Statistics.kolmogorovSmirnovTest(data, "norm");
KolmogorovSmirnovTestResult testResult2 = Statistics.kolmogorovSmirnovTest(
data, "norm", 0.0, 1.0);
@@ -85,7 +89,7 @@ public class JavaStatisticsSuite implements Serializable {
@Test
public void chiSqTest() {
- JavaRDD<LabeledPoint> data = sc.parallelize(Arrays.asList(
+ JavaRDD<LabeledPoint> data = jsc.parallelize(Arrays.asList(
new LabeledPoint(0.0, Vectors.dense(0.1, 2.3)),
new LabeledPoint(1.0, Vectors.dense(1.5, 5.1)),
new LabeledPoint(0.0, Vectors.dense(2.4, 8.1))));
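
This suite shows the one wrinkle in the migration: extra SparkConf settings survive via SparkSession.Builder.config(conf), while master and appName move onto the builder itself. A sketch of that combination with a streaming context (class name illustrative):

    import org.apache.spark.SparkConf;
    import org.apache.spark.api.java.JavaSparkContext;
    import org.apache.spark.sql.SparkSession;
    import org.apache.spark.streaming.Duration;
    import org.apache.spark.streaming.api.java.JavaStreamingContext;

    public class SessionWithStreamingSketch {
      public static void main(String[] args) {
        // Only settings that do not fit on the builder stay in SparkConf.
        SparkConf conf = new SparkConf()
          .set("spark.streaming.clock", "org.apache.spark.util.ManualClock");
        SparkSession spark = SparkSession.builder()
          .master("local[2]")
          .appName("SessionWithStreamingSketch")
          .config(conf)
          .getOrCreate();
        JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext());
        JavaStreamingContext ssc = new JavaStreamingContext(jsc, new Duration(1000));
        ssc.checkpoint("checkpoint");

        // ... build and exercise streams here ...

        ssc.stop();
        spark.stop();
      }
    }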
diff --git a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java
index 60585d2727..5b464a4722 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java
@@ -35,25 +35,31 @@ import org.apache.spark.mllib.tree.configuration.Algo;
import org.apache.spark.mllib.tree.configuration.Strategy;
import org.apache.spark.mllib.tree.impurity.Gini;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
+import org.apache.spark.sql.SparkSession;
public class JavaDecisionTreeSuite implements Serializable {
- private transient JavaSparkContext sc;
+ private transient SparkSession spark;
+ private transient JavaSparkContext jsc;
@Before
public void setUp() {
- sc = new JavaSparkContext("local", "JavaDecisionTreeSuite");
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaDecisionTreeSuite")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
- sc.stop();
- sc = null;
+ spark.stop();
+ spark = null;
}
int validatePrediction(List<LabeledPoint> validationData, DecisionTreeModel model) {
int numCorrect = 0;
- for (LabeledPoint point: validationData) {
+ for (LabeledPoint point : validationData) {
Double prediction = model.predict(point.features());
if (prediction == point.label()) {
numCorrect++;
@@ -65,7 +71,7 @@ public class JavaDecisionTreeSuite implements Serializable {
@Test
public void runDTUsingConstructor() {
List<LabeledPoint> arr = DecisionTreeSuite.generateCategoricalDataPointsAsJavaList();
- JavaRDD<LabeledPoint> rdd = sc.parallelize(arr);
+ JavaRDD<LabeledPoint> rdd = jsc.parallelize(arr);
HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<>();
categoricalFeaturesInfo.put(1, 2); // feature 1 has 2 categories
@@ -73,7 +79,7 @@ public class JavaDecisionTreeSuite implements Serializable {
int numClasses = 2;
int maxBins = 100;
Strategy strategy = new Strategy(Algo.Classification(), Gini.instance(), maxDepth, numClasses,
- maxBins, categoricalFeaturesInfo);
+ maxBins, categoricalFeaturesInfo);
DecisionTree learner = new DecisionTree(strategy);
DecisionTreeModel model = learner.run(rdd.rdd());
@@ -85,7 +91,7 @@ public class JavaDecisionTreeSuite implements Serializable {
@Test
public void runDTUsingStaticMethods() {
List<LabeledPoint> arr = DecisionTreeSuite.generateCategoricalDataPointsAsJavaList();
- JavaRDD<LabeledPoint> rdd = sc.parallelize(arr);
+ JavaRDD<LabeledPoint> rdd = jsc.parallelize(arr);
HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<>();
categoricalFeaturesInfo.put(1, 2); // feature 1 has 2 categories
@@ -93,7 +99,7 @@ public class JavaDecisionTreeSuite implements Serializable {
int numClasses = 2;
int maxBins = 100;
Strategy strategy = new Strategy(Algo.Classification(), Gini.instance(), maxDepth, numClasses,
- maxBins, categoricalFeaturesInfo);
+ maxBins, categoricalFeaturesInfo);
DecisionTreeModel model = DecisionTree$.MODULE$.train(rdd.rdd(), strategy);
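
The same training is also reachable without going through the Scala companion object, via the Java-friendly static helpers; a sketch under the new entry point (the impurity, depth, bin values, and data are illustrative):

    import java.util.Arrays;
    import java.util.HashMap;

    import org.apache.spark.api.java.JavaRDD;
    import org.apache.spark.api.java.JavaSparkContext;
    import org.apache.spark.mllib.linalg.Vectors;
    import org.apache.spark.mllib.regression.LabeledPoint;
    import org.apache.spark.mllib.tree.DecisionTree;
    import org.apache.spark.mllib.tree.model.DecisionTreeModel;
    import org.apache.spark.sql.SparkSession;

    public class DecisionTreeSketch {
      public static void main(String[] args) {
        SparkSession spark = SparkSession.builder()
          .master("local").appName("DecisionTreeSketch").getOrCreate();
        JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext());

        JavaRDD<LabeledPoint> data = jsc.parallelize(Arrays.asList(
          new LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
          new LabeledPoint(1.0, Vectors.dense(1.0, 1.0))));

        HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<>();
        categoricalFeaturesInfo.put(1, 2);  // feature 1 has 2 categories

        DecisionTreeModel model = DecisionTree.trainClassifier(
          data, 2, categoricalFeaturesInfo, "gini", 4, 32);
        System.out.println(model.predict(Vectors.dense(1.0, 1.0)));

        spark.stop();
      }
    }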