author      Sandeep Singh <sandeep@techaddict.me>    2016-05-10 11:17:47 -0700
committer   Andrew Or <andrew@databricks.com>        2016-05-10 11:17:47 -0700
commit      ed0b4070fb50054b1ecf66ff6c32458a4967dfd3 (patch)
tree        68b3ad1a3ca22f2e0b5966db517c9bc42da3d254 /mllib/src/test
parent      bcfee153b1cacfe617e602f3b72c0877e0bdf1f7 (diff)
download    spark-ed0b4070fb50054b1ecf66ff6c32458a4967dfd3.tar.gz
            spark-ed0b4070fb50054b1ecf66ff6c32458a4967dfd3.tar.bz2
            spark-ed0b4070fb50054b1ecf66ff6c32458a4967dfd3.zip
[SPARK-15037][SQL][MLLIB] Use SparkSession instead of SQLContext in Scala/Java TestSuites
## What changes were proposed in this pull request?
Use SparkSession instead of SQLContext in Scala/Java TestSuites. As this PR is already very big, the Python TestSuites are handled in a different PR.

## How was this patch tested?
Existing tests.

Author: Sandeep Singh <sandeep@techaddict.me>

Closes #12907 from techaddict/SPARK-15037.
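Every hunk below applies the same mechanical change: the `JavaSparkContext`/`SQLContext` pair in `setUp()` is replaced by a `SparkSession` built with `master`/`appName`, a `JavaSparkContext` is derived from the session wherever an RDD is still parallelized, and `tearDown()` stops the session instead of the context. A minimal, self-contained sketch of the pattern (hypothetical suite name and query; JUnit 4, as in the suites below):

```java
import java.util.Arrays;

import org.junit.After;
import org.junit.Before;
import org.junit.Test;

import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class ExampleSessionSuite {
  private transient SparkSession spark;
  private transient JavaSparkContext jsc;

  @Before
  public void setUp() {
    // Before this change: jsc = new JavaSparkContext("local", "Suite");
    //                     jsql = new SQLContext(jsc);
    spark = SparkSession.builder()
      .master("local")
      .appName("ExampleSessionSuite")
      .getOrCreate();
    // Suites that still build RDDs wrap the session's SparkContext.
    jsc = new JavaSparkContext(spark.sparkContext());
  }

  @After
  public void tearDown() {
    // Stopping the session also stops the underlying SparkContext,
    // so the old jsc.stop() becomes redundant.
    spark.stop();
    spark = null;
  }

  @Test
  public void sessionAndContextAreUsable() {
    Dataset<Row> df = spark.sql("SELECT 1 AS value");
    df.collectAsList();
    // The derived context still supports the RDD-based helpers the suites rely on.
    jsc.parallelize(Arrays.asList(1, 2, 3), 2).count();
  }
}
```

Because `getOrCreate()` reuses a live session when one exists, each suite still stops its session in `tearDown()` to keep tests isolated.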
Diffstat (limited to 'mllib/src/test')
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java | 27
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/attribute/JavaAttributeSuite.java | 2
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java | 23
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java | 18
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java | 49
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java | 36
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java | 18
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java | 90
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java | 26
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java | 26
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java | 20
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java | 21
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java | 18
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java | 19
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java | 21
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java | 19
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java | 17
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java | 22
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java | 26
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java | 21
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java | 31
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java | 18
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java | 18
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java | 22
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java | 14
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java | 38
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java | 18
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java | 18
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java | 25
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java | 28
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java | 18
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java | 18
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala | 1
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java | 21
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java | 35
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java | 25
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java | 32
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java | 27
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java | 20
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java | 23
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java | 37
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java | 3
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java | 21
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java | 22
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java | 19
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java | 23
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java | 29
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java | 26
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java | 278
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java | 7
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java | 136
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java | 64
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java | 22
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java | 32
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java | 42
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java | 22
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java | 32
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java | 24
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala | 2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala | 4
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala | 4
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala | 6
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala | 12
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala | 8
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala | 12
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala | 4
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala | 4
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala | 2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala | 2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala | 10
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala | 16
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala | 8
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala | 2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala | 4
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala | 8
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala | 8
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala | 9
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala | 14
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala | 2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala | 4
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala | 4
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala | 12
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala | 2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala | 4
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala | 8
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala | 2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala | 6
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala | 2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala | 6
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala | 44
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala | 6
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala | 8
-rwxr-xr-x  mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala | 16
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala | 18
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala | 8
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala | 6
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala | 12
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala | 2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala | 16
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala | 21
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala | 8
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala | 2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala | 6
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala | 32
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala | 8
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala | 18
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala | 2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala | 14
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala | 4
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala | 10
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala | 4
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala | 4
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala | 28
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala | 24
114 files changed, 1283 insertions(+), 1037 deletions(-)
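Most suites below use the builder's `master()`/`appName()` setters directly; `JavaStringIndexerSuite` and `JavaVectorAssemblerSuite` instead pass a `SparkConf` to the builder. A minimal sketch of that variant (hypothetical suite name; both forms produce an equivalent session):

```java
import org.junit.After;
import org.junit.Before;

import org.apache.spark.SparkConf;
import org.apache.spark.sql.SparkSession;

public class ExampleConfSuite {
  private transient SparkSession spark;

  @Before
  public void setUp() {
    SparkConf sparkConf = new SparkConf();
    sparkConf.setMaster("local");
    sparkConf.setAppName("ExampleConfSuite");

    // config(SparkConf) folds the conf into the session's underlying
    // SparkContext, equivalent to builder().master(...).appName(...).
    spark = SparkSession.builder().config(sparkConf).getOrCreate();
  }

  @After
  public void tearDown() {
    spark.stop();
    spark = null;
  }
}
```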
diff --git a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
index 60a4a1d2ea..e0c4363597 100644
--- a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
@@ -17,18 +17,18 @@
package org.apache.spark.ml;
-import org.apache.spark.sql.Dataset;
-import org.apache.spark.sql.Row;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.feature.StandardScaler;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
/**
@@ -36,23 +36,26 @@ import static org.apache.spark.mllib.classification.LogisticRegressionSuite.gene
*/
public class JavaPipelineSuite {
+ private transient SparkSession spark;
private transient JavaSparkContext jsc;
- private transient SQLContext jsql;
private transient Dataset<Row> dataset;
@Before
public void setUp() {
- jsc = new JavaSparkContext("local", "JavaPipelineSuite");
- jsql = new SQLContext(jsc);
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaPipelineSuite")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
JavaRDD<LabeledPoint> points =
jsc.parallelize(generateLogisticInputAsList(1.0, 1.0, 100, 42), 2);
- dataset = jsql.createDataFrame(points, LabeledPoint.class);
+ dataset = spark.createDataFrame(points, LabeledPoint.class);
}
@After
public void tearDown() {
- jsc.stop();
- jsc = null;
+ spark.stop();
+ spark = null;
}
@Test
@@ -63,10 +66,10 @@ public class JavaPipelineSuite {
LogisticRegression lr = new LogisticRegression()
.setFeaturesCol("scaledFeatures");
Pipeline pipeline = new Pipeline()
- .setStages(new PipelineStage[] {scaler, lr});
+ .setStages(new PipelineStage[]{scaler, lr});
PipelineModel model = pipeline.fit(dataset);
model.transform(dataset).registerTempTable("prediction");
- Dataset<Row> predictions = jsql.sql("SELECT label, probability, prediction FROM prediction");
+ Dataset<Row> predictions = spark.sql("SELECT label, probability, prediction FROM prediction");
predictions.collectAsList();
}
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/attribute/JavaAttributeSuite.java b/mllib/src/test/java/org/apache/spark/ml/attribute/JavaAttributeSuite.java
index b74bbed231..15cde0d3c0 100644
--- a/mllib/src/test/java/org/apache/spark/ml/attribute/JavaAttributeSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/attribute/JavaAttributeSuite.java
@@ -17,8 +17,8 @@
package org.apache.spark.ml.attribute;
-import org.junit.Test;
import org.junit.Assert;
+import org.junit.Test;
public class JavaAttributeSuite {
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java
index 1f23682621..8b89991327 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java
@@ -21,8 +21,6 @@ import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
-import org.apache.spark.sql.Dataset;
-import org.apache.spark.sql.Row;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
@@ -32,21 +30,28 @@ import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.tree.impl.TreeTests;
import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.mllib.regression.LabeledPoint;
-
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
public class JavaDecisionTreeClassifierSuite implements Serializable {
- private transient JavaSparkContext sc;
+ private transient SparkSession spark;
+ private transient JavaSparkContext jsc;
@Before
public void setUp() {
- sc = new JavaSparkContext("local", "JavaDecisionTreeClassifierSuite");
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaDecisionTreeClassifierSuite")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
- sc.stop();
- sc = null;
+ spark.stop();
+ spark = null;
}
@Test
@@ -55,7 +60,7 @@ public class JavaDecisionTreeClassifierSuite implements Serializable {
double A = 2.0;
double B = -1.5;
- JavaRDD<LabeledPoint> data = sc.parallelize(
+ JavaRDD<LabeledPoint> data = jsc.parallelize(
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
Map<Integer, Integer> categoricalFeatures = new HashMap<>();
Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
@@ -70,7 +75,7 @@ public class JavaDecisionTreeClassifierSuite implements Serializable {
.setCacheNodeIds(false)
.setCheckpointInterval(10)
.setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
- for (String impurity: DecisionTreeClassifier.supportedImpurities()) {
+ for (String impurity : DecisionTreeClassifier.supportedImpurities()) {
dt.setImpurity(impurity);
}
DecisionTreeClassificationModel model = dt.fit(dataFrame);
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java
index 74841058a2..682371eb9e 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java
@@ -32,21 +32,27 @@ import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
public class JavaGBTClassifierSuite implements Serializable {
- private transient JavaSparkContext sc;
+ private transient SparkSession spark;
+ private transient JavaSparkContext jsc;
@Before
public void setUp() {
- sc = new JavaSparkContext("local", "JavaGBTClassifierSuite");
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaGBTClassifierSuite")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
- sc.stop();
- sc = null;
+ spark.stop();
+ spark = null;
}
@Test
@@ -55,7 +61,7 @@ public class JavaGBTClassifierSuite implements Serializable {
double A = 2.0;
double B = -1.5;
- JavaRDD<LabeledPoint> data = sc.parallelize(
+ JavaRDD<LabeledPoint> data = jsc.parallelize(
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
Map<Integer, Integer> categoricalFeatures = new HashMap<>();
Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
@@ -74,7 +80,7 @@ public class JavaGBTClassifierSuite implements Serializable {
.setMaxIter(3)
.setStepSize(0.1)
.setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
- for (String lossType: GBTClassifier.supportedLossTypes()) {
+ for (String lossType : GBTClassifier.supportedLossTypes()) {
rf.setLossType(lossType);
}
GBTClassificationModel model = rf.fit(dataFrame);
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
index e160a5a47e..e3ff68364e 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
@@ -27,18 +27,17 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
-import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SQLContext;
-
+import org.apache.spark.sql.SparkSession;
+import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
public class JavaLogisticRegressionSuite implements Serializable {
+ private transient SparkSession spark;
private transient JavaSparkContext jsc;
- private transient SQLContext jsql;
private transient Dataset<Row> dataset;
private transient JavaRDD<LabeledPoint> datasetRDD;
@@ -46,18 +45,22 @@ public class JavaLogisticRegressionSuite implements Serializable {
@Before
public void setUp() {
- jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite");
- jsql = new SQLContext(jsc);
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaLogisticRegressionSuite")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
+
List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
datasetRDD = jsc.parallelize(points, 2);
- dataset = jsql.createDataFrame(datasetRDD, LabeledPoint.class);
+ dataset = spark.createDataFrame(datasetRDD, LabeledPoint.class);
dataset.registerTempTable("dataset");
}
@After
public void tearDown() {
- jsc.stop();
- jsc = null;
+ spark.stop();
+ spark = null;
}
@Test
@@ -66,7 +69,7 @@ public class JavaLogisticRegressionSuite implements Serializable {
Assert.assertEquals(lr.getLabelCol(), "label");
LogisticRegressionModel model = lr.fit(dataset);
model.transform(dataset).registerTempTable("prediction");
- Dataset<Row> predictions = jsql.sql("SELECT label, probability, prediction FROM prediction");
+ Dataset<Row> predictions = spark.sql("SELECT label, probability, prediction FROM prediction");
predictions.collectAsList();
// Check defaults
Assert.assertEquals(0.5, model.getThreshold(), eps);
@@ -95,23 +98,23 @@ public class JavaLogisticRegressionSuite implements Serializable {
// Modify model params, and check that the params worked.
model.setThreshold(1.0);
model.transform(dataset).registerTempTable("predAllZero");
- Dataset<Row> predAllZero = jsql.sql("SELECT prediction, myProbability FROM predAllZero");
- for (Row r: predAllZero.collectAsList()) {
+ Dataset<Row> predAllZero = spark.sql("SELECT prediction, myProbability FROM predAllZero");
+ for (Row r : predAllZero.collectAsList()) {
Assert.assertEquals(0.0, r.getDouble(0), eps);
}
// Call transform with params, and check that the params worked.
model.transform(dataset, model.threshold().w(0.0), model.probabilityCol().w("myProb"))
.registerTempTable("predNotAllZero");
- Dataset<Row> predNotAllZero = jsql.sql("SELECT prediction, myProb FROM predNotAllZero");
+ Dataset<Row> predNotAllZero = spark.sql("SELECT prediction, myProb FROM predNotAllZero");
boolean foundNonZero = false;
- for (Row r: predNotAllZero.collectAsList()) {
+ for (Row r : predNotAllZero.collectAsList()) {
if (r.getDouble(0) != 0.0) foundNonZero = true;
}
Assert.assertTrue(foundNonZero);
// Call fit() with new params, and check as many params as we can.
LogisticRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1),
- lr.threshold().w(0.4), lr.probabilityCol().w("theProb"));
+ lr.threshold().w(0.4), lr.probabilityCol().w("theProb"));
LogisticRegression parent2 = (LogisticRegression) model2.parent();
Assert.assertEquals(5, parent2.getMaxIter());
Assert.assertEquals(0.1, parent2.getRegParam(), eps);
@@ -128,10 +131,10 @@ public class JavaLogisticRegressionSuite implements Serializable {
Assert.assertEquals(2, model.numClasses());
model.transform(dataset).registerTempTable("transformed");
- Dataset<Row> trans1 = jsql.sql("SELECT rawPrediction, probability FROM transformed");
- for (Row row: trans1.collectAsList()) {
- Vector raw = (Vector)row.get(0);
- Vector prob = (Vector)row.get(1);
+ Dataset<Row> trans1 = spark.sql("SELECT rawPrediction, probability FROM transformed");
+ for (Row row : trans1.collectAsList()) {
+ Vector raw = (Vector) row.get(0);
+ Vector prob = (Vector) row.get(1);
Assert.assertEquals(raw.size(), 2);
Assert.assertEquals(prob.size(), 2);
double probFromRaw1 = 1.0 / (1.0 + Math.exp(-raw.apply(1)));
@@ -139,11 +142,11 @@ public class JavaLogisticRegressionSuite implements Serializable {
Assert.assertEquals(0, Math.abs(prob.apply(0) - (1.0 - probFromRaw1)), eps);
}
- Dataset<Row> trans2 = jsql.sql("SELECT prediction, probability FROM transformed");
- for (Row row: trans2.collectAsList()) {
+ Dataset<Row> trans2 = spark.sql("SELECT prediction, probability FROM transformed");
+ for (Row row : trans2.collectAsList()) {
double pred = row.getDouble(0);
- Vector prob = (Vector)row.get(1);
- double probOfPred = prob.apply((int)pred);
+ Vector prob = (Vector) row.get(1);
+ double probOfPred = prob.apply((int) pred);
for (int i = 0; i < prob.size(); ++i) {
Assert.assertTrue(probOfPred >= prob.apply(i));
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java
index bc955f3cf6..b0624cea3e 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java
@@ -26,49 +26,49 @@ import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
-import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
public class JavaMultilayerPerceptronClassifierSuite implements Serializable {
- private transient JavaSparkContext jsc;
- private transient SQLContext sqlContext;
+ private transient SparkSession spark;
@Before
public void setUp() {
- jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite");
- sqlContext = new SQLContext(jsc);
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaLogisticRegressionSuite")
+ .getOrCreate();
}
@After
public void tearDown() {
- jsc.stop();
- jsc = null;
- sqlContext = null;
+ spark.stop();
+ spark = null;
}
@Test
public void testMLPC() {
- Dataset<Row> dataFrame = sqlContext.createDataFrame(
- jsc.parallelize(Arrays.asList(
- new LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
- new LabeledPoint(1.0, Vectors.dense(0.0, 1.0)),
- new LabeledPoint(1.0, Vectors.dense(1.0, 0.0)),
- new LabeledPoint(0.0, Vectors.dense(1.0, 1.0)))),
- LabeledPoint.class);
+ List<LabeledPoint> data = Arrays.asList(
+ new LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
+ new LabeledPoint(1.0, Vectors.dense(0.0, 1.0)),
+ new LabeledPoint(1.0, Vectors.dense(1.0, 0.0)),
+ new LabeledPoint(0.0, Vectors.dense(1.0, 1.0))
+ );
+ Dataset<Row> dataFrame = spark.createDataFrame(data, LabeledPoint.class);
+
MultilayerPerceptronClassifier mlpc = new MultilayerPerceptronClassifier()
- .setLayers(new int[] {2, 5, 2})
+ .setLayers(new int[]{2, 5, 2})
.setBlockSize(1)
.setSeed(123L)
.setMaxIter(100);
MultilayerPerceptronClassificationModel model = mlpc.fit(dataFrame);
Dataset<Row> result = model.transform(dataFrame);
List<Row> predictionAndLabels = result.select("prediction", "label").collectAsList();
- for (Row r: predictionAndLabels) {
+ for (Row r : predictionAndLabels) {
Assert.assertEquals((int) r.getDouble(0), (int) r.getDouble(1));
}
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
index 45101f286c..3fc3648627 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
@@ -26,13 +26,12 @@ import org.junit.Before;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
-import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.VectorUDT;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
@@ -40,19 +39,20 @@ import org.apache.spark.sql.types.StructType;
public class JavaNaiveBayesSuite implements Serializable {
- private transient JavaSparkContext jsc;
- private transient SQLContext jsql;
+ private transient SparkSession spark;
@Before
public void setUp() {
- jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite");
- jsql = new SQLContext(jsc);
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaLogisticRegressionSuite")
+ .getOrCreate();
}
@After
public void tearDown() {
- jsc.stop();
- jsc = null;
+ spark.stop();
+ spark = null;
}
public void validatePrediction(Dataset<Row> predictionAndLabels) {
@@ -88,7 +88,7 @@ public class JavaNaiveBayesSuite implements Serializable {
new StructField("features", new VectorUDT(), false, Metadata.empty())
});
- Dataset<Row> dataset = jsql.createDataFrame(data, schema);
+ Dataset<Row> dataset = spark.createDataFrame(data, schema);
NaiveBayes nb = new NaiveBayes().setSmoothing(0.5).setModelType("multinomial");
NaiveBayesModel model = nb.fit(dataset);
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java
index 00f4476841..486fbbd58c 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java
@@ -20,7 +20,6 @@ package org.apache.spark.ml.classification;
import java.io.Serializable;
import java.util.List;
-import org.apache.spark.sql.Row;
import scala.collection.JavaConverters;
import org.junit.After;
@@ -30,56 +29,61 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
-import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateMultinomialLogisticInput;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.Dataset;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
+import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateMultinomialLogisticInput;
public class JavaOneVsRestSuite implements Serializable {
- private transient JavaSparkContext jsc;
- private transient SQLContext jsql;
- private transient Dataset<Row> dataset;
- private transient JavaRDD<LabeledPoint> datasetRDD;
+ private transient SparkSession spark;
+ private transient JavaSparkContext jsc;
+ private transient Dataset<Row> dataset;
+ private transient JavaRDD<LabeledPoint> datasetRDD;
+
+ @Before
+ public void setUp() {
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaLOneVsRestSuite")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
- @Before
- public void setUp() {
- jsc = new JavaSparkContext("local", "JavaLOneVsRestSuite");
- jsql = new SQLContext(jsc);
- int nPoints = 3;
+ int nPoints = 3;
- // The following coefficients and xMean/xVariance are computed from iris dataset with
- // lambda=0.2.
- // As a result, we are drawing samples from probability distribution of an actual model.
- double[] coefficients = {
- -0.57997, 0.912083, -0.371077, -0.819866, 2.688191,
- -0.16624, -0.84355, -0.048509, -0.301789, 4.170682 };
+ // The following coefficients and xMean/xVariance are computed from iris dataset with
+ // lambda=0.2.
+ // As a result, we are drawing samples from probability distribution of an actual model.
+ double[] coefficients = {
+ -0.57997, 0.912083, -0.371077, -0.819866, 2.688191,
+ -0.16624, -0.84355, -0.048509, -0.301789, 4.170682};
- double[] xMean = {5.843, 3.057, 3.758, 1.199};
- double[] xVariance = {0.6856, 0.1899, 3.116, 0.581};
- List<LabeledPoint> points = JavaConverters.seqAsJavaListConverter(
- generateMultinomialLogisticInput(coefficients, xMean, xVariance, true, nPoints, 42)
- ).asJava();
- datasetRDD = jsc.parallelize(points, 2);
- dataset = jsql.createDataFrame(datasetRDD, LabeledPoint.class);
- }
+ double[] xMean = {5.843, 3.057, 3.758, 1.199};
+ double[] xVariance = {0.6856, 0.1899, 3.116, 0.581};
+ List<LabeledPoint> points = JavaConverters.seqAsJavaListConverter(
+ generateMultinomialLogisticInput(coefficients, xMean, xVariance, true, nPoints, 42)
+ ).asJava();
+ datasetRDD = jsc.parallelize(points, 2);
+ dataset = spark.createDataFrame(datasetRDD, LabeledPoint.class);
+ }
- @After
- public void tearDown() {
- jsc.stop();
- jsc = null;
- }
+ @After
+ public void tearDown() {
+ spark.stop();
+ spark = null;
+ }
- @Test
- public void oneVsRestDefaultParams() {
- OneVsRest ova = new OneVsRest();
- ova.setClassifier(new LogisticRegression());
- Assert.assertEquals(ova.getLabelCol() , "label");
- Assert.assertEquals(ova.getPredictionCol() , "prediction");
- OneVsRestModel ovaModel = ova.fit(dataset);
- Dataset<Row> predictions = ovaModel.transform(dataset).select("label", "prediction");
- predictions.collectAsList();
- Assert.assertEquals(ovaModel.getLabelCol(), "label");
- Assert.assertEquals(ovaModel.getPredictionCol() , "prediction");
- }
+ @Test
+ public void oneVsRestDefaultParams() {
+ OneVsRest ova = new OneVsRest();
+ ova.setClassifier(new LogisticRegression());
+ Assert.assertEquals(ova.getLabelCol(), "label");
+ Assert.assertEquals(ova.getPredictionCol(), "prediction");
+ OneVsRestModel ovaModel = ova.fit(dataset);
+ Dataset<Row> predictions = ovaModel.transform(dataset).select("label", "prediction");
+ predictions.collectAsList();
+ Assert.assertEquals(ovaModel.getLabelCol(), "label");
+ Assert.assertEquals(ovaModel.getPredictionCol(), "prediction");
+ }
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
index 4f40fd65b9..e3855662fb 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
@@ -34,21 +34,27 @@ import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
public class JavaRandomForestClassifierSuite implements Serializable {
- private transient JavaSparkContext sc;
+ private transient SparkSession spark;
+ private transient JavaSparkContext jsc;
@Before
public void setUp() {
- sc = new JavaSparkContext("local", "JavaRandomForestClassifierSuite");
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaRandomForestClassifierSuite")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
- sc.stop();
- sc = null;
+ spark.stop();
+ spark = null;
}
@Test
@@ -57,7 +63,7 @@ public class JavaRandomForestClassifierSuite implements Serializable {
double A = 2.0;
double B = -1.5;
- JavaRDD<LabeledPoint> data = sc.parallelize(
+ JavaRDD<LabeledPoint> data = jsc.parallelize(
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
Map<Integer, Integer> categoricalFeatures = new HashMap<>();
Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
@@ -75,22 +81,22 @@ public class JavaRandomForestClassifierSuite implements Serializable {
.setSeed(1234)
.setNumTrees(3)
.setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
- for (String impurity: RandomForestClassifier.supportedImpurities()) {
+ for (String impurity : RandomForestClassifier.supportedImpurities()) {
rf.setImpurity(impurity);
}
- for (String featureSubsetStrategy: RandomForestClassifier.supportedFeatureSubsetStrategies()) {
+ for (String featureSubsetStrategy : RandomForestClassifier.supportedFeatureSubsetStrategies()) {
rf.setFeatureSubsetStrategy(featureSubsetStrategy);
}
String[] realStrategies = {".1", ".10", "0.10", "0.1", "0.9", "1.0"};
- for (String strategy: realStrategies) {
+ for (String strategy : realStrategies) {
rf.setFeatureSubsetStrategy(strategy);
}
String[] integerStrategies = {"1", "10", "100", "1000", "10000"};
- for (String strategy: integerStrategies) {
+ for (String strategy : integerStrategies) {
rf.setFeatureSubsetStrategy(strategy);
}
String[] invalidStrategies = {"-.1", "-.10", "-0.10", ".0", "0.0", "1.1", "0"};
- for (String strategy: invalidStrategies) {
+ for (String strategy : invalidStrategies) {
try {
rf.setFeatureSubsetStrategy(strategy);
Assert.fail("Expected exception to be thrown for invalid strategies");
diff --git a/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java b/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java
index a3fcdb54ee..3ab09ac27d 100644
--- a/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java
@@ -21,37 +21,37 @@ import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertTrue;
-import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
public class JavaKMeansSuite implements Serializable {
private transient int k = 5;
- private transient JavaSparkContext sc;
private transient Dataset<Row> dataset;
- private transient SQLContext sql;
+ private transient SparkSession spark;
@Before
public void setUp() {
- sc = new JavaSparkContext("local", "JavaKMeansSuite");
- sql = new SQLContext(sc);
-
- dataset = KMeansSuite.generateKMeansData(sql, 50, 3, k);
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaKMeansSuite")
+ .getOrCreate();
+ dataset = KMeansSuite.generateKMeansData(spark, 50, 3, k);
}
@After
public void tearDown() {
- sc.stop();
- sc = null;
+ spark.stop();
+ spark = null;
}
@Test
@@ -65,7 +65,7 @@ public class JavaKMeansSuite implements Serializable {
Dataset<Row> transformed = model.transform(dataset);
List<String> columns = Arrays.asList(transformed.columns());
List<String> expectedColumns = Arrays.asList("features", "prediction");
- for (String column: expectedColumns) {
+ for (String column : expectedColumns) {
assertTrue(columns.contains(column));
}
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java
index 77e3a489a9..a96b43de15 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java
@@ -25,40 +25,40 @@ import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
-import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
public class JavaBucketizerSuite {
- private transient JavaSparkContext jsc;
- private transient SQLContext jsql;
+ private transient SparkSession spark;
@Before
public void setUp() {
- jsc = new JavaSparkContext("local", "JavaBucketizerSuite");
- jsql = new SQLContext(jsc);
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaBucketizerSuite")
+ .getOrCreate();
}
@After
public void tearDown() {
- jsc.stop();
- jsc = null;
+ spark.stop();
+ spark = null;
}
@Test
public void bucketizerTest() {
double[] splits = {-0.5, 0.0, 0.5};
- StructType schema = new StructType(new StructField[] {
+ StructType schema = new StructType(new StructField[]{
new StructField("feature", DataTypes.DoubleType, false, Metadata.empty())
});
- Dataset<Row> dataset = jsql.createDataFrame(
+ Dataset<Row> dataset = spark.createDataFrame(
Arrays.asList(
RowFactory.create(-0.5),
RowFactory.create(-0.3),
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java
index ed1ad4c3a3..06482d8f0d 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java
@@ -21,43 +21,44 @@ import java.util.Arrays;
import java.util.List;
import edu.emory.mathcs.jtransforms.dct.DoubleDCT_1D;
+
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
-import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.VectorUDT;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
public class JavaDCTSuite {
- private transient JavaSparkContext jsc;
- private transient SQLContext jsql;
+ private transient SparkSession spark;
@Before
public void setUp() {
- jsc = new JavaSparkContext("local", "JavaDCTSuite");
- jsql = new SQLContext(jsc);
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaDCTSuite")
+ .getOrCreate();
}
@After
public void tearDown() {
- jsc.stop();
- jsc = null;
+ spark.stop();
+ spark = null;
}
@Test
public void javaCompatibilityTest() {
- double[] input = new double[] {1D, 2D, 3D, 4D};
- Dataset<Row> dataset = jsql.createDataFrame(
+ double[] input = new double[]{1D, 2D, 3D, 4D};
+ Dataset<Row> dataset = spark.createDataFrame(
Arrays.asList(RowFactory.create(Vectors.dense(input))),
new StructType(new StructField[]{
new StructField("vec", (new VectorUDT()), false, Metadata.empty())
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java
index 6e2cc7e887..0e21d4a94f 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java
@@ -25,12 +25,11 @@ import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
-import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
@@ -38,19 +37,20 @@ import org.apache.spark.sql.types.StructType;
public class JavaHashingTFSuite {
- private transient JavaSparkContext jsc;
- private transient SQLContext jsql;
+ private transient SparkSession spark;
@Before
public void setUp() {
- jsc = new JavaSparkContext("local", "JavaHashingTFSuite");
- jsql = new SQLContext(jsc);
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaHashingTFSuite")
+ .getOrCreate();
}
@After
public void tearDown() {
- jsc.stop();
- jsc = null;
+ spark.stop();
+ spark = null;
}
@Test
@@ -65,7 +65,7 @@ public class JavaHashingTFSuite {
new StructField("sentence", DataTypes.StringType, false, Metadata.empty())
});
- Dataset<Row> sentenceData = jsql.createDataFrame(data, schema);
+ Dataset<Row> sentenceData = spark.createDataFrame(data, schema);
Tokenizer tokenizer = new Tokenizer()
.setInputCol("sentence")
.setOutputCol("words");
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java
index 5bbd9634b2..04b2897b18 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java
@@ -23,27 +23,30 @@ import org.junit.After;
import org.junit.Before;
import org.junit.Test;
+import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
public class JavaNormalizerSuite {
+ private transient SparkSession spark;
private transient JavaSparkContext jsc;
- private transient SQLContext jsql;
@Before
public void setUp() {
- jsc = new JavaSparkContext("local", "JavaNormalizerSuite");
- jsql = new SQLContext(jsc);
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaNormalizerSuite")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
- jsc.stop();
- jsc = null;
+ spark.stop();
+ spark = null;
}
@Test
@@ -54,7 +57,7 @@ public class JavaNormalizerSuite {
new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 3.0)),
new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 4.0))
));
- Dataset<Row> dataFrame = jsql.createDataFrame(points, VectorIndexerSuite.FeatureData.class);
+ Dataset<Row> dataFrame = spark.createDataFrame(points, VectorIndexerSuite.FeatureData.class);
Normalizer normalizer = new Normalizer()
.setInputCol("features")
.setOutputCol("normFeatures");
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java
index 1389d17e7e..32f6b4375e 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java
@@ -28,31 +28,34 @@ import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
-import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.mllib.linalg.distributed.RowMatrix;
+import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.linalg.Matrix;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.mllib.linalg.distributed.RowMatrix;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
public class JavaPCASuite implements Serializable {
+ private transient SparkSession spark;
private transient JavaSparkContext jsc;
- private transient SQLContext sqlContext;
@Before
public void setUp() {
- jsc = new JavaSparkContext("local", "JavaPCASuite");
- sqlContext = new SQLContext(jsc);
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaPCASuite")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
- jsc.stop();
- jsc = null;
+ spark.stop();
+ spark = null;
}
public static class VectorPair implements Serializable {
@@ -100,7 +103,7 @@ public class JavaPCASuite implements Serializable {
}
);
- Dataset<Row> df = sqlContext.createDataFrame(featuresExpected, VectorPair.class);
+ Dataset<Row> df = spark.createDataFrame(featuresExpected, VectorPair.class);
PCAModel pca = new PCA()
.setInputCol("features")
.setOutputCol("pca_features")
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java
index 6a8bb64801..8f726077a2 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java
@@ -32,19 +32,22 @@ import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
public class JavaPolynomialExpansionSuite {
+ private transient SparkSession spark;
private transient JavaSparkContext jsc;
- private transient SQLContext jsql;
@Before
public void setUp() {
- jsc = new JavaSparkContext("local", "JavaPolynomialExpansionSuite");
- jsql = new SQLContext(jsc);
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaPolynomialExpansionSuite")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
}
@After
@@ -72,20 +75,20 @@ public class JavaPolynomialExpansionSuite {
)
);
- StructType schema = new StructType(new StructField[] {
+ StructType schema = new StructType(new StructField[]{
new StructField("features", new VectorUDT(), false, Metadata.empty()),
new StructField("expected", new VectorUDT(), false, Metadata.empty())
});
- Dataset<Row> dataset = jsql.createDataFrame(data, schema);
+ Dataset<Row> dataset = spark.createDataFrame(data, schema);
List<Row> pairs = polyExpansion.transform(dataset)
.select("polyFeatures", "expected")
.collectAsList();
for (Row r : pairs) {
- double[] polyFeatures = ((Vector)r.get(0)).toArray();
- double[] expected = ((Vector)r.get(1)).toArray();
+ double[] polyFeatures = ((Vector) r.get(0)).toArray();
+ double[] expected = ((Vector) r.get(1)).toArray();
Assert.assertArrayEquals(polyFeatures, expected, 1e-1);
}
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java
index 3f6fc333e4..c7397bdd68 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java
@@ -28,22 +28,25 @@ import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
public class JavaStandardScalerSuite {
+ private transient SparkSession spark;
private transient JavaSparkContext jsc;
- private transient SQLContext jsql;
@Before
public void setUp() {
- jsc = new JavaSparkContext("local", "JavaStandardScalerSuite");
- jsql = new SQLContext(jsc);
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaStandardScalerSuite")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
- jsc.stop();
- jsc = null;
+ spark.stop();
+ spark = null;
}
@Test
@@ -54,7 +57,7 @@ public class JavaStandardScalerSuite {
new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 3.0)),
new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 4.0))
);
- Dataset<Row> dataFrame = jsql.createDataFrame(jsc.parallelize(points, 2),
+ Dataset<Row> dataFrame = spark.createDataFrame(jsc.parallelize(points, 2),
VectorIndexerSuite.FeatureData.class);
StandardScaler scaler = new StandardScaler()
.setInputCol("features")
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java
index bdcbde5e26..2b156f3bca 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java
@@ -24,11 +24,10 @@ import org.junit.After;
import org.junit.Before;
import org.junit.Test;
-import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
@@ -37,19 +36,20 @@ import org.apache.spark.sql.types.StructType;
public class JavaStopWordsRemoverSuite {
- private transient JavaSparkContext jsc;
- private transient SQLContext jsql;
+ private transient SparkSession spark;
@Before
public void setUp() {
- jsc = new JavaSparkContext("local", "JavaStopWordsRemoverSuite");
- jsql = new SQLContext(jsc);
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaStopWordsRemoverSuite")
+ .getOrCreate();
}
@After
public void tearDown() {
- jsc.stop();
- jsc = null;
+ spark.stop();
+ spark = null;
}
@Test
@@ -62,11 +62,11 @@ public class JavaStopWordsRemoverSuite {
RowFactory.create(Arrays.asList("I", "saw", "the", "red", "baloon")),
RowFactory.create(Arrays.asList("Mary", "had", "a", "little", "lamb"))
);
- StructType schema = new StructType(new StructField[] {
+ StructType schema = new StructType(new StructField[]{
new StructField("raw", DataTypes.createArrayType(DataTypes.StringType), false,
- Metadata.empty())
+ Metadata.empty())
});
- Dataset<Row> dataset = jsql.createDataFrame(data, schema);
+ Dataset<Row> dataset = spark.createDataFrame(data, schema);
remover.transform(dataset).collect();
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java
index 431779cd2e..52c0bde8f3 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java
@@ -25,40 +25,42 @@ import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
-import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.SparkConf;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import static org.apache.spark.sql.types.DataTypes.*;
public class JavaStringIndexerSuite {
- private transient JavaSparkContext jsc;
- private transient SQLContext sqlContext;
+ private transient SparkSession spark;
@Before
public void setUp() {
- jsc = new JavaSparkContext("local", "JavaStringIndexerSuite");
- sqlContext = new SQLContext(jsc);
+ SparkConf sparkConf = new SparkConf();
+ sparkConf.setMaster("local");
+ sparkConf.setAppName("JavaStringIndexerSuite");
+
+ spark = SparkSession.builder().config(sparkConf).getOrCreate();
}
@After
public void tearDown() {
- jsc.stop();
- sqlContext = null;
+ spark.stop();
+ spark = null;
}
@Test
public void testStringIndexer() {
- StructType schema = createStructType(new StructField[] {
+ StructType schema = createStructType(new StructField[]{
createStructField("id", IntegerType, false),
createStructField("label", StringType, false)
});
List<Row> data = Arrays.asList(
cr(0, "a"), cr(1, "b"), cr(2, "c"), cr(3, "a"), cr(4, "a"), cr(5, "c"));
- Dataset<Row> dataset = sqlContext.createDataFrame(data, schema);
+ Dataset<Row> dataset = spark.createDataFrame(data, schema);
StringIndexer indexer = new StringIndexer()
.setInputCol("label")
@@ -70,7 +72,9 @@ public class JavaStringIndexerSuite {
output.orderBy("id").select("id", "labelIndex").collectAsList());
}
- /** An alias for RowFactory.create. */
+ /**
+ * An alias for RowFactory.create.
+ */
private Row cr(Object... values) {
return RowFactory.create(values);
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java
index 83d16cbd0e..0bac2839e1 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java
@@ -29,22 +29,25 @@ import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
public class JavaTokenizerSuite {
+ private transient SparkSession spark;
private transient JavaSparkContext jsc;
- private transient SQLContext jsql;
@Before
public void setUp() {
- jsc = new JavaSparkContext("local", "JavaTokenizerSuite");
- jsql = new SQLContext(jsc);
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaTokenizerSuite")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
- jsc.stop();
- jsc = null;
+ spark.stop();
+ spark = null;
}
@Test
@@ -59,10 +62,10 @@ public class JavaTokenizerSuite {
JavaRDD<TokenizerTestData> rdd = jsc.parallelize(Arrays.asList(
- new TokenizerTestData("Test of tok.", new String[] {"Test", "tok."}),
- new TokenizerTestData("Te,st. punct", new String[] {"Te,st.", "punct"})
+ new TokenizerTestData("Test of tok.", new String[]{"Test", "tok."}),
+ new TokenizerTestData("Te,st. punct", new String[]{"Te,st.", "punct"})
));
- Dataset<Row> dataset = jsql.createDataFrame(rdd, TokenizerTestData.class);
+ Dataset<Row> dataset = spark.createDataFrame(rdd, TokenizerTestData.class);
List<Row> pairs = myRegExTokenizer.transform(dataset)
.select("tokens", "wantedTokens")
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java
index e45e198043..8774cd0c69 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java
@@ -24,36 +24,39 @@ import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
-import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.SparkConf;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.VectorUDT;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SQLContext;
-import org.apache.spark.sql.types.*;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
import static org.apache.spark.sql.types.DataTypes.*;
public class JavaVectorAssemblerSuite {
- private transient JavaSparkContext jsc;
- private transient SQLContext sqlContext;
+ private transient SparkSession spark;
@Before
public void setUp() {
- jsc = new JavaSparkContext("local", "JavaVectorAssemblerSuite");
- sqlContext = new SQLContext(jsc);
+ SparkConf sparkConf = new SparkConf();
+ sparkConf.setMaster("local");
+ sparkConf.setAppName("JavaVectorAssemblerSuite");
+
+ spark = SparkSession.builder().config(sparkConf).getOrCreate();
}
@After
public void tearDown() {
- jsc.stop();
- jsc = null;
+ spark.stop();
+ spark = null;
}
@Test
public void testVectorAssembler() {
- StructType schema = createStructType(new StructField[] {
+ StructType schema = createStructType(new StructField[]{
createStructField("id", IntegerType, false),
createStructField("x", DoubleType, false),
createStructField("y", new VectorUDT(), false),
@@ -63,14 +66,14 @@ public class JavaVectorAssemblerSuite {
});
Row row = RowFactory.create(
0, 0.0, Vectors.dense(1.0, 2.0), "a",
- Vectors.sparse(2, new int[] {1}, new double[] {3.0}), 10L);
- Dataset<Row> dataset = sqlContext.createDataFrame(Arrays.asList(row), schema);
+ Vectors.sparse(2, new int[]{1}, new double[]{3.0}), 10L);
+ Dataset<Row> dataset = spark.createDataFrame(Arrays.asList(row), schema);
VectorAssembler assembler = new VectorAssembler()
- .setInputCols(new String[] {"x", "y", "z", "n"})
+ .setInputCols(new String[]{"x", "y", "z", "n"})
.setOutputCol("features");
Dataset<Row> output = assembler.transform(dataset);
Assert.assertEquals(
- Vectors.sparse(6, new int[] {1, 2, 4, 5}, new double[] {1.0, 2.0, 3.0, 10.0}),
+ Vectors.sparse(6, new int[]{1, 2, 4, 5}, new double[]{1.0, 2.0, 3.0, 10.0}),
output.select("features").first().<Vector>getAs(0));
}
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java
index fec6cac8be..c386c9a45b 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java
@@ -32,21 +32,26 @@ import org.apache.spark.ml.feature.VectorIndexerSuite.FeatureData;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
public class JavaVectorIndexerSuite implements Serializable {
- private transient JavaSparkContext sc;
+ private transient SparkSession spark;
+ private JavaSparkContext jsc;
@Before
public void setUp() {
- sc = new JavaSparkContext("local", "JavaVectorIndexerSuite");
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaVectorIndexerSuite")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
- sc.stop();
- sc = null;
+ spark.stop();
+ spark = null;
}
@Test
@@ -57,8 +62,7 @@ public class JavaVectorIndexerSuite implements Serializable {
new FeatureData(Vectors.dense(1.0, 3.0)),
new FeatureData(Vectors.dense(1.0, 4.0))
);
- SQLContext sqlContext = new SQLContext(sc);
- Dataset<Row> data = sqlContext.createDataFrame(sc.parallelize(points, 2), FeatureData.class);
+ Dataset<Row> data = spark.createDataFrame(jsc.parallelize(points, 2), FeatureData.class);
VectorIndexer indexer = new VectorIndexer()
.setInputCol("features")
.setOutputCol("indexed")
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java
index e2da11183b..59ad3c2f61 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java
@@ -25,7 +25,6 @@ import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
-import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.attribute.Attribute;
import org.apache.spark.ml.attribute.AttributeGroup;
import org.apache.spark.ml.attribute.NumericAttribute;
@@ -34,24 +33,25 @@ import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.StructType;
public class JavaVectorSlicerSuite {
- private transient JavaSparkContext jsc;
- private transient SQLContext jsql;
+ private transient SparkSession spark;
@Before
public void setUp() {
- jsc = new JavaSparkContext("local", "JavaVectorSlicerSuite");
- jsql = new SQLContext(jsc);
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaVectorSlicerSuite")
+ .getOrCreate();
}
@After
public void tearDown() {
- jsc.stop();
- jsc = null;
+ spark.stop();
+ spark = null;
}
@Test
@@ -69,7 +69,7 @@ public class JavaVectorSlicerSuite {
);
Dataset<Row> dataset =
- jsql.createDataFrame(data, (new StructType()).add(group.toStructField()));
+ spark.createDataFrame(data, (new StructType()).add(group.toStructField()));
VectorSlicer vectorSlicer = new VectorSlicer()
.setInputCol("userFeatures").setOutputCol("features");
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java
index 7517b70cc9..392aabc96d 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java
@@ -24,28 +24,28 @@ import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
-import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.*;
public class JavaWord2VecSuite {
- private transient JavaSparkContext jsc;
- private transient SQLContext sqlContext;
+ private transient SparkSession spark;
@Before
public void setUp() {
- jsc = new JavaSparkContext("local", "JavaWord2VecSuite");
- sqlContext = new SQLContext(jsc);
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaWord2VecSuite")
+ .getOrCreate();
}
@After
public void tearDown() {
- jsc.stop();
- jsc = null;
+ spark.stop();
+ spark = null;
}
@Test
@@ -53,7 +53,7 @@ public class JavaWord2VecSuite {
StructType schema = new StructType(new StructField[]{
new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty())
});
- Dataset<Row> documentDF = sqlContext.createDataFrame(
+ Dataset<Row> documentDF = spark.createDataFrame(
Arrays.asList(
RowFactory.create(Arrays.asList("Hi I heard about Spark".split(" "))),
RowFactory.create(Arrays.asList("I wish Java could use case classes".split(" "))),
@@ -68,8 +68,8 @@ public class JavaWord2VecSuite {
Word2VecModel model = word2Vec.fit(documentDF);
Dataset<Row> result = model.transform(documentDF);
- for (Row r: result.select("result").collectAsList()) {
- double[] polyFeatures = ((Vector)r.get(0)).toArray();
+ for (Row r : result.select("result").collectAsList()) {
+ double[] polyFeatures = ((Vector) r.get(0)).toArray();
Assert.assertEquals(polyFeatures.length, 3);
}
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java
index fa777f3d42..a5b5dd4088 100644
--- a/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java
@@ -25,23 +25,29 @@ import org.junit.Before;
import org.junit.Test;
import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.SparkSession;
/**
* Test Param and related classes in Java
*/
public class JavaParamsSuite {
+ private transient SparkSession spark;
private transient JavaSparkContext jsc;
@Before
public void setUp() {
- jsc = new JavaSparkContext("local", "JavaParamsSuite");
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaParamsSuite")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
- jsc.stop();
- jsc = null;
+ spark.stop();
+ spark = null;
}
@Test
@@ -51,7 +57,7 @@ public class JavaParamsSuite {
testParams.setMyIntParam(2).setMyDoubleParam(0.4).setMyStringParam("a");
Assert.assertEquals(testParams.getMyDoubleParam(), 0.4, 0.0);
Assert.assertEquals(testParams.getMyStringParam(), "a");
- Assert.assertArrayEquals(testParams.getMyDoubleArrayParam(), new double[] {1.0, 2.0}, 0.0);
+ Assert.assertArrayEquals(testParams.getMyDoubleArrayParam(), new double[]{1.0, 2.0}, 0.0);
}
@Test
diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
index 06f7fbb86e..1ad5f7a442 100644
--- a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
+++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
@@ -45,9 +45,14 @@ public class JavaTestParams extends JavaParams {
}
private IntParam myIntParam_;
- public IntParam myIntParam() { return myIntParam_; }
- public int getMyIntParam() { return (Integer)getOrDefault(myIntParam_); }
+ public IntParam myIntParam() {
+ return myIntParam_;
+ }
+
+ public int getMyIntParam() {
+ return (Integer) getOrDefault(myIntParam_);
+ }
public JavaTestParams setMyIntParam(int value) {
set(myIntParam_, value);
@@ -55,9 +60,14 @@ public class JavaTestParams extends JavaParams {
}
private DoubleParam myDoubleParam_;
- public DoubleParam myDoubleParam() { return myDoubleParam_; }
- public double getMyDoubleParam() { return (Double)getOrDefault(myDoubleParam_); }
+ public DoubleParam myDoubleParam() {
+ return myDoubleParam_;
+ }
+
+ public double getMyDoubleParam() {
+ return (Double) getOrDefault(myDoubleParam_);
+ }
public JavaTestParams setMyDoubleParam(double value) {
set(myDoubleParam_, value);
@@ -65,9 +75,14 @@ public class JavaTestParams extends JavaParams {
}
private Param<String> myStringParam_;
- public Param<String> myStringParam() { return myStringParam_; }
- public String getMyStringParam() { return getOrDefault(myStringParam_); }
+ public Param<String> myStringParam() {
+ return myStringParam_;
+ }
+
+ public String getMyStringParam() {
+ return getOrDefault(myStringParam_);
+ }
public JavaTestParams setMyStringParam(String value) {
set(myStringParam_, value);
@@ -75,9 +90,14 @@ public class JavaTestParams extends JavaParams {
}
private DoubleArrayParam myDoubleArrayParam_;
- public DoubleArrayParam myDoubleArrayParam() { return myDoubleArrayParam_; }
- public double[] getMyDoubleArrayParam() { return getOrDefault(myDoubleArrayParam_); }
+ public DoubleArrayParam myDoubleArrayParam() {
+ return myDoubleArrayParam_;
+ }
+
+ public double[] getMyDoubleArrayParam() {
+ return getOrDefault(myDoubleArrayParam_);
+ }
public JavaTestParams setMyDoubleArrayParam(double[] value) {
set(myDoubleArrayParam_, value);
@@ -96,7 +116,7 @@ public class JavaTestParams extends JavaParams {
setDefault(myIntParam(), 1);
setDefault(myDoubleParam(), 0.5);
- setDefault(myDoubleArrayParam(), new double[] {1.0, 2.0});
+ setDefault(myDoubleArrayParam(), new double[]{1.0, 2.0});
}
@Override
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
index fa3b28ed4f..bbd59a04ec 100644
--- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
@@ -32,21 +32,27 @@ import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
public class JavaDecisionTreeRegressorSuite implements Serializable {
- private transient JavaSparkContext sc;
+ private transient SparkSession spark;
+ private transient JavaSparkContext jsc;
@Before
public void setUp() {
- sc = new JavaSparkContext("local", "JavaDecisionTreeRegressorSuite");
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaDecisionTreeRegressorSuite")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
- sc.stop();
- sc = null;
+ spark.stop();
+ spark = null;
}
@Test
@@ -55,7 +61,7 @@ public class JavaDecisionTreeRegressorSuite implements Serializable {
double A = 2.0;
double B = -1.5;
- JavaRDD<LabeledPoint> data = sc.parallelize(
+ JavaRDD<LabeledPoint> data = jsc.parallelize(
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
Map<Integer, Integer> categoricalFeatures = new HashMap<>();
Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0);
@@ -70,7 +76,7 @@ public class JavaDecisionTreeRegressorSuite implements Serializable {
.setCacheNodeIds(false)
.setCheckpointInterval(10)
.setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
- for (String impurity: DecisionTreeRegressor.supportedImpurities()) {
+ for (String impurity : DecisionTreeRegressor.supportedImpurities()) {
dt.setImpurity(impurity);
}
DecisionTreeRegressionModel model = dt.fit(dataFrame);
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java
index 8413ea0e0a..5370b58e8f 100644
--- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java
@@ -32,21 +32,27 @@ import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
public class JavaGBTRegressorSuite implements Serializable {
- private transient JavaSparkContext sc;
+ private transient SparkSession spark;
+ private transient JavaSparkContext jsc;
@Before
public void setUp() {
- sc = new JavaSparkContext("local", "JavaGBTRegressorSuite");
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaGBTRegressorSuite")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
- sc.stop();
- sc = null;
+ spark.stop();
+ spark = null;
}
@Test
@@ -55,7 +61,7 @@ public class JavaGBTRegressorSuite implements Serializable {
double A = 2.0;
double B = -1.5;
- JavaRDD<LabeledPoint> data = sc.parallelize(
+ JavaRDD<LabeledPoint> data = jsc.parallelize(
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
Map<Integer, Integer> categoricalFeatures = new HashMap<>();
Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0);
@@ -73,7 +79,7 @@ public class JavaGBTRegressorSuite implements Serializable {
.setMaxIter(3)
.setStepSize(0.1)
.setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
- for (String lossType: GBTRegressor.supportedLossTypes()) {
+ for (String lossType : GBTRegressor.supportedLossTypes()) {
rf.setLossType(lossType);
}
GBTRegressionModel model = rf.fit(dataFrame);
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
index 9f817515eb..00c59f08b6 100644
--- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
@@ -30,25 +30,26 @@ import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SQLContext;
-import static org.apache.spark.mllib.classification.LogisticRegressionSuite
- .generateLogisticInputAsList;
-
+import org.apache.spark.sql.SparkSession;
+import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
public class JavaLinearRegressionSuite implements Serializable {
+ private transient SparkSession spark;
private transient JavaSparkContext jsc;
- private transient SQLContext jsql;
private transient Dataset<Row> dataset;
private transient JavaRDD<LabeledPoint> datasetRDD;
@Before
public void setUp() {
- jsc = new JavaSparkContext("local", "JavaLinearRegressionSuite");
- jsql = new SQLContext(jsc);
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaLinearRegressionSuite")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
datasetRDD = jsc.parallelize(points, 2);
- dataset = jsql.createDataFrame(datasetRDD, LabeledPoint.class);
+ dataset = spark.createDataFrame(datasetRDD, LabeledPoint.class);
dataset.registerTempTable("dataset");
}
@@ -65,7 +66,7 @@ public class JavaLinearRegressionSuite implements Serializable {
assertEquals("auto", lr.getSolver());
LinearRegressionModel model = lr.fit(dataset);
model.transform(dataset).registerTempTable("prediction");
- Dataset<Row> predictions = jsql.sql("SELECT label, prediction FROM prediction");
+ Dataset<Row> predictions = spark.sql("SELECT label, prediction FROM prediction");
predictions.collect();
// Check defaults
assertEquals("features", model.getFeaturesCol());
@@ -76,8 +77,8 @@ public class JavaLinearRegressionSuite implements Serializable {
public void linearRegressionWithSetters() {
// Set params, train, and check as many params as we can.
LinearRegression lr = new LinearRegression()
- .setMaxIter(10)
- .setRegParam(1.0).setSolver("l-bfgs");
+ .setMaxIter(10)
+ .setRegParam(1.0).setSolver("l-bfgs");
LinearRegressionModel model = lr.fit(dataset);
LinearRegression parent = (LinearRegression) model.parent();
assertEquals(10, parent.getMaxIter());
@@ -85,7 +86,7 @@ public class JavaLinearRegressionSuite implements Serializable {
// Call fit() with new params, and check as many params as we can.
LinearRegressionModel model2 =
- lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), lr.predictionCol().w("thePred"));
+ lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), lr.predictionCol().w("thePred"));
LinearRegression parent2 = (LinearRegression) model2.parent();
assertEquals(5, parent2.getMaxIter());
assertEquals(0.1, parent2.getRegParam(), 0.0);
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
index 38b895f1fd..fdb41ffc10 100644
--- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
@@ -28,27 +28,33 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.ml.tree.impl.TreeTests;
+import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
public class JavaRandomForestRegressorSuite implements Serializable {
- private transient JavaSparkContext sc;
+ private transient SparkSession spark;
+ private transient JavaSparkContext jsc;
@Before
public void setUp() {
- sc = new JavaSparkContext("local", "JavaRandomForestRegressorSuite");
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaRandomForestRegressorSuite")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
- sc.stop();
- sc = null;
+ spark.stop();
+ spark = null;
}
@Test
@@ -57,7 +63,7 @@ public class JavaRandomForestRegressorSuite implements Serializable {
double A = 2.0;
double B = -1.5;
- JavaRDD<LabeledPoint> data = sc.parallelize(
+ JavaRDD<LabeledPoint> data = jsc.parallelize(
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
Map<Integer, Integer> categoricalFeatures = new HashMap<>();
Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0);
@@ -75,22 +81,22 @@ public class JavaRandomForestRegressorSuite implements Serializable {
.setSeed(1234)
.setNumTrees(3)
.setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
- for (String impurity: RandomForestRegressor.supportedImpurities()) {
+ for (String impurity : RandomForestRegressor.supportedImpurities()) {
rf.setImpurity(impurity);
}
- for (String featureSubsetStrategy: RandomForestRegressor.supportedFeatureSubsetStrategies()) {
+ for (String featureSubsetStrategy : RandomForestRegressor.supportedFeatureSubsetStrategies()) {
rf.setFeatureSubsetStrategy(featureSubsetStrategy);
}
String[] realStrategies = {".1", ".10", "0.10", "0.1", "0.9", "1.0"};
- for (String strategy: realStrategies) {
+ for (String strategy : realStrategies) {
rf.setFeatureSubsetStrategy(strategy);
}
String[] integerStrategies = {"1", "10", "100", "1000", "10000"};
- for (String strategy: integerStrategies) {
+ for (String strategy : integerStrategies) {
rf.setFeatureSubsetStrategy(strategy);
}
String[] invalidStrategies = {"-.1", "-.10", "-0.10", ".0", "0.0", "1.1", "0"};
- for (String strategy: invalidStrategies) {
+ for (String strategy : invalidStrategies) {
try {
rf.setFeatureSubsetStrategy(strategy);
Assert.fail("Expected exception to be thrown for invalid strategies");
diff --git a/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java b/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java
index 1c18b2b266..058f2ddafd 100644
--- a/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java
@@ -28,12 +28,11 @@ import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
-import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.DenseVector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
import org.apache.spark.util.Utils;
@@ -41,16 +40,17 @@ import org.apache.spark.util.Utils;
* Test LibSVMRelation in Java.
*/
public class JavaLibSVMRelationSuite {
- private transient JavaSparkContext jsc;
- private transient SQLContext sqlContext;
+ private transient SparkSession spark;
private File tempDir;
private String path;
@Before
public void setUp() throws IOException {
- jsc = new JavaSparkContext("local", "JavaLibSVMRelationSuite");
- sqlContext = new SQLContext(jsc);
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaLibSVMRelationSuite")
+ .getOrCreate();
tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource");
File file = new File(tempDir, "part-00000");
@@ -61,14 +61,14 @@ public class JavaLibSVMRelationSuite {
@After
public void tearDown() {
- jsc.stop();
- jsc = null;
+ spark.stop();
+ spark = null;
Utils.deleteRecursively(tempDir);
}
@Test
public void verifyLibSVMDF() {
- Dataset<Row> dataset = sqlContext.read().format("libsvm").option("vectorType", "dense")
+ Dataset<Row> dataset = spark.read().format("libsvm").option("vectorType", "dense")
.load(path);
Assert.assertEquals("label", dataset.columns()[0]);
Assert.assertEquals("features", dataset.columns()[1]);
diff --git a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
index 24b0097454..8b4d034ffe 100644
--- a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
@@ -32,21 +32,25 @@ import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
public class JavaCrossValidatorSuite implements Serializable {
+ private transient SparkSession spark;
private transient JavaSparkContext jsc;
- private transient SQLContext jsql;
private transient Dataset<Row> dataset;
@Before
public void setUp() {
- jsc = new JavaSparkContext("local", "JavaCrossValidatorSuite");
- jsql = new SQLContext(jsc);
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaCrossValidatorSuite")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
+
List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
- dataset = jsql.createDataFrame(jsc.parallelize(points, 2), LabeledPoint.class);
+ dataset = spark.createDataFrame(jsc.parallelize(points, 2), LabeledPoint.class);
}
@After
@@ -59,8 +63,8 @@ public class JavaCrossValidatorSuite implements Serializable {
public void crossValidationWithLogisticRegression() {
LogisticRegression lr = new LogisticRegression();
ParamMap[] lrParamMaps = new ParamGridBuilder()
- .addGrid(lr.regParam(), new double[] {0.001, 1000.0})
- .addGrid(lr.maxIter(), new int[] {0, 10})
+ .addGrid(lr.regParam(), new double[]{0.001, 1000.0})
+ .addGrid(lr.maxIter(), new int[]{0, 10})
.build();
BinaryClassificationEvaluator eval = new BinaryClassificationEvaluator();
CrossValidator cv = new CrossValidator()
diff --git a/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala b/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala
index 928301523f..878bc66ee3 100644
--- a/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala
+++ b/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala
@@ -37,4 +37,5 @@ object IdentifiableSuite {
class Test(override val uid: String) extends Identifiable {
def this() = this(Identifiable.randomUID("test"))
}
+
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java b/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java
index 01ff1ea658..7151e27cde 100644
--- a/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java
@@ -27,31 +27,34 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
import org.apache.spark.util.Utils;
public class JavaDefaultReadWriteSuite {
JavaSparkContext jsc = null;
- SQLContext sqlContext = null;
+ SparkSession spark = null;
File tempDir = null;
@Before
public void setUp() {
- jsc = new JavaSparkContext("local[2]", "JavaDefaultReadWriteSuite");
SQLContext.clearActive();
- sqlContext = new SQLContext(jsc);
- SQLContext.setActive(sqlContext);
+ spark = SparkSession.builder()
+ .master("local[2]")
+ .appName("JavaDefaultReadWriteSuite")
+ .getOrCreate();
+ SQLContext.setActive(spark.wrapped());
+
tempDir = Utils.createTempDir(
System.getProperty("java.io.tmpdir"), "JavaDefaultReadWriteSuite");
}
@After
public void tearDown() {
- sqlContext = null;
SQLContext.clearActive();
- if (jsc != null) {
- jsc.stop();
- jsc = null;
+ if (spark != null) {
+ spark.stop();
+ spark = null;
}
Utils.deleteRecursively(tempDir);
}
@@ -70,7 +73,7 @@ public class JavaDefaultReadWriteSuite {
} catch (IOException e) {
// expected
}
- instance.write().context(sqlContext).overwrite().save(outputPath);
+ instance.write().context(spark.wrapped()).overwrite().save(outputPath);
MyParams newInstance = MyParams.load(outputPath);
Assert.assertEquals("UID should match.", instance.uid(), newInstance.uid());
Assert.assertEquals("Params should be preserved.",
diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java
index 862221d487..2f10d14da5 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java
@@ -27,26 +27,31 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
-
import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.sql.SparkSession;
public class JavaLogisticRegressionSuite implements Serializable {
- private transient JavaSparkContext sc;
+ private transient SparkSession spark;
+ private transient JavaSparkContext jsc;
@Before
public void setUp() {
- sc = new JavaSparkContext("local", "JavaLogisticRegressionSuite");
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaLogisticRegressionSuite")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
- sc.stop();
- sc = null;
+ spark.stop();
+ spark = null;
}
int validatePrediction(List<LabeledPoint> validationData, LogisticRegressionModel model) {
int numAccurate = 0;
- for (LabeledPoint point: validationData) {
+ for (LabeledPoint point : validationData) {
Double prediction = model.predict(point.features());
if (prediction == point.label()) {
numAccurate++;
@@ -61,16 +66,16 @@ public class JavaLogisticRegressionSuite implements Serializable {
double A = 2.0;
double B = -1.5;
- JavaRDD<LabeledPoint> testRDD = sc.parallelize(
- LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
+ JavaRDD<LabeledPoint> testRDD = jsc.parallelize(
+ LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
List<LabeledPoint> validationData =
- LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 17);
+ LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 17);
LogisticRegressionWithSGD lrImpl = new LogisticRegressionWithSGD();
lrImpl.setIntercept(true);
lrImpl.optimizer().setStepSize(1.0)
- .setRegParam(1.0)
- .setNumIterations(100);
+ .setRegParam(1.0)
+ .setNumIterations(100);
LogisticRegressionModel model = lrImpl.run(testRDD.rdd());
int numAccurate = validatePrediction(validationData, model);
@@ -83,13 +88,13 @@ public class JavaLogisticRegressionSuite implements Serializable {
double A = 0.0;
double B = -2.5;
- JavaRDD<LabeledPoint> testRDD = sc.parallelize(
- LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
+ JavaRDD<LabeledPoint> testRDD = jsc.parallelize(
+ LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
List<LabeledPoint> validationData =
- LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 17);
+ LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 17);
LogisticRegressionModel model = LogisticRegressionWithSGD.train(
- testRDD.rdd(), 100, 1.0, 1.0);
+ testRDD.rdd(), 100, 1.0, 1.0);
int numAccurate = validatePrediction(validationData, model);
Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0);
diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
index 3771c0ea7a..5e212e2fc5 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
@@ -32,20 +32,26 @@ import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.sql.SparkSession;
public class JavaNaiveBayesSuite implements Serializable {
- private transient JavaSparkContext sc;
+ private transient SparkSession spark;
+ private transient JavaSparkContext jsc;
@Before
public void setUp() {
- sc = new JavaSparkContext("local", "JavaNaiveBayesSuite");
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaNaiveBayesSuite")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
- sc.stop();
- sc = null;
+ spark.stop();
+ spark = null;
}
private static final List<LabeledPoint> POINTS = Arrays.asList(
@@ -59,7 +65,7 @@ public class JavaNaiveBayesSuite implements Serializable {
private int validatePrediction(List<LabeledPoint> points, NaiveBayesModel model) {
int correct = 0;
- for (LabeledPoint p: points) {
+ for (LabeledPoint p : points) {
if (model.predict(p.features()) == p.label()) {
correct += 1;
}
@@ -69,7 +75,7 @@ public class JavaNaiveBayesSuite implements Serializable {
@Test
public void runUsingConstructor() {
- JavaRDD<LabeledPoint> testRDD = sc.parallelize(POINTS, 2).cache();
+ JavaRDD<LabeledPoint> testRDD = jsc.parallelize(POINTS, 2).cache();
NaiveBayes nb = new NaiveBayes().setLambda(1.0);
NaiveBayesModel model = nb.run(testRDD.rdd());
@@ -80,7 +86,7 @@ public class JavaNaiveBayesSuite implements Serializable {
@Test
public void runUsingStaticMethods() {
- JavaRDD<LabeledPoint> testRDD = sc.parallelize(POINTS, 2).cache();
+ JavaRDD<LabeledPoint> testRDD = jsc.parallelize(POINTS, 2).cache();
NaiveBayesModel model1 = NaiveBayes.train(testRDD.rdd());
int numAccurate1 = validatePrediction(POINTS, model1);
@@ -93,13 +99,14 @@ public class JavaNaiveBayesSuite implements Serializable {
@Test
public void testPredictJavaRDD() {
- JavaRDD<LabeledPoint> examples = sc.parallelize(POINTS, 2).cache();
+ JavaRDD<LabeledPoint> examples = jsc.parallelize(POINTS, 2).cache();
NaiveBayesModel model = NaiveBayes.train(examples.rdd());
JavaRDD<Vector> vectors = examples.map(new Function<LabeledPoint, Vector>() {
@Override
public Vector call(LabeledPoint v) throws Exception {
return v.features();
- }});
+ }
+ });
JavaRDD<Double> predictions = model.predict(vectors);
// Should be able to get the first prediction.
predictions.first();
diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java
index 31b9f3e8d4..2a090c054f 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java
@@ -28,24 +28,30 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.sql.SparkSession;
public class JavaSVMSuite implements Serializable {
- private transient JavaSparkContext sc;
+ private transient SparkSession spark;
+ private transient JavaSparkContext jsc;
@Before
public void setUp() {
- sc = new JavaSparkContext("local", "JavaSVMSuite");
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaSVMSuite")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
- sc.stop();
- sc = null;
+ spark.stop();
+ spark = null;
}
int validatePrediction(List<LabeledPoint> validationData, SVMModel model) {
int numAccurate = 0;
- for (LabeledPoint point: validationData) {
+ for (LabeledPoint point : validationData) {
Double prediction = model.predict(point.features());
if (prediction == point.label()) {
numAccurate++;
@@ -60,16 +66,16 @@ public class JavaSVMSuite implements Serializable {
double A = 2.0;
double[] weights = {-1.5, 1.0};
- JavaRDD<LabeledPoint> testRDD = sc.parallelize(SVMSuite.generateSVMInputAsList(A,
- weights, nPoints, 42), 2).cache();
+ JavaRDD<LabeledPoint> testRDD = jsc.parallelize(SVMSuite.generateSVMInputAsList(A,
+ weights, nPoints, 42), 2).cache();
List<LabeledPoint> validationData =
- SVMSuite.generateSVMInputAsList(A, weights, nPoints, 17);
+ SVMSuite.generateSVMInputAsList(A, weights, nPoints, 17);
SVMWithSGD svmSGDImpl = new SVMWithSGD();
svmSGDImpl.setIntercept(true);
svmSGDImpl.optimizer().setStepSize(1.0)
- .setRegParam(1.0)
- .setNumIterations(100);
+ .setRegParam(1.0)
+ .setNumIterations(100);
SVMModel model = svmSGDImpl.run(testRDD.rdd());
int numAccurate = validatePrediction(validationData, model);
@@ -82,10 +88,10 @@ public class JavaSVMSuite implements Serializable {
double A = 0.0;
double[] weights = {-1.5, 1.0};
- JavaRDD<LabeledPoint> testRDD = sc.parallelize(SVMSuite.generateSVMInputAsList(A,
- weights, nPoints, 42), 2).cache();
+ JavaRDD<LabeledPoint> testRDD = jsc.parallelize(SVMSuite.generateSVMInputAsList(A,
+ weights, nPoints, 42), 2).cache();
List<LabeledPoint> validationData =
- SVMSuite.generateSVMInputAsList(A, weights, nPoints, 17);
+ SVMSuite.generateSVMInputAsList(A, weights, nPoints, 17);
SVMModel model = SVMWithSGD.train(testRDD.rdd(), 100, 1.0, 1.0, 1.0);
diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java
index a714620ff7..7f29b05047 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java
@@ -20,6 +20,7 @@ package org.apache.spark.mllib.clustering;
import java.io.Serializable;
import com.google.common.collect.Lists;
+
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
@@ -29,27 +30,33 @@ import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.sql.SparkSession;
public class JavaBisectingKMeansSuite implements Serializable {
- private transient JavaSparkContext sc;
+ private transient SparkSession spark;
+ private transient JavaSparkContext jsc;
@Before
public void setUp() {
- sc = new JavaSparkContext("local", this.getClass().getSimpleName());
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaBisectingKMeansSuite")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
- sc.stop();
- sc = null;
+ spark.stop();
+ spark = null;
}
@Test
public void twoDimensionalData() {
- JavaRDD<Vector> points = sc.parallelize(Lists.newArrayList(
+ JavaRDD<Vector> points = jsc.parallelize(Lists.newArrayList(
Vectors.dense(4, -1),
Vectors.dense(4, 1),
- Vectors.sparse(2, new int[] {0}, new double[] {1.0})
+ Vectors.sparse(2, new int[]{0}, new double[]{1.0})
), 2);
BisectingKMeans bkm = new BisectingKMeans()
@@ -58,15 +65,15 @@ public class JavaBisectingKMeansSuite implements Serializable {
.setSeed(1L);
BisectingKMeansModel model = bkm.run(points);
Assert.assertEquals(3, model.k());
- Assert.assertArrayEquals(new double[] {3.0, 0.0}, model.root().center().toArray(), 1e-12);
- for (ClusteringTreeNode child: model.root().children()) {
+ Assert.assertArrayEquals(new double[]{3.0, 0.0}, model.root().center().toArray(), 1e-12);
+ for (ClusteringTreeNode child : model.root().children()) {
double[] center = child.center().toArray();
if (center[0] > 2) {
Assert.assertEquals(2, child.size());
- Assert.assertArrayEquals(new double[] {4.0, 0.0}, center, 1e-12);
+ Assert.assertArrayEquals(new double[]{4.0, 0.0}, center, 1e-12);
} else {
Assert.assertEquals(1, child.size());
- Assert.assertArrayEquals(new double[] {1.0, 0.0}, center, 1e-12);
+ Assert.assertArrayEquals(new double[]{1.0, 0.0}, center, 1e-12);
}
}
}
diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java
index 123f78da54..20edd08a21 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java
@@ -21,29 +21,35 @@ import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
+import static org.junit.Assert.assertEquals;
+
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
-import static org.junit.Assert.assertEquals;
-
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.sql.SparkSession;
public class JavaGaussianMixtureSuite implements Serializable {
- private transient JavaSparkContext sc;
+ private transient SparkSession spark;
+ private transient JavaSparkContext jsc;
@Before
public void setUp() {
- sc = new JavaSparkContext("local", "JavaGaussianMixture");
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaGaussianMixture")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
- sc.stop();
- sc = null;
+ spark.stop();
+ spark = null;
}
@Test
@@ -54,7 +60,7 @@ public class JavaGaussianMixtureSuite implements Serializable {
Vectors.dense(1.0, 4.0, 6.0)
);
- JavaRDD<Vector> data = sc.parallelize(points, 2);
+ JavaRDD<Vector> data = jsc.parallelize(points, 2);
GaussianMixtureModel model = new GaussianMixture().setK(2).setMaxIterations(1).setSeed(1234)
.run(data);
assertEquals(model.gaussians().length, 2);
diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java
index ad06676c72..4e5b87f588 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java
@@ -21,28 +21,35 @@ import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
+import static org.junit.Assert.assertEquals;
+
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
-import static org.junit.Assert.*;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.sql.SparkSession;
public class JavaKMeansSuite implements Serializable {
- private transient JavaSparkContext sc;
+ private transient SparkSession spark;
+ private transient JavaSparkContext jsc;
@Before
public void setUp() {
- sc = new JavaSparkContext("local", "JavaKMeans");
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaKMeans")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
- sc.stop();
- sc = null;
+ spark.stop();
+ spark = null;
}
@Test
@@ -55,7 +62,7 @@ public class JavaKMeansSuite implements Serializable {
Vector expectedCenter = Vectors.dense(1.0, 3.0, 4.0);
- JavaRDD<Vector> data = sc.parallelize(points, 2);
+ JavaRDD<Vector> data = jsc.parallelize(points, 2);
KMeansModel model = KMeans.train(data.rdd(), 1, 1, 1, KMeans.K_MEANS_PARALLEL());
assertEquals(1, model.clusterCenters().length);
assertEquals(expectedCenter, model.clusterCenters()[0]);
@@ -74,7 +81,7 @@ public class JavaKMeansSuite implements Serializable {
Vector expectedCenter = Vectors.dense(1.0, 3.0, 4.0);
- JavaRDD<Vector> data = sc.parallelize(points, 2);
+ JavaRDD<Vector> data = jsc.parallelize(points, 2);
KMeansModel model = new KMeans().setK(1).setMaxIterations(5).run(data.rdd());
assertEquals(1, model.clusterCenters().length);
assertEquals(expectedCenter, model.clusterCenters()[0]);
@@ -94,7 +101,7 @@ public class JavaKMeansSuite implements Serializable {
Vectors.dense(1.0, 3.0, 0.0),
Vectors.dense(1.0, 4.0, 6.0)
);
- JavaRDD<Vector> data = sc.parallelize(points, 2);
+ JavaRDD<Vector> data = jsc.parallelize(points, 2);
KMeansModel model = new KMeans().setK(1).setMaxIterations(5).run(data.rdd());
JavaRDD<Integer> predictions = model.predict(data);
// Should be able to get the first prediction.
diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
index db19b309f6..f16585aff4 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
@@ -27,37 +27,42 @@ import scala.Tuple3;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
-import static org.junit.Assert.assertArrayEquals;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.*;
-import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.linalg.Matrix;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.sql.SparkSession;
public class JavaLDASuite implements Serializable {
- private transient JavaSparkContext sc;
+ private transient SparkSession spark;
+ private transient JavaSparkContext jsc;
@Before
public void setUp() {
- sc = new JavaSparkContext("local", "JavaLDA");
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaLDASuite")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
+
ArrayList<Tuple2<Long, Vector>> tinyCorpus = new ArrayList<>();
for (int i = 0; i < LDASuite.tinyCorpus().length; i++) {
- tinyCorpus.add(new Tuple2<>((Long)LDASuite.tinyCorpus()[i]._1(),
- LDASuite.tinyCorpus()[i]._2()));
+ tinyCorpus.add(new Tuple2<>((Long) LDASuite.tinyCorpus()[i]._1(),
+ LDASuite.tinyCorpus()[i]._2()));
}
- JavaRDD<Tuple2<Long, Vector>> tmpCorpus = sc.parallelize(tinyCorpus, 2);
+ JavaRDD<Tuple2<Long, Vector>> tmpCorpus = jsc.parallelize(tinyCorpus, 2);
corpus = JavaPairRDD.fromJavaRDD(tmpCorpus);
}
@After
public void tearDown() {
- sc.stop();
- sc = null;
+ spark.stop();
+ spark = null;
}
@Test
@@ -95,7 +100,7 @@ public class JavaLDASuite implements Serializable {
.setMaxIterations(5)
.setSeed(12345);
- DistributedLDAModel model = (DistributedLDAModel)lda.run(corpus);
+ DistributedLDAModel model = (DistributedLDAModel) lda.run(corpus);
// Check: basic parameters
LocalLDAModel localModel = model.toLocal();
@@ -124,7 +129,7 @@ public class JavaLDASuite implements Serializable {
public Boolean call(Tuple2<Long, Vector> tuple2) {
return Vectors.norm(tuple2._2(), 1.0) != 0.0;
}
- });
+ });
assertEquals(topicDistributions.count(), nonEmptyCorpus.count());
// Check: javaTopTopicsPerDocuments
@@ -179,7 +184,7 @@ public class JavaLDASuite implements Serializable {
@Test
public void localLdaMethods() {
- JavaRDD<Tuple2<Long, Vector>> docs = sc.parallelize(toyData, 2);
+ JavaRDD<Tuple2<Long, Vector>> docs = jsc.parallelize(toyData, 2);
JavaPairRDD<Long, Vector> pairedDocs = JavaPairRDD.fromJavaRDD(docs);
// check: topicDistributions
@@ -191,7 +196,7 @@ public class JavaLDASuite implements Serializable {
// check: logLikelihood.
ArrayList<Tuple2<Long, Vector>> docsSingleWord = new ArrayList<>();
docsSingleWord.add(new Tuple2<>(0L, Vectors.dense(1.0, 0.0, 0.0)));
- JavaPairRDD<Long, Vector> single = JavaPairRDD.fromJavaRDD(sc.parallelize(docsSingleWord));
+ JavaPairRDD<Long, Vector> single = JavaPairRDD.fromJavaRDD(jsc.parallelize(docsSingleWord));
double logLikelihood = toyModel.logLikelihood(single);
}
@@ -199,7 +204,7 @@ public class JavaLDASuite implements Serializable {
private static int tinyVocabSize = LDASuite.tinyVocabSize();
private static Matrix tinyTopics = LDASuite.tinyTopics();
private static Tuple2<int[], double[]>[] tinyTopicDescription =
- LDASuite.tinyTopicDescription();
+ LDASuite.tinyTopicDescription();
private JavaPairRDD<Long, Vector> corpus;
private LocalLDAModel toyModel = LDASuite.toyModel();
private ArrayList<Tuple2<Long, Vector>> toyData = LDASuite.javaToyData();
diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java
index 62edbd3a29..d1d618f7de 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java
@@ -27,8 +27,6 @@ import org.junit.After;
import org.junit.Before;
import org.junit.Test;
-import static org.apache.spark.streaming.JavaTestUtils.*;
-
import org.apache.spark.SparkConf;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
@@ -36,6 +34,7 @@ import org.apache.spark.streaming.Duration;
import org.apache.spark.streaming.api.java.JavaDStream;
import org.apache.spark.streaming.api.java.JavaPairDStream;
import org.apache.spark.streaming.api.java.JavaStreamingContext;
+import static org.apache.spark.streaming.JavaTestUtils.*;
public class JavaStreamingKMeansSuite implements Serializable {
diff --git a/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java
index fa4d334801..6a096d6386 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java
@@ -31,27 +31,34 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.SparkSession;
public class JavaRankingMetricsSuite implements Serializable {
- private transient JavaSparkContext sc;
+ private transient SparkSession spark;
+ private transient JavaSparkContext jsc;
private transient JavaRDD<Tuple2<List<Integer>, List<Integer>>> predictionAndLabels;
@Before
public void setUp() {
- sc = new JavaSparkContext("local", "JavaRankingMetricsSuite");
- predictionAndLabels = sc.parallelize(Arrays.asList(
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaPCASuite")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
+
+ predictionAndLabels = jsc.parallelize(Arrays.asList(
Tuple2$.MODULE$.apply(
Arrays.asList(1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Arrays.asList(1, 2, 3, 4, 5)),
Tuple2$.MODULE$.apply(
- Arrays.asList(4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Arrays.asList(1, 2, 3)),
+ Arrays.asList(4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Arrays.asList(1, 2, 3)),
Tuple2$.MODULE$.apply(
- Arrays.asList(1, 2, 3, 4, 5), Arrays.<Integer>asList())), 2);
+ Arrays.asList(1, 2, 3, 4, 5), Arrays.<Integer>asList())), 2);
}
@After
public void tearDown() {
- sc.stop();
- sc = null;
+ spark.stop();
+ spark = null;
}
@Test
diff --git a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java
index 8a320afa4b..de50fb8c4f 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java
@@ -29,19 +29,25 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.sql.SparkSession;
public class JavaTfIdfSuite implements Serializable {
- private transient JavaSparkContext sc;
+ private transient SparkSession spark;
+ private transient JavaSparkContext jsc;
@Before
public void setUp() {
- sc = new JavaSparkContext("local", "JavaTfIdfSuite");
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaPCASuite")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
- sc.stop();
- sc = null;
+ spark.stop();
+ spark = null;
}
@Test
@@ -49,7 +55,7 @@ public class JavaTfIdfSuite implements Serializable {
// The tests are to check Java compatibility.
HashingTF tf = new HashingTF();
@SuppressWarnings("unchecked")
- JavaRDD<List<String>> documents = sc.parallelize(Arrays.asList(
+ JavaRDD<List<String>> documents = jsc.parallelize(Arrays.asList(
Arrays.asList("this is a sentence".split(" ")),
Arrays.asList("this is another sentence".split(" ")),
Arrays.asList("this is still a sentence".split(" "))), 2);
@@ -59,7 +65,7 @@ public class JavaTfIdfSuite implements Serializable {
JavaRDD<Vector> tfIdfs = idf.fit(termFreqs).transform(termFreqs);
List<Vector> localTfIdfs = tfIdfs.collect();
int indexOfThis = tf.indexOf("this");
- for (Vector v: localTfIdfs) {
+ for (Vector v : localTfIdfs) {
Assert.assertEquals(0.0, v.apply(indexOfThis), 1e-15);
}
}
@@ -69,7 +75,7 @@ public class JavaTfIdfSuite implements Serializable {
// The tests are to check Java compatibility.
HashingTF tf = new HashingTF();
@SuppressWarnings("unchecked")
- JavaRDD<List<String>> documents = sc.parallelize(Arrays.asList(
+ JavaRDD<List<String>> documents = jsc.parallelize(Arrays.asList(
Arrays.asList("this is a sentence".split(" ")),
Arrays.asList("this is another sentence".split(" ")),
Arrays.asList("this is still a sentence".split(" "))), 2);
@@ -79,7 +85,7 @@ public class JavaTfIdfSuite implements Serializable {
JavaRDD<Vector> tfIdfs = idf.fit(termFreqs).transform(termFreqs);
List<Vector> localTfIdfs = tfIdfs.collect();
int indexOfThis = tf.indexOf("this");
- for (Vector v: localTfIdfs) {
+ for (Vector v : localTfIdfs) {
Assert.assertEquals(0.0, v.apply(indexOfThis), 1e-15);
}
}
diff --git a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java
index e13ed07e28..64885cc842 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java
@@ -21,9 +21,10 @@ import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
+import com.google.common.base.Strings;
+
import scala.Tuple2;
-import com.google.common.base.Strings;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
@@ -31,19 +32,25 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.SparkSession;
public class JavaWord2VecSuite implements Serializable {
- private transient JavaSparkContext sc;
+ private transient SparkSession spark;
+ private transient JavaSparkContext jsc;
@Before
public void setUp() {
- sc = new JavaSparkContext("local", "JavaWord2VecSuite");
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaPCASuite")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
- sc.stop();
- sc = null;
+ spark.stop();
+ spark = null;
}
@Test
@@ -53,7 +60,7 @@ public class JavaWord2VecSuite implements Serializable {
String sentence = Strings.repeat("a b ", 100) + Strings.repeat("a c ", 10);
List<String> words = Arrays.asList(sentence.split(" "));
List<List<String>> localDoc = Arrays.asList(words, words);
- JavaRDD<List<String>> doc = sc.parallelize(localDoc);
+ JavaRDD<List<String>> doc = jsc.parallelize(localDoc);
Word2Vec word2vec = new Word2Vec()
.setVectorSize(10)
.setSeed(42L);
diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java
index 2bef7a8609..fdc19a5b3d 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java
@@ -26,32 +26,37 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset;
+import org.apache.spark.sql.SparkSession;
public class JavaAssociationRulesSuite implements Serializable {
- private transient JavaSparkContext sc;
+ private transient SparkSession spark;
+ private transient JavaSparkContext jsc;
@Before
public void setUp() {
- sc = new JavaSparkContext("local", "JavaFPGrowth");
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaAssociationRulesSuite")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
- sc.stop();
- sc = null;
+ spark.stop();
+ spark = null;
}
@Test
public void runAssociationRules() {
@SuppressWarnings("unchecked")
- JavaRDD<FPGrowth.FreqItemset<String>> freqItemsets = sc.parallelize(Arrays.asList(
- new FreqItemset<String>(new String[] {"a"}, 15L),
- new FreqItemset<String>(new String[] {"b"}, 35L),
- new FreqItemset<String>(new String[] {"a", "b"}, 12L)
+ JavaRDD<FPGrowth.FreqItemset<String>> freqItemsets = jsc.parallelize(Arrays.asList(
+ new FreqItemset<String>(new String[]{"a"}, 15L),
+ new FreqItemset<String>(new String[]{"b"}, 35L),
+ new FreqItemset<String>(new String[]{"a", "b"}, 12L)
));
JavaRDD<AssociationRules.Rule<String>> results = (new AssociationRules()).run(freqItemsets);
}
}
-
diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
index 916fff14a7..f235251e61 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
@@ -22,34 +22,41 @@ import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
+import static org.junit.Assert.assertEquals;
+
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
-import static org.junit.Assert.*;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.SparkSession;
import org.apache.spark.util.Utils;
public class JavaFPGrowthSuite implements Serializable {
- private transient JavaSparkContext sc;
+ private transient SparkSession spark;
+ private transient JavaSparkContext jsc;
@Before
public void setUp() {
- sc = new JavaSparkContext("local", "JavaFPGrowth");
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaFPGrowth")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
- sc.stop();
- sc = null;
+ spark.stop();
+ spark = null;
}
@Test
public void runFPGrowth() {
@SuppressWarnings("unchecked")
- JavaRDD<List<String>> rdd = sc.parallelize(Arrays.asList(
+ JavaRDD<List<String>> rdd = jsc.parallelize(Arrays.asList(
Arrays.asList("r z h k p".split(" ")),
Arrays.asList("z y x w v u t s".split(" ")),
Arrays.asList("s x o n r".split(" ")),
@@ -65,7 +72,7 @@ public class JavaFPGrowthSuite implements Serializable {
List<FPGrowth.FreqItemset<String>> freqItemsets = model.freqItemsets().toJavaRDD().collect();
assertEquals(18, freqItemsets.size());
- for (FPGrowth.FreqItemset<String> itemset: freqItemsets) {
+ for (FPGrowth.FreqItemset<String> itemset : freqItemsets) {
// Test return types.
List<String> items = itemset.javaItems();
long freq = itemset.freq();
@@ -76,7 +83,7 @@ public class JavaFPGrowthSuite implements Serializable {
public void runFPGrowthSaveLoad() {
@SuppressWarnings("unchecked")
- JavaRDD<List<String>> rdd = sc.parallelize(Arrays.asList(
+ JavaRDD<List<String>> rdd = jsc.parallelize(Arrays.asList(
Arrays.asList("r z h k p".split(" ")),
Arrays.asList("z y x w v u t s".split(" ")),
Arrays.asList("s x o n r".split(" ")),
@@ -94,15 +101,15 @@ public class JavaFPGrowthSuite implements Serializable {
String outputPath = tempDir.getPath();
try {
- model.save(sc.sc(), outputPath);
+ model.save(spark.sparkContext(), outputPath);
@SuppressWarnings("unchecked")
FPGrowthModel<String> newModel =
- (FPGrowthModel<String>) FPGrowthModel.load(sc.sc(), outputPath);
+ (FPGrowthModel<String>) FPGrowthModel.load(spark.sparkContext(), outputPath);
List<FPGrowth.FreqItemset<String>> freqItemsets = newModel.freqItemsets().toJavaRDD()
.collect();
assertEquals(18, freqItemsets.size());
- for (FPGrowth.FreqItemset<String> itemset: freqItemsets) {
+ for (FPGrowth.FreqItemset<String> itemset : freqItemsets) {
// Test return types.
List<String> items = itemset.javaItems();
long freq = itemset.freq();
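
Model persistence is still a SparkContext-based API, which is why the save/load calls above switch from sc.sc() to spark.sparkContext(). A self-contained sketch of that round trip, assuming a scratch path (/tmp/fp-model-sketch is hypothetical; the real suite uses a managed temp dir):

import java.util.Arrays;
import java.util.List;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.fpm.FPGrowth;
import org.apache.spark.mllib.fpm.FPGrowthModel;
import org.apache.spark.sql.SparkSession;

public class FpSaveLoadSketch {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder()
      .master("local").appName("FpSaveLoadSketch").getOrCreate();
    JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext());
    JavaRDD<List<String>> rdd = jsc.parallelize(Arrays.asList(
      Arrays.asList("a b c".split(" ")),
      Arrays.asList("a b".split(" "))), 2);
    FPGrowthModel<String> model = new FPGrowth()
      .setMinSupport(0.5)
      .setNumPartitions(2)
      .run(rdd);
    String path = "/tmp/fp-model-sketch";  // hypothetical output path
    // save/load take a plain SparkContext, hence spark.sparkContext().
    model.save(spark.sparkContext(), path);
    @SuppressWarnings("unchecked")
    FPGrowthModel<String> reloaded =
      (FPGrowthModel<String>) FPGrowthModel.load(spark.sparkContext(), path);
    System.out.println(reloaded.freqItemsets().count());
    spark.stop();
  }
}
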
diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java
index 8a67793abc..bf7f1fc71b 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java
@@ -29,25 +29,31 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.fpm.PrefixSpan.FreqSequence;
+import org.apache.spark.sql.SparkSession;
import org.apache.spark.util.Utils;
public class JavaPrefixSpanSuite {
- private transient JavaSparkContext sc;
+ private transient SparkSession spark;
+ private transient JavaSparkContext jsc;
@Before
public void setUp() {
- sc = new JavaSparkContext("local", "JavaPrefixSpan");
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaPrefixSpan")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
- sc.stop();
- sc = null;
+ spark.stop();
+ spark = null;
}
@Test
public void runPrefixSpan() {
- JavaRDD<List<List<Integer>>> sequences = sc.parallelize(Arrays.asList(
+ JavaRDD<List<List<Integer>>> sequences = jsc.parallelize(Arrays.asList(
Arrays.asList(Arrays.asList(1, 2), Arrays.asList(3)),
Arrays.asList(Arrays.asList(1), Arrays.asList(3, 2), Arrays.asList(1, 2)),
Arrays.asList(Arrays.asList(1, 2), Arrays.asList(5)),
@@ -61,7 +67,7 @@ public class JavaPrefixSpanSuite {
List<FreqSequence<Integer>> localFreqSeqs = freqSeqs.collect();
Assert.assertEquals(5, localFreqSeqs.size());
// Check that each frequent sequence could be materialized.
- for (PrefixSpan.FreqSequence<Integer> freqSeq: localFreqSeqs) {
+ for (PrefixSpan.FreqSequence<Integer> freqSeq : localFreqSeqs) {
List<List<Integer>> seq = freqSeq.javaSequence();
long freq = freqSeq.freq();
}
@@ -69,7 +75,7 @@ public class JavaPrefixSpanSuite {
@Test
public void runPrefixSpanSaveLoad() {
- JavaRDD<List<List<Integer>>> sequences = sc.parallelize(Arrays.asList(
+ JavaRDD<List<List<Integer>>> sequences = jsc.parallelize(Arrays.asList(
Arrays.asList(Arrays.asList(1, 2), Arrays.asList(3)),
Arrays.asList(Arrays.asList(1), Arrays.asList(3, 2), Arrays.asList(1, 2)),
Arrays.asList(Arrays.asList(1, 2), Arrays.asList(5)),
@@ -85,13 +91,13 @@ public class JavaPrefixSpanSuite {
String outputPath = tempDir.getPath();
try {
- model.save(sc.sc(), outputPath);
- PrefixSpanModel newModel = PrefixSpanModel.load(sc.sc(), outputPath);
+ model.save(spark.sparkContext(), outputPath);
+ PrefixSpanModel newModel = PrefixSpanModel.load(spark.sparkContext(), outputPath);
JavaRDD<FreqSequence<Integer>> freqSeqs = newModel.freqSequences().toJavaRDD();
List<FreqSequence<Integer>> localFreqSeqs = freqSeqs.collect();
Assert.assertEquals(5, localFreqSeqs.size());
// Check that each frequent sequence could be materialized.
- for (PrefixSpan.FreqSequence<Integer> freqSeq: localFreqSeqs) {
+ for (PrefixSpan.FreqSequence<Integer> freqSeq : localFreqSeqs) {
List<List<Integer>> seq = freqSeq.javaSequence();
long freq = freqSeq.freq();
}
diff --git a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java
index 8beea102ef..92fc57871c 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java
@@ -17,147 +17,149 @@
package org.apache.spark.mllib.linalg;
-import static org.junit.Assert.*;
-import org.junit.Test;
-
import java.io.Serializable;
import java.util.Random;
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+
+import org.junit.Test;
+
public class JavaMatricesSuite implements Serializable {
- @Test
- public void randMatrixConstruction() {
- Random rng = new Random(24);
- Matrix r = Matrices.rand(3, 4, rng);
- rng.setSeed(24);
- DenseMatrix dr = DenseMatrix.rand(3, 4, rng);
- assertArrayEquals(r.toArray(), dr.toArray(), 0.0);
-
- rng.setSeed(24);
- Matrix rn = Matrices.randn(3, 4, rng);
- rng.setSeed(24);
- DenseMatrix drn = DenseMatrix.randn(3, 4, rng);
- assertArrayEquals(rn.toArray(), drn.toArray(), 0.0);
-
- rng.setSeed(24);
- Matrix s = Matrices.sprand(3, 4, 0.5, rng);
- rng.setSeed(24);
- SparseMatrix sr = SparseMatrix.sprand(3, 4, 0.5, rng);
- assertArrayEquals(s.toArray(), sr.toArray(), 0.0);
-
- rng.setSeed(24);
- Matrix sn = Matrices.sprandn(3, 4, 0.5, rng);
- rng.setSeed(24);
- SparseMatrix srn = SparseMatrix.sprandn(3, 4, 0.5, rng);
- assertArrayEquals(sn.toArray(), srn.toArray(), 0.0);
- }
-
- @Test
- public void identityMatrixConstruction() {
- Matrix r = Matrices.eye(2);
- DenseMatrix dr = DenseMatrix.eye(2);
- SparseMatrix sr = SparseMatrix.speye(2);
- assertArrayEquals(r.toArray(), dr.toArray(), 0.0);
- assertArrayEquals(sr.toArray(), dr.toArray(), 0.0);
- assertArrayEquals(r.toArray(), new double[]{1.0, 0.0, 0.0, 1.0}, 0.0);
- }
-
- @Test
- public void diagonalMatrixConstruction() {
- Vector v = Vectors.dense(1.0, 0.0, 2.0);
- Vector sv = Vectors.sparse(3, new int[]{0, 2}, new double[]{1.0, 2.0});
-
- Matrix m = Matrices.diag(v);
- Matrix sm = Matrices.diag(sv);
- DenseMatrix d = DenseMatrix.diag(v);
- DenseMatrix sd = DenseMatrix.diag(sv);
- SparseMatrix s = SparseMatrix.spdiag(v);
- SparseMatrix ss = SparseMatrix.spdiag(sv);
-
- assertArrayEquals(m.toArray(), sm.toArray(), 0.0);
- assertArrayEquals(d.toArray(), sm.toArray(), 0.0);
- assertArrayEquals(d.toArray(), sd.toArray(), 0.0);
- assertArrayEquals(sd.toArray(), s.toArray(), 0.0);
- assertArrayEquals(s.toArray(), ss.toArray(), 0.0);
- assertArrayEquals(s.values(), ss.values(), 0.0);
- assertEquals(2, s.values().length);
- assertEquals(2, ss.values().length);
- assertEquals(4, s.colPtrs().length);
- assertEquals(4, ss.colPtrs().length);
- }
-
- @Test
- public void zerosMatrixConstruction() {
- Matrix z = Matrices.zeros(2, 2);
- Matrix one = Matrices.ones(2, 2);
- DenseMatrix dz = DenseMatrix.zeros(2, 2);
- DenseMatrix done = DenseMatrix.ones(2, 2);
-
- assertArrayEquals(z.toArray(), new double[]{0.0, 0.0, 0.0, 0.0}, 0.0);
- assertArrayEquals(dz.toArray(), new double[]{0.0, 0.0, 0.0, 0.0}, 0.0);
- assertArrayEquals(one.toArray(), new double[]{1.0, 1.0, 1.0, 1.0}, 0.0);
- assertArrayEquals(done.toArray(), new double[]{1.0, 1.0, 1.0, 1.0}, 0.0);
- }
-
- @Test
- public void sparseDenseConversion() {
- int m = 3;
- int n = 2;
- double[] values = new double[]{1.0, 2.0, 4.0, 5.0};
- double[] allValues = new double[]{1.0, 2.0, 0.0, 0.0, 4.0, 5.0};
- int[] colPtrs = new int[]{0, 2, 4};
- int[] rowIndices = new int[]{0, 1, 1, 2};
-
- SparseMatrix spMat1 = new SparseMatrix(m, n, colPtrs, rowIndices, values);
- DenseMatrix deMat1 = new DenseMatrix(m, n, allValues);
-
- SparseMatrix spMat2 = deMat1.toSparse();
- DenseMatrix deMat2 = spMat1.toDense();
-
- assertArrayEquals(spMat1.toArray(), spMat2.toArray(), 0.0);
- assertArrayEquals(deMat1.toArray(), deMat2.toArray(), 0.0);
- }
-
- @Test
- public void concatenateMatrices() {
- int m = 3;
- int n = 2;
-
- Random rng = new Random(42);
- SparseMatrix spMat1 = SparseMatrix.sprand(m, n, 0.5, rng);
- rng.setSeed(42);
- DenseMatrix deMat1 = DenseMatrix.rand(m, n, rng);
- Matrix deMat2 = Matrices.eye(3);
- Matrix spMat2 = Matrices.speye(3);
- Matrix deMat3 = Matrices.eye(2);
- Matrix spMat3 = Matrices.speye(2);
-
- Matrix spHorz = Matrices.horzcat(new Matrix[]{spMat1, spMat2});
- Matrix deHorz1 = Matrices.horzcat(new Matrix[]{deMat1, deMat2});
- Matrix deHorz2 = Matrices.horzcat(new Matrix[]{spMat1, deMat2});
- Matrix deHorz3 = Matrices.horzcat(new Matrix[]{deMat1, spMat2});
-
- assertEquals(3, deHorz1.numRows());
- assertEquals(3, deHorz2.numRows());
- assertEquals(3, deHorz3.numRows());
- assertEquals(3, spHorz.numRows());
- assertEquals(5, deHorz1.numCols());
- assertEquals(5, deHorz2.numCols());
- assertEquals(5, deHorz3.numCols());
- assertEquals(5, spHorz.numCols());
-
- Matrix spVert = Matrices.vertcat(new Matrix[]{spMat1, spMat3});
- Matrix deVert1 = Matrices.vertcat(new Matrix[]{deMat1, deMat3});
- Matrix deVert2 = Matrices.vertcat(new Matrix[]{spMat1, deMat3});
- Matrix deVert3 = Matrices.vertcat(new Matrix[]{deMat1, spMat3});
-
- assertEquals(5, deVert1.numRows());
- assertEquals(5, deVert2.numRows());
- assertEquals(5, deVert3.numRows());
- assertEquals(5, spVert.numRows());
- assertEquals(2, deVert1.numCols());
- assertEquals(2, deVert2.numCols());
- assertEquals(2, deVert3.numCols());
- assertEquals(2, spVert.numCols());
- }
+ @Test
+ public void randMatrixConstruction() {
+ Random rng = new Random(24);
+ Matrix r = Matrices.rand(3, 4, rng);
+ rng.setSeed(24);
+ DenseMatrix dr = DenseMatrix.rand(3, 4, rng);
+ assertArrayEquals(r.toArray(), dr.toArray(), 0.0);
+
+ rng.setSeed(24);
+ Matrix rn = Matrices.randn(3, 4, rng);
+ rng.setSeed(24);
+ DenseMatrix drn = DenseMatrix.randn(3, 4, rng);
+ assertArrayEquals(rn.toArray(), drn.toArray(), 0.0);
+
+ rng.setSeed(24);
+ Matrix s = Matrices.sprand(3, 4, 0.5, rng);
+ rng.setSeed(24);
+ SparseMatrix sr = SparseMatrix.sprand(3, 4, 0.5, rng);
+ assertArrayEquals(s.toArray(), sr.toArray(), 0.0);
+
+ rng.setSeed(24);
+ Matrix sn = Matrices.sprandn(3, 4, 0.5, rng);
+ rng.setSeed(24);
+ SparseMatrix srn = SparseMatrix.sprandn(3, 4, 0.5, rng);
+ assertArrayEquals(sn.toArray(), srn.toArray(), 0.0);
+ }
+
+ @Test
+ public void identityMatrixConstruction() {
+ Matrix r = Matrices.eye(2);
+ DenseMatrix dr = DenseMatrix.eye(2);
+ SparseMatrix sr = SparseMatrix.speye(2);
+ assertArrayEquals(r.toArray(), dr.toArray(), 0.0);
+ assertArrayEquals(sr.toArray(), dr.toArray(), 0.0);
+ assertArrayEquals(r.toArray(), new double[]{1.0, 0.0, 0.0, 1.0}, 0.0);
+ }
+
+ @Test
+ public void diagonalMatrixConstruction() {
+ Vector v = Vectors.dense(1.0, 0.0, 2.0);
+ Vector sv = Vectors.sparse(3, new int[]{0, 2}, new double[]{1.0, 2.0});
+
+ Matrix m = Matrices.diag(v);
+ Matrix sm = Matrices.diag(sv);
+ DenseMatrix d = DenseMatrix.diag(v);
+ DenseMatrix sd = DenseMatrix.diag(sv);
+ SparseMatrix s = SparseMatrix.spdiag(v);
+ SparseMatrix ss = SparseMatrix.spdiag(sv);
+
+ assertArrayEquals(m.toArray(), sm.toArray(), 0.0);
+ assertArrayEquals(d.toArray(), sm.toArray(), 0.0);
+ assertArrayEquals(d.toArray(), sd.toArray(), 0.0);
+ assertArrayEquals(sd.toArray(), s.toArray(), 0.0);
+ assertArrayEquals(s.toArray(), ss.toArray(), 0.0);
+ assertArrayEquals(s.values(), ss.values(), 0.0);
+ assertEquals(2, s.values().length);
+ assertEquals(2, ss.values().length);
+ assertEquals(4, s.colPtrs().length);
+ assertEquals(4, ss.colPtrs().length);
+ }
+
+ @Test
+ public void zerosMatrixConstruction() {
+ Matrix z = Matrices.zeros(2, 2);
+ Matrix one = Matrices.ones(2, 2);
+ DenseMatrix dz = DenseMatrix.zeros(2, 2);
+ DenseMatrix done = DenseMatrix.ones(2, 2);
+
+ assertArrayEquals(z.toArray(), new double[]{0.0, 0.0, 0.0, 0.0}, 0.0);
+ assertArrayEquals(dz.toArray(), new double[]{0.0, 0.0, 0.0, 0.0}, 0.0);
+ assertArrayEquals(one.toArray(), new double[]{1.0, 1.0, 1.0, 1.0}, 0.0);
+ assertArrayEquals(done.toArray(), new double[]{1.0, 1.0, 1.0, 1.0}, 0.0);
+ }
+
+ @Test
+ public void sparseDenseConversion() {
+ int m = 3;
+ int n = 2;
+ double[] values = new double[]{1.0, 2.0, 4.0, 5.0};
+ double[] allValues = new double[]{1.0, 2.0, 0.0, 0.0, 4.0, 5.0};
+ int[] colPtrs = new int[]{0, 2, 4};
+ int[] rowIndices = new int[]{0, 1, 1, 2};
+
+ SparseMatrix spMat1 = new SparseMatrix(m, n, colPtrs, rowIndices, values);
+ DenseMatrix deMat1 = new DenseMatrix(m, n, allValues);
+
+ SparseMatrix spMat2 = deMat1.toSparse();
+ DenseMatrix deMat2 = spMat1.toDense();
+
+ assertArrayEquals(spMat1.toArray(), spMat2.toArray(), 0.0);
+ assertArrayEquals(deMat1.toArray(), deMat2.toArray(), 0.0);
+ }
+
+ @Test
+ public void concatenateMatrices() {
+ int m = 3;
+ int n = 2;
+
+ Random rng = new Random(42);
+ SparseMatrix spMat1 = SparseMatrix.sprand(m, n, 0.5, rng);
+ rng.setSeed(42);
+ DenseMatrix deMat1 = DenseMatrix.rand(m, n, rng);
+ Matrix deMat2 = Matrices.eye(3);
+ Matrix spMat2 = Matrices.speye(3);
+ Matrix deMat3 = Matrices.eye(2);
+ Matrix spMat3 = Matrices.speye(2);
+
+ Matrix spHorz = Matrices.horzcat(new Matrix[]{spMat1, spMat2});
+ Matrix deHorz1 = Matrices.horzcat(new Matrix[]{deMat1, deMat2});
+ Matrix deHorz2 = Matrices.horzcat(new Matrix[]{spMat1, deMat2});
+ Matrix deHorz3 = Matrices.horzcat(new Matrix[]{deMat1, spMat2});
+
+ assertEquals(3, deHorz1.numRows());
+ assertEquals(3, deHorz2.numRows());
+ assertEquals(3, deHorz3.numRows());
+ assertEquals(3, spHorz.numRows());
+ assertEquals(5, deHorz1.numCols());
+ assertEquals(5, deHorz2.numCols());
+ assertEquals(5, deHorz3.numCols());
+ assertEquals(5, spHorz.numCols());
+
+ Matrix spVert = Matrices.vertcat(new Matrix[]{spMat1, spMat3});
+ Matrix deVert1 = Matrices.vertcat(new Matrix[]{deMat1, deMat3});
+ Matrix deVert2 = Matrices.vertcat(new Matrix[]{spMat1, deMat3});
+ Matrix deVert3 = Matrices.vertcat(new Matrix[]{deMat1, spMat3});
+
+ assertEquals(5, deVert1.numRows());
+ assertEquals(5, deVert2.numRows());
+ assertEquals(5, deVert3.numRows());
+ assertEquals(5, spVert.numRows());
+ assertEquals(2, deVert1.numCols());
+ assertEquals(2, deVert2.numCols());
+ assertEquals(2, deVert3.numCols());
+ assertEquals(2, spVert.numCols());
+ }
}
diff --git a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java
index 4ba8e543a9..817b962c75 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java
@@ -20,10 +20,11 @@ package org.apache.spark.mllib.linalg;
import java.io.Serializable;
import java.util.Arrays;
+import static org.junit.Assert.assertArrayEquals;
+
import scala.Tuple2;
import org.junit.Test;
-import static org.junit.Assert.*;
public class JavaVectorsSuite implements Serializable {
@@ -37,8 +38,8 @@ public class JavaVectorsSuite implements Serializable {
public void sparseArrayConstruction() {
@SuppressWarnings("unchecked")
Vector v = Vectors.sparse(3, Arrays.asList(
- new Tuple2<>(0, 2.0),
- new Tuple2<>(2, 3.0)));
+ new Tuple2<>(0, 2.0),
+ new Tuple2<>(2, 3.0)));
assertArrayEquals(new double[]{2.0, 0.0, 3.0}, v.toArray(), 0.0);
}
}
diff --git a/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java
index be58691f4d..b449108a9b 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java
@@ -20,29 +20,35 @@ package org.apache.spark.mllib.random;
import java.io.Serializable;
import java.util.Arrays;
-import org.apache.spark.api.java.JavaRDD;
-import org.junit.Assert;
import org.junit.After;
+import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.apache.spark.api.java.JavaDoubleRDD;
+import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.sql.SparkSession;
import static org.apache.spark.mllib.random.RandomRDDs.*;
public class JavaRandomRDDsSuite {
- private transient JavaSparkContext sc;
+ private transient SparkSession spark;
+ private transient JavaSparkContext jsc;
@Before
public void setUp() {
- sc = new JavaSparkContext("local", "JavaRandomRDDsSuite");
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaRandomRDDsSuite")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
- sc.stop();
- sc = null;
+ spark.stop();
+ spark = null;
}
@Test
@@ -50,10 +56,10 @@ public class JavaRandomRDDsSuite {
long m = 1000L;
int p = 2;
long seed = 1L;
- JavaDoubleRDD rdd1 = uniformJavaRDD(sc, m);
- JavaDoubleRDD rdd2 = uniformJavaRDD(sc, m, p);
- JavaDoubleRDD rdd3 = uniformJavaRDD(sc, m, p, seed);
- for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
+ JavaDoubleRDD rdd1 = uniformJavaRDD(jsc, m);
+ JavaDoubleRDD rdd2 = uniformJavaRDD(jsc, m, p);
+ JavaDoubleRDD rdd3 = uniformJavaRDD(jsc, m, p, seed);
+ for (JavaDoubleRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
Assert.assertEquals(m, rdd.count());
}
}
@@ -63,10 +69,10 @@ public class JavaRandomRDDsSuite {
long m = 1000L;
int p = 2;
long seed = 1L;
- JavaDoubleRDD rdd1 = normalJavaRDD(sc, m);
- JavaDoubleRDD rdd2 = normalJavaRDD(sc, m, p);
- JavaDoubleRDD rdd3 = normalJavaRDD(sc, m, p, seed);
- for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
+ JavaDoubleRDD rdd1 = normalJavaRDD(jsc, m);
+ JavaDoubleRDD rdd2 = normalJavaRDD(jsc, m, p);
+ JavaDoubleRDD rdd3 = normalJavaRDD(jsc, m, p, seed);
+ for (JavaDoubleRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
Assert.assertEquals(m, rdd.count());
}
}
@@ -78,10 +84,10 @@ public class JavaRandomRDDsSuite {
long m = 1000L;
int p = 2;
long seed = 1L;
- JavaDoubleRDD rdd1 = logNormalJavaRDD(sc, mean, std, m);
- JavaDoubleRDD rdd2 = logNormalJavaRDD(sc, mean, std, m, p);
- JavaDoubleRDD rdd3 = logNormalJavaRDD(sc, mean, std, m, p, seed);
- for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
+ JavaDoubleRDD rdd1 = logNormalJavaRDD(jsc, mean, std, m);
+ JavaDoubleRDD rdd2 = logNormalJavaRDD(jsc, mean, std, m, p);
+ JavaDoubleRDD rdd3 = logNormalJavaRDD(jsc, mean, std, m, p, seed);
+ for (JavaDoubleRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
Assert.assertEquals(m, rdd.count());
}
}
@@ -92,10 +98,10 @@ public class JavaRandomRDDsSuite {
long m = 1000L;
int p = 2;
long seed = 1L;
- JavaDoubleRDD rdd1 = poissonJavaRDD(sc, mean, m);
- JavaDoubleRDD rdd2 = poissonJavaRDD(sc, mean, m, p);
- JavaDoubleRDD rdd3 = poissonJavaRDD(sc, mean, m, p, seed);
- for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
+ JavaDoubleRDD rdd1 = poissonJavaRDD(jsc, mean, m);
+ JavaDoubleRDD rdd2 = poissonJavaRDD(jsc, mean, m, p);
+ JavaDoubleRDD rdd3 = poissonJavaRDD(jsc, mean, m, p, seed);
+ for (JavaDoubleRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
Assert.assertEquals(m, rdd.count());
}
}
@@ -106,10 +112,10 @@ public class JavaRandomRDDsSuite {
long m = 1000L;
int p = 2;
long seed = 1L;
- JavaDoubleRDD rdd1 = exponentialJavaRDD(sc, mean, m);
- JavaDoubleRDD rdd2 = exponentialJavaRDD(sc, mean, m, p);
- JavaDoubleRDD rdd3 = exponentialJavaRDD(sc, mean, m, p, seed);
- for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
+ JavaDoubleRDD rdd1 = exponentialJavaRDD(jsc, mean, m);
+ JavaDoubleRDD rdd2 = exponentialJavaRDD(jsc, mean, m, p);
+ JavaDoubleRDD rdd3 = exponentialJavaRDD(jsc, mean, m, p, seed);
+ for (JavaDoubleRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
Assert.assertEquals(m, rdd.count());
}
}
@@ -117,14 +123,14 @@ public class JavaRandomRDDsSuite {
@Test
public void testGammaRDD() {
double shape = 1.0;
     double scale = 2.0;
long m = 1000L;
int p = 2;
long seed = 1L;
- JavaDoubleRDD rdd1 = gammaJavaRDD(sc, shape, scale, m);
- JavaDoubleRDD rdd2 = gammaJavaRDD(sc, shape, scale, m, p);
- JavaDoubleRDD rdd3 = gammaJavaRDD(sc, shape, scale, m, p, seed);
- for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
+    JavaDoubleRDD rdd1 = gammaJavaRDD(jsc, shape, scale, m);
+    JavaDoubleRDD rdd2 = gammaJavaRDD(jsc, shape, scale, m, p);
+    JavaDoubleRDD rdd3 = gammaJavaRDD(jsc, shape, scale, m, p, seed);
+ for (JavaDoubleRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
Assert.assertEquals(m, rdd.count());
}
}
@@ -137,10 +143,10 @@ public class JavaRandomRDDsSuite {
int n = 10;
int p = 2;
long seed = 1L;
- JavaRDD<Vector> rdd1 = uniformJavaVectorRDD(sc, m, n);
- JavaRDD<Vector> rdd2 = uniformJavaVectorRDD(sc, m, n, p);
- JavaRDD<Vector> rdd3 = uniformJavaVectorRDD(sc, m, n, p, seed);
- for (JavaRDD<Vector> rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
+ JavaRDD<Vector> rdd1 = uniformJavaVectorRDD(jsc, m, n);
+ JavaRDD<Vector> rdd2 = uniformJavaVectorRDD(jsc, m, n, p);
+ JavaRDD<Vector> rdd3 = uniformJavaVectorRDD(jsc, m, n, p, seed);
+ for (JavaRDD<Vector> rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
Assert.assertEquals(m, rdd.count());
Assert.assertEquals(n, rdd.first().size());
}
@@ -153,10 +159,10 @@ public class JavaRandomRDDsSuite {
int n = 10;
int p = 2;
long seed = 1L;
- JavaRDD<Vector> rdd1 = normalJavaVectorRDD(sc, m, n);
- JavaRDD<Vector> rdd2 = normalJavaVectorRDD(sc, m, n, p);
- JavaRDD<Vector> rdd3 = normalJavaVectorRDD(sc, m, n, p, seed);
- for (JavaRDD<Vector> rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
+ JavaRDD<Vector> rdd1 = normalJavaVectorRDD(jsc, m, n);
+ JavaRDD<Vector> rdd2 = normalJavaVectorRDD(jsc, m, n, p);
+ JavaRDD<Vector> rdd3 = normalJavaVectorRDD(jsc, m, n, p, seed);
+ for (JavaRDD<Vector> rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
Assert.assertEquals(m, rdd.count());
Assert.assertEquals(n, rdd.first().size());
}
@@ -171,10 +177,10 @@ public class JavaRandomRDDsSuite {
int n = 10;
int p = 2;
long seed = 1L;
- JavaRDD<Vector> rdd1 = logNormalJavaVectorRDD(sc, mean, std, m, n);
- JavaRDD<Vector> rdd2 = logNormalJavaVectorRDD(sc, mean, std, m, n, p);
- JavaRDD<Vector> rdd3 = logNormalJavaVectorRDD(sc, mean, std, m, n, p, seed);
- for (JavaRDD<Vector> rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
+ JavaRDD<Vector> rdd1 = logNormalJavaVectorRDD(jsc, mean, std, m, n);
+ JavaRDD<Vector> rdd2 = logNormalJavaVectorRDD(jsc, mean, std, m, n, p);
+ JavaRDD<Vector> rdd3 = logNormalJavaVectorRDD(jsc, mean, std, m, n, p, seed);
+ for (JavaRDD<Vector> rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
Assert.assertEquals(m, rdd.count());
Assert.assertEquals(n, rdd.first().size());
}
@@ -188,10 +194,10 @@ public class JavaRandomRDDsSuite {
int n = 10;
int p = 2;
long seed = 1L;
- JavaRDD<Vector> rdd1 = poissonJavaVectorRDD(sc, mean, m, n);
- JavaRDD<Vector> rdd2 = poissonJavaVectorRDD(sc, mean, m, n, p);
- JavaRDD<Vector> rdd3 = poissonJavaVectorRDD(sc, mean, m, n, p, seed);
- for (JavaRDD<Vector> rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
+ JavaRDD<Vector> rdd1 = poissonJavaVectorRDD(jsc, mean, m, n);
+ JavaRDD<Vector> rdd2 = poissonJavaVectorRDD(jsc, mean, m, n, p);
+ JavaRDD<Vector> rdd3 = poissonJavaVectorRDD(jsc, mean, m, n, p, seed);
+ for (JavaRDD<Vector> rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
Assert.assertEquals(m, rdd.count());
Assert.assertEquals(n, rdd.first().size());
}
@@ -205,10 +211,10 @@ public class JavaRandomRDDsSuite {
int n = 10;
int p = 2;
long seed = 1L;
- JavaRDD<Vector> rdd1 = exponentialJavaVectorRDD(sc, mean, m, n);
- JavaRDD<Vector> rdd2 = exponentialJavaVectorRDD(sc, mean, m, n, p);
- JavaRDD<Vector> rdd3 = exponentialJavaVectorRDD(sc, mean, m, n, p, seed);
- for (JavaRDD<Vector> rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
+ JavaRDD<Vector> rdd1 = exponentialJavaVectorRDD(jsc, mean, m, n);
+ JavaRDD<Vector> rdd2 = exponentialJavaVectorRDD(jsc, mean, m, n, p);
+ JavaRDD<Vector> rdd3 = exponentialJavaVectorRDD(jsc, mean, m, n, p, seed);
+ for (JavaRDD<Vector> rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
Assert.assertEquals(m, rdd.count());
Assert.assertEquals(n, rdd.first().size());
}
@@ -218,15 +224,15 @@ public class JavaRandomRDDsSuite {
@SuppressWarnings("unchecked")
public void testGammaVectorRDD() {
double shape = 1.0;
     double scale = 2.0;
long m = 100L;
int n = 10;
int p = 2;
long seed = 1L;
- JavaRDD<Vector> rdd1 = gammaJavaVectorRDD(sc, shape, scale, m, n);
- JavaRDD<Vector> rdd2 = gammaJavaVectorRDD(sc, shape, scale, m, n, p);
- JavaRDD<Vector> rdd3 = gammaJavaVectorRDD(sc, shape, scale, m, n, p, seed);
- for (JavaRDD<Vector> rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
+ JavaRDD<Vector> rdd1 = gammaJavaVectorRDD(jsc, shape, jscale, m, n);
+ JavaRDD<Vector> rdd2 = gammaJavaVectorRDD(jsc, shape, jscale, m, n, p);
+ JavaRDD<Vector> rdd3 = gammaJavaVectorRDD(jsc, shape, jscale, m, n, p, seed);
+ for (JavaRDD<Vector> rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
Assert.assertEquals(m, rdd.count());
Assert.assertEquals(n, rdd.first().size());
}
@@ -238,10 +244,10 @@ public class JavaRandomRDDsSuite {
long seed = 1L;
int numPartitions = 0;
StringGenerator gen = new StringGenerator();
- JavaRDD<String> rdd1 = randomJavaRDD(sc, gen, size);
- JavaRDD<String> rdd2 = randomJavaRDD(sc, gen, size, numPartitions);
- JavaRDD<String> rdd3 = randomJavaRDD(sc, gen, size, numPartitions, seed);
- for (JavaRDD<String> rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
+ JavaRDD<String> rdd1 = randomJavaRDD(jsc, gen, size);
+ JavaRDD<String> rdd2 = randomJavaRDD(jsc, gen, size, numPartitions);
+ JavaRDD<String> rdd3 = randomJavaRDD(jsc, gen, size, numPartitions, seed);
+ for (JavaRDD<String> rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
Assert.assertEquals(size, rdd.count());
Assert.assertEquals(2, rdd.first().length());
}
@@ -255,10 +261,10 @@ public class JavaRandomRDDsSuite {
int n = 10;
int p = 2;
long seed = 1L;
- JavaRDD<Vector> rdd1 = randomJavaVectorRDD(sc, generator, m, n);
- JavaRDD<Vector> rdd2 = randomJavaVectorRDD(sc, generator, m, n, p);
- JavaRDD<Vector> rdd3 = randomJavaVectorRDD(sc, generator, m, n, p, seed);
- for (JavaRDD<Vector> rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
+ JavaRDD<Vector> rdd1 = randomJavaVectorRDD(jsc, generator, m, n);
+ JavaRDD<Vector> rdd2 = randomJavaVectorRDD(jsc, generator, m, n, p);
+ JavaRDD<Vector> rdd3 = randomJavaVectorRDD(jsc, generator, m, n, p, seed);
+ for (JavaRDD<Vector> rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
Assert.assertEquals(m, rdd.count());
Assert.assertEquals(n, rdd.first().size());
}
@@ -271,10 +277,12 @@ class StringGenerator implements RandomDataGenerator<String>, Serializable {
public String nextValue() {
return "42";
}
+
@Override
public StringGenerator copy() {
return new StringGenerator();
}
+
@Override
public void setSeed(long seed) {
}
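
The RandomRDDs factory methods themselves are untouched by this patch and still take a JavaSparkContext, so the suite only renames the receiver variable. A small standalone sketch of the same call shape (class name assumed):

import static org.apache.spark.mllib.random.RandomRDDs.uniformJavaRDD;

import org.apache.spark.api.java.JavaDoubleRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SparkSession;

public class RandomRddSketch {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder()
      .master("local").appName("RandomRddSketch").getOrCreate();
    JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext());
    // size, partitions, seed: the same overload family the suite exercises.
    JavaDoubleRDD rdd = uniformJavaRDD(jsc, 1000L, 2, 1L);
    System.out.println(rdd.count());  // prints 1000
    spark.stop();
  }
}
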
diff --git a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java
index d0bf7f556d..aa784054d5 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java
@@ -32,40 +32,46 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.SparkSession;
public class JavaALSSuite implements Serializable {
- private transient JavaSparkContext sc;
+ private transient SparkSession spark;
+ private transient JavaSparkContext jsc;
@Before
public void setUp() {
- sc = new JavaSparkContext("local", "JavaALS");
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaALS")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
- sc.stop();
- sc = null;
+ spark.stop();
+ spark = null;
}
private void validatePrediction(
- MatrixFactorizationModel model,
- int users,
- int products,
- double[] trueRatings,
- double matchThreshold,
- boolean implicitPrefs,
- double[] truePrefs) {
+ MatrixFactorizationModel model,
+ int users,
+ int products,
+ double[] trueRatings,
+ double matchThreshold,
+ boolean implicitPrefs,
+ double[] truePrefs) {
List<Tuple2<Integer, Integer>> localUsersProducts = new ArrayList<>(users * products);
- for (int u=0; u < users; ++u) {
- for (int p=0; p < products; ++p) {
+ for (int u = 0; u < users; ++u) {
+ for (int p = 0; p < products; ++p) {
localUsersProducts.add(new Tuple2<>(u, p));
}
}
- JavaPairRDD<Integer, Integer> usersProducts = sc.parallelizePairs(localUsersProducts);
+ JavaPairRDD<Integer, Integer> usersProducts = jsc.parallelizePairs(localUsersProducts);
List<Rating> predictedRatings = model.predict(usersProducts).collect();
Assert.assertEquals(users * products, predictedRatings.size());
if (!implicitPrefs) {
- for (Rating r: predictedRatings) {
+ for (Rating r : predictedRatings) {
double prediction = r.rating();
double correct = trueRatings[r.product() * users + r.user()];
Assert.assertTrue(String.format("Prediction=%2.4f not below match threshold of %2.2f",
@@ -76,7 +82,7 @@ public class JavaALSSuite implements Serializable {
// (ref Mahout's implicit ALS tests)
double sqErr = 0.0;
double denom = 0.0;
- for (Rating r: predictedRatings) {
+ for (Rating r : predictedRatings) {
double prediction = r.rating();
double truePref = truePrefs[r.product() * users + r.user()];
double confidence = 1.0 +
@@ -98,9 +104,9 @@ public class JavaALSSuite implements Serializable {
int users = 50;
int products = 100;
Tuple3<List<Rating>, double[], double[]> testData =
- ALSSuite.generateRatingsAsJava(users, products, features, 0.7, false, false);
+ ALSSuite.generateRatingsAsJava(users, products, features, 0.7, false, false);
- JavaRDD<Rating> data = sc.parallelize(testData._1());
+ JavaRDD<Rating> data = jsc.parallelize(testData._1());
MatrixFactorizationModel model = ALS.train(data.rdd(), features, iterations);
validatePrediction(model, users, products, testData._2(), 0.3, false, testData._3());
}
@@ -112,9 +118,9 @@ public class JavaALSSuite implements Serializable {
int users = 100;
int products = 200;
Tuple3<List<Rating>, double[], double[]> testData =
- ALSSuite.generateRatingsAsJava(users, products, features, 0.7, false, false);
+ ALSSuite.generateRatingsAsJava(users, products, features, 0.7, false, false);
- JavaRDD<Rating> data = sc.parallelize(testData._1());
+ JavaRDD<Rating> data = jsc.parallelize(testData._1());
MatrixFactorizationModel model = new ALS().setRank(features)
.setIterations(iterations)
@@ -129,9 +135,9 @@ public class JavaALSSuite implements Serializable {
int users = 80;
int products = 160;
Tuple3<List<Rating>, double[], double[]> testData =
- ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, false);
+ ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, false);
- JavaRDD<Rating> data = sc.parallelize(testData._1());
+ JavaRDD<Rating> data = jsc.parallelize(testData._1());
MatrixFactorizationModel model = ALS.trainImplicit(data.rdd(), features, iterations);
validatePrediction(model, users, products, testData._2(), 0.4, true, testData._3());
}
@@ -143,9 +149,9 @@ public class JavaALSSuite implements Serializable {
int users = 100;
int products = 200;
Tuple3<List<Rating>, double[], double[]> testData =
- ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, false);
+ ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, false);
- JavaRDD<Rating> data = sc.parallelize(testData._1());
+ JavaRDD<Rating> data = jsc.parallelize(testData._1());
MatrixFactorizationModel model = new ALS().setRank(features)
.setIterations(iterations)
@@ -161,9 +167,9 @@ public class JavaALSSuite implements Serializable {
int users = 80;
int products = 160;
Tuple3<List<Rating>, double[], double[]> testData =
- ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, true);
+ ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, true);
- JavaRDD<Rating> data = sc.parallelize(testData._1());
+ JavaRDD<Rating> data = jsc.parallelize(testData._1());
MatrixFactorizationModel model = new ALS().setRank(features)
.setIterations(iterations)
.setImplicitPrefs(true)
@@ -179,8 +185,8 @@ public class JavaALSSuite implements Serializable {
int users = 200;
int products = 50;
List<Rating> testData = ALSSuite.generateRatingsAsJava(
- users, products, features, 0.7, true, false)._1();
- JavaRDD<Rating> data = sc.parallelize(testData);
+ users, products, features, 0.7, true, false)._1();
+ JavaRDD<Rating> data = jsc.parallelize(testData);
MatrixFactorizationModel model = new ALS().setRank(features)
.setIterations(iterations)
.setImplicitPrefs(true)
@@ -193,7 +199,7 @@ public class JavaALSSuite implements Serializable {
private static void validateRecommendations(Rating[] recommendations, int howMany) {
Assert.assertEquals(howMany, recommendations.length);
for (int i = 1; i < recommendations.length; i++) {
- Assert.assertTrue(recommendations[i-1].rating() >= recommendations[i].rating());
+ Assert.assertTrue(recommendations[i - 1].rating() >= recommendations[i].rating());
}
Assert.assertTrue(recommendations[0].rating() > 0.7);
}
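
Bulk prediction in validatePrediction still flows through jsc.parallelizePairs plus MatrixFactorizationModel.predict on a JavaPairRDD. A compact end-to-end sketch with made-up ratings (class name and values assumed):

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import scala.Tuple2;

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.recommendation.ALS;
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
import org.apache.spark.mllib.recommendation.Rating;
import org.apache.spark.sql.SparkSession;

public class AlsPredictSketch {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder()
      .master("local").appName("AlsPredictSketch").getOrCreate();
    JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext());
    JavaRDD<Rating> ratings = jsc.parallelize(Arrays.asList(
      new Rating(0, 0, 5.0), new Rating(0, 1, 1.0),
      new Rating(1, 0, 1.0), new Rating(1, 1, 5.0)));
    // rank = 2, iterations = 5; training still takes the underlying RDD.
    MatrixFactorizationModel model = ALS.train(ratings.rdd(), 2, 5);
    List<Tuple2<Integer, Integer>> pairs = new ArrayList<>();
    for (int u = 0; u < 2; u++) {
      for (int p = 0; p < 2; p++) {
        pairs.add(new Tuple2<>(u, p));
      }
    }
    JavaPairRDD<Integer, Integer> usersProducts = jsc.parallelizePairs(pairs);
    System.out.println(model.predict(usersProducts).collect());
    spark.stop();
  }
}
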
diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java
index 3db9b39e74..8b05675d65 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java
@@ -32,15 +32,17 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaDoubleRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.SparkSession;
public class JavaIsotonicRegressionSuite implements Serializable {
- private transient JavaSparkContext sc;
+ private transient SparkSession spark;
+ private transient JavaSparkContext jsc;
private static List<Tuple3<Double, Double, Double>> generateIsotonicInput(double[] labels) {
List<Tuple3<Double, Double, Double>> input = new ArrayList<>(labels.length);
for (int i = 1; i <= labels.length; i++) {
- input.add(new Tuple3<>(labels[i-1], (double) i, 1.0));
+ input.add(new Tuple3<>(labels[i - 1], (double) i, 1.0));
}
return input;
@@ -48,20 +50,24 @@ public class JavaIsotonicRegressionSuite implements Serializable {
private IsotonicRegressionModel runIsotonicRegression(double[] labels) {
JavaRDD<Tuple3<Double, Double, Double>> trainRDD =
- sc.parallelize(generateIsotonicInput(labels), 2).cache();
+ jsc.parallelize(generateIsotonicInput(labels), 2).cache();
return new IsotonicRegression().run(trainRDD);
}
@Before
public void setUp() {
- sc = new JavaSparkContext("local", "JavaLinearRegressionSuite");
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaLinearRegressionSuite")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
- sc.stop();
- sc = null;
+ spark.stop();
+ spark = null;
}
@Test
@@ -70,7 +76,7 @@ public class JavaIsotonicRegressionSuite implements Serializable {
runIsotonicRegression(new double[]{1, 2, 3, 3, 1, 6, 7, 8, 11, 9, 10, 12});
Assert.assertArrayEquals(
- new double[] {1, 2, 7.0/3, 7.0/3, 6, 7, 8, 10, 10, 12}, model.predictions(), 1.0e-14);
+ new double[]{1, 2, 7.0 / 3, 7.0 / 3, 6, 7, 8, 10, 10, 12}, model.predictions(), 1.0e-14);
}
@Test
@@ -78,7 +84,7 @@ public class JavaIsotonicRegressionSuite implements Serializable {
IsotonicRegressionModel model =
runIsotonicRegression(new double[]{1, 2, 3, 3, 1, 6, 7, 8, 11, 9, 10, 12});
- JavaDoubleRDD testRDD = sc.parallelizeDoubles(Arrays.asList(0.0, 1.0, 9.5, 12.0, 13.0));
+ JavaDoubleRDD testRDD = jsc.parallelizeDoubles(Arrays.asList(0.0, 1.0, 9.5, 12.0, 13.0));
List<Double> predictions = model.predict(testRDD).collect();
Assert.assertEquals(1.0, predictions.get(0).doubleValue(), 1.0e-14);
diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java
index 8950b48888..098bac3bed 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java
@@ -28,24 +28,30 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.util.LinearDataGenerator;
+import org.apache.spark.sql.SparkSession;
public class JavaLassoSuite implements Serializable {
- private transient JavaSparkContext sc;
+ private transient SparkSession spark;
+ private transient JavaSparkContext jsc;
@Before
public void setUp() {
- sc = new JavaSparkContext("local", "JavaLassoSuite");
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaLassoSuite")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
- sc.stop();
- sc = null;
+ spark.stop();
+ spark = null;
}
int validatePrediction(List<LabeledPoint> validationData, LassoModel model) {
int numAccurate = 0;
- for (LabeledPoint point: validationData) {
+ for (LabeledPoint point : validationData) {
Double prediction = model.predict(point.features());
// A prediction is off if the prediction is more than 0.5 away from expected value.
if (Math.abs(prediction - point.label()) <= 0.5) {
@@ -61,15 +67,15 @@ public class JavaLassoSuite implements Serializable {
double A = 0.0;
double[] weights = {-1.5, 1.0e-2};
- JavaRDD<LabeledPoint> testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A,
- weights, nPoints, 42, 0.1), 2).cache();
+ JavaRDD<LabeledPoint> testRDD = jsc.parallelize(LinearDataGenerator.generateLinearInputAsList(A,
+ weights, nPoints, 42, 0.1), 2).cache();
List<LabeledPoint> validationData =
- LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1);
+ LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1);
LassoWithSGD lassoSGDImpl = new LassoWithSGD();
lassoSGDImpl.optimizer().setStepSize(1.0)
- .setRegParam(0.01)
- .setNumIterations(20);
+ .setRegParam(0.01)
+ .setNumIterations(20);
LassoModel model = lassoSGDImpl.run(testRDD.rdd());
int numAccurate = validatePrediction(validationData, model);
@@ -82,10 +88,10 @@ public class JavaLassoSuite implements Serializable {
double A = 0.0;
double[] weights = {-1.5, 1.0e-2};
- JavaRDD<LabeledPoint> testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A,
- weights, nPoints, 42, 0.1), 2).cache();
+ JavaRDD<LabeledPoint> testRDD = jsc.parallelize(LinearDataGenerator.generateLinearInputAsList(A,
+ weights, nPoints, 42, 0.1), 2).cache();
List<LabeledPoint> validationData =
- LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1);
+ LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1);
LassoModel model = LassoWithSGD.train(testRDD.rdd(), 100, 1.0, 0.01, 1.0);
diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java
index 24c4c20d9a..35087a5e46 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java
@@ -25,34 +25,40 @@ import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
-import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.util.LinearDataGenerator;
+import org.apache.spark.sql.SparkSession;
public class JavaLinearRegressionSuite implements Serializable {
- private transient JavaSparkContext sc;
+ private transient SparkSession spark;
+ private transient JavaSparkContext jsc;
@Before
public void setUp() {
- sc = new JavaSparkContext("local", "JavaLinearRegressionSuite");
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaLinearRegressionSuite")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
- sc.stop();
- sc = null;
+ spark.stop();
+ spark = null;
}
int validatePrediction(List<LabeledPoint> validationData, LinearRegressionModel model) {
int numAccurate = 0;
- for (LabeledPoint point: validationData) {
- Double prediction = model.predict(point.features());
- // A prediction is off if the prediction is more than 0.5 away from expected value.
- if (Math.abs(prediction - point.label()) <= 0.5) {
- numAccurate++;
- }
+ for (LabeledPoint point : validationData) {
+ Double prediction = model.predict(point.features());
+ // A prediction is off if the prediction is more than 0.5 away from expected value.
+ if (Math.abs(prediction - point.label()) <= 0.5) {
+ numAccurate++;
+ }
}
return numAccurate;
}
@@ -63,10 +69,10 @@ public class JavaLinearRegressionSuite implements Serializable {
double A = 3.0;
double[] weights = {10, 10};
- JavaRDD<LabeledPoint> testRDD = sc.parallelize(
- LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache();
+ JavaRDD<LabeledPoint> testRDD = jsc.parallelize(
+ LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache();
List<LabeledPoint> validationData =
- LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1);
+ LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1);
LinearRegressionWithSGD linSGDImpl = new LinearRegressionWithSGD();
linSGDImpl.setIntercept(true);
@@ -82,10 +88,10 @@ public class JavaLinearRegressionSuite implements Serializable {
double A = 0.0;
double[] weights = {10, 10};
- JavaRDD<LabeledPoint> testRDD = sc.parallelize(
- LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache();
+ JavaRDD<LabeledPoint> testRDD = jsc.parallelize(
+ LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache();
List<LabeledPoint> validationData =
- LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1);
+ LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1);
LinearRegressionModel model = LinearRegressionWithSGD.train(testRDD.rdd(), 100);
@@ -98,7 +104,7 @@ public class JavaLinearRegressionSuite implements Serializable {
int nPoints = 100;
double A = 0.0;
double[] weights = {10, 10};
- JavaRDD<LabeledPoint> testRDD = sc.parallelize(
+ JavaRDD<LabeledPoint> testRDD = jsc.parallelize(
LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache();
LinearRegressionWithSGD linSGDImpl = new LinearRegressionWithSGD();
LinearRegressionModel model = linSGDImpl.run(testRDD.rdd());
diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java
index c56db703ea..b2efb2e72e 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java
@@ -29,25 +29,31 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.util.LinearDataGenerator;
+import org.apache.spark.sql.SparkSession;
public class JavaRidgeRegressionSuite implements Serializable {
- private transient JavaSparkContext sc;
+ private transient SparkSession spark;
+ private transient JavaSparkContext jsc;
@Before
public void setUp() {
- sc = new JavaSparkContext("local", "JavaRidgeRegressionSuite");
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaRidgeRegressionSuite")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
- sc.stop();
- sc = null;
+ spark.stop();
+ spark = null;
}
private static double predictionError(List<LabeledPoint> validationData,
RidgeRegressionModel model) {
double errorSum = 0;
- for (LabeledPoint point: validationData) {
+ for (LabeledPoint point : validationData) {
Double prediction = model.predict(point.features());
errorSum += (prediction - point.label()) * (prediction - point.label());
}
@@ -68,9 +74,9 @@ public class JavaRidgeRegressionSuite implements Serializable {
public void runRidgeRegressionUsingConstructor() {
int numExamples = 50;
int numFeatures = 20;
- List<LabeledPoint> data = generateRidgeData(2*numExamples, numFeatures, 10.0);
+ List<LabeledPoint> data = generateRidgeData(2 * numExamples, numFeatures, 10.0);
- JavaRDD<LabeledPoint> testRDD = sc.parallelize(data.subList(0, numExamples));
+ JavaRDD<LabeledPoint> testRDD = jsc.parallelize(data.subList(0, numExamples));
List<LabeledPoint> validationData = data.subList(numExamples, 2 * numExamples);
RidgeRegressionWithSGD ridgeSGDImpl = new RidgeRegressionWithSGD();
@@ -94,7 +100,7 @@ public class JavaRidgeRegressionSuite implements Serializable {
int numFeatures = 20;
List<LabeledPoint> data = generateRidgeData(2 * numExamples, numFeatures, 10.0);
- JavaRDD<LabeledPoint> testRDD = sc.parallelize(data.subList(0, numExamples));
+ JavaRDD<LabeledPoint> testRDD = jsc.parallelize(data.subList(0, numExamples));
List<LabeledPoint> validationData = data.subList(numExamples, 2 * numExamples);
RidgeRegressionModel model = RidgeRegressionWithSGD.train(testRDD.rdd(), 200, 1.0, 0.0);
diff --git a/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java
index 5f1d5987e8..373417d3ba 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java
@@ -24,13 +24,11 @@ import java.util.List;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
-
-import static org.apache.spark.streaming.JavaTestUtils.*;
import static org.junit.Assert.assertEquals;
import org.apache.spark.SparkConf;
-import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaDoubleRDD;
+import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
@@ -38,36 +36,42 @@ import org.apache.spark.mllib.stat.test.BinarySample;
import org.apache.spark.mllib.stat.test.ChiSqTestResult;
import org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult;
import org.apache.spark.mllib.stat.test.StreamingTest;
+import org.apache.spark.sql.SparkSession;
import org.apache.spark.streaming.Duration;
import org.apache.spark.streaming.api.java.JavaDStream;
import org.apache.spark.streaming.api.java.JavaStreamingContext;
+import static org.apache.spark.streaming.JavaTestUtils.*;
public class JavaStatisticsSuite implements Serializable {
- private transient JavaSparkContext sc;
+ private transient SparkSession spark;
+ private transient JavaSparkContext jsc;
private transient JavaStreamingContext ssc;
@Before
public void setUp() {
SparkConf conf = new SparkConf()
- .setMaster("local[2]")
- .setAppName("JavaStatistics")
.set("spark.streaming.clock", "org.apache.spark.util.ManualClock");
- sc = new JavaSparkContext(conf);
- ssc = new JavaStreamingContext(sc, new Duration(1000));
+ spark = SparkSession.builder()
+ .master("local[2]")
+ .appName("JavaStatistics")
+ .config(conf)
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
+ ssc = new JavaStreamingContext(jsc, new Duration(1000));
ssc.checkpoint("checkpoint");
}
@After
public void tearDown() {
+ spark.stop();
ssc.stop();
- ssc = null;
- sc = null;
+ spark = null;
}
@Test
public void testCorr() {
- JavaRDD<Double> x = sc.parallelize(Arrays.asList(1.0, 2.0, 3.0, 4.0));
- JavaRDD<Double> y = sc.parallelize(Arrays.asList(1.1, 2.2, 3.1, 4.3));
+ JavaRDD<Double> x = jsc.parallelize(Arrays.asList(1.0, 2.0, 3.0, 4.0));
+ JavaRDD<Double> y = jsc.parallelize(Arrays.asList(1.1, 2.2, 3.1, 4.3));
Double corr1 = Statistics.corr(x, y);
Double corr2 = Statistics.corr(x, y, "pearson");
@@ -77,7 +81,7 @@ public class JavaStatisticsSuite implements Serializable {
@Test
public void kolmogorovSmirnovTest() {
- JavaDoubleRDD data = sc.parallelizeDoubles(Arrays.asList(0.2, 1.0, -1.0, 2.0));
+ JavaDoubleRDD data = jsc.parallelizeDoubles(Arrays.asList(0.2, 1.0, -1.0, 2.0));
KolmogorovSmirnovTestResult testResult1 = Statistics.kolmogorovSmirnovTest(data, "norm");
KolmogorovSmirnovTestResult testResult2 = Statistics.kolmogorovSmirnovTest(
data, "norm", 0.0, 1.0);
@@ -85,7 +89,7 @@ public class JavaStatisticsSuite implements Serializable {
@Test
public void chiSqTest() {
- JavaRDD<LabeledPoint> data = sc.parallelize(Arrays.asList(
+ JavaRDD<LabeledPoint> data = jsc.parallelize(Arrays.asList(
new LabeledPoint(0.0, Vectors.dense(0.1, 2.3)),
new LabeledPoint(1.0, Vectors.dense(1.5, 5.1)),
new LabeledPoint(0.0, Vectors.dense(2.4, 8.1))));
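
JavaStatisticsSuite is the one suite that also needs a streaming context, so master and appName move onto the builder while the remaining SparkConf settings ride along via config(conf), and both contexts are then derived from the session. A sketch of that wiring (class name assumed):

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.streaming.Duration;
import org.apache.spark.streaming.api.java.JavaStreamingContext;

public class StreamingFixtureSketch {
  public static void main(String[] args) {
    SparkConf conf = new SparkConf()
      .set("spark.streaming.clock", "org.apache.spark.util.ManualClock");
    SparkSession spark = SparkSession.builder()
      .master("local[2]")
      .appName("StreamingFixtureSketch")
      .config(conf)  // extra settings fold into the session's config
      .getOrCreate();
    JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext());
    JavaStreamingContext ssc = new JavaStreamingContext(jsc, new Duration(1000));
    ssc.stop(false);  // false: keep the shared SparkContext alive for the session
    spark.stop();
  }
}
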
diff --git a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java
index 60585d2727..5b464a4722 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java
@@ -35,25 +35,31 @@ import org.apache.spark.mllib.tree.configuration.Algo;
import org.apache.spark.mllib.tree.configuration.Strategy;
import org.apache.spark.mllib.tree.impurity.Gini;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
+import org.apache.spark.sql.SparkSession;
public class JavaDecisionTreeSuite implements Serializable {
- private transient JavaSparkContext sc;
+ private transient SparkSession spark;
+ private transient JavaSparkContext jsc;
@Before
public void setUp() {
- sc = new JavaSparkContext("local", "JavaDecisionTreeSuite");
+ spark = SparkSession.builder()
+ .master("local")
+ .appName("JavaDecisionTreeSuite")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
- sc.stop();
- sc = null;
+ spark.stop();
+ spark = null;
}
int validatePrediction(List<LabeledPoint> validationData, DecisionTreeModel model) {
int numCorrect = 0;
- for (LabeledPoint point: validationData) {
+ for (LabeledPoint point : validationData) {
Double prediction = model.predict(point.features());
if (prediction == point.label()) {
numCorrect++;
@@ -65,7 +71,7 @@ public class JavaDecisionTreeSuite implements Serializable {
@Test
public void runDTUsingConstructor() {
List<LabeledPoint> arr = DecisionTreeSuite.generateCategoricalDataPointsAsJavaList();
- JavaRDD<LabeledPoint> rdd = sc.parallelize(arr);
+ JavaRDD<LabeledPoint> rdd = jsc.parallelize(arr);
HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<>();
categoricalFeaturesInfo.put(1, 2); // feature 1 has 2 categories
@@ -73,7 +79,7 @@ public class JavaDecisionTreeSuite implements Serializable {
int numClasses = 2;
int maxBins = 100;
Strategy strategy = new Strategy(Algo.Classification(), Gini.instance(), maxDepth, numClasses,
- maxBins, categoricalFeaturesInfo);
+ maxBins, categoricalFeaturesInfo);
DecisionTree learner = new DecisionTree(strategy);
DecisionTreeModel model = learner.run(rdd.rdd());
@@ -85,7 +91,7 @@ public class JavaDecisionTreeSuite implements Serializable {
@Test
public void runDTUsingStaticMethods() {
List<LabeledPoint> arr = DecisionTreeSuite.generateCategoricalDataPointsAsJavaList();
- JavaRDD<LabeledPoint> rdd = sc.parallelize(arr);
+ JavaRDD<LabeledPoint> rdd = jsc.parallelize(arr);
HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<>();
categoricalFeaturesInfo.put(1, 2); // feature 1 has 2 categories
@@ -93,7 +99,7 @@ public class JavaDecisionTreeSuite implements Serializable {
int numClasses = 2;
int maxBins = 100;
Strategy strategy = new Strategy(Algo.Classification(), Gini.instance(), maxDepth, numClasses,
- maxBins, categoricalFeaturesInfo);
+ maxBins, categoricalFeaturesInfo);
DecisionTreeModel model = DecisionTree$.MODULE$.train(rdd.rdd(), strategy);
diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
index 1de638f245..55448325e4 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
@@ -183,7 +183,7 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
}
test("pipeline validateParams") {
- val df = sqlContext.createDataFrame(
+ val df = spark.createDataFrame(
Seq(
(1, Vectors.dense(0.0, 1.0, 4.0), 1.0),
(2, Vectors.dense(1.0, 0.0, 4.0), 2.0),
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala
index 89afb94b0f..98116656ba 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala
@@ -32,7 +32,7 @@ class ClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
test("extractLabeledPoints") {
def getTestData(labels: Seq[Double]): DataFrame = {
val data = labels.map { label: Double => LabeledPoint(label, Vectors.dense(0.0)) }
- sqlContext.createDataFrame(data)
+ spark.createDataFrame(data)
}
val c = new MockClassifier
@@ -72,7 +72,7 @@ class ClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
test("getNumClasses") {
def getTestData(labels: Seq[Double]): DataFrame = {
val data = labels.map { label: Double => LabeledPoint(label, Vectors.dense(0.0)) }
- sqlContext.createDataFrame(data)
+ spark.createDataFrame(data)
}
val c = new MockClassifier
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index 29845b5554..f94d336df5 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -337,13 +337,13 @@ class DecisionTreeClassifierSuite
test("should support all NumericType labels and not support other types") {
val dt = new DecisionTreeClassifier().setMaxDepth(1)
MLTestingUtils.checkNumericTypes[DecisionTreeClassificationModel, DecisionTreeClassifier](
- dt, isClassification = true, sqlContext) { (expected, actual) =>
+ dt, isClassification = true, spark) { (expected, actual) =>
TreeTests.checkEqual(expected, actual)
}
}
test("Fitting without numClasses in metadata") {
- val df: DataFrame = sqlContext.createDataFrame(TreeTests.featureImportanceData(sc))
+ val df: DataFrame = spark.createDataFrame(TreeTests.featureImportanceData(sc))
val dt = new DecisionTreeClassifier().setMaxDepth(1)
dt.fit(df)
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index 087e201234..c9453aaec2 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -106,7 +106,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
test("should support all NumericType labels and not support other types") {
val gbt = new GBTClassifier().setMaxDepth(1)
MLTestingUtils.checkNumericTypes[GBTClassificationModel, GBTClassifier](
- gbt, isClassification = true, sqlContext) { (expected, actual) =>
+ gbt, isClassification = true, spark) { (expected, actual) =>
TreeTests.checkEqual(expected, actual)
}
}
@@ -130,7 +130,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
*/
test("Fitting without numClasses in metadata") {
- val df: DataFrame = sqlContext.createDataFrame(TreeTests.featureImportanceData(sc))
+ val df: DataFrame = spark.createDataFrame(TreeTests.featureImportanceData(sc))
val gbt = new GBTClassifier().setMaxDepth(1).setMaxIter(1)
gbt.fit(df)
}
@@ -138,7 +138,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
test("extractLabeledPoints with bad data") {
def getTestData(labels: Seq[Double]): DataFrame = {
val data = labels.map { label: Double => LabeledPoint(label, Vectors.dense(0.0)) }
- sqlContext.createDataFrame(data)
+ spark.createDataFrame(data)
}
val gbt = new GBTClassifier().setMaxDepth(1).setMaxIter(1)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 73e961dbbc..cb4d087ce5 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -42,7 +42,7 @@ class LogisticRegressionSuite
override def beforeAll(): Unit = {
super.beforeAll()
- dataset = sqlContext.createDataFrame(generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 42))
+ dataset = spark.createDataFrame(generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 42))
binaryDataset = {
val nPoints = 10000
@@ -54,7 +54,7 @@ class LogisticRegressionSuite
generateMultinomialLogisticInput(coefficients, xMean, xVariance,
addIntercept = true, nPoints, 42)
- sqlContext.createDataFrame(sc.parallelize(testData, 4))
+ spark.createDataFrame(sc.parallelize(testData, 4))
}
}
@@ -202,7 +202,7 @@ class LogisticRegressionSuite
}
test("logistic regression: Predictor, Classifier methods") {
- val sqlContext = this.sqlContext
+ val spark = this.spark
val lr = new LogisticRegression
val model = lr.fit(dataset)
@@ -864,8 +864,8 @@ class LogisticRegressionSuite
}
}
- (sqlContext.createDataFrame(sc.parallelize(data1, 4)),
- sqlContext.createDataFrame(sc.parallelize(data2, 4)))
+ (spark.createDataFrame(sc.parallelize(data1, 4)),
+ spark.createDataFrame(sc.parallelize(data2, 4)))
}
val trainer1a = (new LogisticRegression).setFitIntercept(true)
@@ -938,7 +938,7 @@ class LogisticRegressionSuite
test("should support all NumericType labels and not support other types") {
val lr = new LogisticRegression().setMaxIter(1)
MLTestingUtils.checkNumericTypes[LogisticRegressionModel, LogisticRegression](
- lr, isClassification = true, sqlContext) { (expected, actual) =>
+ lr, isClassification = true, spark) { (expected, actual) =>
assert(expected.intercept === actual.intercept)
assert(expected.coefficients.toArray === actual.coefficients.toArray)
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
index f41db31f1e..876e047db5 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
@@ -36,7 +36,7 @@ class MultilayerPerceptronClassifierSuite
override def beforeAll(): Unit = {
super.beforeAll()
- dataset = sqlContext.createDataFrame(Seq(
+ dataset = spark.createDataFrame(Seq(
(Vectors.dense(0.0, 0.0), 0.0),
(Vectors.dense(0.0, 1.0), 1.0),
(Vectors.dense(1.0, 0.0), 1.0),
@@ -77,7 +77,7 @@ class MultilayerPerceptronClassifierSuite
}
test("Test setWeights by training restart") {
- val dataFrame = sqlContext.createDataFrame(Seq(
+ val dataFrame = spark.createDataFrame(Seq(
(Vectors.dense(0.0, 0.0), 0.0),
(Vectors.dense(0.0, 1.0), 1.0),
(Vectors.dense(1.0, 0.0), 1.0),
@@ -113,7 +113,7 @@ class MultilayerPerceptronClassifierSuite
// the input seed is somewhat magic, to make this test pass
val rdd = sc.parallelize(generateMultinomialLogisticInput(
coefficients, xMean, xVariance, true, nPoints, 1), 2)
- val dataFrame = sqlContext.createDataFrame(rdd).toDF("label", "features")
+ val dataFrame = spark.createDataFrame(rdd).toDF("label", "features")
val numClasses = 3
val numIterations = 100
val layers = Array[Int](4, 5, 4, numClasses)
@@ -169,7 +169,7 @@ class MultilayerPerceptronClassifierSuite
val mpc = new MultilayerPerceptronClassifier().setLayers(layers).setMaxIter(1)
MLTestingUtils.checkNumericTypes[
MultilayerPerceptronClassificationModel, MultilayerPerceptronClassifier](
- mpc, isClassification = true, sqlContext) { (expected, actual) =>
+ mpc, isClassification = true, spark) { (expected, actual) =>
assert(expected.layers === actual.layers)
assert(expected.weights === actual.weights)
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
index 80a46fc70c..15d0358c3f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
@@ -43,7 +43,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
Array(0.10, 0.10, 0.70, 0.10) // label 2
).map(_.map(math.log))
- dataset = sqlContext.createDataFrame(generateNaiveBayesInput(pi, theta, 100, 42))
+ dataset = spark.createDataFrame(generateNaiveBayesInput(pi, theta, 100, 42))
}
def validatePrediction(predictionAndLabels: DataFrame): Unit = {
@@ -127,7 +127,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
val pi = Vectors.dense(piArray)
val theta = new DenseMatrix(3, 4, thetaArray.flatten, true)
- val testDataset = sqlContext.createDataFrame(generateNaiveBayesInput(
+ val testDataset = spark.createDataFrame(generateNaiveBayesInput(
piArray, thetaArray, nPoints, 42, "multinomial"))
val nb = new NaiveBayes().setSmoothing(1.0).setModelType("multinomial")
val model = nb.fit(testDataset)
@@ -135,7 +135,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
validateModelFit(pi, theta, model)
assert(model.hasParent)
- val validationDataset = sqlContext.createDataFrame(generateNaiveBayesInput(
+ val validationDataset = spark.createDataFrame(generateNaiveBayesInput(
piArray, thetaArray, nPoints, 17, "multinomial"))
val predictionAndLabels = model.transform(validationDataset).select("prediction", "label")
@@ -157,7 +157,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
val pi = Vectors.dense(piArray)
val theta = new DenseMatrix(3, 12, thetaArray.flatten, true)
- val testDataset = sqlContext.createDataFrame(generateNaiveBayesInput(
+ val testDataset = spark.createDataFrame(generateNaiveBayesInput(
piArray, thetaArray, nPoints, 45, "bernoulli"))
val nb = new NaiveBayes().setSmoothing(1.0).setModelType("bernoulli")
val model = nb.fit(testDataset)
@@ -165,7 +165,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
validateModelFit(pi, theta, model)
assert(model.hasParent)
- val validationDataset = sqlContext.createDataFrame(generateNaiveBayesInput(
+ val validationDataset = spark.createDataFrame(generateNaiveBayesInput(
piArray, thetaArray, nPoints, 20, "bernoulli"))
val predictionAndLabels = model.transform(validationDataset).select("prediction", "label")
@@ -188,7 +188,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
test("should support all NumericType labels and not support other types") {
val nb = new NaiveBayes()
MLTestingUtils.checkNumericTypes[NaiveBayesModel, NaiveBayes](
- nb, isClassification = true, sqlContext) { (expected, actual) =>
+ nb, isClassification = true, spark) { (expected, actual) =>
assert(expected.pi === actual.pi)
assert(expected.theta === actual.theta)
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
index 51871a9bab..005d609307 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
@@ -53,7 +53,7 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
val xVariance = Array(0.6856, 0.1899, 3.116, 0.581)
rdd = sc.parallelize(generateMultinomialLogisticInput(
coefficients, xMean, xVariance, true, nPoints, 42), 2)
- dataset = sqlContext.createDataFrame(rdd)
+ dataset = spark.createDataFrame(rdd)
}
test("params") {
@@ -228,7 +228,7 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
test("should support all NumericType labels and not support other types") {
val ovr = new OneVsRest().setClassifier(new LogisticRegression().setMaxIter(1))
MLTestingUtils.checkNumericTypes[OneVsRestModel, OneVsRest](
- ovr, isClassification = true, sqlContext) { (expected, actual) =>
+ ovr, isClassification = true, spark) { (expected, actual) =>
val expectedModels = expected.models.map(m => m.asInstanceOf[LogisticRegressionModel])
val actualModels = actual.models.map(m => m.asInstanceOf[LogisticRegressionModel])
assert(expectedModels.length === actualModels.length)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index 90744353d9..97f3feacca 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -155,7 +155,7 @@ class RandomForestClassifierSuite
}
test("Fitting without numClasses in metadata") {
- val df: DataFrame = sqlContext.createDataFrame(TreeTests.featureImportanceData(sc))
+ val df: DataFrame = spark.createDataFrame(TreeTests.featureImportanceData(sc))
val rf = new RandomForestClassifier().setMaxDepth(1).setNumTrees(1)
rf.fit(df)
}
@@ -189,7 +189,7 @@ class RandomForestClassifierSuite
test("should support all NumericType labels and not support other types") {
val rf = new RandomForestClassifier().setMaxDepth(1)
MLTestingUtils.checkNumericTypes[RandomForestClassificationModel, RandomForestClassifier](
- rf, isClassification = true, sqlContext) { (expected, actual) =>
+ rf, isClassification = true, spark) { (expected, actual) =>
TreeTests.checkEqual(expected, actual)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
index 212ea7a0a9..4f7d4418a8 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
@@ -30,7 +30,7 @@ class BisectingKMeansSuite
override def beforeAll(): Unit = {
super.beforeAll()
- dataset = KMeansSuite.generateKMeansData(sqlContext, 50, 3, k)
+ dataset = KMeansSuite.generateKMeansData(spark, 50, 3, k)
}
test("default parameters") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
index 9d868174c1..04366f5250 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
@@ -32,7 +32,7 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext
override def beforeAll(): Unit = {
super.beforeAll()
- dataset = KMeansSuite.generateKMeansData(sqlContext, 50, 3, k)
+ dataset = KMeansSuite.generateKMeansData(spark, 50, 3, k)
}
test("default parameters") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
index 241d21961f..2832db2f99 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
@@ -22,7 +22,7 @@ import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{DataFrame, Dataset, SQLContext}
+import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
private[clustering] case class TestRow(features: Vector)
@@ -34,7 +34,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
override def beforeAll(): Unit = {
super.beforeAll()
- dataset = KMeansSuite.generateKMeansData(sqlContext, 50, 3, k)
+ dataset = KMeansSuite.generateKMeansData(spark, 50, 3, k)
}
test("default parameters") {
@@ -142,11 +142,11 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
}
object KMeansSuite {
- def generateKMeansData(sql: SQLContext, rows: Int, dim: Int, k: Int): DataFrame = {
- val sc = sql.sparkContext
+ def generateKMeansData(spark: SparkSession, rows: Int, dim: Int, k: Int): DataFrame = {
+ val sc = spark.sparkContext
val rdd = sc.parallelize(1 to rows).map(i => Vectors.dense(Array.fill(dim)((i % k).toDouble)))
.map(v => new TestRow(v))
- sql.createDataFrame(rdd)
+ spark.createDataFrame(rdd)
}
/**
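generateKMeansData now takes a SparkSession, so its callers (BisectingKMeansSuite and GaussianMixtureSuite above) change in lockstep. For a helper that cannot be migrated in the same pass, a SparkSession still exposes its SQLContext, so a bridge like the following sketch is possible (legacyHelper is hypothetical, named here only for illustration):
import org.apache.spark.sql.{SQLContext, SparkSession}
// Hypothetical old-style helper that has not been migrated yet.
def legacyHelper(sqlContext: SQLContext): Unit = ()
val spark = SparkSession.builder().master("local").getOrCreate()
// SparkSession carries a SQLContext, so the old signature keeps working.
legacyHelper(spark.sqlContext)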
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
index 6cb07aecb9..34e8964286 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
@@ -17,30 +17,30 @@
package org.apache.spark.ml.clustering
-import org.apache.hadoop.fs.{FileSystem, Path}
+import org.apache.hadoop.fs.Path
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
-import org.apache.spark.sql.{DataFrame, Dataset, Row, SQLContext}
+import org.apache.spark.sql._
object LDASuite {
def generateLDAData(
- sql: SQLContext,
+ spark: SparkSession,
rows: Int,
k: Int,
vocabSize: Int): DataFrame = {
val avgWC = 1 // average instances of each word in a doc
- val sc = sql.sparkContext
+ val sc = spark.sparkContext
val rng = new java.util.Random()
rng.setSeed(1)
val rdd = sc.parallelize(1 to rows).map { i =>
Vectors.dense(Array.fill(vocabSize)(rng.nextInt(2 * avgWC).toDouble))
}.map(v => new TestRow(v))
- sql.createDataFrame(rdd)
+ spark.createDataFrame(rdd)
}
/**
@@ -68,7 +68,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
override def beforeAll(): Unit = {
super.beforeAll()
- dataset = LDASuite.generateLDAData(sqlContext, 50, k, vocabSize)
+ dataset = LDASuite.generateLDAData(spark, 50, k, vocabSize)
}
test("default parameters") {
@@ -140,7 +140,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
new LDA().setTopicConcentration(-1.1)
}
- val dummyDF = sqlContext.createDataFrame(Seq(
+ val dummyDF = spark.createDataFrame(Seq(
(1, Vectors.dense(1.0, 2.0)))).toDF("id", "features")
// validate parameters
lda.transformSchema(dummyDF.schema)
@@ -274,7 +274,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
// There should be 1 checkpoint remaining.
assert(model.getCheckpointFiles.length === 1)
val checkpointFile = new Path(model.getCheckpointFiles.head)
- val fs = checkpointFile.getFileSystem(sqlContext.sparkContext.hadoopConfiguration)
+ val fs = checkpointFile.getFileSystem(spark.sparkContext.hadoopConfiguration)
assert(fs.exists(checkpointFile))
model.deleteCheckpointFiles()
assert(model.getCheckpointFiles.isEmpty)
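The checkpoint assertion above now reaches the Hadoop configuration through spark.sparkContext instead of through the SQLContext. A self-contained sketch of that resolution step (the function name is illustrative):
import org.apache.hadoop.fs.Path
import org.apache.spark.sql.SparkSession
// Resolve a checkpoint path's FileSystem via the session's SparkContext.
def checkpointExists(spark: SparkSession, file: String): Boolean = {
  val path = new Path(file)
  val fs = path.getFileSystem(spark.sparkContext.hadoopConfiguration)
  fs.exists(path)
}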
diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala
index ff34522178..a8766f9035 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala
@@ -42,21 +42,21 @@ class BinaryClassificationEvaluatorSuite
val evaluator = new BinaryClassificationEvaluator()
.setMetricName("areaUnderPR")
- val vectorDF = sqlContext.createDataFrame(Seq(
+ val vectorDF = spark.createDataFrame(Seq(
(0d, Vectors.dense(12, 2.5)),
(1d, Vectors.dense(1, 3)),
(0d, Vectors.dense(10, 2))
)).toDF("label", "rawPrediction")
assert(evaluator.evaluate(vectorDF) === 1.0)
- val doubleDF = sqlContext.createDataFrame(Seq(
+ val doubleDF = spark.createDataFrame(Seq(
(0d, 0d),
(1d, 1d),
(0d, 0d)
)).toDF("label", "rawPrediction")
assert(evaluator.evaluate(doubleDF) === 1.0)
- val stringDF = sqlContext.createDataFrame(Seq(
+ val stringDF = spark.createDataFrame(Seq(
(0d, "0d"),
(1d, "1d"),
(0d, "0d")
@@ -71,6 +71,6 @@ class BinaryClassificationEvaluatorSuite
test("should support all NumericType labels and not support other types") {
val evaluator = new BinaryClassificationEvaluator().setRawPredictionCol("prediction")
- MLTestingUtils.checkNumericTypes(evaluator, sqlContext)
+ MLTestingUtils.checkNumericTypes(evaluator, spark)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala
index 87e511a368..522f6675d7 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala
@@ -38,6 +38,6 @@ class MulticlassClassificationEvaluatorSuite
}
test("should support all NumericType labels and not support other types") {
- MLTestingUtils.checkNumericTypes(new MulticlassClassificationEvaluator, sqlContext)
+ MLTestingUtils.checkNumericTypes(new MulticlassClassificationEvaluator, spark)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala
index c7b9483069..dcc004358d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala
@@ -42,7 +42,7 @@ class RegressionEvaluatorSuite
* data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1))
* .saveAsTextFile("path")
*/
- val dataset = sqlContext.createDataFrame(
+ val dataset = spark.createDataFrame(
sc.parallelize(LinearDataGenerator.generateLinearInput(
6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2))
@@ -85,6 +85,6 @@ class RegressionEvaluatorSuite
}
test("should support all NumericType labels and not support other types") {
- MLTestingUtils.checkNumericTypes(new RegressionEvaluator, sqlContext)
+ MLTestingUtils.checkNumericTypes(new RegressionEvaluator, spark)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
index 714b9db3aa..e91f758112 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
@@ -39,7 +39,7 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
test("Binarize continuous features with default parameter") {
val defaultBinarized: Array[Double] = data.map(x => if (x > 0.0) 1.0 else 0.0)
- val dataFrame: DataFrame = sqlContext.createDataFrame(
+ val dataFrame: DataFrame = spark.createDataFrame(
data.zip(defaultBinarized)).toDF("feature", "expected")
val binarizer: Binarizer = new Binarizer()
@@ -55,7 +55,7 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
test("Binarize continuous features with setter") {
val threshold: Double = 0.2
val thresholdBinarized: Array[Double] = data.map(x => if (x > threshold) 1.0 else 0.0)
- val dataFrame: DataFrame = sqlContext.createDataFrame(
+ val dataFrame: DataFrame = spark.createDataFrame(
data.zip(thresholdBinarized)).toDF("feature", "expected")
val binarizer: Binarizer = new Binarizer()
@@ -71,7 +71,7 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
test("Binarize vector of continuous features with default parameter") {
val defaultBinarized: Array[Double] = data.map(x => if (x > 0.0) 1.0 else 0.0)
- val dataFrame: DataFrame = sqlContext.createDataFrame(Seq(
+ val dataFrame: DataFrame = spark.createDataFrame(Seq(
(Vectors.dense(data), Vectors.dense(defaultBinarized))
)).toDF("feature", "expected")
@@ -88,7 +88,7 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
test("Binarize vector of continuous features with setter") {
val threshold: Double = 0.2
val defaultBinarized: Array[Double] = data.map(x => if (x > threshold) 1.0 else 0.0)
- val dataFrame: DataFrame = sqlContext.createDataFrame(Seq(
+ val dataFrame: DataFrame = spark.createDataFrame(Seq(
(Vectors.dense(data), Vectors.dense(defaultBinarized))
)).toDF("feature", "expected")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
index 9ea7d43176..98b2316d78 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
@@ -39,7 +39,7 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
val validData = Array(-0.5, -0.3, 0.0, 0.2)
val expectedBuckets = Array(0.0, 0.0, 1.0, 1.0)
val dataFrame: DataFrame =
- sqlContext.createDataFrame(validData.zip(expectedBuckets)).toDF("feature", "expected")
+ spark.createDataFrame(validData.zip(expectedBuckets)).toDF("feature", "expected")
val bucketizer: Bucketizer = new Bucketizer()
.setInputCol("feature")
@@ -55,13 +55,13 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
// Check for exceptions when using a set of invalid feature values.
val invalidData1: Array[Double] = Array(-0.9) ++ validData
val invalidData2 = Array(0.51) ++ validData
- val badDF1 = sqlContext.createDataFrame(invalidData1.zipWithIndex).toDF("feature", "idx")
+ val badDF1 = spark.createDataFrame(invalidData1.zipWithIndex).toDF("feature", "idx")
withClue("Invalid feature value -0.9 was not caught as an invalid feature!") {
intercept[SparkException] {
bucketizer.transform(badDF1).collect()
}
}
- val badDF2 = sqlContext.createDataFrame(invalidData2.zipWithIndex).toDF("feature", "idx")
+ val badDF2 = spark.createDataFrame(invalidData2.zipWithIndex).toDF("feature", "idx")
withClue("Invalid feature value 0.51 was not caught as an invalid feature!") {
intercept[SparkException] {
bucketizer.transform(badDF2).collect()
@@ -74,7 +74,7 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
val validData = Array(-0.9, -0.5, -0.3, 0.0, 0.2, 0.5, 0.9)
val expectedBuckets = Array(0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0)
val dataFrame: DataFrame =
- sqlContext.createDataFrame(validData.zip(expectedBuckets)).toDF("feature", "expected")
+ spark.createDataFrame(validData.zip(expectedBuckets)).toDF("feature", "expected")
val bucketizer: Bucketizer = new Bucketizer()
.setInputCol("feature")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
index 7827db2794..4c6d9c5e26 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
@@ -24,14 +24,17 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
-import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.{Row, SparkSession}
class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext
with DefaultReadWriteTest {
test("Test Chi-Square selector") {
- val sqlContext = SQLContext.getOrCreate(sc)
- import sqlContext.implicits._
+ val spark = SparkSession.builder
+ .master("local[2]")
+ .appName("ChiSqSelectorSuite")
+ .getOrCreate()
+ import spark.implicits._
val data = Seq(
LabeledPoint(0.0, Vectors.sparse(3, Array((0, 8.0), (1, 7.0)))),
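Unlike the other suites, ChiSqSelectorSuite builds its own session inside the test body. Since getOrCreate() returns any session that is already active (as one will be under MLlibTestSparkContext), the master and appName set here should only take effect when no session exists yet; the import of spark.implicits._ is what the test actually relies on for toDF over local data. A minimal sketch of the idiom:
import org.apache.spark.sql.SparkSession
val spark = SparkSession.builder
  .master("local[2]") // used only if no session is active yet
  .appName("ChiSqSelectorSuite")
  .getOrCreate()
import spark.implicits._ // enables .toDF/.toDS on local collections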
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
index 7641e3b8cf..b82e3e90b4 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
@@ -35,7 +35,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
private def split(s: String): Seq[String] = s.split("\\s+")
test("CountVectorizerModel common cases") {
- val df = sqlContext.createDataFrame(Seq(
+ val df = spark.createDataFrame(Seq(
(0, split("a b c d"),
Vectors.sparse(4, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0)))),
(1, split("a b b c d a"),
@@ -55,7 +55,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
}
test("CountVectorizer common cases") {
- val df = sqlContext.createDataFrame(Seq(
+ val df = spark.createDataFrame(Seq(
(0, split("a b c d e"),
Vectors.sparse(5, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0), (4, 1.0)))),
(1, split("a a a a a a"), Vectors.sparse(5, Seq((0, 6.0)))),
@@ -76,7 +76,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
}
test("CountVectorizer vocabSize and minDF") {
- val df = sqlContext.createDataFrame(Seq(
+ val df = spark.createDataFrame(Seq(
(0, split("a b c d"), Vectors.sparse(3, Seq((0, 1.0), (1, 1.0)))),
(1, split("a b c"), Vectors.sparse(3, Seq((0, 1.0), (1, 1.0)))),
(2, split("a b"), Vectors.sparse(3, Seq((0, 1.0), (1, 1.0)))),
@@ -118,7 +118,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
test("CountVectorizer throws exception when vocab is empty") {
intercept[IllegalArgumentException] {
- val df = sqlContext.createDataFrame(Seq(
+ val df = spark.createDataFrame(Seq(
(0, split("a a b b c c")),
(1, split("aa bb cc")))
).toDF("id", "words")
@@ -132,7 +132,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
}
test("CountVectorizerModel with minTF count") {
- val df = sqlContext.createDataFrame(Seq(
+ val df = spark.createDataFrame(Seq(
(0, split("a a a b b c c c d "), Vectors.sparse(4, Seq((0, 3.0), (2, 3.0)))),
(1, split("c c c c c c"), Vectors.sparse(4, Seq((2, 6.0)))),
(2, split("a"), Vectors.sparse(4, Seq())),
@@ -151,7 +151,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
}
test("CountVectorizerModel with minTF freq") {
- val df = sqlContext.createDataFrame(Seq(
+ val df = spark.createDataFrame(Seq(
(0, split("a a a b b c c c d "), Vectors.sparse(4, Seq((0, 3.0), (2, 3.0)))),
(1, split("c c c c c c"), Vectors.sparse(4, Seq((2, 6.0)))),
(2, split("a"), Vectors.sparse(4, Seq((0, 1.0)))),
@@ -170,7 +170,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
}
test("CountVectorizerModel and CountVectorizer with binary") {
- val df = sqlContext.createDataFrame(Seq(
+ val df = spark.createDataFrame(Seq(
(0, split("a a a a b b b b c d"),
Vectors.sparse(4, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0)))),
(1, split("c c c"), Vectors.sparse(4, Seq((2, 1.0)))),
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala
index 36cafa290f..dbd5ae8345 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala
@@ -63,7 +63,7 @@ class DCTSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
}
val expectedResult = Vectors.dense(expectedResultBuffer)
- val dataset = sqlContext.createDataFrame(Seq(
+ val dataset = spark.createDataFrame(Seq(
DCTTestData(data, expectedResult)
))
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
index 44bad4aba4..89d67d8e6f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
@@ -34,7 +34,7 @@ class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
}
test("hashingTF") {
- val df = sqlContext.createDataFrame(Seq(
+ val df = spark.createDataFrame(Seq(
(0, "a a b b c d".split(" ").toSeq)
)).toDF("id", "words")
val n = 100
@@ -54,7 +54,7 @@ class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
}
test("applying binary term freqs") {
- val df = sqlContext.createDataFrame(Seq(
+ val df = spark.createDataFrame(Seq(
(0, "a a b c c c".split(" ").toSeq)
)).toDF("id", "words")
val n = 100
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala
index bc958c1585..208ea84913 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala
@@ -60,7 +60,7 @@ class IDFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
})
val expected = scaleDataWithIDF(data, idf)
- val df = sqlContext.createDataFrame(data.zip(expected)).toDF("features", "expected")
+ val df = spark.createDataFrame(data.zip(expected)).toDF("features", "expected")
val idfModel = new IDF()
.setInputCol("features")
@@ -86,7 +86,7 @@ class IDFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
})
val expected = scaleDataWithIDF(data, idf)
- val df = sqlContext.createDataFrame(data.zip(expected)).toDF("features", "expected")
+ val df = spark.createDataFrame(data.zip(expected)).toDF("features", "expected")
val idfModel = new IDF()
.setInputCol("features")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala
index 0d4e00668d..3409928007 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala
@@ -59,7 +59,7 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def
}
test("numeric interaction") {
- val data = sqlContext.createDataFrame(
+ val data = spark.createDataFrame(
Seq(
(2, Vectors.dense(3.0, 4.0)),
(1, Vectors.dense(1.0, 5.0)))
@@ -74,7 +74,7 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def
col("b").as("b", groupAttr.toMetadata()))
val trans = new Interaction().setInputCols(Array("a", "b")).setOutputCol("features")
val res = trans.transform(df)
- val expected = sqlContext.createDataFrame(
+ val expected = spark.createDataFrame(
Seq(
(2, Vectors.dense(3.0, 4.0), Vectors.dense(6.0, 8.0)),
(1, Vectors.dense(1.0, 5.0), Vectors.dense(1.0, 5.0)))
@@ -90,7 +90,7 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def
}
test("nominal interaction") {
- val data = sqlContext.createDataFrame(
+ val data = spark.createDataFrame(
Seq(
(2, Vectors.dense(3.0, 4.0)),
(1, Vectors.dense(1.0, 5.0)))
@@ -106,7 +106,7 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def
col("b").as("b", groupAttr.toMetadata()))
val trans = new Interaction().setInputCols(Array("a", "b")).setOutputCol("features")
val res = trans.transform(df)
- val expected = sqlContext.createDataFrame(
+ val expected = spark.createDataFrame(
Seq(
(2, Vectors.dense(3.0, 4.0), Vectors.dense(0, 0, 0, 0, 3, 4)),
(1, Vectors.dense(1.0, 5.0), Vectors.dense(0, 0, 1, 5, 0, 0)))
@@ -126,7 +126,7 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def
}
test("default attr names") {
- val data = sqlContext.createDataFrame(
+ val data = spark.createDataFrame(
Seq(
(2, Vectors.dense(0.0, 4.0), 1.0),
(1, Vectors.dense(1.0, 5.0), 10.0))
@@ -142,7 +142,7 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def
col("c").as("c", NumericAttribute.defaultAttr.toMetadata()))
val trans = new Interaction().setInputCols(Array("a", "b", "c")).setOutputCol("features")
val res = trans.transform(df)
- val expected = sqlContext.createDataFrame(
+ val expected = spark.createDataFrame(
Seq(
(2, Vectors.dense(0.0, 4.0), 1.0, Vectors.dense(0, 0, 0, 0, 0, 0, 1, 0, 4)),
(1, Vectors.dense(1.0, 5.0), 10.0, Vectors.dense(0, 0, 0, 0, 10, 50, 0, 0, 0)))
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala
index e083d47136..73d69ebfee 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala
@@ -36,7 +36,7 @@ class MaxAbsScalerSuite extends SparkFunSuite with MLlibTestSparkContext with De
Vectors.sparse(3, Array(0, 2), Array(-1, -1)),
Vectors.sparse(3, Array(0), Array(-0.75)))
- val df = sqlContext.createDataFrame(data.zip(expected)).toDF("features", "expected")
+ val df = spark.createDataFrame(data.zip(expected)).toDF("features", "expected")
val scaler = new MaxAbsScaler()
.setInputCol("features")
.setOutputCol("scaled")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
index 87206c777e..e495c8e571 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
@@ -38,7 +38,7 @@ class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with De
Vectors.sparse(3, Array(0, 2), Array(5, 5)),
Vectors.sparse(3, Array(0), Array(-2.5)))
- val df = sqlContext.createDataFrame(data.zip(expected)).toDF("features", "expected")
+ val df = spark.createDataFrame(data.zip(expected)).toDF("features", "expected")
val scaler = new MinMaxScaler()
.setInputCol("features")
.setOutputCol("scaled")
@@ -57,7 +57,7 @@ class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with De
test("MinMaxScaler arguments max must be larger than min") {
withClue("arguments max must be larger than min") {
- val dummyDF = sqlContext.createDataFrame(Seq(
+ val dummyDF = spark.createDataFrame(Seq(
(1, Vectors.dense(1.0, 2.0)))).toDF("id", "feature")
intercept[IllegalArgumentException] {
val scaler = new MinMaxScaler().setMin(10).setMax(0).setInputCol("feature")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala
index a9421e6825..e5288d9259 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala
@@ -34,7 +34,7 @@ class NGramSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRe
val nGram = new NGram()
.setInputCol("inputTokens")
.setOutputCol("nGrams")
- val dataset = sqlContext.createDataFrame(Seq(
+ val dataset = spark.createDataFrame(Seq(
NGramTestData(
Array("Test", "for", "ngram", "."),
Array("Test for", "for ngram", "ngram .")
@@ -47,7 +47,7 @@ class NGramSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRe
.setInputCol("inputTokens")
.setOutputCol("nGrams")
.setN(4)
- val dataset = sqlContext.createDataFrame(Seq(
+ val dataset = spark.createDataFrame(Seq(
NGramTestData(
Array("a", "b", "c", "d", "e"),
Array("a b c d", "b c d e")
@@ -60,7 +60,7 @@ class NGramSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRe
.setInputCol("inputTokens")
.setOutputCol("nGrams")
.setN(4)
- val dataset = sqlContext.createDataFrame(Seq(
+ val dataset = spark.createDataFrame(Seq(
NGramTestData(
Array(),
Array()
@@ -73,7 +73,7 @@ class NGramSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRe
.setInputCol("inputTokens")
.setOutputCol("nGrams")
.setN(6)
- val dataset = sqlContext.createDataFrame(Seq(
+ val dataset = spark.createDataFrame(Seq(
NGramTestData(
Array("a", "b", "c", "d", "e"),
Array()
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala
index 4688339019..241a1e9fb5 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala
@@ -61,7 +61,7 @@ class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
Vectors.sparse(3, Seq())
)
- dataFrame = sqlContext.createDataFrame(sc.parallelize(data, 2).map(NormalizerSuite.FeatureData))
+ dataFrame = spark.createDataFrame(sc.parallelize(data, 2).map(NormalizerSuite.FeatureData))
normalizer = new Normalizer()
.setInputCol("features")
.setOutputCol("normalized_features")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
index 49803aef71..06ffbc386f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
@@ -32,7 +32,7 @@ class OneHotEncoderSuite
def stringIndexed(): DataFrame = {
val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
- val df = sqlContext.createDataFrame(data).toDF("id", "label")
+ val df = spark.createDataFrame(data).toDF("id", "label")
val indexer = new StringIndexer()
.setInputCol("label")
.setOutputCol("labelIndex")
@@ -81,7 +81,7 @@ class OneHotEncoderSuite
test("input column with ML attribute") {
val attr = NominalAttribute.defaultAttr.withValues("small", "medium", "large")
- val df = sqlContext.createDataFrame(Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply)).toDF("size")
+ val df = spark.createDataFrame(Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply)).toDF("size")
.select(col("size").as("size", attr.toMetadata()))
val encoder = new OneHotEncoder()
.setInputCol("size")
@@ -94,7 +94,7 @@ class OneHotEncoderSuite
}
test("input column without ML attribute") {
- val df = sqlContext.createDataFrame(Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply)).toDF("index")
+ val df = spark.createDataFrame(Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply)).toDF("index")
val encoder = new OneHotEncoder()
.setInputCol("index")
.setOutputCol("encoded")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala
index f372ec5826..4befa84dbb 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala
@@ -49,7 +49,7 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
val pc = mat.computePrincipalComponents(3)
val expected = mat.multiply(pc).rows
- val df = sqlContext.createDataFrame(dataRDD.zip(expected)).toDF("features", "expected")
+ val df = spark.createDataFrame(dataRDD.zip(expected)).toDF("features", "expected")
val pca = new PCA()
.setInputCol("features")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala
index 86dbee1cf4..e3adbba9d5 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala
@@ -59,7 +59,7 @@ class PolynomialExpansionSuite
Vectors.sparse(19, Array.empty, Array.empty))
test("Polynomial expansion with default parameter") {
- val df = sqlContext.createDataFrame(data.zip(twoDegreeExpansion)).toDF("features", "expected")
+ val df = spark.createDataFrame(data.zip(twoDegreeExpansion)).toDF("features", "expected")
val polynomialExpansion = new PolynomialExpansion()
.setInputCol("features")
@@ -76,7 +76,7 @@ class PolynomialExpansionSuite
}
test("Polynomial expansion with setter") {
- val df = sqlContext.createDataFrame(data.zip(threeDegreeExpansion)).toDF("features", "expected")
+ val df = spark.createDataFrame(data.zip(threeDegreeExpansion)).toDF("features", "expected")
val polynomialExpansion = new PolynomialExpansion()
.setInputCol("features")
@@ -94,7 +94,7 @@ class PolynomialExpansionSuite
}
test("Polynomial expansion with degree 1 is identity on vectors") {
- val df = sqlContext.createDataFrame(data.zip(data)).toDF("features", "expected")
+ val df = spark.createDataFrame(data.zip(data)).toDF("features", "expected")
val polynomialExpansion = new PolynomialExpansion()
.setInputCol("features")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
index f8476953d8..46e7495297 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
@@ -32,12 +32,12 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
test("transform numeric data") {
val formula = new RFormula().setFormula("id ~ v1 + v2")
- val original = sqlContext.createDataFrame(
+ val original = spark.createDataFrame(
Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2")
val model = formula.fit(original)
val result = model.transform(original)
val resultSchema = model.transformSchema(original.schema)
- val expected = sqlContext.createDataFrame(
+ val expected = spark.createDataFrame(
Seq(
(0, 1.0, 3.0, Vectors.dense(1.0, 3.0), 0.0),
(2, 2.0, 5.0, Vectors.dense(2.0, 5.0), 2.0))
@@ -50,7 +50,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
test("features column already exists") {
val formula = new RFormula().setFormula("y ~ x").setFeaturesCol("x")
- val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y")
+ val original = spark.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y")
intercept[IllegalArgumentException] {
formula.fit(original)
}
@@ -61,7 +61,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
test("label column already exists") {
val formula = new RFormula().setFormula("y ~ x").setLabelCol("y")
- val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y")
+ val original = spark.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y")
val model = formula.fit(original)
val resultSchema = model.transformSchema(original.schema)
assert(resultSchema.length == 3)
@@ -70,7 +70,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
test("label column already exists but is not double type") {
val formula = new RFormula().setFormula("y ~ x").setLabelCol("y")
- val original = sqlContext.createDataFrame(Seq((0, 1), (2, 2))).toDF("x", "y")
+ val original = spark.createDataFrame(Seq((0, 1), (2, 2))).toDF("x", "y")
val model = formula.fit(original)
intercept[IllegalArgumentException] {
model.transformSchema(original.schema)
@@ -82,7 +82,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
test("allow missing label column for test datasets") {
val formula = new RFormula().setFormula("y ~ x").setLabelCol("label")
- val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "_not_y")
+ val original = spark.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "_not_y")
val model = formula.fit(original)
val resultSchema = model.transformSchema(original.schema)
assert(resultSchema.length == 3)
@@ -91,14 +91,14 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
}
test("allow empty label") {
- val original = sqlContext.createDataFrame(
+ val original = spark.createDataFrame(
Seq((1, 2.0, 3.0), (4, 5.0, 6.0), (7, 8.0, 9.0))
).toDF("id", "a", "b")
val formula = new RFormula().setFormula("~ a + b")
val model = formula.fit(original)
val result = model.transform(original)
val resultSchema = model.transformSchema(original.schema)
- val expected = sqlContext.createDataFrame(
+ val expected = spark.createDataFrame(
Seq(
(1, 2.0, 3.0, Vectors.dense(2.0, 3.0)),
(4, 5.0, 6.0, Vectors.dense(5.0, 6.0)),
@@ -110,13 +110,13 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
test("encodes string terms") {
val formula = new RFormula().setFormula("id ~ a + b")
- val original = sqlContext.createDataFrame(
+ val original = spark.createDataFrame(
Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5))
).toDF("id", "a", "b")
val model = formula.fit(original)
val result = model.transform(original)
val resultSchema = model.transformSchema(original.schema)
- val expected = sqlContext.createDataFrame(
+ val expected = spark.createDataFrame(
Seq(
(1, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0),
(2, "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 2.0),
@@ -129,13 +129,13 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
test("index string label") {
val formula = new RFormula().setFormula("id ~ a + b")
- val original = sqlContext.createDataFrame(
+ val original = spark.createDataFrame(
Seq(("male", "foo", 4), ("female", "bar", 4), ("female", "bar", 5), ("male", "baz", 5))
).toDF("id", "a", "b")
val model = formula.fit(original)
val result = model.transform(original)
val resultSchema = model.transformSchema(original.schema)
- val expected = sqlContext.createDataFrame(
+ val expected = spark.createDataFrame(
Seq(
("male", "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0),
("female", "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 0.0),
@@ -148,7 +148,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
test("attribute generation") {
val formula = new RFormula().setFormula("id ~ a + b")
- val original = sqlContext.createDataFrame(
+ val original = spark.createDataFrame(
Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5))
).toDF("id", "a", "b")
val model = formula.fit(original)
@@ -165,7 +165,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
test("vector attribute generation") {
val formula = new RFormula().setFormula("id ~ vec")
- val original = sqlContext.createDataFrame(
+ val original = spark.createDataFrame(
Seq((1, Vectors.dense(0.0, 1.0)), (2, Vectors.dense(1.0, 2.0)))
).toDF("id", "vec")
val model = formula.fit(original)
@@ -181,7 +181,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
test("vector attribute generation with unnamed input attrs") {
val formula = new RFormula().setFormula("id ~ vec2")
- val base = sqlContext.createDataFrame(
+ val base = spark.createDataFrame(
Seq((1, Vectors.dense(0.0, 1.0)), (2, Vectors.dense(1.0, 2.0)))
).toDF("id", "vec")
val metadata = new AttributeGroup(
@@ -203,12 +203,12 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
test("numeric interaction") {
val formula = new RFormula().setFormula("a ~ b:c:d")
- val original = sqlContext.createDataFrame(
+ val original = spark.createDataFrame(
Seq((1, 2, 4, 2), (2, 3, 4, 1))
).toDF("a", "b", "c", "d")
val model = formula.fit(original)
val result = model.transform(original)
- val expected = sqlContext.createDataFrame(
+ val expected = spark.createDataFrame(
Seq(
(1, 2, 4, 2, Vectors.dense(16.0), 1.0),
(2, 3, 4, 1, Vectors.dense(12.0), 2.0))
@@ -223,12 +223,12 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
test("factor numeric interaction") {
val formula = new RFormula().setFormula("id ~ a:b")
- val original = sqlContext.createDataFrame(
+ val original = spark.createDataFrame(
Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5), (4, "baz", 5), (4, "baz", 5))
).toDF("id", "a", "b")
val model = formula.fit(original)
val result = model.transform(original)
- val expected = sqlContext.createDataFrame(
+ val expected = spark.createDataFrame(
Seq(
(1, "foo", 4, Vectors.dense(0.0, 0.0, 4.0), 1.0),
(2, "bar", 4, Vectors.dense(0.0, 4.0, 0.0), 2.0),
@@ -250,12 +250,12 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
test("factor factor interaction") {
val formula = new RFormula().setFormula("id ~ a:b")
- val original = sqlContext.createDataFrame(
+ val original = spark.createDataFrame(
Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz"))
).toDF("id", "a", "b")
val model = formula.fit(original)
val result = model.transform(original)
- val expected = sqlContext.createDataFrame(
+ val expected = spark.createDataFrame(
Seq(
(1, "foo", "zq", Vectors.dense(0.0, 0.0, 1.0, 0.0), 1.0),
(2, "bar", "zq", Vectors.dense(1.0, 0.0, 0.0, 0.0), 2.0),
@@ -299,7 +299,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
}
}
- val dataset = sqlContext.createDataFrame(
+ val dataset = spark.createDataFrame(
Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz"))
).toDF("id", "a", "b")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala
index e213e17d0d..1401ea9c4b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala
@@ -31,13 +31,13 @@ class SQLTransformerSuite
}
test("transform numeric data") {
- val original = sqlContext.createDataFrame(
+ val original = spark.createDataFrame(
Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2")
val sqlTrans = new SQLTransformer().setStatement(
"SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__")
val result = sqlTrans.transform(original)
val resultSchema = sqlTrans.transformSchema(original.schema)
- val expected = sqlContext.createDataFrame(
+ val expected = spark.createDataFrame(
Seq((0, 1.0, 3.0, 4.0, 3.0), (2, 2.0, 5.0, 7.0, 10.0)))
.toDF("id", "v1", "v2", "v3", "v4")
assert(result.schema.toString == resultSchema.toString)
@@ -52,7 +52,7 @@ class SQLTransformerSuite
}
test("transformSchema") {
- val df = sqlContext.range(10)
+ val df = spark.range(10)
val outputSchema = new SQLTransformer()
.setStatement("SELECT id + 1 AS id1 FROM __THIS__")
.transformSchema(df.schema)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala
index 8c5e47a22c..d62301be14 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala
@@ -73,7 +73,7 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext
}
test("Standardization with default parameter") {
- val df0 = sqlContext.createDataFrame(data.zip(resWithStd)).toDF("features", "expected")
+ val df0 = spark.createDataFrame(data.zip(resWithStd)).toDF("features", "expected")
val standardScaler0 = new StandardScaler()
.setInputCol("features")
@@ -84,9 +84,9 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext
}
test("Standardization with setter") {
- val df1 = sqlContext.createDataFrame(data.zip(resWithBoth)).toDF("features", "expected")
- val df2 = sqlContext.createDataFrame(data.zip(resWithMean)).toDF("features", "expected")
- val df3 = sqlContext.createDataFrame(data.zip(data)).toDF("features", "expected")
+ val df1 = spark.createDataFrame(data.zip(resWithBoth)).toDF("features", "expected")
+ val df2 = spark.createDataFrame(data.zip(resWithMean)).toDF("features", "expected")
+ val df3 = spark.createDataFrame(data.zip(data)).toDF("features", "expected")
val standardScaler1 = new StandardScaler()
.setInputCol("features")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
index 8e7e000fbc..125ad02ebc 100755
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.sql.{Dataset, Row}
object StopWordsRemoverSuite extends SparkFunSuite {
def testStopWordsRemover(t: StopWordsRemover, dataset: Dataset[_]): Unit = {
@@ -42,7 +42,7 @@ class StopWordsRemoverSuite
val remover = new StopWordsRemover()
.setInputCol("raw")
.setOutputCol("filtered")
- val dataSet = sqlContext.createDataFrame(Seq(
+ val dataSet = spark.createDataFrame(Seq(
(Seq("test", "test"), Seq("test", "test")),
(Seq("a", "b", "c", "d"), Seq("b", "c")),
(Seq("a", "the", "an"), Seq()),
@@ -60,7 +60,7 @@ class StopWordsRemoverSuite
.setInputCol("raw")
.setOutputCol("filtered")
.setStopWords(stopWords)
- val dataSet = sqlContext.createDataFrame(Seq(
+ val dataSet = spark.createDataFrame(Seq(
(Seq("test", "test"), Seq()),
(Seq("a", "b", "c", "d"), Seq("b", "c", "d")),
(Seq("a", "the", "an"), Seq()),
@@ -77,7 +77,7 @@ class StopWordsRemoverSuite
.setInputCol("raw")
.setOutputCol("filtered")
.setCaseSensitive(true)
- val dataSet = sqlContext.createDataFrame(Seq(
+ val dataSet = spark.createDataFrame(Seq(
(Seq("A"), Seq("A")),
(Seq("The", "the"), Seq("The"))
)).toDF("raw", "expected")
@@ -98,7 +98,7 @@ class StopWordsRemoverSuite
.setInputCol("raw")
.setOutputCol("filtered")
.setStopWords(stopWords)
- val dataSet = sqlContext.createDataFrame(Seq(
+ val dataSet = spark.createDataFrame(Seq(
(Seq("acaba", "ama", "biri"), Seq()),
(Seq("hep", "her", "scala"), Seq("scala"))
)).toDF("raw", "expected")
@@ -112,7 +112,7 @@ class StopWordsRemoverSuite
.setInputCol("raw")
.setOutputCol("filtered")
.setStopWords(stopWords.toArray)
- val dataSet = sqlContext.createDataFrame(Seq(
+ val dataSet = spark.createDataFrame(Seq(
(Seq("python", "scala", "a"), Seq("python", "scala", "a")),
(Seq("Python", "Scala", "swift"), Seq("Python", "Scala", "swift"))
)).toDF("raw", "expected")
@@ -126,7 +126,7 @@ class StopWordsRemoverSuite
.setInputCol("raw")
.setOutputCol("filtered")
.setStopWords(stopWords.toArray)
- val dataSet = sqlContext.createDataFrame(Seq(
+ val dataSet = spark.createDataFrame(Seq(
(Seq("python", "scala", "a"), Seq()),
(Seq("Python", "Scala", "swift"), Seq("swift"))
)).toDF("raw", "expected")
@@ -148,7 +148,7 @@ class StopWordsRemoverSuite
val remover = new StopWordsRemover()
.setInputCol("raw")
.setOutputCol(outputCol)
- val dataSet = sqlContext.createDataFrame(Seq(
+ val dataSet = spark.createDataFrame(Seq(
(Seq("The", "the", "swift"), Seq("swift"))
)).toDF("raw", outputCol)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index d0f3cdc841..c221d4aa55 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -39,7 +39,7 @@ class StringIndexerSuite
test("StringIndexer") {
val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
- val df = sqlContext.createDataFrame(data).toDF("id", "label")
+ val df = spark.createDataFrame(data).toDF("id", "label")
val indexer = new StringIndexer()
.setInputCol("label")
.setOutputCol("labelIndex")
@@ -63,8 +63,8 @@ class StringIndexerSuite
test("StringIndexerUnseen") {
val data = sc.parallelize(Seq((0, "a"), (1, "b"), (4, "b")), 2)
val data2 = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c")), 2)
- val df = sqlContext.createDataFrame(data).toDF("id", "label")
- val df2 = sqlContext.createDataFrame(data2).toDF("id", "label")
+ val df = spark.createDataFrame(data).toDF("id", "label")
+ val df2 = spark.createDataFrame(data2).toDF("id", "label")
val indexer = new StringIndexer()
.setInputCol("label")
.setOutputCol("labelIndex")
@@ -93,7 +93,7 @@ class StringIndexerSuite
test("StringIndexer with a numeric input column") {
val data = sc.parallelize(Seq((0, 100), (1, 200), (2, 300), (3, 100), (4, 100), (5, 300)), 2)
- val df = sqlContext.createDataFrame(data).toDF("id", "label")
+ val df = spark.createDataFrame(data).toDF("id", "label")
val indexer = new StringIndexer()
.setInputCol("label")
.setOutputCol("labelIndex")
@@ -114,12 +114,12 @@ class StringIndexerSuite
val indexerModel = new StringIndexerModel("indexer", Array("a", "b", "c"))
.setInputCol("label")
.setOutputCol("labelIndex")
- val df = sqlContext.range(0L, 10L).toDF()
+ val df = spark.range(0L, 10L).toDF()
assert(indexerModel.transform(df).collect().toSet === df.collect().toSet)
}
test("StringIndexerModel can't overwrite output column") {
- val df = sqlContext.createDataFrame(Seq((1, 2), (3, 4))).toDF("input", "output")
+ val df = spark.createDataFrame(Seq((1, 2), (3, 4))).toDF("input", "output")
val indexer = new StringIndexer()
.setInputCol("input")
.setOutputCol("output")
@@ -153,7 +153,7 @@ class StringIndexerSuite
test("IndexToString.transform") {
val labels = Array("a", "b", "c")
- val df0 = sqlContext.createDataFrame(Seq(
+ val df0 = spark.createDataFrame(Seq(
(0, "a"), (1, "b"), (2, "c"), (0, "a")
)).toDF("index", "expected")
@@ -180,7 +180,7 @@ class StringIndexerSuite
test("StringIndexer, IndexToString are inverses") {
val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
- val df = sqlContext.createDataFrame(data).toDF("id", "label")
+ val df = spark.createDataFrame(data).toDF("id", "label")
val indexer = new StringIndexer()
.setInputCol("label")
.setOutputCol("labelIndex")
@@ -213,7 +213,7 @@ class StringIndexerSuite
test("StringIndexer metadata") {
val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
- val df = sqlContext.createDataFrame(data).toDF("id", "label")
+ val df = spark.createDataFrame(data).toDF("id", "label")
val indexer = new StringIndexer()
.setInputCol("label")
.setOutputCol("labelIndex")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
index 123ddfe42c..f30bdc3ddc 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
@@ -57,13 +57,13 @@ class RegexTokenizerSuite
.setPattern("\\w+|\\p{Punct}")
.setInputCol("rawText")
.setOutputCol("tokens")
- val dataset0 = sqlContext.createDataFrame(Seq(
+ val dataset0 = spark.createDataFrame(Seq(
TokenizerTestData("Test for tokenization.", Array("test", "for", "tokenization", ".")),
TokenizerTestData("Te,st. punct", Array("te", ",", "st", ".", "punct"))
))
testRegexTokenizer(tokenizer0, dataset0)
- val dataset1 = sqlContext.createDataFrame(Seq(
+ val dataset1 = spark.createDataFrame(Seq(
TokenizerTestData("Test for tokenization.", Array("test", "for", "tokenization")),
TokenizerTestData("Te,st. punct", Array("punct"))
))
@@ -73,7 +73,7 @@ class RegexTokenizerSuite
val tokenizer2 = new RegexTokenizer()
.setInputCol("rawText")
.setOutputCol("tokens")
- val dataset2 = sqlContext.createDataFrame(Seq(
+ val dataset2 = spark.createDataFrame(Seq(
TokenizerTestData("Test for tokenization.", Array("test", "for", "tokenization.")),
TokenizerTestData("Te,st. punct", Array("te,st.", "punct"))
))
@@ -85,7 +85,7 @@ class RegexTokenizerSuite
.setInputCol("rawText")
.setOutputCol("tokens")
.setToLowercase(false)
- val dataset = sqlContext.createDataFrame(Seq(
+ val dataset = spark.createDataFrame(Seq(
TokenizerTestData("JAVA SCALA", Array("JAVA", "SCALA")),
TokenizerTestData("java scala", Array("java", "scala"))
))
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
index dce994fdbd..250011c859 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
@@ -57,7 +57,7 @@ class VectorAssemblerSuite
}
test("VectorAssembler") {
- val df = sqlContext.createDataFrame(Seq(
+ val df = spark.createDataFrame(Seq(
(0, 0.0, Vectors.dense(1.0, 2.0), "a", Vectors.sparse(2, Array(1), Array(3.0)), 10L)
)).toDF("id", "x", "y", "name", "z", "n")
val assembler = new VectorAssembler()
@@ -70,7 +70,7 @@ class VectorAssemblerSuite
}
test("transform should throw an exception in case of unsupported type") {
- val df = sqlContext.createDataFrame(Seq(("a", "b", "c"))).toDF("a", "b", "c")
+ val df = spark.createDataFrame(Seq(("a", "b", "c"))).toDF("a", "b", "c")
val assembler = new VectorAssembler()
.setInputCols(Array("a", "b", "c"))
.setOutputCol("features")
@@ -87,7 +87,7 @@ class VectorAssemblerSuite
NominalAttribute.defaultAttr.withName("gender").withValues("male", "female"),
NumericAttribute.defaultAttr.withName("salary")))
val row = (1.0, 0.5, 1, Vectors.dense(1.0, 1000.0), Vectors.sparse(2, Array(1), Array(2.0)))
- val df = sqlContext.createDataFrame(Seq(row)).toDF("browser", "hour", "count", "user", "ad")
+ val df = spark.createDataFrame(Seq(row)).toDF("browser", "hour", "count", "user", "ad")
.select(
col("browser").as("browser", browser.toMetadata()),
col("hour").as("hour", hour.toMetadata()),
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
index 1ffc62b38e..d1c0270a02 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
@@ -85,11 +85,11 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext
checkPair(densePoints1Seq, sparsePoints1Seq)
checkPair(densePoints2Seq, sparsePoints2Seq)
- densePoints1 = sqlContext.createDataFrame(sc.parallelize(densePoints1Seq, 2).map(FeatureData))
- sparsePoints1 = sqlContext.createDataFrame(sc.parallelize(sparsePoints1Seq, 2).map(FeatureData))
- densePoints2 = sqlContext.createDataFrame(sc.parallelize(densePoints2Seq, 2).map(FeatureData))
- sparsePoints2 = sqlContext.createDataFrame(sc.parallelize(sparsePoints2Seq, 2).map(FeatureData))
- badPoints = sqlContext.createDataFrame(sc.parallelize(badPointsSeq, 2).map(FeatureData))
+ densePoints1 = spark.createDataFrame(sc.parallelize(densePoints1Seq, 2).map(FeatureData))
+ sparsePoints1 = spark.createDataFrame(sc.parallelize(sparsePoints1Seq, 2).map(FeatureData))
+ densePoints2 = spark.createDataFrame(sc.parallelize(densePoints2Seq, 2).map(FeatureData))
+ sparsePoints2 = spark.createDataFrame(sc.parallelize(sparsePoints2Seq, 2).map(FeatureData))
+ badPoints = spark.createDataFrame(sc.parallelize(badPointsSeq, 2).map(FeatureData))
}
private def getIndexer: VectorIndexer =
@@ -102,7 +102,7 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext
}
test("Cannot fit an empty DataFrame") {
- val rdd = sqlContext.createDataFrame(sc.parallelize(Array.empty[Vector], 2).map(FeatureData))
+ val rdd = spark.createDataFrame(sc.parallelize(Array.empty[Vector], 2).map(FeatureData))
val vectorIndexer = getIndexer
intercept[IllegalArgumentException] {
vectorIndexer.fit(rdd)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala
index 6bb4678dc5..88a077f9a1 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala
@@ -79,7 +79,7 @@ class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext with De
val resultAttrGroup = new AttributeGroup("expected", resultAttrs.asInstanceOf[Array[Attribute]])
val rdd = sc.parallelize(data.zip(expected)).map { case (a, b) => Row(a, b) }
- val df = sqlContext.createDataFrame(rdd,
+ val df = spark.createDataFrame(rdd,
StructType(Array(attrGroup.toStructField(), resultAttrGroup.toStructField())))
val vectorSlicer = new VectorSlicer().setInputCol("features").setOutputCol("result")
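[Editor's note] VectorSlicerSuite exercises the other createDataFrame overload, which takes an RDD[Row] plus an explicit StructType; that overload carries over to SparkSession unchanged. A hedged sketch with illustrative data, assuming the suite's sc and spark are in scope:

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

val rdd = sc.parallelize(Seq(Row(1.0, 2.0), Row(3.0, 4.0)))
val schema = StructType(Array(
  StructField("features", DoubleType),
  StructField("expected", DoubleType)))
val df = spark.createDataFrame(rdd, schema)  // RDD[Row] + schema variant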
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
index 80c177b8d3..8cbe0f3def 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
@@ -36,8 +36,8 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
test("Word2Vec") {
- val sqlContext = this.sqlContext
- import sqlContext.implicits._
+ val spark = this.spark
+ import spark.implicits._
val sentence = "a b " * 100 + "a c " * 10
val numOfWords = sentence.split(" ").size
@@ -78,8 +78,8 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
test("getVectors") {
- val sqlContext = this.sqlContext
- import sqlContext.implicits._
+ val spark = this.spark
+ import spark.implicits._
val sentence = "a b " * 100 + "a c " * 10
val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" "))
@@ -119,8 +119,8 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
test("findSynonyms") {
- val sqlContext = this.sqlContext
- import sqlContext.implicits._
+ val spark = this.spark
+ import spark.implicits._
val sentence = "a b " * 100 + "a c " * 10
val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" "))
@@ -146,8 +146,8 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
test("window size") {
- val sqlContext = this.sqlContext
- import sqlContext.implicits._
+ val spark = this.spark
+ import spark.implicits._
val sentence = "a q s t q s t b b b s t m s t m q " * 100 + "a c " * 10
val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" "))
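[Editor's note] The two-line idiom above — binding the member to a local val before importing — is needed because Scala only permits imports from a stable identifier, and the trait's spark var does not qualify directly. Once imported, the implicits supply toDF/toDS on local collections and RDDs:

val spark = this.spark
import spark.implicits._

// toDF here comes from the imported implicits
val df = Seq((1, "a"), (2, "b")).toDF("id", "word")
val rddDF = sc.parallelize(Seq((3, "c"))).toDF("id", "word")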
diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index 1704037395..9da0c32dee 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -38,7 +38,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.rdd.RDD
import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted}
-import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
@@ -305,8 +305,8 @@ class ALSSuite
numUserBlocks: Int = 2,
numItemBlocks: Int = 3,
targetRMSE: Double = 0.05): Unit = {
- val sqlContext = this.sqlContext
- import sqlContext.implicits._
+ val spark = this.spark
+ import spark.implicits._
val als = new ALS()
.setRank(rank)
.setRegParam(regParam)
@@ -460,8 +460,8 @@ class ALSSuite
allEstimatorParamSettings.foreach { case (p, v) =>
als.set(als.getParam(p), v)
}
- val sqlContext = this.sqlContext
- import sqlContext.implicits._
+ val spark = this.spark
+ import spark.implicits._
val model = als.fit(ratings.toDF())
// Test Estimator save/load
@@ -535,8 +535,11 @@ class ALSCleanerSuite extends SparkFunSuite {
// Generate test data
val (training, _) = ALSSuite.genImplicitTestData(sc, 20, 5, 1, 0.2, 0)
// Implicitly test the cleaning of parents during ALS training
- val sqlContext = new SQLContext(sc)
- import sqlContext.implicits._
+ val spark = SparkSession.builder
+ .master("local[2]")
+ .appName("ALSCleanerSuite")
+ .getOrCreate()
+ import spark.implicits._
val als = new ALS()
.setRank(1)
.setRegParam(1e-5)
@@ -577,8 +580,8 @@ class ALSStorageSuite
}
test("default and non-default storage params set correct RDD StorageLevels") {
- val sqlContext = this.sqlContext
- import sqlContext.implicits._
+ val spark = this.spark
+ import spark.implicits._
val data = Seq(
(0, 0, 1.0),
(0, 1, 2.0),
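[Editor's note] ALSCleanerSuite was the one spot that built its own `new SQLContext(sc)`; it now goes through the builder. A general SparkSession property worth noting here (not something this patch adds): getOrCreate returns the already-active session when one exists rather than building a second context, so repeated calls inside tests are safe. A small sketch:

val s1 = SparkSession.builder.master("local[2]").appName("first").getOrCreate()
val s2 = SparkSession.builder.appName("second").getOrCreate()
assert(s1 eq s2)  // same instance: the second builder reuses the active session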
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
index 76891ad562..f8fc775676 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
@@ -37,13 +37,13 @@ class AFTSurvivalRegressionSuite
override def beforeAll(): Unit = {
super.beforeAll()
- datasetUnivariate = sqlContext.createDataFrame(
+ datasetUnivariate = spark.createDataFrame(
sc.parallelize(generateAFTInput(
1, Array(5.5), Array(0.8), 1000, 42, 1.0, 2.0, 2.0)))
- datasetMultivariate = sqlContext.createDataFrame(
+ datasetMultivariate = spark.createDataFrame(
sc.parallelize(generateAFTInput(
2, Array(0.9, -1.3), Array(0.7, 1.2), 1000, 42, 1.5, 2.5, 2.0)))
- datasetUnivariateScaled = sqlContext.createDataFrame(
+ datasetUnivariateScaled = spark.createDataFrame(
sc.parallelize(generateAFTInput(
1, Array(5.5), Array(0.8), 1000, 42, 1.0, 2.0, 2.0)).map { x =>
AFTPoint(Vectors.dense(x.features(0) * 1.0E3), x.label, x.censor)
@@ -356,7 +356,7 @@ class AFTSurvivalRegressionSuite
test("should support all NumericType labels") {
val aft = new AFTSurvivalRegression().setMaxIter(1)
MLTestingUtils.checkNumericTypes[AFTSurvivalRegressionModel, AFTSurvivalRegression](
- aft, isClassification = false, sqlContext) { (expected, actual) =>
+ aft, isClassification = false, spark) { (expected, actual) =>
assert(expected.intercept === actual.intercept)
assert(expected.coefficients === actual.coefficients)
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
index e9fb2677b2..d9f26ad8dc 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
@@ -120,7 +120,7 @@ class DecisionTreeRegressorSuite
test("should support all NumericType labels and not support other types") {
val dt = new DecisionTreeRegressor().setMaxDepth(1)
MLTestingUtils.checkNumericTypes[DecisionTreeRegressionModel, DecisionTreeRegressor](
- dt, isClassification = false, sqlContext) { (expected, actual) =>
+ dt, isClassification = false, spark) { (expected, actual) =>
TreeTests.checkEqual(expected, actual)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index 216377959e..f6ea5bb741 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -72,7 +72,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext
}
test("GBTRegressor behaves reasonably on toy data") {
- val df = sqlContext.createDataFrame(Seq(
+ val df = spark.createDataFrame(Seq(
LabeledPoint(10, Vectors.dense(1, 2, 3, 4)),
LabeledPoint(-5, Vectors.dense(6, 3, 2, 1)),
LabeledPoint(11, Vectors.dense(2, 2, 3, 4)),
@@ -99,7 +99,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext
val path = tempDir.toURI.toString
sc.setCheckpointDir(path)
- val df = sqlContext.createDataFrame(data)
+ val df = spark.createDataFrame(data)
val gbt = new GBTRegressor()
.setMaxDepth(2)
.setMaxIter(5)
@@ -115,7 +115,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext
test("should support all NumericType labels and not support other types") {
val gbt = new GBTRegressor().setMaxDepth(1)
MLTestingUtils.checkNumericTypes[GBTRegressionModel, GBTRegressor](
- gbt, isClassification = false, sqlContext) { (expected, actual) =>
+ gbt, isClassification = false, spark) { (expected, actual) =>
TreeTests.checkEqual(expected, actual)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
index b854be2f1f..161f8c80f8 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
@@ -52,19 +52,19 @@ class GeneralizedLinearRegressionSuite
import GeneralizedLinearRegressionSuite._
- datasetGaussianIdentity = sqlContext.createDataFrame(
+ datasetGaussianIdentity = spark.createDataFrame(
sc.parallelize(generateGeneralizedLinearRegressionInput(
intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5),
xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
family = "gaussian", link = "identity"), 2))
- datasetGaussianLog = sqlContext.createDataFrame(
+ datasetGaussianLog = spark.createDataFrame(
sc.parallelize(generateGeneralizedLinearRegressionInput(
intercept = 0.25, coefficients = Array(0.22, 0.06), xMean = Array(2.9, 10.5),
xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
family = "gaussian", link = "log"), 2))
- datasetGaussianInverse = sqlContext.createDataFrame(
+ datasetGaussianInverse = spark.createDataFrame(
sc.parallelize(generateGeneralizedLinearRegressionInput(
intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5),
xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
@@ -80,40 +80,40 @@ class GeneralizedLinearRegressionSuite
generateMultinomialLogisticInput(coefficients, xMean, xVariance,
addIntercept = true, nPoints, seed)
- sqlContext.createDataFrame(sc.parallelize(testData, 2))
+ spark.createDataFrame(sc.parallelize(testData, 2))
}
- datasetPoissonLog = sqlContext.createDataFrame(
+ datasetPoissonLog = spark.createDataFrame(
sc.parallelize(generateGeneralizedLinearRegressionInput(
intercept = 0.25, coefficients = Array(0.22, 0.06), xMean = Array(2.9, 10.5),
xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
family = "poisson", link = "log"), 2))
- datasetPoissonIdentity = sqlContext.createDataFrame(
+ datasetPoissonIdentity = spark.createDataFrame(
sc.parallelize(generateGeneralizedLinearRegressionInput(
intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5),
xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
family = "poisson", link = "identity"), 2))
- datasetPoissonSqrt = sqlContext.createDataFrame(
+ datasetPoissonSqrt = spark.createDataFrame(
sc.parallelize(generateGeneralizedLinearRegressionInput(
intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5),
xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
family = "poisson", link = "sqrt"), 2))
- datasetGammaInverse = sqlContext.createDataFrame(
+ datasetGammaInverse = spark.createDataFrame(
sc.parallelize(generateGeneralizedLinearRegressionInput(
intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5),
xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
family = "gamma", link = "inverse"), 2))
- datasetGammaIdentity = sqlContext.createDataFrame(
+ datasetGammaIdentity = spark.createDataFrame(
sc.parallelize(generateGeneralizedLinearRegressionInput(
intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5),
xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
family = "gamma", link = "identity"), 2))
- datasetGammaLog = sqlContext.createDataFrame(
+ datasetGammaLog = spark.createDataFrame(
sc.parallelize(generateGeneralizedLinearRegressionInput(
intercept = 0.25, coefficients = Array(0.22, 0.06), xMean = Array(2.9, 10.5),
xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
@@ -540,7 +540,7 @@ class GeneralizedLinearRegressionSuite
w <- c(1, 2, 3, 4)
df <- as.data.frame(cbind(A, b))
*/
- val datasetWithWeight = sqlContext.createDataFrame(sc.parallelize(Seq(
+ val datasetWithWeight = spark.createDataFrame(sc.parallelize(Seq(
Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse),
Instance(19.0, 2.0, Vectors.dense(1.0, 7.0)),
Instance(23.0, 3.0, Vectors.dense(2.0, 11.0)),
@@ -668,7 +668,7 @@ class GeneralizedLinearRegressionSuite
w <- c(1, 2, 3, 4)
df <- as.data.frame(cbind(A, b))
*/
- val datasetWithWeight = sqlContext.createDataFrame(sc.parallelize(Seq(
+ val datasetWithWeight = spark.createDataFrame(sc.parallelize(Seq(
Instance(1.0, 1.0, Vectors.dense(0.0, 5.0).toSparse),
Instance(0.0, 2.0, Vectors.dense(1.0, 2.0)),
Instance(1.0, 3.0, Vectors.dense(2.0, 1.0)),
@@ -782,7 +782,7 @@ class GeneralizedLinearRegressionSuite
w <- c(1, 2, 3, 4)
df <- as.data.frame(cbind(A, b))
*/
- val datasetWithWeight = sqlContext.createDataFrame(sc.parallelize(Seq(
+ val datasetWithWeight = spark.createDataFrame(sc.parallelize(Seq(
Instance(2.0, 1.0, Vectors.dense(0.0, 5.0).toSparse),
Instance(8.0, 2.0, Vectors.dense(1.0, 7.0)),
Instance(3.0, 3.0, Vectors.dense(2.0, 11.0)),
@@ -899,7 +899,7 @@ class GeneralizedLinearRegressionSuite
w <- c(1, 2, 3, 4)
df <- as.data.frame(cbind(A, b))
*/
- val datasetWithWeight = sqlContext.createDataFrame(sc.parallelize(Seq(
+ val datasetWithWeight = spark.createDataFrame(sc.parallelize(Seq(
Instance(2.0, 1.0, Vectors.dense(0.0, 5.0).toSparse),
Instance(8.0, 2.0, Vectors.dense(1.0, 7.0)),
Instance(3.0, 3.0, Vectors.dense(2.0, 11.0)),
@@ -1021,14 +1021,14 @@ class GeneralizedLinearRegressionSuite
val glr = new GeneralizedLinearRegression().setMaxIter(1)
MLTestingUtils.checkNumericTypes[
GeneralizedLinearRegressionModel, GeneralizedLinearRegression](
- glr, isClassification = false, sqlContext) { (expected, actual) =>
+ glr, isClassification = false, spark) { (expected, actual) =>
assert(expected.intercept === actual.intercept)
assert(expected.coefficients === actual.coefficients)
}
}
test("glm accepts Dataset[LabeledPoint]") {
- val context = sqlContext
+ val context = spark
import context.implicits._
new GeneralizedLinearRegression()
.setFamily("gaussian")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
index 3a10ad7ed0..9bf7542b12 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
@@ -28,13 +28,13 @@ class IsotonicRegressionSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
private def generateIsotonicInput(labels: Seq[Double]): DataFrame = {
- sqlContext.createDataFrame(
+ spark.createDataFrame(
labels.zipWithIndex.map { case (label, i) => (label, i.toDouble, 1.0) }
).toDF("label", "features", "weight")
}
private def generatePredictionInput(features: Seq[Double]): DataFrame = {
- sqlContext.createDataFrame(features.map(Tuple1.apply))
+ spark.createDataFrame(features.map(Tuple1.apply))
.toDF("features")
}
@@ -145,7 +145,7 @@ class IsotonicRegressionSuite
}
test("vector features column with feature index") {
- val dataset = sqlContext.createDataFrame(Seq(
+ val dataset = spark.createDataFrame(Seq(
(4.0, Vectors.dense(0.0, 1.0)),
(3.0, Vectors.dense(0.0, 2.0)),
(5.0, Vectors.sparse(2, Array(1), Array(3.0))))
@@ -184,7 +184,7 @@ class IsotonicRegressionSuite
test("should support all NumericType labels and not support other types") {
val ir = new IsotonicRegression()
MLTestingUtils.checkNumericTypes[IsotonicRegressionModel, IsotonicRegression](
- ir, isClassification = false, sqlContext) { (expected, actual) =>
+ ir, isClassification = false, spark) { (expected, actual) =>
assert(expected.boundaries === actual.boundaries)
assert(expected.predictions === actual.predictions)
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index eb19d13093..10f547b673 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -42,7 +42,7 @@ class LinearRegressionSuite
override def beforeAll(): Unit = {
super.beforeAll()
- datasetWithDenseFeature = sqlContext.createDataFrame(
+ datasetWithDenseFeature = spark.createDataFrame(
sc.parallelize(LinearDataGenerator.generateLinearInput(
intercept = 6.3, weights = Array(4.7, 7.2), xMean = Array(0.9, -1.3),
xVariance = Array(0.7, 1.2), nPoints = 10000, seed, eps = 0.1), 2))
@@ -50,7 +50,7 @@ class LinearRegressionSuite
datasetWithDenseFeatureWithoutIntercept is not needed for correctness testing
but is useful for illustrating training a model without an intercept
*/
- datasetWithDenseFeatureWithoutIntercept = sqlContext.createDataFrame(
+ datasetWithDenseFeatureWithoutIntercept = spark.createDataFrame(
sc.parallelize(LinearDataGenerator.generateLinearInput(
intercept = 0.0, weights = Array(4.7, 7.2), xMean = Array(0.9, -1.3),
xVariance = Array(0.7, 1.2), nPoints = 10000, seed, eps = 0.1), 2))
@@ -59,7 +59,7 @@ class LinearRegressionSuite
// When the feature size is larger than 4096, the normal optimizer is chosen
// as the solver of linear regression in the case of "auto" mode.
val featureSize = 4100
- datasetWithSparseFeature = sqlContext.createDataFrame(
+ datasetWithSparseFeature = spark.createDataFrame(
sc.parallelize(LinearDataGenerator.generateLinearInput(
intercept = 0.0, weights = Seq.fill(featureSize)(r.nextDouble()).toArray,
xMean = Seq.fill(featureSize)(r.nextDouble()).toArray,
@@ -74,7 +74,7 @@ class LinearRegressionSuite
w <- c(1, 2, 3, 4)
df <- as.data.frame(cbind(A, b))
*/
- datasetWithWeight = sqlContext.createDataFrame(
+ datasetWithWeight = spark.createDataFrame(
sc.parallelize(Seq(
Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse),
Instance(19.0, 2.0, Vectors.dense(1.0, 7.0)),
@@ -90,14 +90,14 @@ class LinearRegressionSuite
w <- c(1, 2, 3, 4)
df.const.label <- as.data.frame(cbind(A, b.const))
*/
- datasetWithWeightConstantLabel = sqlContext.createDataFrame(
+ datasetWithWeightConstantLabel = spark.createDataFrame(
sc.parallelize(Seq(
Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse),
Instance(17.0, 2.0, Vectors.dense(1.0, 7.0)),
Instance(17.0, 3.0, Vectors.dense(2.0, 11.0)),
Instance(17.0, 4.0, Vectors.dense(3.0, 13.0))
), 2))
- datasetWithWeightZeroLabel = sqlContext.createDataFrame(
+ datasetWithWeightZeroLabel = spark.createDataFrame(
sc.parallelize(Seq(
Instance(0.0, 1.0, Vectors.dense(0.0, 5.0).toSparse),
Instance(0.0, 2.0, Vectors.dense(1.0, 7.0)),
@@ -828,8 +828,8 @@ class LinearRegressionSuite
}
val data2 = weightedSignedData ++ weightedNoiseData
- (sqlContext.createDataFrame(sc.parallelize(data1, 4)),
- sqlContext.createDataFrame(sc.parallelize(data2, 4)))
+ (spark.createDataFrame(sc.parallelize(data1, 4)),
+ spark.createDataFrame(sc.parallelize(data2, 4)))
}
val trainer1a = (new LinearRegression).setFitIntercept(true)
@@ -1010,7 +1010,7 @@ class LinearRegressionSuite
test("should support all NumericType labels and not support other types") {
val lr = new LinearRegression().setMaxIter(1)
MLTestingUtils.checkNumericTypes[LinearRegressionModel, LinearRegression](
- lr, isClassification = false, sqlContext) { (expected, actual) =>
+ lr, isClassification = false, spark) { (expected, actual) =>
assert(expected.intercept === actual.intercept)
assert(expected.coefficients === actual.coefficients)
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
index ca400e1914..72f3c65eb8 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
@@ -98,7 +98,7 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContex
test("should support all NumericType labels and not support other types") {
val rf = new RandomForestRegressor().setMaxDepth(1)
MLTestingUtils.checkNumericTypes[RandomForestRegressionModel, RandomForestRegressor](
- rf, isClassification = false, sqlContext) { (expected, actual) =>
+ rf, isClassification = false, spark) { (expected, actual) =>
TreeTests.checkEqual(expected, actual)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
index 1d7144f4e5..7d0e01fd8f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
@@ -56,7 +56,7 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
}
test("select as sparse vector") {
- val df = sqlContext.read.format("libsvm").load(path)
+ val df = spark.read.format("libsvm").load(path)
assert(df.columns(0) == "label")
assert(df.columns(1) == "features")
val row1 = df.first()
@@ -66,7 +66,7 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
}
test("select as dense vector") {
- val df = sqlContext.read.format("libsvm").options(Map("vectorType" -> "dense"))
+ val df = spark.read.format("libsvm").options(Map("vectorType" -> "dense"))
.load(path)
assert(df.columns(0) == "label")
assert(df.columns(1) == "features")
@@ -78,7 +78,7 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
}
test("select a vector with specifying the longer dimension") {
- val df = sqlContext.read.option("numFeatures", "100").format("libsvm")
+ val df = spark.read.option("numFeatures", "100").format("libsvm")
.load(path)
val row1 = df.first()
val v = row1.getAs[SparseVector](1)
@@ -86,27 +86,27 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
}
test("write libsvm data and read it again") {
- val df = sqlContext.read.format("libsvm").load(path)
+ val df = spark.read.format("libsvm").load(path)
val tempDir2 = new File(tempDir, "read_write_test")
val writepath = tempDir2.toURI.toString
// TODO: Remove requirement to coalesce by supporting multiple reads.
df.coalesce(1).write.format("libsvm").mode(SaveMode.Overwrite).save(writepath)
- val df2 = sqlContext.read.format("libsvm").load(writepath)
+ val df2 = spark.read.format("libsvm").load(writepath)
val row1 = df2.first()
val v = row1.getAs[SparseVector](1)
assert(v == Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0))))
}
test("write libsvm data failed due to invalid schema") {
- val df = sqlContext.read.format("text").load(path)
+ val df = spark.read.format("text").load(path)
intercept[SparkException] {
df.write.format("libsvm").save(path + "_2")
}
}
test("select features from libsvm relation") {
- val df = sqlContext.read.format("libsvm").load(path)
+ val df = spark.read.format("libsvm").load(path)
df.select("features").rdd.map { case Row(d: Vector) => d }.first
df.select("features").collect
}
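[Editor's note] The DataFrameReader is reached the same way on both APIs — spark.read returns the same reader that sqlContext.read did — so the format/option/load chains in this suite migrate verbatim. A sketch against an assumed file path (the suite itself writes a temp file):

val df = spark.read
  .format("libsvm")
  .option("numFeatures", "100")
  .load("/tmp/sample.libsvm")  // illustrative path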
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala
index fecf372c3d..de92b51eb0 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala
@@ -37,8 +37,8 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext
val numIterations = 20
val trainRdd = sc.parallelize(OldGBTSuite.trainData, 2)
val validateRdd = sc.parallelize(OldGBTSuite.validateData, 2)
- val trainDF = sqlContext.createDataFrame(trainRdd)
- val validateDF = sqlContext.createDataFrame(validateRdd)
+ val trainDF = spark.createDataFrame(trainRdd)
+ val validateDF = spark.createDataFrame(validateRdd)
val algos = Array(Regression, Regression, Classification)
val losses = Array(SquaredError, AbsoluteError, LogLoss)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
index e3f09899d7..12ade4c92f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
@@ -26,7 +26,7 @@ import org.apache.spark.ml.tree._
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, SQLContext}
+import org.apache.spark.sql.{DataFrame, SparkSession}
private[ml] object TreeTests extends SparkFunSuite {
@@ -42,8 +42,12 @@ private[ml] object TreeTests extends SparkFunSuite {
data: RDD[LabeledPoint],
categoricalFeatures: Map[Int, Int],
numClasses: Int): DataFrame = {
- val sqlContext = SQLContext.getOrCreate(data.sparkContext)
- import sqlContext.implicits._
+ val spark = SparkSession.builder
+ .master("local[2]")
+ .appName("TreeTests")
+ .getOrCreate()
+ import spark.implicits._
+
val df = data.toDF()
val numFeatures = data.first().features.size
val featuresAttributes = Range(0, numFeatures).map { feature =>
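[Editor's note] TreeTests is a shared object with no suite state, so it cannot reference a suite's spark member; SparkSession.builder.getOrCreate() picks up whichever session the running suite already started, and the master/appName set here only take effect if no session exists yet. A condensed, illustrative sketch of the helper's shape (the method name is hypothetical):

def toDF(data: RDD[LabeledPoint]): DataFrame = {
  val spark = SparkSession.builder.getOrCreate()
  import spark.implicits._
  data.toDF()  // LabeledPoint is a case class, so the product encoder applies
}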
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
index 061d04c932..85df6da7a1 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -39,7 +39,7 @@ class CrossValidatorSuite
override def beforeAll(): Unit = {
super.beforeAll()
- dataset = sqlContext.createDataFrame(
+ dataset = spark.createDataFrame(
sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2))
}
@@ -67,7 +67,7 @@ class CrossValidatorSuite
}
test("cross validation with linear regression") {
- val dataset = sqlContext.createDataFrame(
+ val dataset = spark.createDataFrame(
sc.parallelize(LinearDataGenerator.generateLinearInput(
6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2))
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
index df9ba418b8..f8d3de19b0 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql.types.StructType
class TrainValidationSplitSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
test("train validation with logistic regression") {
- val dataset = sqlContext.createDataFrame(
+ val dataset = spark.createDataFrame(
sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2))
val lr = new LogisticRegression
@@ -58,7 +58,7 @@ class TrainValidationSplitSuite
}
test("train validation with linear regression") {
- val dataset = sqlContext.createDataFrame(
+ val dataset = spark.createDataFrame(
sc.parallelize(LinearDataGenerator.generateLinearInput(
6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2))
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
index d9e6fd5aae..4fe473bbac 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
@@ -23,7 +23,7 @@ import org.apache.spark.ml.evaluation.Evaluator
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree.impl.TreeTests
import org.apache.spark.mllib.linalg.Vectors
-import org.apache.spark.sql.{DataFrame, SQLContext}
+import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
@@ -38,17 +38,17 @@ object MLTestingUtils extends SparkFunSuite {
def checkNumericTypes[M <: Model[M], T <: Estimator[M]](
estimator: T,
isClassification: Boolean,
- sqlContext: SQLContext)(check: (M, M) => Unit): Unit = {
+ spark: SparkSession)(check: (M, M) => Unit): Unit = {
val dfs = if (isClassification) {
- genClassifDFWithNumericLabelCol(sqlContext)
+ genClassifDFWithNumericLabelCol(spark)
} else {
- genRegressionDFWithNumericLabelCol(sqlContext)
+ genRegressionDFWithNumericLabelCol(spark)
}
val expected = estimator.fit(dfs(DoubleType))
val actuals = dfs.keys.filter(_ != DoubleType).map(t => estimator.fit(dfs(t)))
actuals.foreach(actual => check(expected, actual))
- val dfWithStringLabels = sqlContext.createDataFrame(Seq(
+ val dfWithStringLabels = spark.createDataFrame(Seq(
("0", Vectors.dense(0, 2, 3), 0.0)
)).toDF("label", "features", "censor")
val thrown = intercept[IllegalArgumentException] {
@@ -58,13 +58,13 @@ object MLTestingUtils extends SparkFunSuite {
"Column label must be of type NumericType but was actually of type StringType"))
}
- def checkNumericTypes[T <: Evaluator](evaluator: T, sqlContext: SQLContext): Unit = {
- val dfs = genEvaluatorDFWithNumericLabelCol(sqlContext, "label", "prediction")
+ def checkNumericTypes[T <: Evaluator](evaluator: T, spark: SparkSession): Unit = {
+ val dfs = genEvaluatorDFWithNumericLabelCol(spark, "label", "prediction")
val expected = evaluator.evaluate(dfs(DoubleType))
val actuals = dfs.keys.filter(_ != DoubleType).map(t => evaluator.evaluate(dfs(t)))
actuals.foreach(actual => assert(expected === actual))
- val dfWithStringLabels = sqlContext.createDataFrame(Seq(
+ val dfWithStringLabels = spark.createDataFrame(Seq(
("0", 0d)
)).toDF("label", "prediction")
val thrown = intercept[IllegalArgumentException] {
@@ -75,10 +75,10 @@ object MLTestingUtils extends SparkFunSuite {
}
def genClassifDFWithNumericLabelCol(
- sqlContext: SQLContext,
+ spark: SparkSession,
labelColName: String = "label",
featuresColName: String = "features"): Map[NumericType, DataFrame] = {
- val df = sqlContext.createDataFrame(Seq(
+ val df = spark.createDataFrame(Seq(
(0, Vectors.dense(0, 2, 3)),
(1, Vectors.dense(0, 3, 1)),
(0, Vectors.dense(0, 2, 2)),
@@ -95,11 +95,11 @@ object MLTestingUtils extends SparkFunSuite {
}
def genRegressionDFWithNumericLabelCol(
- sqlContext: SQLContext,
+ spark: SparkSession,
labelColName: String = "label",
featuresColName: String = "features",
censorColName: String = "censor"): Map[NumericType, DataFrame] = {
- val df = sqlContext.createDataFrame(Seq(
+ val df = spark.createDataFrame(Seq(
(0, Vectors.dense(0)),
(1, Vectors.dense(1)),
(2, Vectors.dense(2)),
@@ -117,10 +117,10 @@ object MLTestingUtils extends SparkFunSuite {
}
def genEvaluatorDFWithNumericLabelCol(
- sqlContext: SQLContext,
+ spark: SparkSession,
labelColName: String = "label",
predictionColName: String = "prediction"): Map[NumericType, DataFrame] = {
- val df = sqlContext.createDataFrame(Seq(
+ val df = spark.createDataFrame(Seq(
(0, 0d),
(1, 1d),
(2, 2d),
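[Editor's note] For any helper that still genuinely needs a SQLContext, SparkSession keeps one reachable, which allows signatures like these to be migrated incrementally — a hedged note on the general API, not something this patch relies on:

import org.apache.spark.sql.SQLContext
val legacy: SQLContext = spark.sqlContext  // same underlying session state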
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala
index 7f9e340f54..ba8d36f45f 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala
@@ -23,23 +23,22 @@ import org.scalatest.Suite
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.ml.util.TempDirectory
-import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.{SparkSession, SQLContext}
import org.apache.spark.util.Utils
trait MLlibTestSparkContext extends TempDirectory { self: Suite =>
+ @transient var spark: SparkSession = _
@transient var sc: SparkContext = _
- @transient var sqlContext: SQLContext = _
@transient var checkpointDir: String = _
override def beforeAll() {
super.beforeAll()
- val conf = new SparkConf()
- .setMaster("local[2]")
- .setAppName("MLlibUnitTest")
- sc = new SparkContext(conf)
- SQLContext.clearActive()
- sqlContext = new SQLContext(sc)
- SQLContext.setActive(sqlContext)
+ spark = SparkSession.builder
+ .master("local[2]")
+ .appName("MLlibUnitTest")
+ .getOrCreate()
+ sc = spark.sparkContext
+
checkpointDir = Utils.createDirectory(tempDir.getCanonicalPath, "checkpoints").toString
sc.setCheckpointDir(checkpointDir)
}
@@ -47,12 +46,11 @@ trait MLlibTestSparkContext extends TempDirectory { self: Suite =>
override def afterAll() {
try {
Utils.deleteRecursively(new File(checkpointDir))
- sqlContext = null
SQLContext.clearActive()
- if (sc != null) {
- sc.stop()
+ if (spark != null) {
+ spark.stop()
}
- sc = null
+ spark = null
} finally {
super.afterAll()
}
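[Editor's note] Net effect for every suite mixing in the trait: spark is the single entry point, sc is derived from it, and one stop() in afterAll tears both down. An illustrative consumer (the suite name is hypothetical):

class ExampleSuite extends SparkFunSuite with MLlibTestSparkContext {
  test("toy DataFrame via the shared session") {
    val df = spark.createDataFrame(Seq((0, "a"), (1, "b"))).toDF("id", "label")
    assert(df.count() === 2)
  }
}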