author     Zheng RuiFeng <ruifengz@foxmail.com>  2016-05-20 16:40:33 -0700
committer  Andrew Or <andrew@databricks.com>     2016-05-20 16:40:33 -0700
commit     127bf1bb07967e2e4f99ad7abaa7f6fab3b3f407
tree       a127031cd361df2f1d895cb11489f8e183c76f73 /examples/src/main/python
parent     06c9f520714e07259c6f8ce6f9ea5a230a278cb5
[SPARK-15031][EXAMPLE] Use SparkSession in examples
## What changes were proposed in this pull request?
Use `SparkSession` according to [SPARK-15031](https://issues.apache.org/jira/browse/SPARK-15031).
The RDD-based `MLlib` API is no longer recommended, so the examples in `mllib` are skipped in this PR.
A `StreamingContext` cannot be obtained directly from a `SparkSession`, so the examples in `streaming` are skipped too.
cc andrewor14
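The conversion pattern is the same across all ten files; a minimal sketch (the app name and input path are illustrative, not taken from any single example):

```python
from pyspark.sql import SparkSession

if __name__ == "__main__":
    # Before this PR, the examples created a SparkContext directly:
    #     sc = SparkContext(appName="PythonExample")
    #     lines = sc.textFile("data.txt")

    # After this PR, SparkSession is the single entry point:
    spark = SparkSession\
        .builder\
        .appName("PythonExample")\
        .getOrCreate()

    # Text input goes through the DataFrame reader and is unwrapped
    # back into an RDD of plain line strings for the RDD-based code.
    lines = spark.read.text("data.txt").rdd.map(lambda r: r[0])

    print(lines.count())
    spark.stop()
```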
## How was this patch tested?
Manual tests with `spark-submit`.
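For example, a converted script can be exercised from the Spark checkout with `./bin/spark-submit examples/src/main/python/wordcount.py README.md` (the input file here is just an illustrative choice).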
Author: Zheng RuiFeng <ruifengz@foxmail.com>
Closes #13164 from zhengruifeng/use_sparksession_ii.
Diffstat (limited to 'examples/src/main/python')
-rwxr-xr-x  examples/src/main/python/als.py                  | 12
-rw-r--r--  examples/src/main/python/avro_inputformat.py     | 12
-rwxr-xr-x  examples/src/main/python/kmeans.py               | 12
-rwxr-xr-x  examples/src/main/python/logistic_regression.py  | 13
-rwxr-xr-x  examples/src/main/python/pagerank.py             | 11
-rw-r--r--  examples/src/main/python/parquet_inputformat.py  | 12
-rwxr-xr-x  examples/src/main/python/pi.py                   | 12
-rwxr-xr-x  examples/src/main/python/sort.py                 | 13
-rwxr-xr-x  examples/src/main/python/transitive_closure.py   | 12
-rwxr-xr-x  examples/src/main/python/wordcount.py            | 13
10 files changed, 87 insertions, 35 deletions
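A recurring pattern in the diff below is replacing `sc.textFile(path)` with `spark.read.text(path).rdd.map(lambda r: r[0])`. The extra unwrapping step is needed because the DataFrame reader returns `Row` objects rather than raw strings; a minimal sketch (the input file is illustrative):

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("ReadTextSketch").getOrCreate()

# spark.read.text yields a DataFrame with one string column, "value";
# each row holds a single line of the file.
df = spark.read.text("README.md")

# df.rdd is an RDD of Row objects, so r[0] extracts the raw line,
# recovering the RDD[str] that sc.textFile used to return.
lines = df.rdd.map(lambda r: r[0])

print(lines.take(2))
spark.stop()
```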
diff --git a/examples/src/main/python/als.py b/examples/src/main/python/als.py
index f07020b503..81562e20a9 100755
--- a/examples/src/main/python/als.py
+++ b/examples/src/main/python/als.py
@@ -28,7 +28,7 @@ import sys
 import numpy as np
 from numpy.random import rand
 from numpy import matrix
-from pyspark import SparkContext
+from pyspark.sql import SparkSession
 
 LAMBDA = 0.01   # regularization
 np.random.seed(42)
@@ -62,7 +62,13 @@ if __name__ == "__main__":
       example. Please use pyspark.ml.recommendation.ALS for more
       conventional use.""", file=sys.stderr)
 
-    sc = SparkContext(appName="PythonALS")
+    spark = SparkSession\
+        .builder\
+        .appName("PythonALS")\
+        .getOrCreate()
+
+    sc = spark._sc
+
     M = int(sys.argv[1]) if len(sys.argv) > 1 else 100
     U = int(sys.argv[2]) if len(sys.argv) > 2 else 500
     F = int(sys.argv[3]) if len(sys.argv) > 3 else 10
@@ -99,4 +105,4 @@ if __name__ == "__main__":
         print("Iteration %d:" % i)
         print("\nRMSE: %5.4f\n" % error)
 
-    sc.stop()
+    spark.stop()
diff --git a/examples/src/main/python/avro_inputformat.py b/examples/src/main/python/avro_inputformat.py
index da368ac628..3f65e8f79a 100644
--- a/examples/src/main/python/avro_inputformat.py
+++ b/examples/src/main/python/avro_inputformat.py
@@ -19,8 +19,8 @@ from __future__ import print_function
 
 import sys
 
-from pyspark import SparkContext
 from functools import reduce
+from pyspark.sql import SparkSession
 
 """
 Read data file users.avro in local Spark distro:
@@ -64,7 +64,13 @@ if __name__ == "__main__":
         exit(-1)
 
     path = sys.argv[1]
-    sc = SparkContext(appName="AvroKeyInputFormat")
+
+    spark = SparkSession\
+        .builder\
+        .appName("AvroKeyInputFormat")\
+        .getOrCreate()
+
+    sc = spark._sc
 
     conf = None
     if len(sys.argv) == 3:
@@ -82,4 +88,4 @@ if __name__ == "__main__":
     for k in output:
         print(k)
 
-    sc.stop()
+    spark.stop()
diff --git a/examples/src/main/python/kmeans.py b/examples/src/main/python/kmeans.py
index 3426e491dc..92e0a3ae2e 100755
--- a/examples/src/main/python/kmeans.py
+++ b/examples/src/main/python/kmeans.py
@@ -27,7 +27,7 @@ from __future__ import print_function
 import sys
 
 import numpy as np
-from pyspark import SparkContext
+from pyspark.sql import SparkSession
 
 
 def parseVector(line):
@@ -55,8 +55,12 @@ if __name__ == "__main__":
        as an example!
        Please refer to examples/src/main/python/ml/kmeans_example.py for an example on
       how to use ML's KMeans implementation.""", file=sys.stderr)
-    sc = SparkContext(appName="PythonKMeans")
-    lines = sc.textFile(sys.argv[1])
+    spark = SparkSession\
+        .builder\
+        .appName("PythonKMeans")\
+        .getOrCreate()
+
+    lines = spark.read.text(sys.argv[1]).rdd.map(lambda r: r[0])
     data = lines.map(parseVector).cache()
     K = int(sys.argv[2])
     convergeDist = float(sys.argv[3])
@@ -79,4 +83,4 @@ if __name__ == "__main__":
 
     print("Final centers: " + str(kPoints))
 
-    sc.stop()
+    spark.stop()
diff --git a/examples/src/main/python/logistic_regression.py b/examples/src/main/python/logistic_regression.py
index 7d33be7e81..01c938454b 100755
--- a/examples/src/main/python/logistic_regression.py
+++ b/examples/src/main/python/logistic_regression.py
@@ -27,7 +27,7 @@ from __future__ import print_function
 import sys
 
 import numpy as np
-from pyspark import SparkContext
+from pyspark.sql import SparkSession
 
 
 D = 10  # Number of dimensions
@@ -55,8 +55,13 @@ if __name__ == "__main__":
       Please refer to examples/src/main/python/ml/logistic_regression_with_elastic_net.py
       to see how ML's implementation is used.""", file=sys.stderr)
-    sc = SparkContext(appName="PythonLR")
-    points = sc.textFile(sys.argv[1]).mapPartitions(readPointBatch).cache()
+    spark = SparkSession\
+        .builder\
+        .appName("PythonLR")\
+        .getOrCreate()
+
+    points = spark.read.text(sys.argv[1]).rdd.map(lambda r: r[0])\
+        .mapPartitions(readPointBatch).cache()
     iterations = int(sys.argv[2])
 
     # Initialize w to a random value
@@ -80,4 +85,4 @@ if __name__ == "__main__":
 
     print("Final w: " + str(w))
 
-    sc.stop()
+    spark.stop()
diff --git a/examples/src/main/python/pagerank.py b/examples/src/main/python/pagerank.py
index 2fdc9773d4..a399a9c37c 100755
--- a/examples/src/main/python/pagerank.py
+++ b/examples/src/main/python/pagerank.py
@@ -25,7 +25,7 @@ import re
 import sys
 from operator import add
 
-from pyspark import SparkContext
+from pyspark.sql import SparkSession
 
 
 def computeContribs(urls, rank):
@@ -51,14 +51,17 @@ if __name__ == "__main__":
         file=sys.stderr)
 
     # Initialize the spark context.
-    sc = SparkContext(appName="PythonPageRank")
+    spark = SparkSession\
+        .builder\
+        .appName("PythonPageRank")\
+        .getOrCreate()
 
     # Loads in input file. It should be in format of:
     #     URL         neighbor URL
     #     URL         neighbor URL
     #     URL         neighbor URL
     #     ...
-    lines = sc.textFile(sys.argv[1], 1)
+    lines = spark.read.text(sys.argv[1]).rdd.map(lambda r: r[0])
 
     # Loads all URLs from input file and initialize their neighbors.
     links = lines.map(lambda urls: parseNeighbors(urls)).distinct().groupByKey().cache()
@@ -79,4 +82,4 @@ if __name__ == "__main__":
     for (link, rank) in ranks.collect():
         print("%s has rank: %s." % (link, rank))
 
-    sc.stop()
+    spark.stop()
diff --git a/examples/src/main/python/parquet_inputformat.py b/examples/src/main/python/parquet_inputformat.py
index e1fd85b082..2f09f4d573 100644
--- a/examples/src/main/python/parquet_inputformat.py
+++ b/examples/src/main/python/parquet_inputformat.py
@@ -18,7 +18,7 @@ from __future__ import print_function
 
 import sys
 
-from pyspark import SparkContext
+from pyspark.sql import SparkSession
 
 """
 Read data file users.parquet in local Spark distro:
@@ -47,7 +47,13 @@ if __name__ == "__main__":
         exit(-1)
 
     path = sys.argv[1]
-    sc = SparkContext(appName="ParquetInputFormat")
+
+    spark = SparkSession\
+        .builder\
+        .appName("ParquetInputFormat")\
+        .getOrCreate()
+
+    sc = spark._sc
 
     parquet_rdd = sc.newAPIHadoopFile(
         path,
@@ -59,4 +65,4 @@ if __name__ == "__main__":
     for k in output:
         print(k)
 
-    sc.stop()
+    spark.stop()
diff --git a/examples/src/main/python/pi.py b/examples/src/main/python/pi.py
index 92e5cf45ab..5db03e4a21 100755
--- a/examples/src/main/python/pi.py
+++ b/examples/src/main/python/pi.py
@@ -20,14 +20,20 @@
 import sys
 from random import random
 from operator import add
 
-from pyspark import SparkContext
+from pyspark.sql import SparkSession
 
 
 if __name__ == "__main__":
     """
         Usage: pi [partitions]
     """
-    sc = SparkContext(appName="PythonPi")
+    spark = SparkSession\
+        .builder\
+        .appName("PythonPi")\
+        .getOrCreate()
+
+    sc = spark._sc
+
     partitions = int(sys.argv[1]) if len(sys.argv) > 1 else 2
     n = 100000 * partitions
@@ -39,4 +45,4 @@ if __name__ == "__main__":
 
     count = sc.parallelize(range(1, n + 1), partitions).map(f).reduce(add)
     print("Pi is roughly %f" % (4.0 * count / n))
 
-    sc.stop()
+    spark.stop()
diff --git a/examples/src/main/python/sort.py b/examples/src/main/python/sort.py
index b6c2916254..81898cf6d5 100755
--- a/examples/src/main/python/sort.py
+++ b/examples/src/main/python/sort.py
@@ -19,15 +19,20 @@ from __future__ import print_function
 
 import sys
 
-from pyspark import SparkContext
+from pyspark.sql import SparkSession
 
 
 if __name__ == "__main__":
     if len(sys.argv) != 2:
         print("Usage: sort <file>", file=sys.stderr)
         exit(-1)
-    sc = SparkContext(appName="PythonSort")
-    lines = sc.textFile(sys.argv[1], 1)
+
+    spark = SparkSession\
+        .builder\
+        .appName("PythonSort")\
+        .getOrCreate()
+
+    lines = spark.read.text(sys.argv[1]).rdd.map(lambda r: r[0])
     sortedCount = lines.flatMap(lambda x: x.split(' ')) \
         .map(lambda x: (int(x), 1)) \
         .sortByKey()
@@ -37,4 +42,4 @@ if __name__ == "__main__":
     for (num, unitcount) in output:
         print(num)
 
-    sc.stop()
+    spark.stop()
diff --git a/examples/src/main/python/transitive_closure.py b/examples/src/main/python/transitive_closure.py
index 3d61250d8b..37c41dcd03 100755
--- a/examples/src/main/python/transitive_closure.py
+++ b/examples/src/main/python/transitive_closure.py
@@ -20,7 +20,7 @@ from __future__ import print_function
 import sys
 from random import Random
 
-from pyspark import SparkContext
+from pyspark.sql import SparkSession
 
 numEdges = 200
 numVertices = 100
@@ -41,7 +41,13 @@ if __name__ == "__main__":
     """
     Usage: transitive_closure [partitions]
     """
-    sc = SparkContext(appName="PythonTransitiveClosure")
+    spark = SparkSession\
+        .builder\
+        .appName("PythonTransitiveClosure")\
+        .getOrCreate()
+
+    sc = spark._sc
+
     partitions = int(sys.argv[1]) if len(sys.argv) > 1 else 2
     tc = sc.parallelize(generateGraph(), partitions).cache()
 
@@ -67,4 +73,4 @@ if __name__ == "__main__":
 
     print("TC has %i edges" % tc.count())
 
-    sc.stop()
+    spark.stop()
diff --git a/examples/src/main/python/wordcount.py b/examples/src/main/python/wordcount.py
index 7c0143607b..3d5e44d5b2 100755
--- a/examples/src/main/python/wordcount.py
+++ b/examples/src/main/python/wordcount.py
@@ -20,15 +20,20 @@ from __future__ import print_function
 import sys
 from operator import add
 
-from pyspark import SparkContext
+from pyspark.sql import SparkSession
 
 
 if __name__ == "__main__":
     if len(sys.argv) != 2:
         print("Usage: wordcount <file>", file=sys.stderr)
         exit(-1)
-    sc = SparkContext(appName="PythonWordCount")
-    lines = sc.textFile(sys.argv[1], 1)
+
+    spark = SparkSession\
+        .builder\
+        .appName("PythonWordCount")\
+        .getOrCreate()
+
+    lines = spark.read.text(sys.argv[1]).rdd.map(lambda r: r[0])
     counts = lines.flatMap(lambda x: x.split(' ')) \
         .map(lambda x: (x, 1)) \
         .reduceByKey(add)
@@ -36,4 +41,4 @@ if __name__ == "__main__":
     for (word, count) in output:
         print("%s: %i" % (word, count))
 
-    sc.stop()
+    spark.stop()
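One detail worth noting in the diff above: several examples reach for the private `spark._sc` attribute to get the underlying `SparkContext` for RDD-only APIs. A minimal sketch of the equivalent access through the public `sparkContext` property, which `SparkSession` also exposes (assuming Spark 2.0+; the app name is illustrative):

```python
from operator import add
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("ContextAccessSketch").getOrCreate()

# Public accessor for the session's underlying SparkContext; the diff
# above uses the private equivalent, spark._sc.
sc = spark.sparkContext

# RDD-only APIs (parallelize, newAPIHadoopFile, ...) still live on
# the SparkContext.
print(sc.parallelize(range(1, 11)).reduce(add))  # 55

spark.stop()
```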