author      Zheng RuiFeng <ruifengz@foxmail.com>  2016-05-20 16:40:33 -0700
committer   Andrew Or <andrew@databricks.com>     2016-05-20 16:40:33 -0700
commit      127bf1bb07967e2e4f99ad7abaa7f6fab3b3f407
tree        a127031cd361df2f1d895cb11489f8e183c76f73 /examples/src/main/scala
parent      06c9f520714e07259c6f8ce6f9ea5a230a278cb5
[SPARK-15031][EXAMPLE] Use SparkSession in examples
## What changes were proposed in this pull request?
Use `SparkSession` according to [SPARK-15031](https://issues.apache.org/jira/browse/SPARK-15031)
`MLlib` is no longer the recommended API, so the `MLlib` examples are left unchanged in this PR.
A `StreamingContext` cannot be obtained directly from a `SparkSession`, so the `Streaming` examples are left unchanged as well.
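
For reference, the conversion applied across these examples follows a single pattern; a minimal sketch (`ExampleApp` is an illustrative name, not one of the files touched here):

```scala
import org.apache.spark.sql.SparkSession

object ExampleApp {
  def main(args: Array[String]): Unit = {
    // Replaces: val sc = new SparkContext(new SparkConf().setAppName("ExampleApp"))
    val spark = SparkSession
      .builder
      .appName("ExampleApp")
      .getOrCreate()

    // Existing RDD-based example code keeps working through the underlying SparkContext.
    val sc = spark.sparkContext
    val counts = sc.parallelize(1 to 100).map(_ % 10).countByValue()
    println(counts)

    spark.stop() // replaces sc.stop()
  }
}
```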
cc andrewor14
## How was this patch tested?
Manual tests with `spark-submit`.
Author: Zheng RuiFeng <ruifengz@foxmail.com>
Closes #13164 from zhengruifeng/use_sparksession_ii.
Diffstat (limited to 'examples/src/main/scala')
16 files changed, 127 insertions, 70 deletions
diff --git a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala
index af5a815f6e..c50f25d951 100644
--- a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala
@@ -18,7 +18,8 @@
 // scalastyle:off println
 package org.apache.spark.examples
 
-import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.SparkSession
 
 /**
  * Usage: BroadcastTest [slices] [numElem] [blockSize]
@@ -28,9 +29,16 @@ object BroadcastTest {
 
     val blockSize = if (args.length > 2) args(2) else "4096"
 
-    val sparkConf = new SparkConf().setAppName("Broadcast Test")
+    val sparkConf = new SparkConf()
       .set("spark.broadcast.blockSize", blockSize)
-    val sc = new SparkContext(sparkConf)
+
+    val spark = SparkSession
+      .builder
+      .config(sparkConf)
+      .appName("Broadcast Test")
+      .getOrCreate()
+
+    val sc = spark.sparkContext
 
     val slices = if (args.length > 0) args(0).toInt else 2
     val num = if (args.length > 1) args(1).toInt else 1000000
@@ -48,7 +56,7 @@ object BroadcastTest {
       println("Iteration %d took %.0f milliseconds".format(i, (System.nanoTime - startTime) / 1E6))
     }
 
-    sc.stop()
+    spark.stop()
   }
 }
 // scalastyle:on println
diff --git a/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala b/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala
index 7bf023667d..4b5e36c736 100644
--- a/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala
@@ -22,7 +22,7 @@ import java.io.File
 
 import scala.io.Source._
 
-import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.sql.SparkSession
 
 /**
  * Simple test for reading and writing to a distributed
@@ -101,11 +101,14 @@ object DFSReadWriteTest {
     val fileContents = readFile(localFilePath.toString())
     val localWordCount = runLocalWordCount(fileContents)
 
-    println("Creating SparkConf")
-    val conf = new SparkConf().setAppName("DFS Read Write Test")
+    println("Creating SparkSession")
+    val spark = SparkSession
+      .builder
+      .appName("DFS Read Write Test")
+      .getOrCreate()
 
     println("Creating SparkContext")
-    val sc = new SparkContext(conf)
+    val sc = spark.sparkContext
 
     println("Writing local file to DFS")
     val dfsFilename = dfsDirPath + "/dfs_read_write_test"
@@ -124,7 +127,7 @@ object DFSReadWriteTest {
       .values
       .sum
 
-    sc.stop()
+    spark.stop()
 
     if (localWordCount == dfsWordCount) {
       println(s"Success! Local Word Count ($localWordCount) " +
diff --git a/examples/src/main/scala/org/apache/spark/examples/ExceptionHandlingTest.scala b/examples/src/main/scala/org/apache/spark/examples/ExceptionHandlingTest.scala
index d42f63e870..6a1bbed290 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ExceptionHandlingTest.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ExceptionHandlingTest.scala
@@ -17,18 +17,22 @@
 
 package org.apache.spark.examples
 
-import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.sql.SparkSession
 
 object ExceptionHandlingTest {
   def main(args: Array[String]) {
-    val sparkConf = new SparkConf().setAppName("ExceptionHandlingTest")
-    val sc = new SparkContext(sparkConf)
+    val spark = SparkSession
+      .builder
+      .appName("ExceptionHandlingTest")
+      .getOrCreate()
+    val sc = spark.sparkContext
+
     sc.parallelize(0 until sc.defaultParallelism).foreach { i =>
       if (math.random > 0.75) {
        throw new Exception("Testing exception handling")
      }
    }
 
-    sc.stop()
+    spark.stop()
   }
 }
diff --git a/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala
index 4db229b5de..0cb61d7495 100644
--- a/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala
@@ -20,20 +20,24 @@ package org.apache.spark.examples
 
 import java.util.Random
 
-import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.sql.SparkSession
 
 /**
  * Usage: GroupByTest [numMappers] [numKVPairs] [KeySize] [numReducers]
  */
 object GroupByTest {
   def main(args: Array[String]) {
-    val sparkConf = new SparkConf().setAppName("GroupBy Test")
+    val spark = SparkSession
+      .builder
+      .appName("GroupBy Test")
+      .getOrCreate()
+
     var numMappers = if (args.length > 0) args(0).toInt else 2
     var numKVPairs = if (args.length > 1) args(1).toInt else 1000
     var valSize = if (args.length > 2) args(2).toInt else 1000
     var numReducers = if (args.length > 3) args(3).toInt else numMappers
 
-    val sc = new SparkContext(sparkConf)
+    val sc = spark.sparkContext
 
     val pairs1 = sc.parallelize(0 until numMappers, numMappers).flatMap { p =>
       val ranGen = new Random
@@ -50,7 +54,7 @@ object GroupByTest {
 
     println(pairs1.groupByKey(numReducers).count())
 
-    sc.stop()
+    spark.stop()
   }
 }
 // scalastyle:on println
diff --git a/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala b/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala
index 124dc9af63..aa8de69839 100644
--- a/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala
@@ -18,7 +18,7 @@
 // scalastyle:off println
 package org.apache.spark.examples
 
-import org.apache.spark._
+import org.apache.spark.sql.SparkSession
 
 object HdfsTest {
 
@@ -29,9 +29,11 @@ object HdfsTest {
       System.err.println("Usage: HdfsTest <file>")
       System.exit(1)
     }
-    val sparkConf = new SparkConf().setAppName("HdfsTest")
-    val sc = new SparkContext(sparkConf)
-    val file = sc.textFile(args(0))
+    val spark = SparkSession
+      .builder
+      .appName("HdfsTest")
+      .getOrCreate()
+    val file = spark.read.text(args(0)).rdd
     val mapped = file.map(s => s.length).cache()
     for (iter <- 1 to 10) {
       val start = System.currentTimeMillis()
@@ -39,7 +41,7 @@ object HdfsTest {
       val end = System.currentTimeMillis()
       println("Iteration " + iter + " took " + (end-start) + " ms")
     }
-    sc.stop()
+    spark.stop()
   }
 }
 // scalastyle:on println
diff --git a/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala
index 3eb0c27723..961ab99200 100644
--- a/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala
@@ -18,8 +18,9 @@
 // scalastyle:off println
 package org.apache.spark.examples
 
-import org.apache.spark.{SparkConf, SparkContext}
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.SparkSession
+
 
 /**
  * Usage: MultiBroadcastTest [slices] [numElem]
@@ -27,8 +28,12 @@ import org.apache.spark.rdd.RDD
 object MultiBroadcastTest {
   def main(args: Array[String]) {
-    val sparkConf = new SparkConf().setAppName("Multi-Broadcast Test")
-    val sc = new SparkContext(sparkConf)
+    val spark = SparkSession
+      .builder
+      .appName("Multi-Broadcast Test")
+      .getOrCreate()
+
+    val sc = spark.sparkContext
 
     val slices = if (args.length > 0) args(0).toInt else 2
     val num = if (args.length > 1) args(1).toInt else 1000000
@@ -51,7 +56,7 @@ object MultiBroadcastTest {
     // Collect the small RDD so we can print the observed sizes locally.
     observedSizes.collect().foreach(i => println(i))
 
-    sc.stop()
+    spark.stop()
   }
 }
 // scalastyle:on println
diff --git a/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala
index ec07e6323e..255c2bfcee 100644
--- a/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala
@@ -20,23 +20,26 @@ package org.apache.spark.examples
 
 import java.util.Random
 
-import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.sql.SparkSession
 
 /**
  * Usage: SimpleSkewedGroupByTest [numMappers] [numKVPairs] [valSize] [numReducers] [ratio]
  */
 object SimpleSkewedGroupByTest {
   def main(args: Array[String]) {
+    val spark = SparkSession
+      .builder
+      .appName("SimpleSkewedGroupByTest")
+      .getOrCreate()
+
+    val sc = spark.sparkContext
 
-    val sparkConf = new SparkConf().setAppName("SimpleSkewedGroupByTest")
     var numMappers = if (args.length > 0) args(0).toInt else 2
     var numKVPairs = if (args.length > 1) args(1).toInt else 1000
     var valSize = if (args.length > 2) args(2).toInt else 1000
     var numReducers = if (args.length > 3) args(3).toInt else numMappers
     var ratio = if (args.length > 4) args(4).toInt else 5.0
 
-    val sc = new SparkContext(sparkConf)
-
     val pairs1 = sc.parallelize(0 until numMappers, numMappers).flatMap { p =>
       val ranGen = new Random
       var result = new Array[(Int, Array[Byte])](numKVPairs)
@@ -64,7 +67,7 @@ object SimpleSkewedGroupByTest {
     //    .map{case (k,v) => (k, v.size)}
     //    .collectAsMap)
 
-    sc.stop()
+    spark.stop()
   }
 }
 // scalastyle:on println
diff --git a/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala
index 8e4c2b6229..efd40147f7 100644
--- a/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala
@@ -20,20 +20,25 @@ package org.apache.spark.examples
 
 import java.util.Random
 
-import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.sql.SparkSession
 
 /**
  * Usage: GroupByTest [numMappers] [numKVPairs] [KeySize] [numReducers]
  */
 object SkewedGroupByTest {
   def main(args: Array[String]) {
-    val sparkConf = new SparkConf().setAppName("GroupBy Test")
+    val spark = SparkSession
+      .builder
+      .appName("GroupBy Test")
+      .getOrCreate()
+
+    val sc = spark.sparkContext
+
     var numMappers = if (args.length > 0) args(0).toInt else 2
     var numKVPairs = if (args.length > 1) args(1).toInt else 1000
     var valSize = if (args.length > 2) args(2).toInt else 1000
     var numReducers = if (args.length > 3) args(3).toInt else numMappers
 
-    val sc = new SparkContext(sparkConf)
 
     val pairs1 = sc.parallelize(0 until numMappers, numMappers).flatMap { p =>
       val ranGen = new Random
@@ -54,7 +59,7 @@ object SkewedGroupByTest {
 
     println(pairs1.groupByKey(numReducers).count())
 
-    sc.stop()
+    spark.stop()
   }
 }
 // scalastyle:on println
diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala
index b06c629802..8a3d08f459 100644
--- a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala
@@ -20,7 +20,7 @@ package org.apache.spark.examples
 
 import org.apache.commons.math3.linear._
 
-import org.apache.spark._
+import org.apache.spark.sql.SparkSession
 
 /**
  * Alternating least squares matrix factorization.
@@ -108,8 +108,12 @@ object SparkALS {
 
     println(s"Running with M=$M, U=$U, F=$F, iters=$ITERATIONS")
 
-    val sparkConf = new SparkConf().setAppName("SparkALS")
-    val sc = new SparkContext(sparkConf)
+    val spark = SparkSession
+      .builder
+      .appName("SparkALS")
+      .getOrCreate()
+
+    val sc = spark.sparkContext
 
     val R = generateR()
 
@@ -135,7 +139,7 @@ object SparkALS {
       println()
     }
 
-    sc.stop()
+    spark.stop()
   }
 
   private def randomVector(n: Int): RealVector =
diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala
index c514eb0fa5..84f133e011 100644
--- a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala
@@ -23,9 +23,8 @@ import java.util.Random
 import scala.math.exp
 
 import breeze.linalg.{DenseVector, Vector}
-import org.apache.hadoop.conf.Configuration
 
-import org.apache.spark._
+import org.apache.spark.sql.SparkSession
 
 /**
  * Logistic regression based classification.
@@ -67,11 +66,14 @@ object SparkHdfsLR {
 
     showWarning()
 
-    val sparkConf = new SparkConf().setAppName("SparkHdfsLR")
+    val spark = SparkSession
+      .builder
+      .appName("SparkHdfsLR")
+      .getOrCreate()
+
     val inputPath = args(0)
-    val conf = new Configuration()
-    val sc = new SparkContext(sparkConf)
-    val lines = sc.textFile(inputPath)
+    val lines = spark.read.text(inputPath).rdd
+
     val points = lines.map(parsePoint).cache()
     val ITERATIONS = args(1).toInt
 
@@ -88,7 +90,7 @@ object SparkHdfsLR {
     }
 
     println("Final w: " + w)
-    sc.stop()
+    spark.stop()
   }
 }
 // scalastyle:on println
diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala
index 676164806e..aa93c93c44 100644
--- a/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala
@@ -20,7 +20,7 @@ package org.apache.spark.examples
 
 import breeze.linalg.{squaredDistance, DenseVector, Vector}
 
-import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.sql.SparkSession
 
 /**
  * K-means clustering.
@@ -66,14 +66,17 @@ object SparkKMeans {
 
     showWarning()
 
-    val sparkConf = new SparkConf().setAppName("SparkKMeans")
-    val sc = new SparkContext(sparkConf)
-    val lines = sc.textFile(args(0))
+    val spark = SparkSession
+      .builder
+      .appName("SparkKMeans")
+      .getOrCreate()
+
+    val lines = spark.read.text(args(0)).rdd
     val data = lines.map(parseVector _).cache()
     val K = args(1).toInt
     val convergeDist = args(2).toDouble
 
-    val kPoints = data.takeSample(withReplacement = false, K, 42).toArray
+    val kPoints = data.takeSample(withReplacement = false, K, 42)
     var tempDist = 1.0
 
     while(tempDist > convergeDist) {
@@ -97,7 +100,7 @@ object SparkKMeans {
 
     println("Final centers:")
     kPoints.foreach(println)
-    sc.stop()
+    spark.stop()
   }
 }
 // scalastyle:on println
diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala
index 718f84f645..8ef3aab657 100644
--- a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala
@@ -24,7 +24,7 @@ import scala.math.exp
 
 import breeze.linalg.{DenseVector, Vector}
 
-import org.apache.spark._
+import org.apache.spark.sql.SparkSession
 
 /**
  * Logistic regression based classification.
@@ -63,8 +63,13 @@ object SparkLR {
 
     showWarning()
 
-    val sparkConf = new SparkConf().setAppName("SparkLR")
-    val sc = new SparkContext(sparkConf)
+    val spark = SparkSession
+      .builder
+      .appName("SparkLR")
+      .getOrCreate()
+
+    val sc = spark.sparkContext
+
     val numSlices = if (args.length > 0) args(0).toInt else 2
     val points = sc.parallelize(generateData, numSlices).cache()
 
@@ -82,7 +87,7 @@ object SparkLR {
 
     println("Final w: " + w)
 
-    sc.stop()
+    spark.stop()
   }
 }
 // scalastyle:on println
diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala
index 2664ddbb87..b7c363c7d4 100644
--- a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala
@@ -18,7 +18,7 @@
 // scalastyle:off println
 package org.apache.spark.examples
 
-import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.sql.SparkSession
 
 /**
  * Computes the PageRank of URLs from an input file. Input file should
@@ -50,10 +50,13 @@ object SparkPageRank {
 
     showWarning()
 
-    val sparkConf = new SparkConf().setAppName("PageRank")
+    val spark = SparkSession
+      .builder
+      .appName("SparkPageRank")
+      .getOrCreate()
+
     val iters = if (args.length > 1) args(1).toInt else 10
-    val ctx = new SparkContext(sparkConf)
-    val lines = ctx.textFile(args(0), 1)
+    val lines = spark.read.text(args(0)).rdd
     val links = lines.map{ s =>
       val parts = s.split("\\s+")
       (parts(0), parts(1))
@@ -71,7 +74,7 @@ object SparkPageRank {
     val output = ranks.collect()
     output.foreach(tup => println(tup._1 + " has rank: " + tup._2 + "."))
 
-    ctx.stop()
+    spark.stop()
   }
 }
 // scalastyle:on println
diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala b/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala
index 818d4f2b81..5be8f3b073 100644
--- a/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala
@@ -20,16 +20,19 @@ package org.apache.spark.examples
 
 import scala.math.random
 
-import org.apache.spark._
+import org.apache.spark.sql.SparkSession
 
 /** Computes an approximation to pi */
 object SparkPi {
   def main(args: Array[String]) {
-    val conf = new SparkConf().setAppName("Spark Pi")
-    val spark = new SparkContext(conf)
+    val spark = SparkSession
+      .builder
+      .appName("Spark Pi")
+      .getOrCreate()
+    val sc = spark.sparkContext
     val slices = if (args.length > 0) args(0).toInt else 2
     val n = math.min(100000L * slices, Int.MaxValue).toInt // avoid overflow
-    val count = spark.parallelize(1 until n, slices).map { i =>
+    val count = sc.parallelize(1 until n, slices).map { i =>
       val x = random * 2 - 1
       val y = random * 2 - 1
       if (x*x + y*y < 1) 1 else 0
diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala
index fc7a1f859f..46aa68b8b8 100644
--- a/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala
@@ -21,7 +21,7 @@ package org.apache.spark.examples
 import scala.collection.mutable
 import scala.util.Random
 
-import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.sql.SparkSession
 
 /**
  * Transitive closure on a graph.
@@ -42,10 +42,13 @@ object SparkTC {
   }
 
   def main(args: Array[String]) {
-    val sparkConf = new SparkConf().setAppName("SparkTC")
-    val spark = new SparkContext(sparkConf)
+    val spark = SparkSession
+      .builder
+      .appName("SparkTC")
+      .getOrCreate()
+    val sc = spark.sparkContext
     val slices = if (args.length > 0) args(0).toInt else 2
-    var tc = spark.parallelize(generateGraph, slices).cache()
+    var tc = sc.parallelize(generateGraph, slices).cache()
 
     // Linear transitive closure: each round grows paths by one edge,
     // by joining the graph's edges with the already-discovered paths.
diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala
index 7293cb51b2..59bdfa09ad 100644
--- a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala
@@ -22,7 +22,7 @@ import java.io.File
 
 import com.google.common.io.{ByteStreams, Files}
 
-import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.SparkConf
 import org.apache.spark.sql._
 
 object HiveFromSpark {
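
A note on the file-reading examples above: `sc.textFile(path)` yields an `RDD[String]`, while `spark.read.text(path).rdd` yields an `RDD[Row]` with a single string column. A minimal sketch of recovering the plain lines (illustrative only, not part of this patch):

```scala
// Each Row holds one line in column 0 ("value"); extract it back to a String.
val lines = spark.read.text(args(0)).rdd.map(_.getString(0))
```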