aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/sql/group.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/sql/group.py')
-rw-r--r--python/pyspark/sql/group.py12
1 files changed, 7 insertions, 5 deletions
diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py
index ee734cb439..6987af69cf 100644
--- a/python/pyspark/sql/group.py
+++ b/python/pyspark/sql/group.py
@@ -195,13 +195,15 @@ class GroupedData(object):
def _test():
import doctest
- from pyspark.context import SparkContext
- from pyspark.sql import Row, SQLContext
+ from pyspark.sql import Row, SparkSession
import pyspark.sql.group
globs = pyspark.sql.group.__dict__.copy()
- sc = SparkContext('local[4]', 'PythonTest')
+ spark = SparkSession.builder\
+ .master("local[4]")\
+ .appName("sql.group tests")\
+ .getOrCreate()
+ sc = spark.sparkContext
globs['sc'] = sc
- globs['sqlContext'] = SQLContext(sc)
globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')]) \
.toDF(StructType([StructField('age', IntegerType()),
StructField('name', StringType())]))
@@ -216,7 +218,7 @@ def _test():
(failure_count, test_count) = doctest.testmod(
pyspark.sql.group, globs=globs,
optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF)
- globs['sc'].stop()
+ spark.stop()
if failure_count:
exit(-1)