aboutsummaryrefslogtreecommitdiff
path: root/examples/src/main/python/streaming
diff options
context:
space:
mode:
Diffstat (limited to 'examples/src/main/python/streaming')
-rw-r--r--examples/src/main/python/streaming/sql_network_wordcount.py19
1 files changed, 10 insertions, 9 deletions
diff --git a/examples/src/main/python/streaming/sql_network_wordcount.py b/examples/src/main/python/streaming/sql_network_wordcount.py
index 1ba5e9fb78..588cbfee14 100644
--- a/examples/src/main/python/streaming/sql_network_wordcount.py
+++ b/examples/src/main/python/streaming/sql_network_wordcount.py
@@ -33,13 +33,14 @@ import sys
from pyspark import SparkContext
from pyspark.streaming import StreamingContext
-from pyspark.sql import SQLContext, Row
+from pyspark.sql import Row, SparkSession
-def getSqlContextInstance(sparkContext):
- if ('sqlContextSingletonInstance' not in globals()):
- globals()['sqlContextSingletonInstance'] = SQLContext(sparkContext)
- return globals()['sqlContextSingletonInstance']
+def getSparkSessionInstance(sparkConf):
+ if ('sparkSessionSingletonInstance' not in globals()):
+ globals()['sparkSessionSingletonInstance'] =\
+ SparkSession.builder.config(conf=sparkConf).getOrCreate()
+ return globals()['sparkSessionSingletonInstance']
if __name__ == "__main__":
@@ -60,19 +61,19 @@ if __name__ == "__main__":
print("========= %s =========" % str(time))
try:
- # Get the singleton instance of SQLContext
- sqlContext = getSqlContextInstance(rdd.context)
+ # Get the singleton instance of SparkSession
+ spark = getSparkSessionInstance(rdd.context.getConf())
# Convert RDD[String] to RDD[Row] to DataFrame
rowRdd = rdd.map(lambda w: Row(word=w))
- wordsDataFrame = sqlContext.createDataFrame(rowRdd)
+ wordsDataFrame = spark.createDataFrame(rowRdd)
# Register as table
wordsDataFrame.registerTempTable("words")
# Do word count on table using SQL and print it
wordCountsDataFrame = \
- sqlContext.sql("select word, count(*) as total from words group by word")
+ spark.sql("select word, count(*) as total from words group by word")
wordCountsDataFrame.show()
except:
pass