diff options
author | WeichenXu <WeichenXu123@outlook.com> | 2016-05-23 18:14:48 -0700 |
---|---|---|
committer | Andrew Or <andrew@databricks.com> | 2016-05-23 18:14:48 -0700 |
commit | a15ca5533db91fefaf3248255a59c4d94eeda1a9 (patch) | |
tree | 80867e08b17b01d96611a9a695cbb1f01c0198f6 /python/pyspark/sql | |
parent | 5afd927a47aa7ede3039234f2f7262e2247aa2ae (diff) | |
download | spark-a15ca5533db91fefaf3248255a59c4d94eeda1a9.tar.gz spark-a15ca5533db91fefaf3248255a59c4d94eeda1a9.tar.bz2 spark-a15ca5533db91fefaf3248255a59c4d94eeda1a9.zip |
[SPARK-15464][ML][MLLIB][SQL][TESTS] Replace SQLContext and SparkContext with SparkSession using builder pattern in python test code
## What changes were proposed in this pull request?
Replace SQLContext and SparkContext with SparkSession using builder pattern in python test code.
## How was this patch tested?
Existing test.
Author: WeichenXu <WeichenXu123@outlook.com>
Closes #13242 from WeichenXu123/python_doctest_update_sparksession.
Diffstat (limited to 'python/pyspark/sql')
-rw-r--r-- | python/pyspark/sql/catalog.py | 14 | ||||
-rw-r--r-- | python/pyspark/sql/column.py | 12 | ||||
-rw-r--r-- | python/pyspark/sql/conf.py | 11 | ||||
-rw-r--r-- | python/pyspark/sql/functions.py | 153 | ||||
-rw-r--r-- | python/pyspark/sql/group.py | 12 |
5 files changed, 107 insertions, 95 deletions
diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py index 812dbba59e..3033f147bc 100644 --- a/python/pyspark/sql/catalog.py +++ b/python/pyspark/sql/catalog.py @@ -244,21 +244,23 @@ class Catalog(object): def _test(): import os import doctest - from pyspark.context import SparkContext - from pyspark.sql.session import SparkSession + from pyspark.sql import SparkSession import pyspark.sql.catalog os.chdir(os.environ["SPARK_HOME"]) globs = pyspark.sql.catalog.__dict__.copy() - sc = SparkContext('local[4]', 'PythonTest') - globs['sc'] = sc - globs['spark'] = SparkSession(sc) + spark = SparkSession.builder\ + .master("local[4]")\ + .appName("sql.catalog tests")\ + .getOrCreate() + globs['sc'] = spark.sparkContext + globs['spark'] = spark (failure_count, test_count) = doctest.testmod( pyspark.sql.catalog, globs=globs, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) - globs['sc'].stop() + spark.stop() if failure_count: exit(-1) diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index 5b26e94698..4b99f3058b 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -434,13 +434,15 @@ class Column(object): def _test(): import doctest - from pyspark.context import SparkContext - from pyspark.sql import SQLContext + from pyspark.sql import SparkSession import pyspark.sql.column globs = pyspark.sql.column.__dict__.copy() - sc = SparkContext('local[4]', 'PythonTest') + spark = SparkSession.builder\ + .master("local[4]")\ + .appName("sql.column tests")\ + .getOrCreate() + sc = spark.sparkContext globs['sc'] = sc - globs['sqlContext'] = SQLContext(sc) globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')]) \ .toDF(StructType([StructField('age', IntegerType()), StructField('name', StringType())])) @@ -448,7 +450,7 @@ def _test(): (failure_count, test_count) = doctest.testmod( pyspark.sql.column, globs=globs, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF) - globs['sc'].stop() + spark.stop() if failure_count: exit(-1) diff --git a/python/pyspark/sql/conf.py b/python/pyspark/sql/conf.py index 609d882a95..792c420ca6 100644 --- a/python/pyspark/sql/conf.py +++ b/python/pyspark/sql/conf.py @@ -71,11 +71,14 @@ def _test(): os.chdir(os.environ["SPARK_HOME"]) globs = pyspark.sql.conf.__dict__.copy() - sc = SparkContext('local[4]', 'PythonTest') - globs['sc'] = sc - globs['spark'] = SparkSession(sc) + spark = SparkSession.builder\ + .master("local[4]")\ + .appName("sql.conf tests")\ + .getOrCreate() + globs['sc'] = spark.sparkContext + globs['spark'] = spark (failure_count, test_count) = doctest.testmod(pyspark.sql.conf, globs=globs) - globs['sc'].stop() + spark.stop() if failure_count: exit(-1) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 716b16fdc9..1f15eec645 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -212,7 +212,7 @@ def broadcast(df): def coalesce(*cols): """Returns the first column that is not null. - >>> cDf = sqlContext.createDataFrame([(None, None), (1, None), (None, 2)], ("a", "b")) + >>> cDf = spark.createDataFrame([(None, None), (1, None), (None, 2)], ("a", "b")) >>> cDf.show() +----+----+ | a| b| @@ -252,7 +252,7 @@ def corr(col1, col2): >>> a = range(20) >>> b = [2 * x for x in range(20)] - >>> df = sqlContext.createDataFrame(zip(a, b), ["a", "b"]) + >>> df = spark.createDataFrame(zip(a, b), ["a", "b"]) >>> df.agg(corr("a", "b").alias('c')).collect() [Row(c=1.0)] """ @@ -267,7 +267,7 @@ def covar_pop(col1, col2): >>> a = [1] * 10 >>> b = [1] * 10 - >>> df = sqlContext.createDataFrame(zip(a, b), ["a", "b"]) + >>> df = spark.createDataFrame(zip(a, b), ["a", "b"]) >>> df.agg(covar_pop("a", "b").alias('c')).collect() [Row(c=0.0)] """ @@ -282,7 +282,7 @@ def covar_samp(col1, col2): >>> a = [1] * 10 >>> b = [1] * 10 - >>> df = sqlContext.createDataFrame(zip(a, b), ["a", "b"]) + >>> df = spark.createDataFrame(zip(a, b), ["a", "b"]) >>> df.agg(covar_samp("a", "b").alias('c')).collect() [Row(c=0.0)] """ @@ -373,7 +373,7 @@ def input_file_name(): def isnan(col): """An expression that returns true iff the column is NaN. - >>> df = sqlContext.createDataFrame([(1.0, float('nan')), (float('nan'), 2.0)], ("a", "b")) + >>> df = spark.createDataFrame([(1.0, float('nan')), (float('nan'), 2.0)], ("a", "b")) >>> df.select(isnan("a").alias("r1"), isnan(df.a).alias("r2")).collect() [Row(r1=False, r2=False), Row(r1=True, r2=True)] """ @@ -385,7 +385,7 @@ def isnan(col): def isnull(col): """An expression that returns true iff the column is null. - >>> df = sqlContext.createDataFrame([(1, None), (None, 2)], ("a", "b")) + >>> df = spark.createDataFrame([(1, None), (None, 2)], ("a", "b")) >>> df.select(isnull("a").alias("r1"), isnull(df.a).alias("r2")).collect() [Row(r1=False, r2=False), Row(r1=True, r2=True)] """ @@ -432,7 +432,7 @@ def nanvl(col1, col2): Both inputs should be floating point columns (DoubleType or FloatType). - >>> df = sqlContext.createDataFrame([(1.0, float('nan')), (float('nan'), 2.0)], ("a", "b")) + >>> df = spark.createDataFrame([(1.0, float('nan')), (float('nan'), 2.0)], ("a", "b")) >>> df.select(nanvl("a", "b").alias("r1"), nanvl(df.a, df.b).alias("r2")).collect() [Row(r1=1.0, r2=1.0), Row(r1=2.0, r2=2.0)] """ @@ -470,7 +470,7 @@ def round(col, scale=0): Round the given value to `scale` decimal places using HALF_UP rounding mode if `scale` >= 0 or at integral part when `scale` < 0. - >>> sqlContext.createDataFrame([(2.5,)], ['a']).select(round('a', 0).alias('r')).collect() + >>> spark.createDataFrame([(2.5,)], ['a']).select(round('a', 0).alias('r')).collect() [Row(r=3.0)] """ sc = SparkContext._active_spark_context @@ -483,7 +483,7 @@ def bround(col, scale=0): Round the given value to `scale` decimal places using HALF_EVEN rounding mode if `scale` >= 0 or at integral part when `scale` < 0. - >>> sqlContext.createDataFrame([(2.5,)], ['a']).select(bround('a', 0).alias('r')).collect() + >>> spark.createDataFrame([(2.5,)], ['a']).select(bround('a', 0).alias('r')).collect() [Row(r=2.0)] """ sc = SparkContext._active_spark_context @@ -494,7 +494,7 @@ def bround(col, scale=0): def shiftLeft(col, numBits): """Shift the given value numBits left. - >>> sqlContext.createDataFrame([(21,)], ['a']).select(shiftLeft('a', 1).alias('r')).collect() + >>> spark.createDataFrame([(21,)], ['a']).select(shiftLeft('a', 1).alias('r')).collect() [Row(r=42)] """ sc = SparkContext._active_spark_context @@ -505,7 +505,7 @@ def shiftLeft(col, numBits): def shiftRight(col, numBits): """Shift the given value numBits right. - >>> sqlContext.createDataFrame([(42,)], ['a']).select(shiftRight('a', 1).alias('r')).collect() + >>> spark.createDataFrame([(42,)], ['a']).select(shiftRight('a', 1).alias('r')).collect() [Row(r=21)] """ sc = SparkContext._active_spark_context @@ -517,7 +517,7 @@ def shiftRight(col, numBits): def shiftRightUnsigned(col, numBits): """Unsigned shift the given value numBits right. - >>> df = sqlContext.createDataFrame([(-42,)], ['a']) + >>> df = spark.createDataFrame([(-42,)], ['a']) >>> df.select(shiftRightUnsigned('a', 1).alias('r')).collect() [Row(r=9223372036854775787)] """ @@ -575,7 +575,7 @@ def greatest(*cols): Returns the greatest value of the list of column names, skipping null values. This function takes at least 2 parameters. It will return null iff all parameters are null. - >>> df = sqlContext.createDataFrame([(1, 4, 3)], ['a', 'b', 'c']) + >>> df = spark.createDataFrame([(1, 4, 3)], ['a', 'b', 'c']) >>> df.select(greatest(df.a, df.b, df.c).alias("greatest")).collect() [Row(greatest=4)] """ @@ -591,7 +591,7 @@ def least(*cols): Returns the least value of the list of column names, skipping null values. This function takes at least 2 parameters. It will return null iff all parameters are null. - >>> df = sqlContext.createDataFrame([(1, 4, 3)], ['a', 'b', 'c']) + >>> df = spark.createDataFrame([(1, 4, 3)], ['a', 'b', 'c']) >>> df.select(least(df.a, df.b, df.c).alias("least")).collect() [Row(least=1)] """ @@ -647,7 +647,7 @@ def log(arg1, arg2=None): def log2(col): """Returns the base-2 logarithm of the argument. - >>> sqlContext.createDataFrame([(4,)], ['a']).select(log2('a').alias('log2')).collect() + >>> spark.createDataFrame([(4,)], ['a']).select(log2('a').alias('log2')).collect() [Row(log2=2.0)] """ sc = SparkContext._active_spark_context @@ -660,7 +660,7 @@ def conv(col, fromBase, toBase): """ Convert a number in a string column from one base to another. - >>> df = sqlContext.createDataFrame([("010101",)], ['n']) + >>> df = spark.createDataFrame([("010101",)], ['n']) >>> df.select(conv(df.n, 2, 16).alias('hex')).collect() [Row(hex=u'15')] """ @@ -673,7 +673,7 @@ def factorial(col): """ Computes the factorial of the given value. - >>> df = sqlContext.createDataFrame([(5,)], ['n']) + >>> df = spark.createDataFrame([(5,)], ['n']) >>> df.select(factorial(df.n).alias('f')).collect() [Row(f=120)] """ @@ -765,7 +765,7 @@ def date_format(date, format): NOTE: Use when ever possible specialized functions like `year`. These benefit from a specialized implementation. - >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['a']) + >>> df = spark.createDataFrame([('2015-04-08',)], ['a']) >>> df.select(date_format('a', 'MM/dd/yyy').alias('date')).collect() [Row(date=u'04/08/2015')] """ @@ -778,7 +778,7 @@ def year(col): """ Extract the year of a given date as integer. - >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['a']) + >>> df = spark.createDataFrame([('2015-04-08',)], ['a']) >>> df.select(year('a').alias('year')).collect() [Row(year=2015)] """ @@ -791,7 +791,7 @@ def quarter(col): """ Extract the quarter of a given date as integer. - >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['a']) + >>> df = spark.createDataFrame([('2015-04-08',)], ['a']) >>> df.select(quarter('a').alias('quarter')).collect() [Row(quarter=2)] """ @@ -804,7 +804,7 @@ def month(col): """ Extract the month of a given date as integer. - >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['a']) + >>> df = spark.createDataFrame([('2015-04-08',)], ['a']) >>> df.select(month('a').alias('month')).collect() [Row(month=4)] """ @@ -817,7 +817,7 @@ def dayofmonth(col): """ Extract the day of the month of a given date as integer. - >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['a']) + >>> df = spark.createDataFrame([('2015-04-08',)], ['a']) >>> df.select(dayofmonth('a').alias('day')).collect() [Row(day=8)] """ @@ -830,7 +830,7 @@ def dayofyear(col): """ Extract the day of the year of a given date as integer. - >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['a']) + >>> df = spark.createDataFrame([('2015-04-08',)], ['a']) >>> df.select(dayofyear('a').alias('day')).collect() [Row(day=98)] """ @@ -843,7 +843,7 @@ def hour(col): """ Extract the hours of a given date as integer. - >>> df = sqlContext.createDataFrame([('2015-04-08 13:08:15',)], ['a']) + >>> df = spark.createDataFrame([('2015-04-08 13:08:15',)], ['a']) >>> df.select(hour('a').alias('hour')).collect() [Row(hour=13)] """ @@ -856,7 +856,7 @@ def minute(col): """ Extract the minutes of a given date as integer. - >>> df = sqlContext.createDataFrame([('2015-04-08 13:08:15',)], ['a']) + >>> df = spark.createDataFrame([('2015-04-08 13:08:15',)], ['a']) >>> df.select(minute('a').alias('minute')).collect() [Row(minute=8)] """ @@ -869,7 +869,7 @@ def second(col): """ Extract the seconds of a given date as integer. - >>> df = sqlContext.createDataFrame([('2015-04-08 13:08:15',)], ['a']) + >>> df = spark.createDataFrame([('2015-04-08 13:08:15',)], ['a']) >>> df.select(second('a').alias('second')).collect() [Row(second=15)] """ @@ -882,7 +882,7 @@ def weekofyear(col): """ Extract the week number of a given date as integer. - >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['a']) + >>> df = spark.createDataFrame([('2015-04-08',)], ['a']) >>> df.select(weekofyear(df.a).alias('week')).collect() [Row(week=15)] """ @@ -895,7 +895,7 @@ def date_add(start, days): """ Returns the date that is `days` days after `start` - >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['d']) + >>> df = spark.createDataFrame([('2015-04-08',)], ['d']) >>> df.select(date_add(df.d, 1).alias('d')).collect() [Row(d=datetime.date(2015, 4, 9))] """ @@ -908,7 +908,7 @@ def date_sub(start, days): """ Returns the date that is `days` days before `start` - >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['d']) + >>> df = spark.createDataFrame([('2015-04-08',)], ['d']) >>> df.select(date_sub(df.d, 1).alias('d')).collect() [Row(d=datetime.date(2015, 4, 7))] """ @@ -921,7 +921,7 @@ def datediff(end, start): """ Returns the number of days from `start` to `end`. - >>> df = sqlContext.createDataFrame([('2015-04-08','2015-05-10')], ['d1', 'd2']) + >>> df = spark.createDataFrame([('2015-04-08','2015-05-10')], ['d1', 'd2']) >>> df.select(datediff(df.d2, df.d1).alias('diff')).collect() [Row(diff=32)] """ @@ -934,7 +934,7 @@ def add_months(start, months): """ Returns the date that is `months` months after `start` - >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['d']) + >>> df = spark.createDataFrame([('2015-04-08',)], ['d']) >>> df.select(add_months(df.d, 1).alias('d')).collect() [Row(d=datetime.date(2015, 5, 8))] """ @@ -947,7 +947,7 @@ def months_between(date1, date2): """ Returns the number of months between date1 and date2. - >>> df = sqlContext.createDataFrame([('1997-02-28 10:30:00', '1996-10-30')], ['t', 'd']) + >>> df = spark.createDataFrame([('1997-02-28 10:30:00', '1996-10-30')], ['t', 'd']) >>> df.select(months_between(df.t, df.d).alias('months')).collect() [Row(months=3.9495967...)] """ @@ -960,7 +960,7 @@ def to_date(col): """ Converts the column of StringType or TimestampType into DateType. - >>> df = sqlContext.createDataFrame([('1997-02-28 10:30:00',)], ['t']) + >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t']) >>> df.select(to_date(df.t).alias('date')).collect() [Row(date=datetime.date(1997, 2, 28))] """ @@ -975,7 +975,7 @@ def trunc(date, format): :param format: 'year', 'YYYY', 'yy' or 'month', 'mon', 'mm' - >>> df = sqlContext.createDataFrame([('1997-02-28',)], ['d']) + >>> df = spark.createDataFrame([('1997-02-28',)], ['d']) >>> df.select(trunc(df.d, 'year').alias('year')).collect() [Row(year=datetime.date(1997, 1, 1))] >>> df.select(trunc(df.d, 'mon').alias('month')).collect() @@ -993,7 +993,7 @@ def next_day(date, dayOfWeek): Day of the week parameter is case insensitive, and accepts: "Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun". - >>> df = sqlContext.createDataFrame([('2015-07-27',)], ['d']) + >>> df = spark.createDataFrame([('2015-07-27',)], ['d']) >>> df.select(next_day(df.d, 'Sun').alias('date')).collect() [Row(date=datetime.date(2015, 8, 2))] """ @@ -1006,7 +1006,7 @@ def last_day(date): """ Returns the last day of the month which the given date belongs to. - >>> df = sqlContext.createDataFrame([('1997-02-10',)], ['d']) + >>> df = spark.createDataFrame([('1997-02-10',)], ['d']) >>> df.select(last_day(df.d).alias('date')).collect() [Row(date=datetime.date(1997, 2, 28))] """ @@ -1045,7 +1045,7 @@ def from_utc_timestamp(timestamp, tz): """ Assumes given timestamp is UTC and converts to given timezone. - >>> df = sqlContext.createDataFrame([('1997-02-28 10:30:00',)], ['t']) + >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t']) >>> df.select(from_utc_timestamp(df.t, "PST").alias('t')).collect() [Row(t=datetime.datetime(1997, 2, 28, 2, 30))] """ @@ -1058,7 +1058,7 @@ def to_utc_timestamp(timestamp, tz): """ Assumes given timestamp is in given timezone and converts to UTC. - >>> df = sqlContext.createDataFrame([('1997-02-28 10:30:00',)], ['t']) + >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t']) >>> df.select(to_utc_timestamp(df.t, "PST").alias('t')).collect() [Row(t=datetime.datetime(1997, 2, 28, 18, 30))] """ @@ -1087,7 +1087,7 @@ def window(timeColumn, windowDuration, slideDuration=None, startTime=None): The output column will be a struct called 'window' by default with the nested columns 'start' and 'end', where 'start' and 'end' will be of `TimestampType`. - >>> df = sqlContext.createDataFrame([("2016-03-11 09:00:07", 1)]).toDF("date", "val") + >>> df = spark.createDataFrame([("2016-03-11 09:00:07", 1)]).toDF("date", "val") >>> w = df.groupBy(window("date", "5 seconds")).agg(sum("val").alias("sum")) >>> w.select(w.window.start.cast("string").alias("start"), ... w.window.end.cast("string").alias("end"), "sum").collect() @@ -1124,7 +1124,7 @@ def crc32(col): Calculates the cyclic redundancy check value (CRC32) of a binary column and returns the value as a bigint. - >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(crc32('a').alias('crc32')).collect() + >>> spark.createDataFrame([('ABC',)], ['a']).select(crc32('a').alias('crc32')).collect() [Row(crc32=2743272264)] """ sc = SparkContext._active_spark_context @@ -1136,7 +1136,7 @@ def crc32(col): def md5(col): """Calculates the MD5 digest and returns the value as a 32 character hex string. - >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(md5('a').alias('hash')).collect() + >>> spark.createDataFrame([('ABC',)], ['a']).select(md5('a').alias('hash')).collect() [Row(hash=u'902fbdd2b1df0c4f70b4a5d23525e932')] """ sc = SparkContext._active_spark_context @@ -1149,7 +1149,7 @@ def md5(col): def sha1(col): """Returns the hex string result of SHA-1. - >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(sha1('a').alias('hash')).collect() + >>> spark.createDataFrame([('ABC',)], ['a']).select(sha1('a').alias('hash')).collect() [Row(hash=u'3c01bdbb26f358bab27f267924aa2c9a03fcfdb8')] """ sc = SparkContext._active_spark_context @@ -1179,7 +1179,7 @@ def sha2(col, numBits): def hash(*cols): """Calculates the hash code of given columns, and returns the result as a int column. - >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(hash('a').alias('hash')).collect() + >>> spark.createDataFrame([('ABC',)], ['a']).select(hash('a').alias('hash')).collect() [Row(hash=-757602832)] """ sc = SparkContext._active_spark_context @@ -1215,7 +1215,7 @@ def concat(*cols): """ Concatenates multiple input string columns together into a single string column. - >>> df = sqlContext.createDataFrame([('abcd','123')], ['s', 'd']) + >>> df = spark.createDataFrame([('abcd','123')], ['s', 'd']) >>> df.select(concat(df.s, df.d).alias('s')).collect() [Row(s=u'abcd123')] """ @@ -1230,7 +1230,7 @@ def concat_ws(sep, *cols): Concatenates multiple input string columns together into a single string column, using the given separator. - >>> df = sqlContext.createDataFrame([('abcd','123')], ['s', 'd']) + >>> df = spark.createDataFrame([('abcd','123')], ['s', 'd']) >>> df.select(concat_ws('-', df.s, df.d).alias('s')).collect() [Row(s=u'abcd-123')] """ @@ -1268,7 +1268,7 @@ def format_number(col, d): :param col: the column name of the numeric value to be formatted :param d: the N decimal places - >>> sqlContext.createDataFrame([(5,)], ['a']).select(format_number('a', 4).alias('v')).collect() + >>> spark.createDataFrame([(5,)], ['a']).select(format_number('a', 4).alias('v')).collect() [Row(v=u'5.0000')] """ sc = SparkContext._active_spark_context @@ -1284,7 +1284,7 @@ def format_string(format, *cols): :param col: the column name of the numeric value to be formatted :param d: the N decimal places - >>> df = sqlContext.createDataFrame([(5, "hello")], ['a', 'b']) + >>> df = spark.createDataFrame([(5, "hello")], ['a', 'b']) >>> df.select(format_string('%d %s', df.a, df.b).alias('v')).collect() [Row(v=u'5 hello')] """ @@ -1301,7 +1301,7 @@ def instr(str, substr): NOTE: The position is not zero based, but 1 based index, returns 0 if substr could not be found in str. - >>> df = sqlContext.createDataFrame([('abcd',)], ['s',]) + >>> df = spark.createDataFrame([('abcd',)], ['s',]) >>> df.select(instr(df.s, 'b').alias('s')).collect() [Row(s=2)] """ @@ -1317,7 +1317,7 @@ def substring(str, pos, len): returns the slice of byte array that starts at `pos` in byte and is of length `len` when str is Binary type - >>> df = sqlContext.createDataFrame([('abcd',)], ['s',]) + >>> df = spark.createDataFrame([('abcd',)], ['s',]) >>> df.select(substring(df.s, 1, 2).alias('s')).collect() [Row(s=u'ab')] """ @@ -1334,7 +1334,7 @@ def substring_index(str, delim, count): returned. If count is negative, every to the right of the final delimiter (counting from the right) is returned. substring_index performs a case-sensitive match when searching for delim. - >>> df = sqlContext.createDataFrame([('a.b.c.d',)], ['s']) + >>> df = spark.createDataFrame([('a.b.c.d',)], ['s']) >>> df.select(substring_index(df.s, '.', 2).alias('s')).collect() [Row(s=u'a.b')] >>> df.select(substring_index(df.s, '.', -3).alias('s')).collect() @@ -1349,7 +1349,7 @@ def substring_index(str, delim, count): def levenshtein(left, right): """Computes the Levenshtein distance of the two given strings. - >>> df0 = sqlContext.createDataFrame([('kitten', 'sitting',)], ['l', 'r']) + >>> df0 = spark.createDataFrame([('kitten', 'sitting',)], ['l', 'r']) >>> df0.select(levenshtein('l', 'r').alias('d')).collect() [Row(d=3)] """ @@ -1370,7 +1370,7 @@ def locate(substr, str, pos=0): :param str: a Column of StringType :param pos: start position (zero based) - >>> df = sqlContext.createDataFrame([('abcd',)], ['s',]) + >>> df = spark.createDataFrame([('abcd',)], ['s',]) >>> df.select(locate('b', df.s, 1).alias('s')).collect() [Row(s=2)] """ @@ -1384,7 +1384,7 @@ def lpad(col, len, pad): """ Left-pad the string column to width `len` with `pad`. - >>> df = sqlContext.createDataFrame([('abcd',)], ['s',]) + >>> df = spark.createDataFrame([('abcd',)], ['s',]) >>> df.select(lpad(df.s, 6, '#').alias('s')).collect() [Row(s=u'##abcd')] """ @@ -1398,7 +1398,7 @@ def rpad(col, len, pad): """ Right-pad the string column to width `len` with `pad`. - >>> df = sqlContext.createDataFrame([('abcd',)], ['s',]) + >>> df = spark.createDataFrame([('abcd',)], ['s',]) >>> df.select(rpad(df.s, 6, '#').alias('s')).collect() [Row(s=u'abcd##')] """ @@ -1412,7 +1412,7 @@ def repeat(col, n): """ Repeats a string column n times, and returns it as a new string column. - >>> df = sqlContext.createDataFrame([('ab',)], ['s',]) + >>> df = spark.createDataFrame([('ab',)], ['s',]) >>> df.select(repeat(df.s, 3).alias('s')).collect() [Row(s=u'ababab')] """ @@ -1428,7 +1428,7 @@ def split(str, pattern): NOTE: pattern is a string represent the regular expression. - >>> df = sqlContext.createDataFrame([('ab12cd',)], ['s',]) + >>> df = spark.createDataFrame([('ab12cd',)], ['s',]) >>> df.select(split(df.s, '[0-9]+').alias('s')).collect() [Row(s=[u'ab', u'cd'])] """ @@ -1441,7 +1441,7 @@ def split(str, pattern): def regexp_extract(str, pattern, idx): """Extract a specific(idx) group identified by a java regex, from the specified string column. - >>> df = sqlContext.createDataFrame([('100-200',)], ['str']) + >>> df = spark.createDataFrame([('100-200',)], ['str']) >>> df.select(regexp_extract('str', '(\d+)-(\d+)', 1).alias('d')).collect() [Row(d=u'100')] """ @@ -1455,7 +1455,7 @@ def regexp_extract(str, pattern, idx): def regexp_replace(str, pattern, replacement): """Replace all substrings of the specified string value that match regexp with rep. - >>> df = sqlContext.createDataFrame([('100-200',)], ['str']) + >>> df = spark.createDataFrame([('100-200',)], ['str']) >>> df.select(regexp_replace('str', '(\\d+)', '--').alias('d')).collect() [Row(d=u'-----')] """ @@ -1469,7 +1469,7 @@ def regexp_replace(str, pattern, replacement): def initcap(col): """Translate the first letter of each word to upper case in the sentence. - >>> sqlContext.createDataFrame([('ab cd',)], ['a']).select(initcap("a").alias('v')).collect() + >>> spark.createDataFrame([('ab cd',)], ['a']).select(initcap("a").alias('v')).collect() [Row(v=u'Ab Cd')] """ sc = SparkContext._active_spark_context @@ -1482,7 +1482,7 @@ def soundex(col): """ Returns the SoundEx encoding for a string - >>> df = sqlContext.createDataFrame([("Peters",),("Uhrbach",)], ['name']) + >>> df = spark.createDataFrame([("Peters",),("Uhrbach",)], ['name']) >>> df.select(soundex(df.name).alias("soundex")).collect() [Row(soundex=u'P362'), Row(soundex=u'U612')] """ @@ -1509,7 +1509,7 @@ def hex(col): """Computes hex value of the given column, which could be StringType, BinaryType, IntegerType or LongType. - >>> sqlContext.createDataFrame([('ABC', 3)], ['a', 'b']).select(hex('a'), hex('b')).collect() + >>> spark.createDataFrame([('ABC', 3)], ['a', 'b']).select(hex('a'), hex('b')).collect() [Row(hex(a)=u'414243', hex(b)=u'3')] """ sc = SparkContext._active_spark_context @@ -1523,7 +1523,7 @@ def unhex(col): """Inverse of hex. Interprets each pair of characters as a hexadecimal number and converts to the byte representation of number. - >>> sqlContext.createDataFrame([('414243',)], ['a']).select(unhex('a')).collect() + >>> spark.createDataFrame([('414243',)], ['a']).select(unhex('a')).collect() [Row(unhex(a)=bytearray(b'ABC'))] """ sc = SparkContext._active_spark_context @@ -1535,7 +1535,7 @@ def unhex(col): def length(col): """Calculates the length of a string or binary expression. - >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(length('a').alias('length')).collect() + >>> spark.createDataFrame([('ABC',)], ['a']).select(length('a').alias('length')).collect() [Row(length=3)] """ sc = SparkContext._active_spark_context @@ -1550,7 +1550,7 @@ def translate(srcCol, matching, replace): The translate will happen when any character in the string matching with the character in the `matching`. - >>> sqlContext.createDataFrame([('translate',)], ['a']).select(translate('a', "rnlt", "123")\ + >>> spark.createDataFrame([('translate',)], ['a']).select(translate('a', "rnlt", "123")\ .alias('r')).collect() [Row(r=u'1a2s3ae')] """ @@ -1608,7 +1608,7 @@ def array_contains(col, value): :param col: name of column containing array :param value: value to check for in array - >>> df = sqlContext.createDataFrame([(["a", "b", "c"],), ([],)], ['data']) + >>> df = spark.createDataFrame([(["a", "b", "c"],), ([],)], ['data']) >>> df.select(array_contains(df.data, "a")).collect() [Row(array_contains(data, a)=True), Row(array_contains(data, a)=False)] """ @@ -1621,7 +1621,7 @@ def explode(col): """Returns a new row for each element in the given array or map. >>> from pyspark.sql import Row - >>> eDF = sqlContext.createDataFrame([Row(a=1, intlist=[1,2,3], mapfield={"a": "b"})]) + >>> eDF = spark.createDataFrame([Row(a=1, intlist=[1,2,3], mapfield={"a": "b"})]) >>> eDF.select(explode(eDF.intlist).alias("anInt")).collect() [Row(anInt=1), Row(anInt=2), Row(anInt=3)] @@ -1648,7 +1648,7 @@ def get_json_object(col, path): :param path: path to the json object to extract >>> data = [("1", '''{"f1": "value1", "f2": "value2"}'''), ("2", '''{"f1": "value12"}''')] - >>> df = sqlContext.createDataFrame(data, ("key", "jstring")) + >>> df = spark.createDataFrame(data, ("key", "jstring")) >>> df.select(df.key, get_json_object(df.jstring, '$.f1').alias("c0"), \ get_json_object(df.jstring, '$.f2').alias("c1") ).collect() [Row(key=u'1', c0=u'value1', c1=u'value2'), Row(key=u'2', c0=u'value12', c1=None)] @@ -1667,7 +1667,7 @@ def json_tuple(col, *fields): :param fields: list of fields to extract >>> data = [("1", '''{"f1": "value1", "f2": "value2"}'''), ("2", '''{"f1": "value12"}''')] - >>> df = sqlContext.createDataFrame(data, ("key", "jstring")) + >>> df = spark.createDataFrame(data, ("key", "jstring")) >>> df.select(df.key, json_tuple(df.jstring, 'f1', 'f2')).collect() [Row(key=u'1', c0=u'value1', c1=u'value2'), Row(key=u'2', c0=u'value12', c1=None)] """ @@ -1683,7 +1683,7 @@ def size(col): :param col: name of column or expression - >>> df = sqlContext.createDataFrame([([1, 2, 3],),([1],),([],)], ['data']) + >>> df = spark.createDataFrame([([1, 2, 3],),([1],),([],)], ['data']) >>> df.select(size(df.data)).collect() [Row(size(data)=3), Row(size(data)=1), Row(size(data)=0)] """ @@ -1698,7 +1698,7 @@ def sort_array(col, asc=True): :param col: name of column or expression - >>> df = sqlContext.createDataFrame([([2, 1, 3],),([1],),([],)], ['data']) + >>> df = spark.createDataFrame([([2, 1, 3],),([1],),([],)], ['data']) >>> df.select(sort_array(df.data).alias('r')).collect() [Row(r=[1, 2, 3]), Row(r=[1]), Row(r=[])] >>> df.select(sort_array(df.data, asc=False).alias('r')).collect() @@ -1775,18 +1775,21 @@ __all__.sort() def _test(): import doctest - from pyspark.context import SparkContext - from pyspark.sql import Row, SQLContext + from pyspark.sql import Row, SparkSession import pyspark.sql.functions globs = pyspark.sql.functions.__dict__.copy() - sc = SparkContext('local[4]', 'PythonTest') + spark = SparkSession.builder\ + .master("local[4]")\ + .appName("sql.functions tests")\ + .getOrCreate() + sc = spark.sparkContext globs['sc'] = sc - globs['sqlContext'] = SQLContext(sc) + globs['spark'] = spark globs['df'] = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)]).toDF() (failure_count, test_count) = doctest.testmod( pyspark.sql.functions, globs=globs, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) - globs['sc'].stop() + spark.stop() if failure_count: exit(-1) diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index ee734cb439..6987af69cf 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -195,13 +195,15 @@ class GroupedData(object): def _test(): import doctest - from pyspark.context import SparkContext - from pyspark.sql import Row, SQLContext + from pyspark.sql import Row, SparkSession import pyspark.sql.group globs = pyspark.sql.group.__dict__.copy() - sc = SparkContext('local[4]', 'PythonTest') + spark = SparkSession.builder\ + .master("local[4]")\ + .appName("sql.group tests")\ + .getOrCreate() + sc = spark.sparkContext globs['sc'] = sc - globs['sqlContext'] = SQLContext(sc) globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')]) \ .toDF(StructType([StructField('age', IntegerType()), StructField('name', StringType())])) @@ -216,7 +218,7 @@ def _test(): (failure_count, test_count) = doctest.testmod( pyspark.sql.group, globs=globs, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF) - globs['sc'].stop() + spark.stop() if failure_count: exit(-1) |