author     Zheng RuiFeng <ruifengz@foxmail.com>      2016-05-20 16:40:33 -0700
committer  Andrew Or <andrew@databricks.com>         2016-05-20 16:40:33 -0700
commit     127bf1bb07967e2e4f99ad7abaa7f6fab3b3f407 (patch)
tree       a127031cd361df2f1d895cb11489f8e183c76f73 /examples/src/main/python
parent     06c9f520714e07259c6f8ce6f9ea5a230a278cb5 (diff)
[SPARK-15031][EXAMPLE] Use SparkSession in examples
## What changes were proposed in this pull request?

Use `SparkSession` in the examples, per [SPARK-15031](https://issues.apache.org/jira/browse/SPARK-15031).

`MLlib` is no longer the recommended API, so the examples under `mllib` are skipped in this PR. A `StreamingContext` cannot be obtained directly from a `SparkSession`, so the `streaming` examples are skipped too.

cc andrewor14

## How was this patch tested?

Manual tests with spark-submit.

Author: Zheng RuiFeng <ruifengz@foxmail.com>

Closes #13164 from zhengruifeng/use_sparksession_ii.
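The same migration pattern is applied to every file below: build a `SparkSession` through the builder API instead of constructing a `SparkContext` directly, and call `stop()` on the session at the end. A minimal sketch of that pattern (the app name here is illustrative, not taken from the patch):

```python
from pyspark.sql import SparkSession

if __name__ == "__main__":
    # Before this commit: sc = SparkContext(appName="MyExample")
    spark = SparkSession\
        .builder\
        .appName("MyExample")\
        .getOrCreate()

    # ... job logic, using the session (or the underlying SparkContext) ...

    # Before this commit: sc.stop()
    spark.stop()
```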
Diffstat (limited to 'examples/src/main/python')
-rwxr-xr-x  examples/src/main/python/als.py                  12
-rw-r--r--  examples/src/main/python/avro_inputformat.py     12
-rwxr-xr-x  examples/src/main/python/kmeans.py               12
-rwxr-xr-x  examples/src/main/python/logistic_regression.py  13
-rwxr-xr-x  examples/src/main/python/pagerank.py             11
-rw-r--r--  examples/src/main/python/parquet_inputformat.py  12
-rwxr-xr-x  examples/src/main/python/pi.py                   12
-rwxr-xr-x  examples/src/main/python/sort.py                 13
-rwxr-xr-x  examples/src/main/python/transitive_closure.py   12
-rwxr-xr-x  examples/src/main/python/wordcount.py            13
10 files changed, 87 insertions, 35 deletions
diff --git a/examples/src/main/python/als.py b/examples/src/main/python/als.py
index f07020b503..81562e20a9 100755
--- a/examples/src/main/python/als.py
+++ b/examples/src/main/python/als.py
@@ -28,7 +28,7 @@ import sys
import numpy as np
from numpy.random import rand
from numpy import matrix
-from pyspark import SparkContext
+from pyspark.sql import SparkSession
LAMBDA = 0.01 # regularization
np.random.seed(42)
@@ -62,7 +62,13 @@ if __name__ == "__main__":
example. Please use pyspark.ml.recommendation.ALS for more
conventional use.""", file=sys.stderr)
- sc = SparkContext(appName="PythonALS")
+ spark = SparkSession\
+ .builder\
+ .appName("PythonALS")\
+ .getOrCreate()
+
+ sc = spark._sc
+
M = int(sys.argv[1]) if len(sys.argv) > 1 else 100
U = int(sys.argv[2]) if len(sys.argv) > 2 else 500
F = int(sys.argv[3]) if len(sys.argv) > 3 else 10
@@ -99,4 +105,4 @@ if __name__ == "__main__":
print("Iteration %d:" % i)
print("\nRMSE: %5.4f\n" % error)
- sc.stop()
+ spark.stop()
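Several of the examples still need the RDD API, so they pull the underlying context out of the session via the internal `spark._sc` attribute, as above. As a side note (not part of this commit), `SparkSession` also exposes the same object through the public `sparkContext` property, so a sketch like the following should behave the same way:

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("PythonALS").getOrCreate()

# Public accessor for the context the diff reaches via spark._sc
sc = spark.sparkContext

# The RDD API is then available exactly as before, e.g.:
rdd = sc.parallelize(range(10))
print(rdd.count())

spark.stop()
```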
diff --git a/examples/src/main/python/avro_inputformat.py b/examples/src/main/python/avro_inputformat.py
index da368ac628..3f65e8f79a 100644
--- a/examples/src/main/python/avro_inputformat.py
+++ b/examples/src/main/python/avro_inputformat.py
@@ -19,8 +19,8 @@ from __future__ import print_function
import sys
-from pyspark import SparkContext
from functools import reduce
+from pyspark.sql import SparkSession
"""
Read data file users.avro in local Spark distro:
@@ -64,7 +64,13 @@ if __name__ == "__main__":
exit(-1)
path = sys.argv[1]
- sc = SparkContext(appName="AvroKeyInputFormat")
+
+ spark = SparkSession\
+ .builder\
+ .appName("AvroKeyInputFormat")\
+ .getOrCreate()
+
+ sc = spark._sc
conf = None
if len(sys.argv) == 3:
@@ -82,4 +88,4 @@ if __name__ == "__main__":
for k in output:
print(k)
- sc.stop()
+ spark.stop()
diff --git a/examples/src/main/python/kmeans.py b/examples/src/main/python/kmeans.py
index 3426e491dc..92e0a3ae2e 100755
--- a/examples/src/main/python/kmeans.py
+++ b/examples/src/main/python/kmeans.py
@@ -27,7 +27,7 @@ from __future__ import print_function
import sys
import numpy as np
-from pyspark import SparkContext
+from pyspark.sql import SparkSession
def parseVector(line):
@@ -55,8 +55,12 @@ if __name__ == "__main__":
as an example! Please refer to examples/src/main/python/ml/kmeans_example.py for an
example on how to use ML's KMeans implementation.""", file=sys.stderr)
- sc = SparkContext(appName="PythonKMeans")
- lines = sc.textFile(sys.argv[1])
+ spark = SparkSession\
+ .builder\
+ .appName("PythonKMeans")\
+ .getOrCreate()
+
+ lines = spark.read.text(sys.argv[1]).rdd.map(lambda r: r[0])
data = lines.map(parseVector).cache()
K = int(sys.argv[2])
convergeDist = float(sys.argv[3])
@@ -79,4 +83,4 @@ if __name__ == "__main__":
print("Final centers: " + str(kPoints))
- sc.stop()
+ spark.stop()
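The kmeans diff (and the text-reading examples that follow) replaces `sc.textFile(path)` with `spark.read.text(path).rdd.map(lambda r: r[0])`. `spark.read.text()` returns a DataFrame with a single string column (`value`), so taking the first field of each `Row` recovers an RDD of plain strings, matching what `sc.textFile()` used to return. A small standalone sketch, with a hypothetical file path:

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("ReadLinesSketch").getOrCreate()

# DataFrame with one string column per input line ...
df = spark.read.text("data.txt")
# ... converted back to an RDD of plain strings, like sc.textFile() gave
lines = df.rdd.map(lambda r: r[0])

print(lines.take(5))
spark.stop()
```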
diff --git a/examples/src/main/python/logistic_regression.py b/examples/src/main/python/logistic_regression.py
index 7d33be7e81..01c938454b 100755
--- a/examples/src/main/python/logistic_regression.py
+++ b/examples/src/main/python/logistic_regression.py
@@ -27,7 +27,7 @@ from __future__ import print_function
import sys
import numpy as np
-from pyspark import SparkContext
+from pyspark.sql import SparkSession
D = 10 # Number of dimensions
@@ -55,8 +55,13 @@ if __name__ == "__main__":
Please refer to examples/src/main/python/ml/logistic_regression_with_elastic_net.py
to see how ML's implementation is used.""", file=sys.stderr)
- sc = SparkContext(appName="PythonLR")
- points = sc.textFile(sys.argv[1]).mapPartitions(readPointBatch).cache()
+ spark = SparkSession\
+ .builder\
+ .appName("PythonLR")\
+ .getOrCreate()
+
+ points = spark.read.text(sys.argv[1]).rdd.map(lambda r: r[0])\
+ .mapPartitions(readPointBatch).cache()
iterations = int(sys.argv[2])
# Initialize w to a random value
@@ -80,4 +85,4 @@ if __name__ == "__main__":
print("Final w: " + str(w))
- sc.stop()
+ spark.stop()
diff --git a/examples/src/main/python/pagerank.py b/examples/src/main/python/pagerank.py
index 2fdc9773d4..a399a9c37c 100755
--- a/examples/src/main/python/pagerank.py
+++ b/examples/src/main/python/pagerank.py
@@ -25,7 +25,7 @@ import re
import sys
from operator import add
-from pyspark import SparkContext
+from pyspark.sql import SparkSession
def computeContribs(urls, rank):
@@ -51,14 +51,17 @@ if __name__ == "__main__":
file=sys.stderr)
# Initialize the spark context.
- sc = SparkContext(appName="PythonPageRank")
+ spark = SparkSession\
+ .builder\
+ .appName("PythonPageRank")\
+ .getOrCreate()
# Loads in input file. It should be in format of:
# URL neighbor URL
# URL neighbor URL
# URL neighbor URL
# ...
- lines = sc.textFile(sys.argv[1], 1)
+ lines = spark.read.text(sys.argv[1]).rdd.map(lambda r: r[0])
# Loads all URLs from input file and initialize their neighbors.
links = lines.map(lambda urls: parseNeighbors(urls)).distinct().groupByKey().cache()
@@ -79,4 +82,4 @@ if __name__ == "__main__":
for (link, rank) in ranks.collect():
print("%s has rank: %s." % (link, rank))
- sc.stop()
+ spark.stop()
diff --git a/examples/src/main/python/parquet_inputformat.py b/examples/src/main/python/parquet_inputformat.py
index e1fd85b082..2f09f4d573 100644
--- a/examples/src/main/python/parquet_inputformat.py
+++ b/examples/src/main/python/parquet_inputformat.py
@@ -18,7 +18,7 @@ from __future__ import print_function
import sys
-from pyspark import SparkContext
+from pyspark.sql import SparkSession
"""
Read data file users.parquet in local Spark distro:
@@ -47,7 +47,13 @@ if __name__ == "__main__":
exit(-1)
path = sys.argv[1]
- sc = SparkContext(appName="ParquetInputFormat")
+
+ spark = SparkSession\
+ .builder\
+ .appName("ParquetInputFormat")\
+ .getOrCreate()
+
+ sc = spark._sc
parquet_rdd = sc.newAPIHadoopFile(
path,
@@ -59,4 +65,4 @@ if __name__ == "__main__":
for k in output:
print(k)
- sc.stop()
+ spark.stop()
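The parquet example deliberately keeps the low-level `newAPIHadoopFile` path and only swaps the entry point. For comparison only (not what this commit does), the DataFrame reader can load the same file directly; a hedged sketch:

```python
import sys

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("ParquetReaderSketch").getOrCreate()

# Same path argument as the example; spark.read.parquet returns a DataFrame
df = spark.read.parquet(sys.argv[1])
df.show()

spark.stop()
```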
diff --git a/examples/src/main/python/pi.py b/examples/src/main/python/pi.py
index 92e5cf45ab..5db03e4a21 100755
--- a/examples/src/main/python/pi.py
+++ b/examples/src/main/python/pi.py
@@ -20,14 +20,20 @@ import sys
from random import random
from operator import add
-from pyspark import SparkContext
+from pyspark.sql import SparkSession
if __name__ == "__main__":
"""
Usage: pi [partitions]
"""
- sc = SparkContext(appName="PythonPi")
+ spark = SparkSession\
+ .builder\
+ .appName("PythonPi")\
+ .getOrCreate()
+
+ sc = spark._sc
+
partitions = int(sys.argv[1]) if len(sys.argv) > 1 else 2
n = 100000 * partitions
@@ -39,4 +45,4 @@ if __name__ == "__main__":
count = sc.parallelize(range(1, n + 1), partitions).map(f).reduce(add)
print("Pi is roughly %f" % (4.0 * count / n))
- sc.stop()
+ spark.stop()
diff --git a/examples/src/main/python/sort.py b/examples/src/main/python/sort.py
index b6c2916254..81898cf6d5 100755
--- a/examples/src/main/python/sort.py
+++ b/examples/src/main/python/sort.py
@@ -19,15 +19,20 @@ from __future__ import print_function
import sys
-from pyspark import SparkContext
+from pyspark.sql import SparkSession
if __name__ == "__main__":
if len(sys.argv) != 2:
print("Usage: sort <file>", file=sys.stderr)
exit(-1)
- sc = SparkContext(appName="PythonSort")
- lines = sc.textFile(sys.argv[1], 1)
+
+ spark = SparkSession\
+ .builder\
+ .appName("PythonSort")\
+ .getOrCreate()
+
+ lines = spark.read.text(sys.argv[1]).rdd.map(lambda r: r[0])
sortedCount = lines.flatMap(lambda x: x.split(' ')) \
.map(lambda x: (int(x), 1)) \
.sortByKey()
@@ -37,4 +42,4 @@ if __name__ == "__main__":
for (num, unitcount) in output:
print(num)
- sc.stop()
+ spark.stop()
diff --git a/examples/src/main/python/transitive_closure.py b/examples/src/main/python/transitive_closure.py
index 3d61250d8b..37c41dcd03 100755
--- a/examples/src/main/python/transitive_closure.py
+++ b/examples/src/main/python/transitive_closure.py
@@ -20,7 +20,7 @@ from __future__ import print_function
import sys
from random import Random
-from pyspark import SparkContext
+from pyspark.sql import SparkSession
numEdges = 200
numVertices = 100
@@ -41,7 +41,13 @@ if __name__ == "__main__":
"""
Usage: transitive_closure [partitions]
"""
- sc = SparkContext(appName="PythonTransitiveClosure")
+ spark = SparkSession\
+ .builder\
+ .appName("PythonTransitiveClosure")\
+ .getOrCreate()
+
+ sc = spark._sc
+
partitions = int(sys.argv[1]) if len(sys.argv) > 1 else 2
tc = sc.parallelize(generateGraph(), partitions).cache()
@@ -67,4 +73,4 @@ if __name__ == "__main__":
print("TC has %i edges" % tc.count())
- sc.stop()
+ spark.stop()
diff --git a/examples/src/main/python/wordcount.py b/examples/src/main/python/wordcount.py
index 7c0143607b..3d5e44d5b2 100755
--- a/examples/src/main/python/wordcount.py
+++ b/examples/src/main/python/wordcount.py
@@ -20,15 +20,20 @@ from __future__ import print_function
import sys
from operator import add
-from pyspark import SparkContext
+from pyspark.sql import SparkSession
if __name__ == "__main__":
if len(sys.argv) != 2:
print("Usage: wordcount <file>", file=sys.stderr)
exit(-1)
- sc = SparkContext(appName="PythonWordCount")
- lines = sc.textFile(sys.argv[1], 1)
+
+ spark = SparkSession\
+ .builder\
+ .appName("PythonWordCount")\
+ .getOrCreate()
+
+ lines = spark.read.text(sys.argv[1]).rdd.map(lambda r: r[0])
counts = lines.flatMap(lambda x: x.split(' ')) \
.map(lambda x: (x, 1)) \
.reduceByKey(add)
@@ -36,4 +41,4 @@ if __name__ == "__main__":
for (word, count) in output:
print("%s: %i" % (word, count))
- sc.stop()
+ spark.stop()
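Reassembling the hunks above, the post-patch wordcount.py reads roughly as follows (context lines outside the hunks are assumed unchanged):

```python
from __future__ import print_function

import sys
from operator import add

from pyspark.sql import SparkSession

if __name__ == "__main__":
    if len(sys.argv) != 2:
        print("Usage: wordcount <file>", file=sys.stderr)
        exit(-1)

    spark = SparkSession\
        .builder\
        .appName("PythonWordCount")\
        .getOrCreate()

    lines = spark.read.text(sys.argv[1]).rdd.map(lambda r: r[0])
    counts = lines.flatMap(lambda x: x.split(' ')) \
        .map(lambda x: (x, 1)) \
        .reduceByKey(add)
    output = counts.collect()
    for (word, count) in output:
        print("%s: %i" % (word, count))

    spark.stop()
```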