author    Dongjoon Hyun <dongjoon@apache.org>    2016-06-10 15:40:29 -0700
committer Reynold Xin <rxin@databricks.com>      2016-06-10 15:40:29 -0700
commit    2022afe57dbf8cb0c9909399962c4a3649e0601c (patch)
tree      3f850e4ad28effb207044df759aade5b24afd394 /examples
parent    127a6678d7af6b5164a115be7c64525bb80001fe (diff)
[SPARK-15773][CORE][EXAMPLE] Avoid creating local variable `sc` in examples if possible
## What changes were proposed in this pull request?

Instead of using a local variable `sc` as in the following example, this PR uses `spark.sparkContext` directly. This makes the examples more concise and also removes a misleading pattern, i.e., code that looks as if it creates a SparkContext from a SparkSession.

```
-    println("Creating SparkContext")
-    val sc = spark.sparkContext
-
     println("Writing local file to DFS")
     val dfsFilename = dfsDirPath + "/dfs_read_write_test"
-    val fileRDD = sc.parallelize(fileContents)
+    val fileRDD = spark.sparkContext.parallelize(fileContents)
```

This changes 12 files (+30 lines, -52 lines).

## How was this patch tested?

Manual.

Author: Dongjoon Hyun <dongjoon@apache.org>

Closes #13520 from dongjoon-hyun/SPARK-15773.
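For reference, here is a minimal, self-contained Scala sketch of the pattern the patch adopts: call `spark.sparkContext` inline instead of binding it to a local `sc`. The object name and the sample data are illustrative only and are not part of the patch.

```
// A minimal sketch (not part of the patch) showing the pattern the examples
// now follow: use spark.sparkContext inline instead of binding it to a local
// variable `sc`. Object name and sample data are illustrative only.
import org.apache.spark.sql.SparkSession

object SparkContextInlineSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder
      .appName("SparkContext Inline Sketch")
      .getOrCreate()

    // Before: val sc = spark.sparkContext; val rdd = sc.parallelize(...)
    // After:  call spark.sparkContext directly, as the updated examples do.
    val rdd = spark.sparkContext.parallelize(1 to 10, 2)
    println(s"Sum = ${rdd.reduce(_ + _)}")

    spark.stop()
  }
}
```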
Diffstat (limited to 'examples')
-rwxr-xr-x  examples/src/main/python/pi.py                                                  |  4
-rwxr-xr-x  examples/src/main/python/transitive_closure.py                                  |  4
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala        |  7
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/ExceptionHandlingTest.scala   |  3
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala             | 14
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala      |  8
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala | 16
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala       | 13
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/SparkLR.scala                 |  4
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/SparkPi.scala                 |  3
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/SparkTC.scala                 |  3
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala  |  3
12 files changed, 30 insertions, 52 deletions
diff --git a/examples/src/main/python/pi.py b/examples/src/main/python/pi.py
index b39d710540..e3f0c4aeef 100755
--- a/examples/src/main/python/pi.py
+++ b/examples/src/main/python/pi.py
@@ -32,8 +32,6 @@ if __name__ == "__main__":
.appName("PythonPi")\
.getOrCreate()
- sc = spark.sparkContext
-
partitions = int(sys.argv[1]) if len(sys.argv) > 1 else 2
n = 100000 * partitions
@@ -42,7 +40,7 @@ if __name__ == "__main__":
y = random() * 2 - 1
return 1 if x ** 2 + y ** 2 < 1 else 0
- count = sc.parallelize(range(1, n + 1), partitions).map(f).reduce(add)
+ count = spark.sparkContext.parallelize(range(1, n + 1), partitions).map(f).reduce(add)
print("Pi is roughly %f" % (4.0 * count / n))
spark.stop()
diff --git a/examples/src/main/python/transitive_closure.py b/examples/src/main/python/transitive_closure.py
index d88ea94e41..49551d4085 100755
--- a/examples/src/main/python/transitive_closure.py
+++ b/examples/src/main/python/transitive_closure.py
@@ -46,10 +46,8 @@ if __name__ == "__main__":
.appName("PythonTransitiveClosure")\
.getOrCreate()
- sc = spark.sparkContext
-
partitions = int(sys.argv[1]) if len(sys.argv) > 1 else 2
- tc = sc.parallelize(generateGraph(), partitions).cache()
+ tc = spark.sparkContext.parallelize(generateGraph(), partitions).cache()
# Linear transitive closure: each round grows paths by one edge,
# by joining the graph's edges with the already-discovered paths.
diff --git a/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala b/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala
index 4b5e36c736..3bff7ce736 100644
--- a/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala
@@ -107,16 +107,13 @@ object DFSReadWriteTest {
.appName("DFS Read Write Test")
.getOrCreate()
- println("Creating SparkContext")
- val sc = spark.sparkContext
-
println("Writing local file to DFS")
val dfsFilename = dfsDirPath + "/dfs_read_write_test"
- val fileRDD = sc.parallelize(fileContents)
+ val fileRDD = spark.sparkContext.parallelize(fileContents)
fileRDD.saveAsTextFile(dfsFilename)
println("Reading file from DFS and running Word Count")
- val readFileRDD = sc.textFile(dfsFilename)
+ val readFileRDD = spark.sparkContext.textFile(dfsFilename)
val dfsWordCount = readFileRDD
.flatMap(_.split(" "))
diff --git a/examples/src/main/scala/org/apache/spark/examples/ExceptionHandlingTest.scala b/examples/src/main/scala/org/apache/spark/examples/ExceptionHandlingTest.scala
index 6a1bbed290..45c4953a84 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ExceptionHandlingTest.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ExceptionHandlingTest.scala
@@ -25,9 +25,8 @@ object ExceptionHandlingTest {
.builder
.appName("ExceptionHandlingTest")
.getOrCreate()
- val sc = spark.sparkContext
- sc.parallelize(0 until sc.defaultParallelism).foreach { i =>
+ spark.sparkContext.parallelize(0 until spark.sparkContext.defaultParallelism).foreach { i =>
if (math.random > 0.75) {
throw new Exception("Testing exception handling")
}
diff --git a/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala
index 0cb61d7495..2f2bbb1275 100644
--- a/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala
@@ -32,16 +32,14 @@ object GroupByTest {
.appName("GroupBy Test")
.getOrCreate()
- var numMappers = if (args.length > 0) args(0).toInt else 2
- var numKVPairs = if (args.length > 1) args(1).toInt else 1000
- var valSize = if (args.length > 2) args(2).toInt else 1000
- var numReducers = if (args.length > 3) args(3).toInt else numMappers
+ val numMappers = if (args.length > 0) args(0).toInt else 2
+ val numKVPairs = if (args.length > 1) args(1).toInt else 1000
+ val valSize = if (args.length > 2) args(2).toInt else 1000
+ val numReducers = if (args.length > 3) args(3).toInt else numMappers
- val sc = spark.sparkContext
-
- val pairs1 = sc.parallelize(0 until numMappers, numMappers).flatMap { p =>
+ val pairs1 = spark.sparkContext.parallelize(0 until numMappers, numMappers).flatMap { p =>
val ranGen = new Random
- var arr1 = new Array[(Int, Array[Byte])](numKVPairs)
+ val arr1 = new Array[(Int, Array[Byte])](numKVPairs)
for (i <- 0 until numKVPairs) {
val byteArr = new Array[Byte](valSize)
ranGen.nextBytes(byteArr)
diff --git a/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala
index 961ab99200..6495a86fcd 100644
--- a/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala
@@ -33,8 +33,6 @@ object MultiBroadcastTest {
.appName("Multi-Broadcast Test")
.getOrCreate()
- val sc = spark.sparkContext
-
val slices = if (args.length > 0) args(0).toInt else 2
val num = if (args.length > 1) args(1).toInt else 1000000
@@ -48,9 +46,9 @@ object MultiBroadcastTest {
arr2(i) = i
}
- val barr1 = sc.broadcast(arr1)
- val barr2 = sc.broadcast(arr2)
- val observedSizes: RDD[(Int, Int)] = sc.parallelize(1 to 10, slices).map { _ =>
+ val barr1 = spark.sparkContext.broadcast(arr1)
+ val barr2 = spark.sparkContext.broadcast(arr2)
+ val observedSizes: RDD[(Int, Int)] = spark.sparkContext.parallelize(1 to 10, slices).map { _ =>
(barr1.value.length, barr2.value.length)
}
// Collect the small RDD so we can print the observed sizes locally.
diff --git a/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala
index 255c2bfcee..8e1a574c92 100644
--- a/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala
@@ -32,17 +32,15 @@ object SimpleSkewedGroupByTest {
.appName("SimpleSkewedGroupByTest")
.getOrCreate()
- val sc = spark.sparkContext
+ val numMappers = if (args.length > 0) args(0).toInt else 2
+ val numKVPairs = if (args.length > 1) args(1).toInt else 1000
+ val valSize = if (args.length > 2) args(2).toInt else 1000
+ val numReducers = if (args.length > 3) args(3).toInt else numMappers
+ val ratio = if (args.length > 4) args(4).toInt else 5.0
- var numMappers = if (args.length > 0) args(0).toInt else 2
- var numKVPairs = if (args.length > 1) args(1).toInt else 1000
- var valSize = if (args.length > 2) args(2).toInt else 1000
- var numReducers = if (args.length > 3) args(3).toInt else numMappers
- var ratio = if (args.length > 4) args(4).toInt else 5.0
-
- val pairs1 = sc.parallelize(0 until numMappers, numMappers).flatMap { p =>
+ val pairs1 = spark.sparkContext.parallelize(0 until numMappers, numMappers).flatMap { p =>
val ranGen = new Random
- var result = new Array[(Int, Array[Byte])](numKVPairs)
+ val result = new Array[(Int, Array[Byte])](numKVPairs)
for (i <- 0 until numKVPairs) {
val byteArr = new Array[Byte](valSize)
ranGen.nextBytes(byteArr)
diff --git a/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala
index efd40147f7..4d3c34041b 100644
--- a/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala
@@ -32,21 +32,18 @@ object SkewedGroupByTest {
.appName("GroupBy Test")
.getOrCreate()
- val sc = spark.sparkContext
-
- var numMappers = if (args.length > 0) args(0).toInt else 2
+ val numMappers = if (args.length > 0) args(0).toInt else 2
var numKVPairs = if (args.length > 1) args(1).toInt else 1000
- var valSize = if (args.length > 2) args(2).toInt else 1000
- var numReducers = if (args.length > 3) args(3).toInt else numMappers
-
+ val valSize = if (args.length > 2) args(2).toInt else 1000
+ val numReducers = if (args.length > 3) args(3).toInt else numMappers
- val pairs1 = sc.parallelize(0 until numMappers, numMappers).flatMap { p =>
+ val pairs1 = spark.sparkContext.parallelize(0 until numMappers, numMappers).flatMap { p =>
val ranGen = new Random
// map output sizes linearly increase from the 1st to the last
numKVPairs = (1.0 * (p + 1) / numMappers * numKVPairs).toInt
- var arr1 = new Array[(Int, Array[Byte])](numKVPairs)
+ val arr1 = new Array[(Int, Array[Byte])](numKVPairs)
for (i <- 0 until numKVPairs) {
val byteArr = new Array[Byte](valSize)
ranGen.nextBytes(byteArr)
diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala
index 8ef3aab657..afa8f58c96 100644
--- a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala
@@ -68,10 +68,8 @@ object SparkLR {
.appName("SparkLR")
.getOrCreate()
- val sc = spark.sparkContext
-
val numSlices = if (args.length > 0) args(0).toInt else 2
- val points = sc.parallelize(generateData, numSlices).cache()
+ val points = spark.sparkContext.parallelize(generateData, numSlices).cache()
// Initialize w to a random value
var w = DenseVector.fill(D) {2 * rand.nextDouble - 1}
diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala b/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala
index 5be8f3b073..42f6cef4e1 100644
--- a/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala
@@ -29,10 +29,9 @@ object SparkPi {
.builder
.appName("Spark Pi")
.getOrCreate()
- val sc = spark.sparkContext
val slices = if (args.length > 0) args(0).toInt else 2
val n = math.min(100000L * slices, Int.MaxValue).toInt // avoid overflow
- val count = sc.parallelize(1 until n, slices).map { i =>
+ val count = spark.sparkContext.parallelize(1 until n, slices).map { i =>
val x = random * 2 - 1
val y = random * 2 - 1
if (x*x + y*y < 1) 1 else 0
diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala
index 46aa68b8b8..558295ab92 100644
--- a/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala
@@ -46,9 +46,8 @@ object SparkTC {
.builder
.appName("SparkTC")
.getOrCreate()
- val sc = spark.sparkContext
val slices = if (args.length > 0) args(0).toInt else 2
- var tc = sc.parallelize(generateGraph, slices).cache()
+ var tc = spark.sparkContext.parallelize(generateGraph, slices).cache()
// Linear transitive closure: each round grows paths by one edge,
// by joining the graph's edges with the already-discovered paths.
diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala
index 2d7a01a95d..2343f98c8d 100644
--- a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala
@@ -45,7 +45,6 @@ object HiveFromSpark {
.appName("HiveFromSpark")
.enableHiveSupport()
.getOrCreate()
- val sc = spark.sparkContext
import spark.implicits._
import spark.sql
@@ -71,7 +70,7 @@ object HiveFromSpark {
}
// You can also use RDDs to create temporary views within a HiveContext.
- val rdd = sc.parallelize((1 to 100).map(i => Record(i, s"val_$i")))
+ val rdd = spark.sparkContext.parallelize((1 to 100).map(i => Record(i, s"val_$i")))
rdd.toDF().createOrReplaceTempView("records")
// Queries can then join RDD data with data stored in Hive.