author     Sandeep Singh <sandeep@techaddict.me>    2016-05-10 11:17:47 -0700
committer  Andrew Or <andrew@databricks.com>        2016-05-10 11:17:47 -0700
commit     ed0b4070fb50054b1ecf66ff6c32458a4967dfd3 (patch)
tree       68b3ad1a3ca22f2e0b5966db517c9bc42da3d254 /sql/core/src/test/java
parent     bcfee153b1cacfe617e602f3b72c0877e0bdf1f7 (diff)
[SPARK-15037][SQL][MLLIB] Use SparkSession instead of SQLContext in Scala/Java TestSuites
## What changes were proposed in this pull request?
Use SparkSession instead of SQLContext in the Scala/Java TestSuites. This PR is already very large, so the Python TestSuites are being updated in a separate PR.

## How was this patch tested?
Existing tests

Author: Sandeep Singh <sandeep@techaddict.me>

Closes #12907 from techaddict/SPARK-15037.
Diffstat (limited to 'sql/core/src/test/java')
-rw-r--r--  sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java             42
-rw-r--r--  sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java               70
-rw-r--r--  sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java                 89
-rw-r--r--  sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java                     27
-rw-r--r--  sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuiteBase.java  20
-rw-r--r--  sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java        33
6 files changed, 133 insertions, 148 deletions
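
For reference, a minimal sketch of the setUp/tearDown pattern this change applies across the Java suites: the SparkContext + SQLContext pair is replaced by a single SparkSession, with a JavaSparkContext derived from it where tests still need RDD helpers. The suite name and test body below are illustrative only, not part of the patch.

// Illustrative sketch of the SQLContext -> SparkSession migration pattern.
// Class and test names here are hypothetical; only the builder/stop calls
// mirror the changed test suites.
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

public class JavaSparkSessionMigrationExample {
  private transient SparkSession spark;
  private transient JavaSparkContext jsc;

  @Before
  public void setUp() {
    // New style: build a SparkSession directly instead of
    // new SparkContext(...) followed by new SQLContext(sc).
    spark = SparkSession.builder()
      .master("local[*]")
      .appName("testing")
      .getOrCreate();
    // A JavaSparkContext is still handy for parallelize() in tests.
    jsc = new JavaSparkContext(spark.sparkContext());
  }

  @After
  public void tearDown() {
    // Stopping the session also stops the underlying SparkContext.
    spark.stop();
    spark = null;
  }

  @Test
  public void sqlThroughSparkSession() {
    // spark.sql()/spark.read()/spark.createDataFrame() replace the
    // corresponding sqlContext calls.
    Dataset<Row> df = spark.range(3).toDF("id");
    Assert.assertEquals(3, df.collectAsList().size());
  }
}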
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java
index 189cc3972c..f2ae40e644 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java
@@ -28,14 +28,13 @@ import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
-import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
@@ -44,21 +43,22 @@ import org.apache.spark.sql.types.StructType;
// serialized, as an alternative to converting these anonymous classes to static inner classes;
// see http://stackoverflow.com/questions/758570/.
public class JavaApplySchemaSuite implements Serializable {
- private transient JavaSparkContext javaCtx;
- private transient SQLContext sqlContext;
+ private transient SparkSession spark;
+ private transient JavaSparkContext jsc;
@Before
public void setUp() {
- SparkContext context = new SparkContext("local[*]", "testing");
- javaCtx = new JavaSparkContext(context);
- sqlContext = new SQLContext(context);
+ spark = SparkSession.builder()
+ .master("local[*]")
+ .appName("testing")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
- sqlContext.sparkContext().stop();
- sqlContext = null;
- javaCtx = null;
+ spark.stop();
+ spark = null;
}
public static class Person implements Serializable {
@@ -94,7 +94,7 @@ public class JavaApplySchemaSuite implements Serializable {
person2.setAge(28);
personList.add(person2);
- JavaRDD<Row> rowRDD = javaCtx.parallelize(personList).map(
+ JavaRDD<Row> rowRDD = jsc.parallelize(personList).map(
new Function<Person, Row>() {
@Override
public Row call(Person person) throws Exception {
@@ -107,9 +107,9 @@ public class JavaApplySchemaSuite implements Serializable {
fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false));
StructType schema = DataTypes.createStructType(fields);
- Dataset<Row> df = sqlContext.createDataFrame(rowRDD, schema);
+ Dataset<Row> df = spark.createDataFrame(rowRDD, schema);
df.registerTempTable("people");
- List<Row> actual = sqlContext.sql("SELECT * FROM people").collectAsList();
+ List<Row> actual = spark.sql("SELECT * FROM people").collectAsList();
List<Row> expected = new ArrayList<>(2);
expected.add(RowFactory.create("Michael", 29));
@@ -130,7 +130,7 @@ public class JavaApplySchemaSuite implements Serializable {
person2.setAge(28);
personList.add(person2);
- JavaRDD<Row> rowRDD = javaCtx.parallelize(personList).map(
+ JavaRDD<Row> rowRDD = jsc.parallelize(personList).map(
new Function<Person, Row>() {
@Override
public Row call(Person person) {
@@ -143,9 +143,9 @@ public class JavaApplySchemaSuite implements Serializable {
fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false));
StructType schema = DataTypes.createStructType(fields);
- Dataset<Row> df = sqlContext.createDataFrame(rowRDD, schema);
+ Dataset<Row> df = spark.createDataFrame(rowRDD, schema);
df.registerTempTable("people");
- List<String> actual = sqlContext.sql("SELECT * FROM people").toJavaRDD()
+ List<String> actual = spark.sql("SELECT * FROM people").toJavaRDD()
.map(new Function<Row, String>() {
@Override
public String call(Row row) {
@@ -162,7 +162,7 @@ public class JavaApplySchemaSuite implements Serializable {
@Test
public void applySchemaToJSON() {
- JavaRDD<String> jsonRDD = javaCtx.parallelize(Arrays.asList(
+ JavaRDD<String> jsonRDD = jsc.parallelize(Arrays.asList(
"{\"string\":\"this is a simple string.\", \"integer\":10, \"long\":21474836470, " +
"\"bigInteger\":92233720368547758070, \"double\":1.7976931348623157E308, " +
"\"boolean\":true, \"null\":null}",
@@ -199,18 +199,18 @@ public class JavaApplySchemaSuite implements Serializable {
null,
"this is another simple string."));
- Dataset<Row> df1 = sqlContext.read().json(jsonRDD);
+ Dataset<Row> df1 = spark.read().json(jsonRDD);
StructType actualSchema1 = df1.schema();
Assert.assertEquals(expectedSchema, actualSchema1);
df1.registerTempTable("jsonTable1");
- List<Row> actual1 = sqlContext.sql("select * from jsonTable1").collectAsList();
+ List<Row> actual1 = spark.sql("select * from jsonTable1").collectAsList();
Assert.assertEquals(expectedResult, actual1);
- Dataset<Row> df2 = sqlContext.read().schema(expectedSchema).json(jsonRDD);
+ Dataset<Row> df2 = spark.read().schema(expectedSchema).json(jsonRDD);
StructType actualSchema2 = df2.schema();
Assert.assertEquals(expectedSchema, actualSchema2);
df2.registerTempTable("jsonTable2");
- List<Row> actual2 = sqlContext.sql("select * from jsonTable2").collectAsList();
+ List<Row> actual2 = spark.sql("select * from jsonTable2").collectAsList();
Assert.assertEquals(expectedResult, actual2);
}
}
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index 1eb680dc4c..324ebbae38 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -20,12 +20,7 @@ package test.org.apache.spark.sql;
import java.io.Serializable;
import java.net.URISyntaxException;
import java.net.URL;
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.Comparator;
-import java.util.List;
-import java.util.Map;
-import java.util.ArrayList;
+import java.util.*;
import scala.collection.JavaConverters;
import scala.collection.Seq;
@@ -34,46 +29,45 @@ import com.google.common.collect.ImmutableMap;
import com.google.common.primitives.Ints;
import org.junit.*;
-import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.sql.*;
-import org.apache.spark.sql.test.TestSQLContext;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.test.TestSparkSession;
import org.apache.spark.sql.types.*;
+import org.apache.spark.util.sketch.BloomFilter;
import org.apache.spark.util.sketch.CountMinSketch;
import static org.apache.spark.sql.functions.*;
import static org.apache.spark.sql.types.DataTypes.*;
-import org.apache.spark.util.sketch.BloomFilter;
public class JavaDataFrameSuite {
+ private transient TestSparkSession spark;
private transient JavaSparkContext jsc;
- private transient TestSQLContext context;
@Before
public void setUp() {
// Trigger static initializer of TestData
- SparkContext sc = new SparkContext("local[*]", "testing");
- jsc = new JavaSparkContext(sc);
- context = new TestSQLContext(sc);
- context.loadTestData();
+ spark = new TestSparkSession();
+ jsc = new JavaSparkContext(spark.sparkContext());
+ spark.loadTestData();
}
@After
public void tearDown() {
- context.sparkContext().stop();
- context = null;
- jsc = null;
+ spark.stop();
+ spark = null;
}
@Test
public void testExecution() {
- Dataset<Row> df = context.table("testData").filter("key = 1");
+ Dataset<Row> df = spark.table("testData").filter("key = 1");
Assert.assertEquals(1, df.select("key").collectAsList().get(0).get(0));
}
@Test
public void testCollectAndTake() {
- Dataset<Row> df = context.table("testData").filter("key = 1 or key = 2 or key = 3");
+ Dataset<Row> df = spark.table("testData").filter("key = 1 or key = 2 or key = 3");
Assert.assertEquals(3, df.select("key").collectAsList().size());
Assert.assertEquals(2, df.select("key").takeAsList(2).size());
}
@@ -83,7 +77,7 @@ public class JavaDataFrameSuite {
*/
@Test
public void testVarargMethods() {
- Dataset<Row> df = context.table("testData");
+ Dataset<Row> df = spark.table("testData");
df.toDF("key1", "value1");
@@ -112,7 +106,7 @@ public class JavaDataFrameSuite {
df.select(coalesce(col("key")));
// Varargs with mathfunctions
- Dataset<Row> df2 = context.table("testData2");
+ Dataset<Row> df2 = spark.table("testData2");
df2.select(exp("a"), exp("b"));
df2.select(exp(log("a")));
df2.select(pow("a", "a"), pow("b", 2.0));
@@ -126,7 +120,7 @@ public class JavaDataFrameSuite {
@Ignore
public void testShow() {
// This test case is intended ignored, but to make sure it compiles correctly
- Dataset<Row> df = context.table("testData");
+ Dataset<Row> df = spark.table("testData");
df.show();
df.show(1000);
}
@@ -194,7 +188,7 @@ public class JavaDataFrameSuite {
public void testCreateDataFrameFromLocalJavaBeans() {
Bean bean = new Bean();
List<Bean> data = Arrays.asList(bean);
- Dataset<Row> df = context.createDataFrame(data, Bean.class);
+ Dataset<Row> df = spark.createDataFrame(data, Bean.class);
validateDataFrameWithBeans(bean, df);
}
@@ -202,7 +196,7 @@ public class JavaDataFrameSuite {
public void testCreateDataFrameFromJavaBeans() {
Bean bean = new Bean();
JavaRDD<Bean> rdd = jsc.parallelize(Arrays.asList(bean));
- Dataset<Row> df = context.createDataFrame(rdd, Bean.class);
+ Dataset<Row> df = spark.createDataFrame(rdd, Bean.class);
validateDataFrameWithBeans(bean, df);
}
@@ -210,7 +204,7 @@ public class JavaDataFrameSuite {
public void testCreateDataFromFromList() {
StructType schema = createStructType(Arrays.asList(createStructField("i", IntegerType, true)));
List<Row> rows = Arrays.asList(RowFactory.create(0));
- Dataset<Row> df = context.createDataFrame(rows, schema);
+ Dataset<Row> df = spark.createDataFrame(rows, schema);
List<Row> result = df.collectAsList();
Assert.assertEquals(1, result.size());
}
@@ -239,7 +233,7 @@ public class JavaDataFrameSuite {
@Test
public void testCrosstab() {
- Dataset<Row> df = context.table("testData2");
+ Dataset<Row> df = spark.table("testData2");
Dataset<Row> crosstab = df.stat().crosstab("a", "b");
String[] columnNames = crosstab.schema().fieldNames();
Assert.assertEquals("a_b", columnNames[0]);
@@ -258,7 +252,7 @@ public class JavaDataFrameSuite {
@Test
public void testFrequentItems() {
- Dataset<Row> df = context.table("testData2");
+ Dataset<Row> df = spark.table("testData2");
String[] cols = {"a"};
Dataset<Row> results = df.stat().freqItems(cols, 0.2);
Assert.assertTrue(results.collectAsList().get(0).getSeq(0).contains(1));
@@ -266,21 +260,21 @@ public class JavaDataFrameSuite {
@Test
public void testCorrelation() {
- Dataset<Row> df = context.table("testData2");
+ Dataset<Row> df = spark.table("testData2");
Double pearsonCorr = df.stat().corr("a", "b", "pearson");
Assert.assertTrue(Math.abs(pearsonCorr) < 1.0e-6);
}
@Test
public void testCovariance() {
- Dataset<Row> df = context.table("testData2");
+ Dataset<Row> df = spark.table("testData2");
Double result = df.stat().cov("a", "b");
Assert.assertTrue(Math.abs(result) < 1.0e-6);
}
@Test
public void testSampleBy() {
- Dataset<Row> df = context.range(0, 100, 1, 2).select(col("id").mod(3).as("key"));
+ Dataset<Row> df = spark.range(0, 100, 1, 2).select(col("id").mod(3).as("key"));
Dataset<Row> sampled = df.stat().<Integer>sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L);
List<Row> actual = sampled.groupBy("key").count().orderBy("key").collectAsList();
Assert.assertEquals(0, actual.get(0).getLong(0));
@@ -291,7 +285,7 @@ public class JavaDataFrameSuite {
@Test
public void pivot() {
- Dataset<Row> df = context.table("courseSales");
+ Dataset<Row> df = spark.table("courseSales");
List<Row> actual = df.groupBy("year")
.pivot("course", Arrays.<Object>asList("dotNET", "Java"))
.agg(sum("earnings")).orderBy("year").collectAsList();
@@ -324,10 +318,10 @@ public class JavaDataFrameSuite {
@Test
public void testGenericLoad() {
- Dataset<Row> df1 = context.read().format("text").load(getResource("text-suite.txt"));
+ Dataset<Row> df1 = spark.read().format("text").load(getResource("text-suite.txt"));
Assert.assertEquals(4L, df1.count());
- Dataset<Row> df2 = context.read().format("text").load(
+ Dataset<Row> df2 = spark.read().format("text").load(
getResource("text-suite.txt"),
getResource("text-suite2.txt"));
Assert.assertEquals(5L, df2.count());
@@ -335,10 +329,10 @@ public class JavaDataFrameSuite {
@Test
public void testTextLoad() {
- Dataset<String> ds1 = context.read().text(getResource("text-suite.txt"));
+ Dataset<String> ds1 = spark.read().text(getResource("text-suite.txt"));
Assert.assertEquals(4L, ds1.count());
- Dataset<String> ds2 = context.read().text(
+ Dataset<String> ds2 = spark.read().text(
getResource("text-suite.txt"),
getResource("text-suite2.txt"));
Assert.assertEquals(5L, ds2.count());
@@ -346,7 +340,7 @@ public class JavaDataFrameSuite {
@Test
public void testCountMinSketch() {
- Dataset<Long> df = context.range(1000);
+ Dataset<Long> df = spark.range(1000);
CountMinSketch sketch1 = df.stat().countMinSketch("id", 10, 20, 42);
Assert.assertEquals(sketch1.totalCount(), 1000);
@@ -371,7 +365,7 @@ public class JavaDataFrameSuite {
@Test
public void testBloomFilter() {
- Dataset<Long> df = context.range(1000);
+ Dataset<Long> df = spark.range(1000);
BloomFilter filter1 = df.stat().bloomFilter("id", 1000, 0.03);
Assert.assertTrue(filter1.expectedFpp() - 0.03 < 1e-3);
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index f1b1c22e4a..8354a5bdac 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -23,46 +23,43 @@ import java.sql.Date;
import java.sql.Timestamp;
import java.util.*;
-import com.google.common.base.Objects;
-import org.junit.rules.ExpectedException;
import scala.Tuple2;
import scala.Tuple3;
import scala.Tuple4;
import scala.Tuple5;
+import com.google.common.base.Objects;
import org.junit.*;
+import org.junit.rules.ExpectedException;
import org.apache.spark.Accumulator;
-import org.apache.spark.SparkContext;
-import org.apache.spark.api.java.function.*;
import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.*;
import org.apache.spark.sql.*;
-import org.apache.spark.sql.test.TestSQLContext;
import org.apache.spark.sql.catalyst.encoders.OuterScopes;
import org.apache.spark.sql.catalyst.expressions.GenericRow;
+import org.apache.spark.sql.test.TestSparkSession;
import org.apache.spark.sql.types.StructType;
-
-import static org.apache.spark.sql.functions.*;
+import static org.apache.spark.sql.functions.col;
+import static org.apache.spark.sql.functions.expr;
import static org.apache.spark.sql.types.DataTypes.*;
public class JavaDatasetSuite implements Serializable {
+ private transient TestSparkSession spark;
private transient JavaSparkContext jsc;
- private transient TestSQLContext context;
@Before
public void setUp() {
// Trigger static initializer of TestData
- SparkContext sc = new SparkContext("local[*]", "testing");
- jsc = new JavaSparkContext(sc);
- context = new TestSQLContext(sc);
- context.loadTestData();
+ spark = new TestSparkSession();
+ jsc = new JavaSparkContext(spark.sparkContext());
+ spark.loadTestData();
}
@After
public void tearDown() {
- context.sparkContext().stop();
- context = null;
- jsc = null;
+ spark.stop();
+ spark = null;
}
private <T1, T2> Tuple2<T1, T2> tuple2(T1 t1, T2 t2) {
@@ -72,7 +69,7 @@ public class JavaDatasetSuite implements Serializable {
@Test
public void testCollect() {
List<String> data = Arrays.asList("hello", "world");
- Dataset<String> ds = context.createDataset(data, Encoders.STRING());
+ Dataset<String> ds = spark.createDataset(data, Encoders.STRING());
List<String> collected = ds.collectAsList();
Assert.assertEquals(Arrays.asList("hello", "world"), collected);
}
@@ -80,7 +77,7 @@ public class JavaDatasetSuite implements Serializable {
@Test
public void testTake() {
List<String> data = Arrays.asList("hello", "world");
- Dataset<String> ds = context.createDataset(data, Encoders.STRING());
+ Dataset<String> ds = spark.createDataset(data, Encoders.STRING());
List<String> collected = ds.takeAsList(1);
Assert.assertEquals(Arrays.asList("hello"), collected);
}
@@ -88,7 +85,7 @@ public class JavaDatasetSuite implements Serializable {
@Test
public void testToLocalIterator() {
List<String> data = Arrays.asList("hello", "world");
- Dataset<String> ds = context.createDataset(data, Encoders.STRING());
+ Dataset<String> ds = spark.createDataset(data, Encoders.STRING());
Iterator<String> iter = ds.toLocalIterator();
Assert.assertEquals("hello", iter.next());
Assert.assertEquals("world", iter.next());
@@ -98,7 +95,7 @@ public class JavaDatasetSuite implements Serializable {
@Test
public void testCommonOperation() {
List<String> data = Arrays.asList("hello", "world");
- Dataset<String> ds = context.createDataset(data, Encoders.STRING());
+ Dataset<String> ds = spark.createDataset(data, Encoders.STRING());
Assert.assertEquals("hello", ds.first());
Dataset<String> filtered = ds.filter(new FilterFunction<String>() {
@@ -149,7 +146,7 @@ public class JavaDatasetSuite implements Serializable {
public void testForeach() {
final Accumulator<Integer> accum = jsc.accumulator(0);
List<String> data = Arrays.asList("a", "b", "c");
- Dataset<String> ds = context.createDataset(data, Encoders.STRING());
+ Dataset<String> ds = spark.createDataset(data, Encoders.STRING());
ds.foreach(new ForeachFunction<String>() {
@Override
@@ -163,7 +160,7 @@ public class JavaDatasetSuite implements Serializable {
@Test
public void testReduce() {
List<Integer> data = Arrays.asList(1, 2, 3);
- Dataset<Integer> ds = context.createDataset(data, Encoders.INT());
+ Dataset<Integer> ds = spark.createDataset(data, Encoders.INT());
int reduced = ds.reduce(new ReduceFunction<Integer>() {
@Override
@@ -177,7 +174,7 @@ public class JavaDatasetSuite implements Serializable {
@Test
public void testGroupBy() {
List<String> data = Arrays.asList("a", "foo", "bar");
- Dataset<String> ds = context.createDataset(data, Encoders.STRING());
+ Dataset<String> ds = spark.createDataset(data, Encoders.STRING());
KeyValueGroupedDataset<Integer, String> grouped = ds.groupByKey(
new MapFunction<String, Integer>() {
@Override
@@ -227,7 +224,7 @@ public class JavaDatasetSuite implements Serializable {
toSet(reduced.collectAsList()));
List<Integer> data2 = Arrays.asList(2, 6, 10);
- Dataset<Integer> ds2 = context.createDataset(data2, Encoders.INT());
+ Dataset<Integer> ds2 = spark.createDataset(data2, Encoders.INT());
KeyValueGroupedDataset<Integer, Integer> grouped2 = ds2.groupByKey(
new MapFunction<Integer, Integer>() {
@Override
@@ -261,7 +258,7 @@ public class JavaDatasetSuite implements Serializable {
@Test
public void testSelect() {
List<Integer> data = Arrays.asList(2, 6);
- Dataset<Integer> ds = context.createDataset(data, Encoders.INT());
+ Dataset<Integer> ds = spark.createDataset(data, Encoders.INT());
Dataset<Tuple2<Integer, String>> selected = ds.select(
expr("value + 1"),
@@ -275,12 +272,12 @@ public class JavaDatasetSuite implements Serializable {
@Test
public void testSetOperation() {
List<String> data = Arrays.asList("abc", "abc", "xyz");
- Dataset<String> ds = context.createDataset(data, Encoders.STRING());
+ Dataset<String> ds = spark.createDataset(data, Encoders.STRING());
Assert.assertEquals(asSet("abc", "xyz"), toSet(ds.distinct().collectAsList()));
List<String> data2 = Arrays.asList("xyz", "foo", "foo");
- Dataset<String> ds2 = context.createDataset(data2, Encoders.STRING());
+ Dataset<String> ds2 = spark.createDataset(data2, Encoders.STRING());
Dataset<String> intersected = ds.intersect(ds2);
Assert.assertEquals(Arrays.asList("xyz"), intersected.collectAsList());
@@ -307,9 +304,9 @@ public class JavaDatasetSuite implements Serializable {
@Test
public void testJoin() {
List<Integer> data = Arrays.asList(1, 2, 3);
- Dataset<Integer> ds = context.createDataset(data, Encoders.INT()).as("a");
+ Dataset<Integer> ds = spark.createDataset(data, Encoders.INT()).as("a");
List<Integer> data2 = Arrays.asList(2, 3, 4);
- Dataset<Integer> ds2 = context.createDataset(data2, Encoders.INT()).as("b");
+ Dataset<Integer> ds2 = spark.createDataset(data2, Encoders.INT()).as("b");
Dataset<Tuple2<Integer, Integer>> joined =
ds.joinWith(ds2, col("a.value").equalTo(col("b.value")));
@@ -322,21 +319,21 @@ public class JavaDatasetSuite implements Serializable {
public void testTupleEncoder() {
Encoder<Tuple2<Integer, String>> encoder2 = Encoders.tuple(Encoders.INT(), Encoders.STRING());
List<Tuple2<Integer, String>> data2 = Arrays.asList(tuple2(1, "a"), tuple2(2, "b"));
- Dataset<Tuple2<Integer, String>> ds2 = context.createDataset(data2, encoder2);
+ Dataset<Tuple2<Integer, String>> ds2 = spark.createDataset(data2, encoder2);
Assert.assertEquals(data2, ds2.collectAsList());
Encoder<Tuple3<Integer, Long, String>> encoder3 =
Encoders.tuple(Encoders.INT(), Encoders.LONG(), Encoders.STRING());
List<Tuple3<Integer, Long, String>> data3 =
Arrays.asList(new Tuple3<>(1, 2L, "a"));
- Dataset<Tuple3<Integer, Long, String>> ds3 = context.createDataset(data3, encoder3);
+ Dataset<Tuple3<Integer, Long, String>> ds3 = spark.createDataset(data3, encoder3);
Assert.assertEquals(data3, ds3.collectAsList());
Encoder<Tuple4<Integer, String, Long, String>> encoder4 =
Encoders.tuple(Encoders.INT(), Encoders.STRING(), Encoders.LONG(), Encoders.STRING());
List<Tuple4<Integer, String, Long, String>> data4 =
Arrays.asList(new Tuple4<>(1, "b", 2L, "a"));
- Dataset<Tuple4<Integer, String, Long, String>> ds4 = context.createDataset(data4, encoder4);
+ Dataset<Tuple4<Integer, String, Long, String>> ds4 = spark.createDataset(data4, encoder4);
Assert.assertEquals(data4, ds4.collectAsList());
Encoder<Tuple5<Integer, String, Long, String, Boolean>> encoder5 =
@@ -345,7 +342,7 @@ public class JavaDatasetSuite implements Serializable {
List<Tuple5<Integer, String, Long, String, Boolean>> data5 =
Arrays.asList(new Tuple5<>(1, "b", 2L, "a", true));
Dataset<Tuple5<Integer, String, Long, String, Boolean>> ds5 =
- context.createDataset(data5, encoder5);
+ spark.createDataset(data5, encoder5);
Assert.assertEquals(data5, ds5.collectAsList());
}
@@ -356,7 +353,7 @@ public class JavaDatasetSuite implements Serializable {
Encoders.tuple(Encoders.tuple(Encoders.INT(), Encoders.STRING()), Encoders.STRING());
List<Tuple2<Tuple2<Integer, String>, String>> data =
Arrays.asList(tuple2(tuple2(1, "a"), "a"), tuple2(tuple2(2, "b"), "b"));
- Dataset<Tuple2<Tuple2<Integer, String>, String>> ds = context.createDataset(data, encoder);
+ Dataset<Tuple2<Tuple2<Integer, String>, String>> ds = spark.createDataset(data, encoder);
Assert.assertEquals(data, ds.collectAsList());
// test (int, (string, string, long))
@@ -366,7 +363,7 @@ public class JavaDatasetSuite implements Serializable {
List<Tuple2<Integer, Tuple3<String, String, Long>>> data2 =
Arrays.asList(tuple2(1, new Tuple3<>("a", "b", 3L)));
Dataset<Tuple2<Integer, Tuple3<String, String, Long>>> ds2 =
- context.createDataset(data2, encoder2);
+ spark.createDataset(data2, encoder2);
Assert.assertEquals(data2, ds2.collectAsList());
// test (int, ((string, long), string))
@@ -376,7 +373,7 @@ public class JavaDatasetSuite implements Serializable {
List<Tuple2<Integer, Tuple2<Tuple2<String, Long>, String>>> data3 =
Arrays.asList(tuple2(1, tuple2(tuple2("a", 2L), "b")));
Dataset<Tuple2<Integer, Tuple2<Tuple2<String, Long>, String>>> ds3 =
- context.createDataset(data3, encoder3);
+ spark.createDataset(data3, encoder3);
Assert.assertEquals(data3, ds3.collectAsList());
}
@@ -390,7 +387,7 @@ public class JavaDatasetSuite implements Serializable {
1.7976931348623157E308, new BigDecimal("0.922337203685477589"),
Date.valueOf("1970-01-01"), new Timestamp(System.currentTimeMillis()), Float.MAX_VALUE));
Dataset<Tuple5<Double, BigDecimal, Date, Timestamp, Float>> ds =
- context.createDataset(data, encoder);
+ spark.createDataset(data, encoder);
Assert.assertEquals(data, ds.collectAsList());
}
@@ -441,7 +438,7 @@ public class JavaDatasetSuite implements Serializable {
Encoder<KryoSerializable> encoder = Encoders.kryo(KryoSerializable.class);
List<KryoSerializable> data = Arrays.asList(
new KryoSerializable("hello"), new KryoSerializable("world"));
- Dataset<KryoSerializable> ds = context.createDataset(data, encoder);
+ Dataset<KryoSerializable> ds = spark.createDataset(data, encoder);
Assert.assertEquals(data, ds.collectAsList());
}
@@ -450,14 +447,14 @@ public class JavaDatasetSuite implements Serializable {
Encoder<JavaSerializable> encoder = Encoders.javaSerialization(JavaSerializable.class);
List<JavaSerializable> data = Arrays.asList(
new JavaSerializable("hello"), new JavaSerializable("world"));
- Dataset<JavaSerializable> ds = context.createDataset(data, encoder);
+ Dataset<JavaSerializable> ds = spark.createDataset(data, encoder);
Assert.assertEquals(data, ds.collectAsList());
}
@Test
public void testRandomSplit() {
List<String> data = Arrays.asList("hello", "world", "from", "spark");
- Dataset<String> ds = context.createDataset(data, Encoders.STRING());
+ Dataset<String> ds = spark.createDataset(data, Encoders.STRING());
double[] arraySplit = {1, 2, 3};
List<Dataset<String>> randomSplit = ds.randomSplitAsList(arraySplit, 1);
@@ -647,14 +644,14 @@ public class JavaDatasetSuite implements Serializable {
obj2.setF(Arrays.asList(300L, null, 400L));
List<SimpleJavaBean> data = Arrays.asList(obj1, obj2);
- Dataset<SimpleJavaBean> ds = context.createDataset(data, Encoders.bean(SimpleJavaBean.class));
+ Dataset<SimpleJavaBean> ds = spark.createDataset(data, Encoders.bean(SimpleJavaBean.class));
Assert.assertEquals(data, ds.collectAsList());
NestedJavaBean obj3 = new NestedJavaBean();
obj3.setA(obj1);
List<NestedJavaBean> data2 = Arrays.asList(obj3);
- Dataset<NestedJavaBean> ds2 = context.createDataset(data2, Encoders.bean(NestedJavaBean.class));
+ Dataset<NestedJavaBean> ds2 = spark.createDataset(data2, Encoders.bean(NestedJavaBean.class));
Assert.assertEquals(data2, ds2.collectAsList());
Row row1 = new GenericRow(new Object[]{
@@ -678,7 +675,7 @@ public class JavaDatasetSuite implements Serializable {
.add("d", createArrayType(StringType))
.add("e", createArrayType(StringType))
.add("f", createArrayType(LongType));
- Dataset<SimpleJavaBean> ds3 = context.createDataFrame(Arrays.asList(row1, row2), schema)
+ Dataset<SimpleJavaBean> ds3 = spark.createDataFrame(Arrays.asList(row1, row2), schema)
.as(Encoders.bean(SimpleJavaBean.class));
Assert.assertEquals(data, ds3.collectAsList());
}
@@ -692,7 +689,7 @@ public class JavaDatasetSuite implements Serializable {
obj.setB(new Date(0));
obj.setC(java.math.BigDecimal.valueOf(1));
Dataset<SimpleJavaBean2> ds =
- context.createDataset(Arrays.asList(obj), Encoders.bean(SimpleJavaBean2.class));
+ spark.createDataset(Arrays.asList(obj), Encoders.bean(SimpleJavaBean2.class));
ds.collect();
}
@@ -776,7 +773,7 @@ public class JavaDatasetSuite implements Serializable {
})
});
- Dataset<Row> df = context.createDataFrame(Collections.singletonList(row), schema);
+ Dataset<Row> df = spark.createDataFrame(Collections.singletonList(row), schema);
Dataset<NestedSmallBean> ds = df.as(Encoders.bean(NestedSmallBean.class));
SmallBean smallBean = new SmallBean();
@@ -793,7 +790,7 @@ public class JavaDatasetSuite implements Serializable {
{
Row row = new GenericRow(new Object[] { null });
- Dataset<Row> df = context.createDataFrame(Collections.singletonList(row), schema);
+ Dataset<Row> df = spark.createDataFrame(Collections.singletonList(row), schema);
Dataset<NestedSmallBean> ds = df.as(Encoders.bean(NestedSmallBean.class));
NestedSmallBean nestedSmallBean = new NestedSmallBean();
@@ -810,7 +807,7 @@ public class JavaDatasetSuite implements Serializable {
})
});
- Dataset<Row> df = context.createDataFrame(Collections.singletonList(row), schema);
+ Dataset<Row> df = spark.createDataFrame(Collections.singletonList(row), schema);
Dataset<NestedSmallBean> ds = df.as(Encoders.bean(NestedSmallBean.class));
ds.collect();
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java
index 4a78dca7fe..2274912521 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java
@@ -24,33 +24,30 @@ import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
-import org.apache.spark.SparkContext;
import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.api.java.UDF1;
import org.apache.spark.sql.api.java.UDF2;
-import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.types.DataTypes;
// The test suite itself is Serializable so that anonymous Function implementations can be
// serialized, as an alternative to converting these anonymous classes to static inner classes;
// see http://stackoverflow.com/questions/758570/.
public class JavaUDFSuite implements Serializable {
- private transient JavaSparkContext sc;
- private transient SQLContext sqlContext;
+ private transient SparkSession spark;
@Before
public void setUp() {
- SparkContext _sc = new SparkContext("local[*]", "testing");
- sqlContext = new SQLContext(_sc);
- sc = new JavaSparkContext(_sc);
+ spark = SparkSession.builder()
+ .master("local[*]")
+ .appName("testing")
+ .getOrCreate();
}
@After
public void tearDown() {
- sqlContext.sparkContext().stop();
- sqlContext = null;
- sc = null;
+ spark.stop();
+ spark = null;
}
@SuppressWarnings("unchecked")
@@ -60,14 +57,14 @@ public class JavaUDFSuite implements Serializable {
// sqlContext.registerFunction(
// "stringLengthTest", (String str) -> str.length(), DataType.IntegerType);
- sqlContext.udf().register("stringLengthTest", new UDF1<String, Integer>() {
+ spark.udf().register("stringLengthTest", new UDF1<String, Integer>() {
@Override
public Integer call(String str) {
return str.length();
}
}, DataTypes.IntegerType);
- Row result = sqlContext.sql("SELECT stringLengthTest('test')").head();
+ Row result = spark.sql("SELECT stringLengthTest('test')").head();
Assert.assertEquals(4, result.getInt(0));
}
@@ -80,14 +77,14 @@ public class JavaUDFSuite implements Serializable {
// (String str1, String str2) -> str1.length() + str2.length,
// DataType.IntegerType);
- sqlContext.udf().register("stringLengthTest", new UDF2<String, String, Integer>() {
+ spark.udf().register("stringLengthTest", new UDF2<String, String, Integer>() {
@Override
public Integer call(String str1, String str2) {
return str1.length() + str2.length();
}
}, DataTypes.IntegerType);
- Row result = sqlContext.sql("SELECT stringLengthTest('test', 'test2')").head();
+ Row result = spark.sql("SELECT stringLengthTest('test', 'test2')").head();
Assert.assertEquals(9, result.getInt(0));
}
}
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuiteBase.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuiteBase.java
index 7863177093..059c2d9f2c 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuiteBase.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuiteBase.java
@@ -26,36 +26,30 @@ import scala.Tuple2;
import org.junit.After;
import org.junit.Before;
-import org.apache.spark.SparkContext;
-import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.KeyValueGroupedDataset;
-import org.apache.spark.sql.test.TestSQLContext;
+import org.apache.spark.sql.test.TestSparkSession;
/**
* Common test base shared across this and Java8DatasetAggregatorSuite.
*/
public class JavaDatasetAggregatorSuiteBase implements Serializable {
- protected transient JavaSparkContext jsc;
- protected transient TestSQLContext context;
+ private transient TestSparkSession spark;
@Before
public void setUp() {
// Trigger static initializer of TestData
- SparkContext sc = new SparkContext("local[*]", "testing");
- jsc = new JavaSparkContext(sc);
- context = new TestSQLContext(sc);
- context.loadTestData();
+ spark = new TestSparkSession();
+ spark.loadTestData();
}
@After
public void tearDown() {
- context.sparkContext().stop();
- context = null;
- jsc = null;
+ spark.stop();
+ spark = null;
}
protected <T1, T2> Tuple2<T1, T2> tuple2(T1 t1, T2 t2) {
@@ -66,7 +60,7 @@ public class JavaDatasetAggregatorSuiteBase implements Serializable {
Encoder<Tuple2<String, Integer>> encoder = Encoders.tuple(Encoders.STRING(), Encoders.INT());
List<Tuple2<String, Integer>> data =
Arrays.asList(tuple2("a", 1), tuple2("a", 2), tuple2("b", 3));
- Dataset<Tuple2<String, Integer>> ds = context.createDataset(data, encoder);
+ Dataset<Tuple2<String, Integer>> ds = spark.createDataset(data, encoder);
return ds.groupByKey(
new MapFunction<Tuple2<String, Integer>, String>() {
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java
index 9e65158eb0..d0435e4d43 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java
@@ -19,14 +19,16 @@ package test.org.apache.spark.sql.sources;
import java.io.File;
import java.io.IOException;
-import java.util.*;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
-import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.*;
@@ -37,8 +39,8 @@ import org.apache.spark.util.Utils;
public class JavaSaveLoadSuite {
- private transient JavaSparkContext sc;
- private transient SQLContext sqlContext;
+ private transient SparkSession spark;
+ private transient JavaSparkContext jsc;
File path;
Dataset<Row> df;
@@ -52,9 +54,11 @@ public class JavaSaveLoadSuite {
@Before
public void setUp() throws IOException {
- SparkContext _sc = new SparkContext("local[*]", "testing");
- sqlContext = new SQLContext(_sc);
- sc = new JavaSparkContext(_sc);
+ spark = SparkSession.builder()
+ .master("local[*]")
+ .appName("testing")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
path =
Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource").getCanonicalFile();
@@ -66,16 +70,15 @@ public class JavaSaveLoadSuite {
for (int i = 0; i < 10; i++) {
jsonObjects.add("{\"a\":" + i + ", \"b\":\"str" + i + "\"}");
}
- JavaRDD<String> rdd = sc.parallelize(jsonObjects);
- df = sqlContext.read().json(rdd);
+ JavaRDD<String> rdd = jsc.parallelize(jsonObjects);
+ df = spark.read().json(rdd);
df.registerTempTable("jsonTable");
}
@After
public void tearDown() {
- sqlContext.sparkContext().stop();
- sqlContext = null;
- sc = null;
+ spark.stop();
+ spark = null;
}
@Test
@@ -83,7 +86,7 @@ public class JavaSaveLoadSuite {
Map<String, String> options = new HashMap<>();
options.put("path", path.toString());
df.write().mode(SaveMode.ErrorIfExists).format("json").options(options).save();
- Dataset<Row> loadedDF = sqlContext.read().format("json").options(options).load();
+ Dataset<Row> loadedDF = spark.read().format("json").options(options).load();
checkAnswer(loadedDF, df.collectAsList());
}
@@ -96,8 +99,8 @@ public class JavaSaveLoadSuite {
List<StructField> fields = new ArrayList<>();
fields.add(DataTypes.createStructField("b", DataTypes.StringType, true));
StructType schema = DataTypes.createStructType(fields);
- Dataset<Row> loadedDF = sqlContext.read().format("json").schema(schema).options(options).load();
+ Dataset<Row> loadedDF = spark.read().format("json").schema(schema).options(options).load();
- checkAnswer(loadedDF, sqlContext.sql("SELECT b FROM jsonTable").collectAsList());
+ checkAnswer(loadedDF, spark.sql("SELECT b FROM jsonTable").collectAsList());
}
}