From 2b67fdb60be95778e016efae4f0a9cdf2fbfe779 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 4 Aug 2015 19:25:24 -0700 Subject: [SPARK-9513] [SQL] [PySpark] Add python API for DataFrame functions This adds Python API for those DataFrame functions that is introduced in 1.5. There is issue with serialize byte_array in Python 3, so some of functions (for BinaryType) does not have tests. cc rxin Author: Davies Liu Closes #7922 from davies/python_functions and squashes the following commits: 8ad942f [Davies Liu] fix test 5fb6ec3 [Davies Liu] fix bugs 3495ed3 [Davies Liu] fix issues ea5f7bb [Davies Liu] Add python API for DataFrame functions --- python/pyspark/sql/functions.py | 885 ++++++++++++++++++++++++++++------------ 1 file changed, 620 insertions(+), 265 deletions(-) (limited to 'python/pyspark/sql/functions.py') diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index a73ecc7d93..e65b14dc0e 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -32,41 +32,6 @@ from pyspark.sql.types import StringType from pyspark.sql.column import Column, _to_java_column, _to_seq -__all__ = [ - 'array', - 'approxCountDistinct', - 'bin', - 'coalesce', - 'countDistinct', - 'explode', - 'format_number', - 'length', - 'log2', - 'md5', - 'monotonicallyIncreasingId', - 'rand', - 'randn', - 'regexp_extract', - 'regexp_replace', - 'sha1', - 'sha2', - 'size', - 'sort_array', - 'sparkPartitionId', - 'struct', - 'udf', - 'when'] - -__all__ += ['lag', 'lead', 'ntile'] - -__all__ += [ - 'date_format', 'date_add', 'date_sub', 'add_months', 'months_between', - 'year', 'quarter', 'month', 'hour', 'minute', 'second', - 'dayofmonth', 'dayofyear', 'weekofyear'] - -__all__ += ['soundex', 'substring', 'substring_index'] - - def _create_function(name, doc=""): """ Create a function for aggregator by name""" def _(col): @@ -208,30 +173,6 @@ for _name, _doc in _binary_mathfunctions.items(): for _name, _doc in _window_functions.items(): globals()[_name] = since(1.4)(_create_window_function(_name, _doc)) del _name, _doc -__all__ += _functions.keys() -__all__ += _functions_1_4.keys() -__all__ += _binary_mathfunctions.keys() -__all__ += _window_functions.keys() -__all__.sort() - - -@since(1.4) -def array(*cols): - """Creates a new array column. - - :param cols: list of column names (string) or list of :class:`Column` expressions that have - the same data type. - - >>> df.select(array('age', 'age').alias("arr")).collect() - [Row(arr=[2, 2]), Row(arr=[5, 5])] - >>> df.select(array([df.age, df.age]).alias("arr")).collect() - [Row(arr=[2, 2]), Row(arr=[5, 5])] - """ - sc = SparkContext._active_spark_context - if len(cols) == 1 and isinstance(cols[0], (list, set)): - cols = cols[0] - jc = sc._jvm.functions.array(_to_seq(sc, cols, _to_java_column)) - return Column(jc) @since(1.3) @@ -249,19 +190,6 @@ def approxCountDistinct(col, rsd=None): return Column(jc) -@ignore_unicode_prefix -@since(1.5) -def bin(col): - """Returns the string representation of the binary value of the given column. - - >>> df.select(bin(df.age).alias('c')).collect() - [Row(c=u'10'), Row(c=u'101')] - """ - sc = SparkContext._active_spark_context - jc = sc._jvm.functions.bin(_to_java_column(col)) - return Column(jc) - - @since(1.4) def coalesce(*cols): """Returns the first column that is not null. @@ -314,82 +242,6 @@ def countDistinct(col, *cols): return Column(jc) -@since(1.4) -def explode(col): - """Returns a new row for each element in the given array or map. 
- - >>> from pyspark.sql import Row - >>> eDF = sqlContext.createDataFrame([Row(a=1, intlist=[1,2,3], mapfield={"a": "b"})]) - >>> eDF.select(explode(eDF.intlist).alias("anInt")).collect() - [Row(anInt=1), Row(anInt=2), Row(anInt=3)] - - >>> eDF.select(explode(eDF.mapfield).alias("key", "value")).show() - +---+-----+ - |key|value| - +---+-----+ - | a| b| - +---+-----+ - """ - sc = SparkContext._active_spark_context - jc = sc._jvm.functions.explode(_to_java_column(col)) - return Column(jc) - - -@ignore_unicode_prefix -@since(1.5) -def levenshtein(left, right): - """Computes the Levenshtein distance of the two given strings. - - >>> df0 = sqlContext.createDataFrame([('kitten', 'sitting',)], ['l', 'r']) - >>> df0.select(levenshtein('l', 'r').alias('d')).collect() - [Row(d=3)] - """ - sc = SparkContext._active_spark_context - jc = sc._jvm.functions.levenshtein(_to_java_column(left), _to_java_column(right)) - return Column(jc) - - -@ignore_unicode_prefix -@since(1.5) -def regexp_extract(str, pattern, idx): - """Extract a specific(idx) group identified by a java regex, from the specified string column. - - >>> df = sqlContext.createDataFrame([('100-200',)], ['str']) - >>> df.select(regexp_extract('str', '(\d+)-(\d+)', 1).alias('d')).collect() - [Row(d=u'100')] - """ - sc = SparkContext._active_spark_context - jc = sc._jvm.functions.regexp_extract(_to_java_column(str), pattern, idx) - return Column(jc) - - -@ignore_unicode_prefix -@since(1.5) -def regexp_replace(str, pattern, replacement): - """Replace all substrings of the specified string value that match regexp with rep. - - >>> df = sqlContext.createDataFrame([('100-200',)], ['str']) - >>> df.select(regexp_replace('str', '(\\d+)', '##').alias('d')).collect() - [Row(d=u'##-##')] - """ - sc = SparkContext._active_spark_context - jc = sc._jvm.functions.regexp_replace(_to_java_column(str), pattern, replacement) - return Column(jc) - - -@ignore_unicode_prefix -@since(1.5) -def md5(col): - """Calculates the MD5 digest and returns the value as a 32 character hex string. - - >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(md5('a').alias('hash')).collect() - [Row(hash=u'902fbdd2b1df0c4f70b4a5d23525e932')] - """ - sc = SparkContext._active_spark_context - jc = sc._jvm.functions.md5(_to_java_column(col)) - return Column(jc) - - @since(1.4) def monotonicallyIncreasingId(): """A column that generates monotonically increasing 64-bit integers. @@ -435,63 +287,17 @@ def randn(seed=None): return Column(jc) -@ignore_unicode_prefix -@since(1.5) -def hex(col): - """Computes hex value of the given column, which could be StringType, - BinaryType, IntegerType or LongType. - - >>> sqlContext.createDataFrame([('ABC', 3)], ['a', 'b']).select(hex('a'), hex('b')).collect() - [Row(hex(a)=u'414243', hex(b)=u'3')] - """ - sc = SparkContext._active_spark_context - jc = sc._jvm.functions.hex(_to_java_column(col)) - return Column(jc) - - -@ignore_unicode_prefix -@since(1.5) -def unhex(col): - """Inverse of hex. Interprets each pair of characters as a hexadecimal number - and converts to the byte representation of number. - - >>> sqlContext.createDataFrame([('414243',)], ['a']).select(unhex('a')).collect() - [Row(unhex(a)=bytearray(b'ABC'))] - """ - sc = SparkContext._active_spark_context - jc = sc._jvm.functions.unhex(_to_java_column(col)) - return Column(jc) - - -@ignore_unicode_prefix @since(1.5) -def sha1(col): - """Returns the hex string result of SHA-1. 
- - >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(sha1('a').alias('hash')).collect() - [Row(hash=u'3c01bdbb26f358bab27f267924aa2c9a03fcfdb8')] +def round(col, scale=0): """ - sc = SparkContext._active_spark_context - jc = sc._jvm.functions.sha1(_to_java_column(col)) - return Column(jc) - - -@ignore_unicode_prefix -@since(1.5) -def sha2(col, numBits): - """Returns the hex string result of SHA-2 family of hash functions (SHA-224, SHA-256, SHA-384, - and SHA-512). The numBits indicates the desired bit length of the result, which must have a - value of 224, 256, 384, 512, or 0 (which is equivalent to 256). + Round the value of `e` to `scale` decimal places if `scale` >= 0 + or at integral part when `scale` < 0. - >>> digests = df.select(sha2(df.name, 256).alias('s')).collect() - >>> digests[0] - Row(s=u'3bc51062973c458d5a6f2d8d64a023246354ad7e064b1e4e009ec8a0699a3043') - >>> digests[1] - Row(s=u'cd9fb1e148ccd8442e5aa74904cc73bf6fb54d1d54d333bd596aa9bb4bb4e961') + >>> sqlContext.createDataFrame([(2.546,)], ['a']).select(round('a', 1).alias('r')).collect() + [Row(r=2.5)] """ sc = SparkContext._active_spark_context - jc = sc._jvm.functions.sha2(_to_java_column(col), numBits) - return Column(jc) + return Column(sc._jvm.functions.round(_to_java_column(col), scale)) @since(1.5) @@ -502,8 +308,7 @@ def shiftLeft(col, numBits): [Row(r=42)] """ sc = SparkContext._active_spark_context - jc = sc._jvm.functions.shiftLeft(_to_java_column(col), numBits) - return Column(jc) + return Column(sc._jvm.functions.shiftLeft(_to_java_column(col), numBits)) @since(1.5) @@ -522,8 +327,8 @@ def shiftRight(col, numBits): def shiftRightUnsigned(col, numBits): """Unsigned shift the the given value numBits right. - >>> sqlContext.createDataFrame([(-42,)], ['a']).select(shiftRightUnsigned('a', 1).alias('r'))\ - .collect() + >>> df = sqlContext.createDataFrame([(-42,)], ['a']) + >>> df.select(shiftRightUnsigned('a', 1).alias('r')).collect() [Row(r=9223372036854775787)] """ sc = SparkContext._active_spark_context @@ -544,6 +349,7 @@ def sparkPartitionId(): return Column(sc._jvm.functions.sparkPartitionId()) +@since(1.5) def expr(str): """Parses the expression string into the column that it represents @@ -554,34 +360,6 @@ def expr(str): return Column(sc._jvm.functions.expr(str)) -@ignore_unicode_prefix -@since(1.5) -def length(col): - """Calculates the length of a string or binary expression. - - >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(length('a').alias('length')).collect() - [Row(length=3)] - """ - sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.length(_to_java_column(col))) - - -@ignore_unicode_prefix -@since(1.5) -def format_number(col, d): - """Formats the number X to a format like '#,###,###.##', rounded to d decimal places, - and returns the result as a string. - - :param col: the column name of the numeric value to be formatted - :param d: the N decimal places - - >>> sqlContext.createDataFrame([(5,)], ['a']).select(format_number('a', 4).alias('v')).collect() - [Row(v=u'5.0000')] - """ - sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.format_number(_to_java_column(col), d)) - - @ignore_unicode_prefix @since(1.4) def struct(*cols): @@ -601,6 +379,38 @@ def struct(*cols): return Column(jc) +@since(1.5) +def greatest(*cols): + """ + Returns the greatest value of the list of column names, skipping null values. + This function takes at least 2 parameters. It will return null iff all parameters are null. 
+ + >>> df = sqlContext.createDataFrame([(1, 4, 3)], ['a', 'b', 'c']) + >>> df.select(greatest(df.a, df.b, df.c).alias("greatest")).collect() + [Row(greatest=4)] + """ + if len(cols) < 2: + raise ValueError("greatest should take at least two columns") + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.greatest(_to_seq(sc, cols, _to_java_column))) + + +@since(1.5) +def least(*cols): + """ + Returns the least value of the list of column names, skipping null values. + This function takes at least 2 parameters. It will return null iff all parameters are null. + + >>> df = sqlContext.createDataFrame([(1, 4, 3)], ['a', 'b', 'c']) + >>> df.select(least(df.a, df.b, df.c).alias("least")).collect() + [Row(least=1)] + """ + if len(cols) < 2: + raise ValueError("least should take at least two columns") + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.least(_to_seq(sc, cols, _to_java_column))) + + @since(1.4) def when(condition, value): """Evaluates a list of conditions and returns one of multiple possible result expressions. @@ -654,6 +464,35 @@ def log2(col): return Column(sc._jvm.functions.log2(_to_java_column(col))) +@since(1.5) +@ignore_unicode_prefix +def conv(col, fromBase, toBase): + """ + Convert a number in a string column from one base to another. + + >>> df = sqlContext.createDataFrame([("010101",)], ['n']) + >>> df.select(conv(df.n, 2, 16).alias('hex')).collect() + [Row(hex=u'15')] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.conv(_to_java_column(col), fromBase, toBase)) + + +@since(1.5) +def factorial(col): + """ + Computes the factorial of the given value. + + >>> df = sqlContext.createDataFrame([(5,)], ['n']) + >>> df.select(factorial(df.n).alias('f')).collect() + [Row(f=120)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.factorial(_to_java_column(col))) + + +# --------------- Window functions ------------------------ + @since(1.4) def lag(col, count=1, default=None): """ @@ -703,9 +542,28 @@ def ntile(n): return Column(sc._jvm.functions.ntile(int(n))) +# ---------------------- Date/Timestamp functions ------------------------------ + +@since(1.5) +def current_date(): + """ + Returns the current date as a date column. + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.current_date()) + + +def current_timestamp(): + """ + Returns the current timestamp as a timestamp column. + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.current_timestamp()) + + @ignore_unicode_prefix @since(1.5) -def date_format(dateCol, format): +def date_format(date, format): """ Converts a date/timestamp/string to a value of string in the format specified by the date format given by the second argument. @@ -721,7 +579,7 @@ def date_format(dateCol, format): [Row(date=u'04/08/2015')] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.date_format(_to_java_column(dateCol), format)) + return Column(sc._jvm.functions.date_format(_to_java_column(date), format)) @since(1.5) @@ -867,6 +725,19 @@ def date_sub(start, days): return Column(sc._jvm.functions.date_sub(_to_java_column(start), days)) +@since(1.5) +def datediff(end, start): + """ + Returns the number of days from `start` to `end`. 
+ + >>> df = sqlContext.createDataFrame([('2015-04-08','2015-05-10')], ['d1', 'd2']) + >>> df.select(datediff(df.d2, df.d1).alias('diff')).collect() + [Row(diff=32)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.datediff(_to_java_column(end), _to_java_column(start))) + + @since(1.5) def add_months(start, months): """ @@ -924,33 +795,296 @@ def trunc(date, format): @since(1.5) -@ignore_unicode_prefix -def substring(str, pos, len): +def next_day(date, dayOfWeek): """ - Substring starts at `pos` and is of length `len` when str is String type or - returns the slice of byte array that starts at `pos` in byte and is of length `len` - when str is Binary type + Returns the first date which is later than the value of the date column. - >>> df = sqlContext.createDataFrame([('abcd',)], ['s',]) - >>> df.select(substring(df.s, 1, 2).alias('s')).collect() - [Row(s=u'ab')] + Day of the week parameter is case insensitive, and accepts: + "Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun". + + >>> df = sqlContext.createDataFrame([('2015-07-27',)], ['d']) + >>> df.select(next_day(df.d, 'Sun').alias('date')).collect() + [Row(date=datetime.date(2015, 8, 2))] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.substring(_to_java_column(str), pos, len)) + return Column(sc._jvm.functions.next_day(_to_java_column(date), dayOfWeek)) @since(1.5) -@ignore_unicode_prefix -def substring_index(str, delim, count): +def last_day(date): """ - Returns the substring from string str before count occurrences of the delimiter delim. - If count is positive, everything the left of the final delimiter (counting from left) is - returned. If count is negative, every to the right of the final delimiter (counting from the - right) is returned. substring_index performs a case-sensitive match when searching for delim. + Returns the last day of the month which the given date belongs to. - >>> df = sqlContext.createDataFrame([('a.b.c.d',)], ['s']) - >>> df.select(substring_index(df.s, '.', 2).alias('s')).collect() - [Row(s=u'a.b')] + >>> df = sqlContext.createDataFrame([('1997-02-10',)], ['d']) + >>> df.select(last_day(df.d).alias('date')).collect() + [Row(date=datetime.date(1997, 2, 28))] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.last_day(_to_java_column(date))) + + +@since(1.5) +def from_unixtime(timestamp, format="yyyy-MM-dd HH:mm:ss"): + """ + Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string + representing the timestamp of that moment in the current system time zone in the given + format. + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.from_unixtime(_to_java_column(timestamp), format)) + + +@since(1.5) +def unix_timestamp(timestamp=None, format='yyyy-MM-dd HH:mm:ss'): + """ + Convert time string with given pattern ('yyyy-MM-dd HH:mm:ss', by default) + to Unix time stamp (in seconds), using the default timezone and the default + locale, return null if fail. + + if `timestamp` is None, then it returns current timestamp. + """ + sc = SparkContext._active_spark_context + if timestamp is None: + return Column(sc._jvm.functions.unix_timestamp()) + return Column(sc._jvm.functions.unix_timestamp(_to_java_column(timestamp), format)) + + +@since(1.5) +def from_utc_timestamp(timestamp, tz): + """ + Assumes given timestamp is UTC and converts to given timezone. 
+ + >>> df = sqlContext.createDataFrame([('1997-02-28 10:30:00',)], ['t']) + >>> df.select(from_utc_timestamp(df.t, "PST").alias('t')).collect() + [Row(t=datetime.datetime(1997, 2, 28, 2, 30))] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.from_utc_timestamp(_to_java_column(timestamp), tz)) + + +@since(1.5) +def to_utc_timestamp(timestamp, tz): + """ + Assumes given timestamp is in given timezone and converts to UTC. + + >>> df = sqlContext.createDataFrame([('1997-02-28 10:30:00',)], ['t']) + >>> df.select(to_utc_timestamp(df.t, "PST").alias('t')).collect() + [Row(t=datetime.datetime(1997, 2, 28, 18, 30))] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.to_utc_timestamp(_to_java_column(timestamp), tz)) + + +# ---------------------------- misc functions ---------------------------------- + +@since(1.5) +@ignore_unicode_prefix +def crc32(col): + """ + Calculates the cyclic redundancy check value (CRC32) of a binary column and + returns the value as a bigint. + + >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(crc32('a').alias('crc32')).collect() + [Row(crc32=u'902fbdd2b1df0c4f70b4a5d23525e932')] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.md5(_to_java_column(col))) + + +@ignore_unicode_prefix +@since(1.5) +def md5(col): + """Calculates the MD5 digest and returns the value as a 32 character hex string. + + >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(md5('a').alias('hash')).collect() + [Row(hash=u'902fbdd2b1df0c4f70b4a5d23525e932')] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.md5(_to_java_column(col)) + return Column(jc) + + +@ignore_unicode_prefix +@since(1.5) +def sha1(col): + """Returns the hex string result of SHA-1. + + >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(sha1('a').alias('hash')).collect() + [Row(hash=u'3c01bdbb26f358bab27f267924aa2c9a03fcfdb8')] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.sha1(_to_java_column(col)) + return Column(jc) + + +@ignore_unicode_prefix +@since(1.5) +def sha2(col, numBits): + """Returns the hex string result of SHA-2 family of hash functions (SHA-224, SHA-256, SHA-384, + and SHA-512). The numBits indicates the desired bit length of the result, which must have a + value of 224, 256, 384, 512, or 0 (which is equivalent to 256). + + >>> digests = df.select(sha2(df.name, 256).alias('s')).collect() + >>> digests[0] + Row(s=u'3bc51062973c458d5a6f2d8d64a023246354ad7e064b1e4e009ec8a0699a3043') + >>> digests[1] + Row(s=u'cd9fb1e148ccd8442e5aa74904cc73bf6fb54d1d54d333bd596aa9bb4bb4e961') + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.sha2(_to_java_column(col), numBits) + return Column(jc) + + +# ---------------------- String/Binary functions ------------------------------ + +_string_functions = { + 'ascii': 'Computes the numeric value of the first character of the string column.', + 'base64': 'Computes the BASE64 encoding of a binary column and returns it as a string column.', + 'unbase64': 'Decodes a BASE64 encoded string column and returns it as a binary column.', + 'initcap': 'Returns a new string column by converting the first letter of each word to ' + + 'uppercase. 
Words are delimited by whitespace.', + 'lower': 'Converts a string column to lower case.', + 'upper': 'Converts a string column to upper case.', + 'reverse': 'Reverses the string column and returns it as a new string column.', + 'ltrim': 'Trim the spaces from right end for the specified string value.', + 'rtrim': 'Trim the spaces from right end for the specified string value.', + 'trim': 'Trim the spaces from both ends for the specified string column.', +} + + +for _name, _doc in _string_functions.items(): + globals()[_name] = since(1.5)(_create_function(_name, _doc)) +del _name, _doc + + +@since(1.5) +@ignore_unicode_prefix +def concat(*cols): + """ + Concatenates multiple input string columns together into a single string column. + + >>> df = sqlContext.createDataFrame([('abcd','123')], ['s', 'd']) + >>> df.select(concat(df.s, df.d).alias('s')).collect() + [Row(s=u'abcd123')] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.concat(_to_seq(sc, cols, _to_java_column))) + + +@since(1.5) +@ignore_unicode_prefix +def concat_ws(sep, *cols): + """ + Concatenates multiple input string columns together into a single string column, + using the given separator. + + >>> df = sqlContext.createDataFrame([('abcd','123')], ['s', 'd']) + >>> df.select(concat_ws('-', df.s, df.d).alias('s')).collect() + [Row(s=u'abcd-123')] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.concat_ws(sep, _to_seq(sc, cols, _to_java_column))) + + +@since(1.5) +def decode(col, charset): + """ + Computes the first argument into a string from a binary using the provided character set + (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.decode(_to_java_column(col), charset)) + + +@since(1.5) +def encode(col, charset): + """ + Computes the first argument into a binary from a string using the provided character set + (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.encode(_to_java_column(col), charset)) + + +@ignore_unicode_prefix +@since(1.5) +def format_number(col, d): + """ + Formats the number X to a format like '#,--#,--#.--', rounded to d decimal places, + and returns the result as a string. + + :param col: the column name of the numeric value to be formatted + :param d: the N decimal places + + >>> sqlContext.createDataFrame([(5,)], ['a']).select(format_number('a', 4).alias('v')).collect() + [Row(v=u'5.0000')] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.format_number(_to_java_column(col), d)) + + +@ignore_unicode_prefix +@since(1.5) +def format_string(format, *cols): + """ + Formats the arguments in printf-style and returns the result as a string column. + + :param col: the column name of the numeric value to be formatted + :param d: the N decimal places + + >>> df = sqlContext.createDataFrame([(5, "hello")], ['a', 'b']) + >>> df.select(format_string('%d %s', df.a, df.b).alias('v')).collect() + [Row(v=u'5 hello')] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.format_string(format, _to_seq(sc, cols, _to_java_column))) + + +@since(1.5) +def instr(str, substr): + """ + Locate the position of the first occurrence of substr column in the given string. + Returns null if either of the arguments are null. 
+ + NOTE: The position is not zero based, but 1 based index, returns 0 if substr + could not be found in str. + + >>> df = sqlContext.createDataFrame([('abcd',)], ['s',]) + >>> df.select(instr(df.s, 'b').alias('s')).collect() + [Row(s=2)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.instr(_to_java_column(str), substr)) + + +@since(1.5) +@ignore_unicode_prefix +def substring(str, pos, len): + """ + Substring starts at `pos` and is of length `len` when str is String type or + returns the slice of byte array that starts at `pos` in byte and is of length `len` + when str is Binary type + + >>> df = sqlContext.createDataFrame([('abcd',)], ['s',]) + >>> df.select(substring(df.s, 1, 2).alias('s')).collect() + [Row(s=u'ab')] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.substring(_to_java_column(str), pos, len)) + + +@since(1.5) +@ignore_unicode_prefix +def substring_index(str, delim, count): + """ + Returns the substring from string str before count occurrences of the delimiter delim. + If count is positive, everything the left of the final delimiter (counting from left) is + returned. If count is negative, every to the right of the final delimiter (counting from the + right) is returned. substring_index performs a case-sensitive match when searching for delim. + + >>> df = sqlContext.createDataFrame([('a.b.c.d',)], ['s']) + >>> df.select(substring_index(df.s, '.', 2).alias('s')).collect() + [Row(s=u'a.b')] >>> df.select(substring_index(df.s, '.', -3).alias('s')).collect() [Row(s=u'b.c.d')] """ @@ -958,6 +1092,126 @@ def substring_index(str, delim, count): return Column(sc._jvm.functions.substring_index(_to_java_column(str), delim, count)) +@ignore_unicode_prefix +@since(1.5) +def levenshtein(left, right): + """Computes the Levenshtein distance of the two given strings. + + >>> df0 = sqlContext.createDataFrame([('kitten', 'sitting',)], ['l', 'r']) + >>> df0.select(levenshtein('l', 'r').alias('d')).collect() + [Row(d=3)] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.levenshtein(_to_java_column(left), _to_java_column(right)) + return Column(jc) + + +@since(1.5) +def locate(substr, str, pos=0): + """ + Locate the position of the first occurrence of substr in a string column, after position pos. + + NOTE: The position is not zero based, but 1 based index. returns 0 if substr + could not be found in str. + + :param substr: a string + :param str: a Column of StringType + :param pos: start position (zero based) + + >>> df = sqlContext.createDataFrame([('abcd',)], ['s',]) + >>> df.select(locate('b', df.s, 1).alias('s')).collect() + [Row(s=2)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.locate(substr, _to_java_column(str), pos)) + + +@since(1.5) +@ignore_unicode_prefix +def lpad(col, len, pad): + """ + Left-pad the string column to width `len` with `pad`. + + >>> df = sqlContext.createDataFrame([('abcd',)], ['s',]) + >>> df.select(lpad(df.s, 6, '#').alias('s')).collect() + [Row(s=u'##abcd')] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.lpad(_to_java_column(col), len, pad)) + + +@since(1.5) +@ignore_unicode_prefix +def rpad(col, len, pad): + """ + Right-pad the string column to width `len` with `pad`. 
+ + >>> df = sqlContext.createDataFrame([('abcd',)], ['s',]) + >>> df.select(rpad(df.s, 6, '#').alias('s')).collect() + [Row(s=u'abcd##')] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.rpad(_to_java_column(col), len, pad)) + + +@since(1.5) +@ignore_unicode_prefix +def repeat(col, n): + """ + Repeats a string column n times, and returns it as a new string column. + + >>> df = sqlContext.createDataFrame([('ab',)], ['s',]) + >>> df.select(repeat(df.s, 3).alias('s')).collect() + [Row(s=u'ababab')] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.repeat(_to_java_column(col), n)) + + +@since(1.5) +@ignore_unicode_prefix +def split(str, pattern): + """ + Splits str around pattern (pattern is a regular expression). + + NOTE: pattern is a string represent the regular expression. + + >>> df = sqlContext.createDataFrame([('ab12cd',)], ['s',]) + >>> df.select(split(df.s, '[0-9]+').alias('s')).collect() + [Row(s=[u'ab', u'cd'])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.split(_to_java_column(str), pattern)) + + +@ignore_unicode_prefix +@since(1.5) +def regexp_extract(str, pattern, idx): + """Extract a specific(idx) group identified by a java regex, from the specified string column. + + >>> df = sqlContext.createDataFrame([('100-200',)], ['str']) + >>> df.select(regexp_extract('str', '(\d+)-(\d+)', 1).alias('d')).collect() + [Row(d=u'100')] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.regexp_extract(_to_java_column(str), pattern, idx) + return Column(jc) + + +@ignore_unicode_prefix +@since(1.5) +def regexp_replace(str, pattern, replacement): + """Replace all substrings of the specified string value that match regexp with rep. + + >>> df = sqlContext.createDataFrame([('100-200',)], ['str']) + >>> df.select(regexp_replace('str', '(\\d+)', '--').alias('d')).collect() + [Row(d=u'-----')] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.regexp_replace(_to_java_column(str), pattern, replacement) + return Column(jc) + + @ignore_unicode_prefix @since(1.5) def initcap(col): @@ -970,6 +1224,114 @@ def initcap(col): return Column(sc._jvm.functions.initcap(_to_java_column(col))) +@since(1.5) +@ignore_unicode_prefix +def soundex(col): + """ + Returns the SoundEx encoding for a string + + >>> df = sqlContext.createDataFrame([("Peters",),("Uhrbach",)], ['name']) + >>> df.select(soundex(df.name).alias("soundex")).collect() + [Row(soundex=u'P362'), Row(soundex=u'U612')] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.soundex(_to_java_column(col))) + + +@ignore_unicode_prefix +@since(1.5) +def bin(col): + """Returns the string representation of the binary value of the given column. + + >>> df.select(bin(df.age).alias('c')).collect() + [Row(c=u'10'), Row(c=u'101')] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.bin(_to_java_column(col)) + return Column(jc) + + +@ignore_unicode_prefix +@since(1.5) +def hex(col): + """Computes hex value of the given column, which could be StringType, + BinaryType, IntegerType or LongType. + + >>> sqlContext.createDataFrame([('ABC', 3)], ['a', 'b']).select(hex('a'), hex('b')).collect() + [Row(hex(a)=u'414243', hex(b)=u'3')] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.hex(_to_java_column(col)) + return Column(jc) + + +@ignore_unicode_prefix +@since(1.5) +def unhex(col): + """Inverse of hex. 
Interprets each pair of characters as a hexadecimal number + and converts to the byte representation of number. + + >>> sqlContext.createDataFrame([('414243',)], ['a']).select(unhex('a')).collect() + [Row(unhex(a)=bytearray(b'ABC'))] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.unhex(_to_java_column(col))) + + +@ignore_unicode_prefix +@since(1.5) +def length(col): + """Calculates the length of a string or binary expression. + + >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(length('a').alias('length')).collect() + [Row(length=3)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.length(_to_java_column(col))) + + +# ---------------------- Collection functions ------------------------------ + +@since(1.4) +def array(*cols): + """Creates a new array column. + + :param cols: list of column names (string) or list of :class:`Column` expressions that have + the same data type. + + >>> df.select(array('age', 'age').alias("arr")).collect() + [Row(arr=[2, 2]), Row(arr=[5, 5])] + >>> df.select(array([df.age, df.age]).alias("arr")).collect() + [Row(arr=[2, 2]), Row(arr=[5, 5])] + """ + sc = SparkContext._active_spark_context + if len(cols) == 1 and isinstance(cols[0], (list, set)): + cols = cols[0] + jc = sc._jvm.functions.array(_to_seq(sc, cols, _to_java_column)) + return Column(jc) + + +@since(1.4) +def explode(col): + """Returns a new row for each element in the given array or map. + + >>> from pyspark.sql import Row + >>> eDF = sqlContext.createDataFrame([Row(a=1, intlist=[1,2,3], mapfield={"a": "b"})]) + >>> eDF.select(explode(eDF.intlist).alias("anInt")).collect() + [Row(anInt=1), Row(anInt=2), Row(anInt=3)] + + >>> eDF.select(explode(eDF.mapfield).alias("key", "value")).show() + +---+-----+ + |key|value| + +---+-----+ + | a| b| + +---+-----+ + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.explode(_to_java_column(col)) + return Column(jc) + + @since(1.5) def size(col): """ @@ -1002,19 +1364,7 @@ def sort_array(col, asc=True): return Column(sc._jvm.functions.sort_array(_to_java_column(col), asc)) -@since -@ignore_unicode_prefix -def soundex(col): - """ - Returns the SoundEx encoding for a string - - >>> df = sqlContext.createDataFrame([("Peters",),("Uhrbach",)], ['name']) - >>> df.select(soundex(df.name).alias("soundex")).collect() - [Row(soundex=u'P362'), Row(soundex=u'U612')] - """ - sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.size(_to_java_column(col))) - +# ---------------------------- User Defined Function ---------------------------------- class UserDefinedFunction(object): """ @@ -1066,6 +1416,11 @@ def udf(f, returnType=StringType()): """ return UserDefinedFunction(f, returnType) +blacklist = ['map', 'since', 'ignore_unicode_prefix'] +__all__ = [k for k, v in globals().items() + if not k.startswith('_') and k[0].islower() and callable(v) and k not in blacklist] +__all__.sort() + def _test(): import doctest -- cgit v1.2.3
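
For reference, a minimal usage sketch of a few of the functions added by this patch. Assumed setup: a PySpark 1.5 shell with a SQLContext bound to `sqlContext`, as in the doctests above; the DataFrame and column names are illustrative only, and the expected values in the comments are taken from the doctests in the diff.

    # note: importing `round` from pyspark.sql.functions shadows the builtin round()
    from pyspark.sql.functions import round, concat_ws, datediff, levenshtein

    df = sqlContext.createDataFrame(
        [('2015-04-08', '2015-05-10', 2.546, 'kitten', 'sitting')],
        ['d1', 'd2', 'x', 'l', 'r'])

    df.select(
        datediff(df.d2, df.d1).alias('days'),       # 32
        round(df.x, 1).alias('rounded'),            # 2.5
        concat_ws('-', df.d1, df.d2).alias('span'), # u'2015-04-08-2015-05-10'
        levenshtein(df.l, df.r).alias('dist'),      # 3
    ).show()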