aboutsummaryrefslogtreecommitdiff
path: root/examples/src/main
diff options
context:
space:
mode:
authorZheng RuiFeng <ruifengz@foxmail.com>2016-05-20 16:40:33 -0700
committerAndrew Or <andrew@databricks.com>2016-05-20 16:40:33 -0700
commit127bf1bb07967e2e4f99ad7abaa7f6fab3b3f407 (patch)
treea127031cd361df2f1d895cb11489f8e183c76f73 /examples/src/main
parent06c9f520714e07259c6f8ce6f9ea5a230a278cb5 (diff)
downloadspark-127bf1bb07967e2e4f99ad7abaa7f6fab3b3f407.tar.gz
spark-127bf1bb07967e2e4f99ad7abaa7f6fab3b3f407.tar.bz2
spark-127bf1bb07967e2e4f99ad7abaa7f6fab3b3f407.zip
[SPARK-15031][EXAMPLE] Use SparkSession in examples
## What changes were proposed in this pull request? Use `SparkSession` according to [SPARK-15031](https://issues.apache.org/jira/browse/SPARK-15031) `MLLLIB` is not recommended to use now, so examples in `MLLIB` are ignored in this PR. `StreamingContext` can not be directly obtained from `SparkSession`, so example in `Streaming` are ignored too. cc andrewor14 ## How was this patch tested? manual tests with spark-submit Author: Zheng RuiFeng <ruifengz@foxmail.com> Closes #13164 from zhengruifeng/use_sparksession_ii.
Diffstat (limited to 'examples/src/main')
-rw-r--r--examples/src/main/java/org/apache/spark/examples/JavaHdfsLR.java14
-rw-r--r--examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java12
-rw-r--r--examples/src/main/java/org/apache/spark/examples/JavaPageRank.java13
-rw-r--r--examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java12
-rw-r--r--examples/src/main/java/org/apache/spark/examples/JavaStatusTrackerDemo.java19
-rw-r--r--examples/src/main/java/org/apache/spark/examples/JavaTC.java15
-rw-r--r--examples/src/main/java/org/apache/spark/examples/JavaWordCount.java15
-rwxr-xr-xexamples/src/main/python/als.py12
-rw-r--r--examples/src/main/python/avro_inputformat.py12
-rwxr-xr-xexamples/src/main/python/kmeans.py12
-rwxr-xr-xexamples/src/main/python/logistic_regression.py13
-rwxr-xr-xexamples/src/main/python/pagerank.py11
-rw-r--r--examples/src/main/python/parquet_inputformat.py12
-rwxr-xr-xexamples/src/main/python/pi.py12
-rwxr-xr-xexamples/src/main/python/sort.py13
-rwxr-xr-xexamples/src/main/python/transitive_closure.py12
-rwxr-xr-xexamples/src/main/python/wordcount.py13
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala16
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala13
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/ExceptionHandlingTest.scala12
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala12
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala12
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala13
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala13
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala13
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/SparkALS.scala12
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala16
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala15
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/SparkLR.scala13
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala13
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/SparkPi.scala11
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/SparkTC.scala11
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala2
33 files changed, 276 insertions, 143 deletions
diff --git a/examples/src/main/java/org/apache/spark/examples/JavaHdfsLR.java b/examples/src/main/java/org/apache/spark/examples/JavaHdfsLR.java
index f64155ce3c..ded442096c 100644
--- a/examples/src/main/java/org/apache/spark/examples/JavaHdfsLR.java
+++ b/examples/src/main/java/org/apache/spark/examples/JavaHdfsLR.java
@@ -17,11 +17,10 @@
package org.apache.spark.examples;
-import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
+import org.apache.spark.sql.SparkSession;
import java.io.Serializable;
import java.util.Arrays;
@@ -122,9 +121,12 @@ public final class JavaHdfsLR {
showWarning();
- SparkConf sparkConf = new SparkConf().setAppName("JavaHdfsLR");
- JavaSparkContext sc = new JavaSparkContext(sparkConf);
- JavaRDD<String> lines = sc.textFile(args[0]);
+ SparkSession spark = SparkSession
+ .builder()
+ .appName("JavaHdfsLR")
+ .getOrCreate();
+
+ JavaRDD<String> lines = spark.read().text(args[0]).javaRDD();
JavaRDD<DataPoint> points = lines.map(new ParsePoint()).cache();
int ITERATIONS = Integer.parseInt(args[1]);
@@ -152,6 +154,6 @@ public final class JavaHdfsLR {
System.out.print("Final w: ");
printWeights(w);
- sc.stop();
+ spark.stop();
}
}
diff --git a/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java b/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java
index ebb0687b14..7775443861 100644
--- a/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java
+++ b/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java
@@ -20,12 +20,13 @@ package org.apache.spark.examples;
import com.google.common.collect.Lists;
import scala.Tuple2;
import scala.Tuple3;
-import org.apache.spark.SparkConf;
+
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;
+import org.apache.spark.sql.SparkSession;
import java.io.Serializable;
import java.util.List;
@@ -99,9 +100,12 @@ public final class JavaLogQuery {
}
public static void main(String[] args) {
+ SparkSession spark = SparkSession
+ .builder()
+ .appName("JavaLogQuery")
+ .getOrCreate();
- SparkConf sparkConf = new SparkConf().setAppName("JavaLogQuery");
- JavaSparkContext jsc = new JavaSparkContext(sparkConf);
+ JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext());
JavaRDD<String> dataSet = (args.length == 1) ? jsc.textFile(args[0]) : jsc.parallelize(exampleApacheLogs);
@@ -123,6 +127,6 @@ public final class JavaLogQuery {
for (Tuple2<?,?> t : output) {
System.out.println(t._1() + "\t" + t._2());
}
- jsc.stop();
+ spark.stop();
}
}
diff --git a/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java b/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java
index 229d123441..128b5ab17c 100644
--- a/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java
+++ b/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java
@@ -26,14 +26,13 @@ import scala.Tuple2;
import com.google.common.collect.Iterables;
-import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.api.java.function.PairFunction;
+import org.apache.spark.sql.SparkSession;
/**
* Computes the PageRank of URLs from an input file. Input file should
@@ -73,15 +72,17 @@ public final class JavaPageRank {
showWarning();
- SparkConf sparkConf = new SparkConf().setAppName("JavaPageRank");
- JavaSparkContext ctx = new JavaSparkContext(sparkConf);
+ SparkSession spark = SparkSession
+ .builder()
+ .appName("JavaPageRank")
+ .getOrCreate();
// Loads in input file. It should be in format of:
// URL neighbor URL
// URL neighbor URL
// URL neighbor URL
// ...
- JavaRDD<String> lines = ctx.textFile(args[0], 1);
+ JavaRDD<String> lines = spark.read().text(args[0]).javaRDD();
// Loads all URLs from input file and initialize their neighbors.
JavaPairRDD<String, Iterable<String>> links = lines.mapToPair(
@@ -132,6 +133,6 @@ public final class JavaPageRank {
System.out.println(tuple._1() + " has rank: " + tuple._2() + ".");
}
- ctx.stop();
+ spark.stop();
}
}
diff --git a/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java b/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java
index 04a57a6bfb..7df145e311 100644
--- a/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java
+++ b/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java
@@ -17,11 +17,11 @@
package org.apache.spark.examples;
-import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
+import org.apache.spark.sql.SparkSession;
import java.util.ArrayList;
import java.util.List;
@@ -33,8 +33,12 @@ import java.util.List;
public final class JavaSparkPi {
public static void main(String[] args) throws Exception {
- SparkConf sparkConf = new SparkConf().setAppName("JavaSparkPi");
- JavaSparkContext jsc = new JavaSparkContext(sparkConf);
+ SparkSession spark = SparkSession
+ .builder()
+ .appName("JavaSparkPi")
+ .getOrCreate();
+
+ JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext());
int slices = (args.length == 1) ? Integer.parseInt(args[0]) : 2;
int n = 100000 * slices;
@@ -61,6 +65,6 @@ public final class JavaSparkPi {
System.out.println("Pi is roughly " + 4.0 * count / n);
- jsc.stop();
+ spark.stop();
}
}
diff --git a/examples/src/main/java/org/apache/spark/examples/JavaStatusTrackerDemo.java b/examples/src/main/java/org/apache/spark/examples/JavaStatusTrackerDemo.java
index e68ec74c3e..6f899c772e 100644
--- a/examples/src/main/java/org/apache/spark/examples/JavaStatusTrackerDemo.java
+++ b/examples/src/main/java/org/apache/spark/examples/JavaStatusTrackerDemo.java
@@ -17,13 +17,14 @@
package org.apache.spark.examples;
-import org.apache.spark.SparkConf;
import org.apache.spark.SparkJobInfo;
import org.apache.spark.SparkStageInfo;
import org.apache.spark.api.java.JavaFutureAction;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
+import org.apache.spark.sql.SparkSession;
+
import java.util.Arrays;
import java.util.List;
@@ -44,11 +45,15 @@ public final class JavaStatusTrackerDemo {
}
public static void main(String[] args) throws Exception {
- SparkConf sparkConf = new SparkConf().setAppName(APP_NAME);
- final JavaSparkContext sc = new JavaSparkContext(sparkConf);
+ SparkSession spark = SparkSession
+ .builder()
+ .appName(APP_NAME)
+ .getOrCreate();
+
+ final JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext());
// Example of implementing a progress reporter for a simple job.
- JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 5).map(
+ JavaRDD<Integer> rdd = jsc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 5).map(
new IdentityWithDelay<Integer>());
JavaFutureAction<List<Integer>> jobFuture = rdd.collectAsync();
while (!jobFuture.isDone()) {
@@ -58,13 +63,13 @@ public final class JavaStatusTrackerDemo {
continue;
}
int currentJobId = jobIds.get(jobIds.size() - 1);
- SparkJobInfo jobInfo = sc.statusTracker().getJobInfo(currentJobId);
- SparkStageInfo stageInfo = sc.statusTracker().getStageInfo(jobInfo.stageIds()[0]);
+ SparkJobInfo jobInfo = jsc.statusTracker().getJobInfo(currentJobId);
+ SparkStageInfo stageInfo = jsc.statusTracker().getStageInfo(jobInfo.stageIds()[0]);
System.out.println(stageInfo.numTasks() + " tasks total: " + stageInfo.numActiveTasks() +
" active, " + stageInfo.numCompletedTasks() + " complete");
}
System.out.println("Job results are: " + jobFuture.get());
- sc.stop();
+ spark.stop();
}
}
diff --git a/examples/src/main/java/org/apache/spark/examples/JavaTC.java b/examples/src/main/java/org/apache/spark/examples/JavaTC.java
index ca10384212..f12ca77ed1 100644
--- a/examples/src/main/java/org/apache/spark/examples/JavaTC.java
+++ b/examples/src/main/java/org/apache/spark/examples/JavaTC.java
@@ -25,10 +25,10 @@ import java.util.Set;
import scala.Tuple2;
-import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.PairFunction;
+import org.apache.spark.sql.SparkSession;
/**
* Transitive closure on a graph, implemented in Java.
@@ -64,10 +64,15 @@ public final class JavaTC {
}
public static void main(String[] args) {
- SparkConf sparkConf = new SparkConf().setAppName("JavaHdfsLR");
- JavaSparkContext sc = new JavaSparkContext(sparkConf);
+ SparkSession spark = SparkSession
+ .builder()
+ .appName("JavaTC")
+ .getOrCreate();
+
+ JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext());
+
Integer slices = (args.length > 0) ? Integer.parseInt(args[0]): 2;
- JavaPairRDD<Integer, Integer> tc = sc.parallelizePairs(generateGraph(), slices).cache();
+ JavaPairRDD<Integer, Integer> tc = jsc.parallelizePairs(generateGraph(), slices).cache();
// Linear transitive closure: each round grows paths by one edge,
// by joining the graph's edges with the already-discovered paths.
@@ -94,6 +99,6 @@ public final class JavaTC {
} while (nextCount != oldCount);
System.out.println("TC has " + tc.count() + " edges.");
- sc.stop();
+ spark.stop();
}
}
diff --git a/examples/src/main/java/org/apache/spark/examples/JavaWordCount.java b/examples/src/main/java/org/apache/spark/examples/JavaWordCount.java
index 3ff5412b93..1caee60e34 100644
--- a/examples/src/main/java/org/apache/spark/examples/JavaWordCount.java
+++ b/examples/src/main/java/org/apache/spark/examples/JavaWordCount.java
@@ -18,13 +18,13 @@
package org.apache.spark.examples;
import scala.Tuple2;
-import org.apache.spark.SparkConf;
+
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;
+import org.apache.spark.sql.SparkSession;
import java.util.Arrays;
import java.util.Iterator;
@@ -41,9 +41,12 @@ public final class JavaWordCount {
System.exit(1);
}
- SparkConf sparkConf = new SparkConf().setAppName("JavaWordCount");
- JavaSparkContext ctx = new JavaSparkContext(sparkConf);
- JavaRDD<String> lines = ctx.textFile(args[0], 1);
+ SparkSession spark = SparkSession
+ .builder()
+ .appName("JavaWordCount")
+ .getOrCreate();
+
+ JavaRDD<String> lines = spark.read().text(args[0]).javaRDD();
JavaRDD<String> words = lines.flatMap(new FlatMapFunction<String, String>() {
@Override
@@ -72,6 +75,6 @@ public final class JavaWordCount {
for (Tuple2<?,?> tuple : output) {
System.out.println(tuple._1() + ": " + tuple._2());
}
- ctx.stop();
+ spark.stop();
}
}
diff --git a/examples/src/main/python/als.py b/examples/src/main/python/als.py
index f07020b503..81562e20a9 100755
--- a/examples/src/main/python/als.py
+++ b/examples/src/main/python/als.py
@@ -28,7 +28,7 @@ import sys
import numpy as np
from numpy.random import rand
from numpy import matrix
-from pyspark import SparkContext
+from pyspark.sql import SparkSession
LAMBDA = 0.01 # regularization
np.random.seed(42)
@@ -62,7 +62,13 @@ if __name__ == "__main__":
example. Please use pyspark.ml.recommendation.ALS for more
conventional use.""", file=sys.stderr)
- sc = SparkContext(appName="PythonALS")
+ spark = SparkSession\
+ .builder\
+ .appName("PythonALS")\
+ .getOrCreate()
+
+ sc = spark._sc
+
M = int(sys.argv[1]) if len(sys.argv) > 1 else 100
U = int(sys.argv[2]) if len(sys.argv) > 2 else 500
F = int(sys.argv[3]) if len(sys.argv) > 3 else 10
@@ -99,4 +105,4 @@ if __name__ == "__main__":
print("Iteration %d:" % i)
print("\nRMSE: %5.4f\n" % error)
- sc.stop()
+ spark.stop()
diff --git a/examples/src/main/python/avro_inputformat.py b/examples/src/main/python/avro_inputformat.py
index da368ac628..3f65e8f79a 100644
--- a/examples/src/main/python/avro_inputformat.py
+++ b/examples/src/main/python/avro_inputformat.py
@@ -19,8 +19,8 @@ from __future__ import print_function
import sys
-from pyspark import SparkContext
from functools import reduce
+from pyspark.sql import SparkSession
"""
Read data file users.avro in local Spark distro:
@@ -64,7 +64,13 @@ if __name__ == "__main__":
exit(-1)
path = sys.argv[1]
- sc = SparkContext(appName="AvroKeyInputFormat")
+
+ spark = SparkSession\
+ .builder\
+ .appName("AvroKeyInputFormat")\
+ .getOrCreate()
+
+ sc = spark._sc
conf = None
if len(sys.argv) == 3:
@@ -82,4 +88,4 @@ if __name__ == "__main__":
for k in output:
print(k)
- sc.stop()
+ spark.stop()
diff --git a/examples/src/main/python/kmeans.py b/examples/src/main/python/kmeans.py
index 3426e491dc..92e0a3ae2e 100755
--- a/examples/src/main/python/kmeans.py
+++ b/examples/src/main/python/kmeans.py
@@ -27,7 +27,7 @@ from __future__ import print_function
import sys
import numpy as np
-from pyspark import SparkContext
+from pyspark.sql import SparkSession
def parseVector(line):
@@ -55,8 +55,12 @@ if __name__ == "__main__":
as an example! Please refer to examples/src/main/python/ml/kmeans_example.py for an
example on how to use ML's KMeans implementation.""", file=sys.stderr)
- sc = SparkContext(appName="PythonKMeans")
- lines = sc.textFile(sys.argv[1])
+ spark = SparkSession\
+ .builder\
+ .appName("PythonKMeans")\
+ .getOrCreate()
+
+ lines = spark.read.text(sys.argv[1]).rdd.map(lambda r: r[0])
data = lines.map(parseVector).cache()
K = int(sys.argv[2])
convergeDist = float(sys.argv[3])
@@ -79,4 +83,4 @@ if __name__ == "__main__":
print("Final centers: " + str(kPoints))
- sc.stop()
+ spark.stop()
diff --git a/examples/src/main/python/logistic_regression.py b/examples/src/main/python/logistic_regression.py
index 7d33be7e81..01c938454b 100755
--- a/examples/src/main/python/logistic_regression.py
+++ b/examples/src/main/python/logistic_regression.py
@@ -27,7 +27,7 @@ from __future__ import print_function
import sys
import numpy as np
-from pyspark import SparkContext
+from pyspark.sql import SparkSession
D = 10 # Number of dimensions
@@ -55,8 +55,13 @@ if __name__ == "__main__":
Please refer to examples/src/main/python/ml/logistic_regression_with_elastic_net.py
to see how ML's implementation is used.""", file=sys.stderr)
- sc = SparkContext(appName="PythonLR")
- points = sc.textFile(sys.argv[1]).mapPartitions(readPointBatch).cache()
+ spark = SparkSession\
+ .builder\
+ .appName("PythonLR")\
+ .getOrCreate()
+
+ points = spark.read.text(sys.argv[1]).rdd.map(lambda r: r[0])\
+ .mapPartitions(readPointBatch).cache()
iterations = int(sys.argv[2])
# Initialize w to a random value
@@ -80,4 +85,4 @@ if __name__ == "__main__":
print("Final w: " + str(w))
- sc.stop()
+ spark.stop()
diff --git a/examples/src/main/python/pagerank.py b/examples/src/main/python/pagerank.py
index 2fdc9773d4..a399a9c37c 100755
--- a/examples/src/main/python/pagerank.py
+++ b/examples/src/main/python/pagerank.py
@@ -25,7 +25,7 @@ import re
import sys
from operator import add
-from pyspark import SparkContext
+from pyspark.sql import SparkSession
def computeContribs(urls, rank):
@@ -51,14 +51,17 @@ if __name__ == "__main__":
file=sys.stderr)
# Initialize the spark context.
- sc = SparkContext(appName="PythonPageRank")
+ spark = SparkSession\
+ .builder\
+ .appName("PythonPageRank")\
+ .getOrCreate()
# Loads in input file. It should be in format of:
# URL neighbor URL
# URL neighbor URL
# URL neighbor URL
# ...
- lines = sc.textFile(sys.argv[1], 1)
+ lines = spark.read.text(sys.argv[1]).rdd.map(lambda r: r[0])
# Loads all URLs from input file and initialize their neighbors.
links = lines.map(lambda urls: parseNeighbors(urls)).distinct().groupByKey().cache()
@@ -79,4 +82,4 @@ if __name__ == "__main__":
for (link, rank) in ranks.collect():
print("%s has rank: %s." % (link, rank))
- sc.stop()
+ spark.stop()
diff --git a/examples/src/main/python/parquet_inputformat.py b/examples/src/main/python/parquet_inputformat.py
index e1fd85b082..2f09f4d573 100644
--- a/examples/src/main/python/parquet_inputformat.py
+++ b/examples/src/main/python/parquet_inputformat.py
@@ -18,7 +18,7 @@ from __future__ import print_function
import sys
-from pyspark import SparkContext
+from pyspark.sql import SparkSession
"""
Read data file users.parquet in local Spark distro:
@@ -47,7 +47,13 @@ if __name__ == "__main__":
exit(-1)
path = sys.argv[1]
- sc = SparkContext(appName="ParquetInputFormat")
+
+ spark = SparkSession\
+ .builder\
+ .appName("ParquetInputFormat")\
+ .getOrCreate()
+
+ sc = spark._sc
parquet_rdd = sc.newAPIHadoopFile(
path,
@@ -59,4 +65,4 @@ if __name__ == "__main__":
for k in output:
print(k)
- sc.stop()
+ spark.stop()
diff --git a/examples/src/main/python/pi.py b/examples/src/main/python/pi.py
index 92e5cf45ab..5db03e4a21 100755
--- a/examples/src/main/python/pi.py
+++ b/examples/src/main/python/pi.py
@@ -20,14 +20,20 @@ import sys
from random import random
from operator import add
-from pyspark import SparkContext
+from pyspark.sql import SparkSession
if __name__ == "__main__":
"""
Usage: pi [partitions]
"""
- sc = SparkContext(appName="PythonPi")
+ spark = SparkSession\
+ .builder\
+ .appName("PythonPi")\
+ .getOrCreate()
+
+ sc = spark._sc
+
partitions = int(sys.argv[1]) if len(sys.argv) > 1 else 2
n = 100000 * partitions
@@ -39,4 +45,4 @@ if __name__ == "__main__":
count = sc.parallelize(range(1, n + 1), partitions).map(f).reduce(add)
print("Pi is roughly %f" % (4.0 * count / n))
- sc.stop()
+ spark.stop()
diff --git a/examples/src/main/python/sort.py b/examples/src/main/python/sort.py
index b6c2916254..81898cf6d5 100755
--- a/examples/src/main/python/sort.py
+++ b/examples/src/main/python/sort.py
@@ -19,15 +19,20 @@ from __future__ import print_function
import sys
-from pyspark import SparkContext
+from pyspark.sql import SparkSession
if __name__ == "__main__":
if len(sys.argv) != 2:
print("Usage: sort <file>", file=sys.stderr)
exit(-1)
- sc = SparkContext(appName="PythonSort")
- lines = sc.textFile(sys.argv[1], 1)
+
+ spark = SparkSession\
+ .builder\
+ .appName("PythonSort")\
+ .getOrCreate()
+
+ lines = spark.read.text(sys.argv[1]).rdd.map(lambda r: r[0])
sortedCount = lines.flatMap(lambda x: x.split(' ')) \
.map(lambda x: (int(x), 1)) \
.sortByKey()
@@ -37,4 +42,4 @@ if __name__ == "__main__":
for (num, unitcount) in output:
print(num)
- sc.stop()
+ spark.stop()
diff --git a/examples/src/main/python/transitive_closure.py b/examples/src/main/python/transitive_closure.py
index 3d61250d8b..37c41dcd03 100755
--- a/examples/src/main/python/transitive_closure.py
+++ b/examples/src/main/python/transitive_closure.py
@@ -20,7 +20,7 @@ from __future__ import print_function
import sys
from random import Random
-from pyspark import SparkContext
+from pyspark.sql import SparkSession
numEdges = 200
numVertices = 100
@@ -41,7 +41,13 @@ if __name__ == "__main__":
"""
Usage: transitive_closure [partitions]
"""
- sc = SparkContext(appName="PythonTransitiveClosure")
+ spark = SparkSession\
+ .builder\
+ .appName("PythonTransitiveClosure")\
+ .getOrCreate()
+
+ sc = spark._sc
+
partitions = int(sys.argv[1]) if len(sys.argv) > 1 else 2
tc = sc.parallelize(generateGraph(), partitions).cache()
@@ -67,4 +73,4 @@ if __name__ == "__main__":
print("TC has %i edges" % tc.count())
- sc.stop()
+ spark.stop()
diff --git a/examples/src/main/python/wordcount.py b/examples/src/main/python/wordcount.py
index 7c0143607b..3d5e44d5b2 100755
--- a/examples/src/main/python/wordcount.py
+++ b/examples/src/main/python/wordcount.py
@@ -20,15 +20,20 @@ from __future__ import print_function
import sys
from operator import add
-from pyspark import SparkContext
+from pyspark.sql import SparkSession
if __name__ == "__main__":
if len(sys.argv) != 2:
print("Usage: wordcount <file>", file=sys.stderr)
exit(-1)
- sc = SparkContext(appName="PythonWordCount")
- lines = sc.textFile(sys.argv[1], 1)
+
+ spark = SparkSession\
+ .builder\
+ .appName("PythonWordCount")\
+ .getOrCreate()
+
+ lines = spark.read.text(sys.argv[1]).rdd.map(lambda r: r[0])
counts = lines.flatMap(lambda x: x.split(' ')) \
.map(lambda x: (x, 1)) \
.reduceByKey(add)
@@ -36,4 +41,4 @@ if __name__ == "__main__":
for (word, count) in output:
print("%s: %i" % (word, count))
- sc.stop()
+ spark.stop()
diff --git a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala
index af5a815f6e..c50f25d951 100644
--- a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala
@@ -18,7 +18,8 @@
// scalastyle:off println
package org.apache.spark.examples
-import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.SparkSession
/**
* Usage: BroadcastTest [slices] [numElem] [blockSize]
@@ -28,9 +29,16 @@ object BroadcastTest {
val blockSize = if (args.length > 2) args(2) else "4096"
- val sparkConf = new SparkConf().setAppName("Broadcast Test")
+ val sparkConf = new SparkConf()
.set("spark.broadcast.blockSize", blockSize)
- val sc = new SparkContext(sparkConf)
+
+ val spark = SparkSession
+ .builder
+ .config(sparkConf)
+ .appName("Broadcast Test")
+ .getOrCreate()
+
+ val sc = spark.sparkContext
val slices = if (args.length > 0) args(0).toInt else 2
val num = if (args.length > 1) args(1).toInt else 1000000
@@ -48,7 +56,7 @@ object BroadcastTest {
println("Iteration %d took %.0f milliseconds".format(i, (System.nanoTime - startTime) / 1E6))
}
- sc.stop()
+ spark.stop()
}
}
// scalastyle:on println
diff --git a/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala b/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala
index 7bf023667d..4b5e36c736 100644
--- a/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala
@@ -22,7 +22,7 @@ import java.io.File
import scala.io.Source._
-import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.sql.SparkSession
/**
* Simple test for reading and writing to a distributed
@@ -101,11 +101,14 @@ object DFSReadWriteTest {
val fileContents = readFile(localFilePath.toString())
val localWordCount = runLocalWordCount(fileContents)
- println("Creating SparkConf")
- val conf = new SparkConf().setAppName("DFS Read Write Test")
+ println("Creating SparkSession")
+ val spark = SparkSession
+ .builder
+ .appName("DFS Read Write Test")
+ .getOrCreate()
println("Creating SparkContext")
- val sc = new SparkContext(conf)
+ val sc = spark.sparkContext
println("Writing local file to DFS")
val dfsFilename = dfsDirPath + "/dfs_read_write_test"
@@ -124,7 +127,7 @@ object DFSReadWriteTest {
.values
.sum
- sc.stop()
+ spark.stop()
if (localWordCount == dfsWordCount) {
println(s"Success! Local Word Count ($localWordCount) " +
diff --git a/examples/src/main/scala/org/apache/spark/examples/ExceptionHandlingTest.scala b/examples/src/main/scala/org/apache/spark/examples/ExceptionHandlingTest.scala
index d42f63e870..6a1bbed290 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ExceptionHandlingTest.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ExceptionHandlingTest.scala
@@ -17,18 +17,22 @@
package org.apache.spark.examples
-import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.sql.SparkSession
object ExceptionHandlingTest {
def main(args: Array[String]) {
- val sparkConf = new SparkConf().setAppName("ExceptionHandlingTest")
- val sc = new SparkContext(sparkConf)
+ val spark = SparkSession
+ .builder
+ .appName("ExceptionHandlingTest")
+ .getOrCreate()
+ val sc = spark.sparkContext
+
sc.parallelize(0 until sc.defaultParallelism).foreach { i =>
if (math.random > 0.75) {
throw new Exception("Testing exception handling")
}
}
- sc.stop()
+ spark.stop()
}
}
diff --git a/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala
index 4db229b5de..0cb61d7495 100644
--- a/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala
@@ -20,20 +20,24 @@ package org.apache.spark.examples
import java.util.Random
-import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.sql.SparkSession
/**
* Usage: GroupByTest [numMappers] [numKVPairs] [KeySize] [numReducers]
*/
object GroupByTest {
def main(args: Array[String]) {
- val sparkConf = new SparkConf().setAppName("GroupBy Test")
+ val spark = SparkSession
+ .builder
+ .appName("GroupBy Test")
+ .getOrCreate()
+
var numMappers = if (args.length > 0) args(0).toInt else 2
var numKVPairs = if (args.length > 1) args(1).toInt else 1000
var valSize = if (args.length > 2) args(2).toInt else 1000
var numReducers = if (args.length > 3) args(3).toInt else numMappers
- val sc = new SparkContext(sparkConf)
+ val sc = spark.sparkContext
val pairs1 = sc.parallelize(0 until numMappers, numMappers).flatMap { p =>
val ranGen = new Random
@@ -50,7 +54,7 @@ object GroupByTest {
println(pairs1.groupByKey(numReducers).count())
- sc.stop()
+ spark.stop()
}
}
// scalastyle:on println
diff --git a/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala b/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala
index 124dc9af63..aa8de69839 100644
--- a/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala
@@ -18,7 +18,7 @@
// scalastyle:off println
package org.apache.spark.examples
-import org.apache.spark._
+import org.apache.spark.sql.SparkSession
object HdfsTest {
@@ -29,9 +29,11 @@ object HdfsTest {
System.err.println("Usage: HdfsTest <file>")
System.exit(1)
}
- val sparkConf = new SparkConf().setAppName("HdfsTest")
- val sc = new SparkContext(sparkConf)
- val file = sc.textFile(args(0))
+ val spark = SparkSession
+ .builder
+ .appName("HdfsTest")
+ .getOrCreate()
+ val file = spark.read.text(args(0)).rdd
val mapped = file.map(s => s.length).cache()
for (iter <- 1 to 10) {
val start = System.currentTimeMillis()
@@ -39,7 +41,7 @@ object HdfsTest {
val end = System.currentTimeMillis()
println("Iteration " + iter + " took " + (end-start) + " ms")
}
- sc.stop()
+ spark.stop()
}
}
// scalastyle:on println
diff --git a/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala
index 3eb0c27723..961ab99200 100644
--- a/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala
@@ -18,8 +18,9 @@
// scalastyle:off println
package org.apache.spark.examples
-import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.SparkSession
+
/**
* Usage: MultiBroadcastTest [slices] [numElem]
@@ -27,8 +28,12 @@ import org.apache.spark.rdd.RDD
object MultiBroadcastTest {
def main(args: Array[String]) {
- val sparkConf = new SparkConf().setAppName("Multi-Broadcast Test")
- val sc = new SparkContext(sparkConf)
+ val spark = SparkSession
+ .builder
+ .appName("Multi-Broadcast Test")
+ .getOrCreate()
+
+ val sc = spark.sparkContext
val slices = if (args.length > 0) args(0).toInt else 2
val num = if (args.length > 1) args(1).toInt else 1000000
@@ -51,7 +56,7 @@ object MultiBroadcastTest {
// Collect the small RDD so we can print the observed sizes locally.
observedSizes.collect().foreach(i => println(i))
- sc.stop()
+ spark.stop()
}
}
// scalastyle:on println
diff --git a/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala
index ec07e6323e..255c2bfcee 100644
--- a/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala
@@ -20,23 +20,26 @@ package org.apache.spark.examples
import java.util.Random
-import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.sql.SparkSession
/**
* Usage: SimpleSkewedGroupByTest [numMappers] [numKVPairs] [valSize] [numReducers] [ratio]
*/
object SimpleSkewedGroupByTest {
def main(args: Array[String]) {
+ val spark = SparkSession
+ .builder
+ .appName("SimpleSkewedGroupByTest")
+ .getOrCreate()
+
+ val sc = spark.sparkContext
- val sparkConf = new SparkConf().setAppName("SimpleSkewedGroupByTest")
var numMappers = if (args.length > 0) args(0).toInt else 2
var numKVPairs = if (args.length > 1) args(1).toInt else 1000
var valSize = if (args.length > 2) args(2).toInt else 1000
var numReducers = if (args.length > 3) args(3).toInt else numMappers
var ratio = if (args.length > 4) args(4).toInt else 5.0
- val sc = new SparkContext(sparkConf)
-
val pairs1 = sc.parallelize(0 until numMappers, numMappers).flatMap { p =>
val ranGen = new Random
var result = new Array[(Int, Array[Byte])](numKVPairs)
@@ -64,7 +67,7 @@ object SimpleSkewedGroupByTest {
// .map{case (k,v) => (k, v.size)}
// .collectAsMap)
- sc.stop()
+ spark.stop()
}
}
// scalastyle:on println
diff --git a/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala
index 8e4c2b6229..efd40147f7 100644
--- a/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala
@@ -20,20 +20,25 @@ package org.apache.spark.examples
import java.util.Random
-import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.sql.SparkSession
/**
* Usage: GroupByTest [numMappers] [numKVPairs] [KeySize] [numReducers]
*/
object SkewedGroupByTest {
def main(args: Array[String]) {
- val sparkConf = new SparkConf().setAppName("GroupBy Test")
+ val spark = SparkSession
+ .builder
+ .appName("GroupBy Test")
+ .getOrCreate()
+
+ val sc = spark.sparkContext
+
var numMappers = if (args.length > 0) args(0).toInt else 2
var numKVPairs = if (args.length > 1) args(1).toInt else 1000
var valSize = if (args.length > 2) args(2).toInt else 1000
var numReducers = if (args.length > 3) args(3).toInt else numMappers
- val sc = new SparkContext(sparkConf)
val pairs1 = sc.parallelize(0 until numMappers, numMappers).flatMap { p =>
val ranGen = new Random
@@ -54,7 +59,7 @@ object SkewedGroupByTest {
println(pairs1.groupByKey(numReducers).count())
- sc.stop()
+ spark.stop()
}
}
// scalastyle:on println
diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala
index b06c629802..8a3d08f459 100644
--- a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala
@@ -20,7 +20,7 @@ package org.apache.spark.examples
import org.apache.commons.math3.linear._
-import org.apache.spark._
+import org.apache.spark.sql.SparkSession
/**
* Alternating least squares matrix factorization.
@@ -108,8 +108,12 @@ object SparkALS {
println(s"Running with M=$M, U=$U, F=$F, iters=$ITERATIONS")
- val sparkConf = new SparkConf().setAppName("SparkALS")
- val sc = new SparkContext(sparkConf)
+ val spark = SparkSession
+ .builder
+ .appName("SparkALS")
+ .getOrCreate()
+
+ val sc = spark.sparkContext
val R = generateR()
@@ -135,7 +139,7 @@ object SparkALS {
println()
}
- sc.stop()
+ spark.stop()
}
private def randomVector(n: Int): RealVector =
diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala
index c514eb0fa5..84f133e011 100644
--- a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala
@@ -23,9 +23,8 @@ import java.util.Random
import scala.math.exp
import breeze.linalg.{DenseVector, Vector}
-import org.apache.hadoop.conf.Configuration
-import org.apache.spark._
+import org.apache.spark.sql.SparkSession
/**
* Logistic regression based classification.
@@ -67,11 +66,14 @@ object SparkHdfsLR {
showWarning()
- val sparkConf = new SparkConf().setAppName("SparkHdfsLR")
+ val spark = SparkSession
+ .builder
+ .appName("SparkHdfsLR")
+ .getOrCreate()
+
val inputPath = args(0)
- val conf = new Configuration()
- val sc = new SparkContext(sparkConf)
- val lines = sc.textFile(inputPath)
+ val lines = spark.read.text(inputPath).rdd
+
val points = lines.map(parsePoint).cache()
val ITERATIONS = args(1).toInt
@@ -88,7 +90,7 @@ object SparkHdfsLR {
}
println("Final w: " + w)
- sc.stop()
+ spark.stop()
}
}
// scalastyle:on println
diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala
index 676164806e..aa93c93c44 100644
--- a/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala
@@ -20,7 +20,7 @@ package org.apache.spark.examples
import breeze.linalg.{squaredDistance, DenseVector, Vector}
-import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.sql.SparkSession
/**
* K-means clustering.
@@ -66,14 +66,17 @@ object SparkKMeans {
showWarning()
- val sparkConf = new SparkConf().setAppName("SparkKMeans")
- val sc = new SparkContext(sparkConf)
- val lines = sc.textFile(args(0))
+ val spark = SparkSession
+ .builder
+ .appName("SparkKMeans")
+ .getOrCreate()
+
+ val lines = spark.read.text(args(0)).rdd
val data = lines.map(parseVector _).cache()
val K = args(1).toInt
val convergeDist = args(2).toDouble
- val kPoints = data.takeSample(withReplacement = false, K, 42).toArray
+ val kPoints = data.takeSample(withReplacement = false, K, 42)
var tempDist = 1.0
while(tempDist > convergeDist) {
@@ -97,7 +100,7 @@ object SparkKMeans {
println("Final centers:")
kPoints.foreach(println)
- sc.stop()
+ spark.stop()
}
}
// scalastyle:on println
diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala
index 718f84f645..8ef3aab657 100644
--- a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala
@@ -24,7 +24,7 @@ import scala.math.exp
import breeze.linalg.{DenseVector, Vector}
-import org.apache.spark._
+import org.apache.spark.sql.SparkSession
/**
* Logistic regression based classification.
@@ -63,8 +63,13 @@ object SparkLR {
showWarning()
- val sparkConf = new SparkConf().setAppName("SparkLR")
- val sc = new SparkContext(sparkConf)
+ val spark = SparkSession
+ .builder
+ .appName("SparkLR")
+ .getOrCreate()
+
+ val sc = spark.sparkContext
+
val numSlices = if (args.length > 0) args(0).toInt else 2
val points = sc.parallelize(generateData, numSlices).cache()
@@ -82,7 +87,7 @@ object SparkLR {
println("Final w: " + w)
- sc.stop()
+ spark.stop()
}
}
// scalastyle:on println
diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala
index 2664ddbb87..b7c363c7d4 100644
--- a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala
@@ -18,7 +18,7 @@
// scalastyle:off println
package org.apache.spark.examples
-import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.sql.SparkSession
/**
* Computes the PageRank of URLs from an input file. Input file should
@@ -50,10 +50,13 @@ object SparkPageRank {
showWarning()
- val sparkConf = new SparkConf().setAppName("PageRank")
+ val spark = SparkSession
+ .builder
+ .appName("SparkPageRank")
+ .getOrCreate()
+
val iters = if (args.length > 1) args(1).toInt else 10
- val ctx = new SparkContext(sparkConf)
- val lines = ctx.textFile(args(0), 1)
+ val lines = spark.read.text(args(0)).rdd
val links = lines.map{ s =>
val parts = s.split("\\s+")
(parts(0), parts(1))
@@ -71,7 +74,7 @@ object SparkPageRank {
val output = ranks.collect()
output.foreach(tup => println(tup._1 + " has rank: " + tup._2 + "."))
- ctx.stop()
+ spark.stop()
}
}
// scalastyle:on println
diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala b/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala
index 818d4f2b81..5be8f3b073 100644
--- a/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala
@@ -20,16 +20,19 @@ package org.apache.spark.examples
import scala.math.random
-import org.apache.spark._
+import org.apache.spark.sql.SparkSession
/** Computes an approximation to pi */
object SparkPi {
def main(args: Array[String]) {
- val conf = new SparkConf().setAppName("Spark Pi")
- val spark = new SparkContext(conf)
+ val spark = SparkSession
+ .builder
+ .appName("Spark Pi")
+ .getOrCreate()
+ val sc = spark.sparkContext
val slices = if (args.length > 0) args(0).toInt else 2
val n = math.min(100000L * slices, Int.MaxValue).toInt // avoid overflow
- val count = spark.parallelize(1 until n, slices).map { i =>
+ val count = sc.parallelize(1 until n, slices).map { i =>
val x = random * 2 - 1
val y = random * 2 - 1
if (x*x + y*y < 1) 1 else 0
diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala
index fc7a1f859f..46aa68b8b8 100644
--- a/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala
@@ -21,7 +21,7 @@ package org.apache.spark.examples
import scala.collection.mutable
import scala.util.Random
-import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.sql.SparkSession
/**
* Transitive closure on a graph.
@@ -42,10 +42,13 @@ object SparkTC {
}
def main(args: Array[String]) {
- val sparkConf = new SparkConf().setAppName("SparkTC")
- val spark = new SparkContext(sparkConf)
+ val spark = SparkSession
+ .builder
+ .appName("SparkTC")
+ .getOrCreate()
+ val sc = spark.sparkContext
val slices = if (args.length > 0) args(0).toInt else 2
- var tc = spark.parallelize(generateGraph, slices).cache()
+ var tc = sc.parallelize(generateGraph, slices).cache()
// Linear transitive closure: each round grows paths by one edge,
// by joining the graph's edges with the already-discovered paths.
diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala
index 7293cb51b2..59bdfa09ad 100644
--- a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala
@@ -22,7 +22,7 @@ import java.io.File
import com.google.common.io.{ByteStreams, Files}
-import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.SparkConf
import org.apache.spark.sql._
object HiveFromSpark {