author     Davies Liu <davies@databricks.com>    2015-08-04 19:25:24 -0700
committer  Reynold Xin <rxin@databricks.com>     2015-08-04 19:25:24 -0700
commit     2b67fdb60be95778e016efae4f0a9cdf2fbfe779 (patch)
tree       23672d80a51fdbafb5ddfbf7e9c6317145e7aaad /python
parent     6f8f0e265a29e89bd5192a8d5217cba19f0875da (diff)
[SPARK-9513] [SQL] [PySpark] Add python API for DataFrame functions
This adds a Python API for the DataFrame functions introduced in 1.5. There is an issue with serializing byte arrays in Python 3, so some functions (for BinaryType) do not have tests.

cc rxin

Author: Davies Liu <davies@databricks.com>

Closes #7922 from davies/python_functions and squashes the following commits:

8ad942f [Davies Liu] fix test
5fb6ec3 [Davies Liu] fix bugs
3495ed3 [Davies Liu] fix issues
ea5f7bb [Davies Liu] Add python API for DataFrame functions
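A minimal usage sketch (not part of the patch) of a few of the functions this commit exposes, assuming a Spark 1.5-era deployment with a local SparkContext; the column names and sample values below are illustrative only:

    from pyspark import SparkContext
    from pyspark.sql import SQLContext
    from pyspark.sql.functions import round, datediff, concat_ws, crc32, next_day

    sc = SparkContext("local[2]", "functions-demo")
    sqlContext = SQLContext(sc)

    # One-row DataFrame with a numeric, two date strings, and a plain string column
    df = sqlContext.createDataFrame(
        [(2.546, '2015-04-08', '2015-05-10', 'ABC')],
        ['x', 'd1', 'd2', 's'])

    df.select(
        round('x', 1).alias('rounded'),            # 2.5
        datediff(df.d2, df.d1).alias('days'),      # 32
        concat_ws('/', df.d1, df.s).alias('tag'),  # u'2015-04-08/ABC'
        crc32('s').alias('checksum'),              # CRC32 of the UTF-8 bytes of 'ABC'
        next_day(df.d1, 'Sun').alias('sunday'),    # first Sunday after 2015-04-08
    ).show()

    sc.stop()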
Diffstat (limited to 'python')
-rw-r--r--  python/pyspark/sql/functions.py  849
1 file changed, 602 insertions, 247 deletions
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index a73ecc7d93..e65b14dc0e 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -32,41 +32,6 @@ from pyspark.sql.types import StringType
from pyspark.sql.column import Column, _to_java_column, _to_seq
-__all__ = [
- 'array',
- 'approxCountDistinct',
- 'bin',
- 'coalesce',
- 'countDistinct',
- 'explode',
- 'format_number',
- 'length',
- 'log2',
- 'md5',
- 'monotonicallyIncreasingId',
- 'rand',
- 'randn',
- 'regexp_extract',
- 'regexp_replace',
- 'sha1',
- 'sha2',
- 'size',
- 'sort_array',
- 'sparkPartitionId',
- 'struct',
- 'udf',
- 'when']
-
-__all__ += ['lag', 'lead', 'ntile']
-
-__all__ += [
- 'date_format', 'date_add', 'date_sub', 'add_months', 'months_between',
- 'year', 'quarter', 'month', 'hour', 'minute', 'second',
- 'dayofmonth', 'dayofyear', 'weekofyear']
-
-__all__ += ['soundex', 'substring', 'substring_index']
-
-
def _create_function(name, doc=""):
""" Create a function for aggregator by name"""
def _(col):
@@ -208,30 +173,6 @@ for _name, _doc in _binary_mathfunctions.items():
for _name, _doc in _window_functions.items():
globals()[_name] = since(1.4)(_create_window_function(_name, _doc))
del _name, _doc
-__all__ += _functions.keys()
-__all__ += _functions_1_4.keys()
-__all__ += _binary_mathfunctions.keys()
-__all__ += _window_functions.keys()
-__all__.sort()
-
-
-@since(1.4)
-def array(*cols):
- """Creates a new array column.
-
- :param cols: list of column names (string) or list of :class:`Column` expressions that have
- the same data type.
-
- >>> df.select(array('age', 'age').alias("arr")).collect()
- [Row(arr=[2, 2]), Row(arr=[5, 5])]
- >>> df.select(array([df.age, df.age]).alias("arr")).collect()
- [Row(arr=[2, 2]), Row(arr=[5, 5])]
- """
- sc = SparkContext._active_spark_context
- if len(cols) == 1 and isinstance(cols[0], (list, set)):
- cols = cols[0]
- jc = sc._jvm.functions.array(_to_seq(sc, cols, _to_java_column))
- return Column(jc)
@since(1.3)
@@ -249,19 +190,6 @@ def approxCountDistinct(col, rsd=None):
return Column(jc)
-@ignore_unicode_prefix
-@since(1.5)
-def bin(col):
- """Returns the string representation of the binary value of the given column.
-
- >>> df.select(bin(df.age).alias('c')).collect()
- [Row(c=u'10'), Row(c=u'101')]
- """
- sc = SparkContext._active_spark_context
- jc = sc._jvm.functions.bin(_to_java_column(col))
- return Column(jc)
-
-
@since(1.4)
def coalesce(*cols):
"""Returns the first column that is not null.
@@ -315,82 +243,6 @@ def countDistinct(col, *cols):
@since(1.4)
-def explode(col):
- """Returns a new row for each element in the given array or map.
-
- >>> from pyspark.sql import Row
- >>> eDF = sqlContext.createDataFrame([Row(a=1, intlist=[1,2,3], mapfield={"a": "b"})])
- >>> eDF.select(explode(eDF.intlist).alias("anInt")).collect()
- [Row(anInt=1), Row(anInt=2), Row(anInt=3)]
-
- >>> eDF.select(explode(eDF.mapfield).alias("key", "value")).show()
- +---+-----+
- |key|value|
- +---+-----+
- | a| b|
- +---+-----+
- """
- sc = SparkContext._active_spark_context
- jc = sc._jvm.functions.explode(_to_java_column(col))
- return Column(jc)
-
-
-@ignore_unicode_prefix
-@since(1.5)
-def levenshtein(left, right):
- """Computes the Levenshtein distance of the two given strings.
-
- >>> df0 = sqlContext.createDataFrame([('kitten', 'sitting',)], ['l', 'r'])
- >>> df0.select(levenshtein('l', 'r').alias('d')).collect()
- [Row(d=3)]
- """
- sc = SparkContext._active_spark_context
- jc = sc._jvm.functions.levenshtein(_to_java_column(left), _to_java_column(right))
- return Column(jc)
-
-
-@ignore_unicode_prefix
-@since(1.5)
-def regexp_extract(str, pattern, idx):
- """Extract a specific(idx) group identified by a java regex, from the specified string column.
-
- >>> df = sqlContext.createDataFrame([('100-200',)], ['str'])
- >>> df.select(regexp_extract('str', '(\d+)-(\d+)', 1).alias('d')).collect()
- [Row(d=u'100')]
- """
- sc = SparkContext._active_spark_context
- jc = sc._jvm.functions.regexp_extract(_to_java_column(str), pattern, idx)
- return Column(jc)
-
-
-@ignore_unicode_prefix
-@since(1.5)
-def regexp_replace(str, pattern, replacement):
- """Replace all substrings of the specified string value that match regexp with rep.
-
- >>> df = sqlContext.createDataFrame([('100-200',)], ['str'])
- >>> df.select(regexp_replace('str', '(\\d+)', '##').alias('d')).collect()
- [Row(d=u'##-##')]
- """
- sc = SparkContext._active_spark_context
- jc = sc._jvm.functions.regexp_replace(_to_java_column(str), pattern, replacement)
- return Column(jc)
-
-
-@ignore_unicode_prefix
-@since(1.5)
-def md5(col):
- """Calculates the MD5 digest and returns the value as a 32 character hex string.
-
- >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(md5('a').alias('hash')).collect()
- [Row(hash=u'902fbdd2b1df0c4f70b4a5d23525e932')]
- """
- sc = SparkContext._active_spark_context
- jc = sc._jvm.functions.md5(_to_java_column(col))
- return Column(jc)
-
-
-@since(1.4)
def monotonicallyIncreasingId():
"""A column that generates monotonically increasing 64-bit integers.
@@ -435,63 +287,17 @@ def randn(seed=None):
return Column(jc)
-@ignore_unicode_prefix
-@since(1.5)
-def hex(col):
- """Computes hex value of the given column, which could be StringType,
- BinaryType, IntegerType or LongType.
-
- >>> sqlContext.createDataFrame([('ABC', 3)], ['a', 'b']).select(hex('a'), hex('b')).collect()
- [Row(hex(a)=u'414243', hex(b)=u'3')]
- """
- sc = SparkContext._active_spark_context
- jc = sc._jvm.functions.hex(_to_java_column(col))
- return Column(jc)
-
-
-@ignore_unicode_prefix
-@since(1.5)
-def unhex(col):
- """Inverse of hex. Interprets each pair of characters as a hexadecimal number
- and converts to the byte representation of number.
-
- >>> sqlContext.createDataFrame([('414243',)], ['a']).select(unhex('a')).collect()
- [Row(unhex(a)=bytearray(b'ABC'))]
- """
- sc = SparkContext._active_spark_context
- jc = sc._jvm.functions.unhex(_to_java_column(col))
- return Column(jc)
-
-
-@ignore_unicode_prefix
@since(1.5)
-def sha1(col):
- """Returns the hex string result of SHA-1.
-
- >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(sha1('a').alias('hash')).collect()
- [Row(hash=u'3c01bdbb26f358bab27f267924aa2c9a03fcfdb8')]
+def round(col, scale=0):
"""
- sc = SparkContext._active_spark_context
- jc = sc._jvm.functions.sha1(_to_java_column(col))
- return Column(jc)
-
-
-@ignore_unicode_prefix
-@since(1.5)
-def sha2(col, numBits):
- """Returns the hex string result of SHA-2 family of hash functions (SHA-224, SHA-256, SHA-384,
- and SHA-512). The numBits indicates the desired bit length of the result, which must have a
- value of 224, 256, 384, 512, or 0 (which is equivalent to 256).
+ Round the value of the given column to `scale` decimal places if `scale` >= 0,
+ or to the integral part when `scale` < 0.
- >>> digests = df.select(sha2(df.name, 256).alias('s')).collect()
- >>> digests[0]
- Row(s=u'3bc51062973c458d5a6f2d8d64a023246354ad7e064b1e4e009ec8a0699a3043')
- >>> digests[1]
- Row(s=u'cd9fb1e148ccd8442e5aa74904cc73bf6fb54d1d54d333bd596aa9bb4bb4e961')
+ >>> sqlContext.createDataFrame([(2.546,)], ['a']).select(round('a', 1).alias('r')).collect()
+ [Row(r=2.5)]
"""
sc = SparkContext._active_spark_context
- jc = sc._jvm.functions.sha2(_to_java_column(col), numBits)
- return Column(jc)
+ return Column(sc._jvm.functions.round(_to_java_column(col), scale))
@since(1.5)
@@ -502,8 +308,7 @@ def shiftLeft(col, numBits):
[Row(r=42)]
"""
sc = SparkContext._active_spark_context
- jc = sc._jvm.functions.shiftLeft(_to_java_column(col), numBits)
- return Column(jc)
+ return Column(sc._jvm.functions.shiftLeft(_to_java_column(col), numBits))
@since(1.5)
@@ -522,8 +327,8 @@ def shiftRight(col, numBits):
def shiftRightUnsigned(col, numBits):
"""Unsigned shift the the given value numBits right.
- >>> sqlContext.createDataFrame([(-42,)], ['a']).select(shiftRightUnsigned('a', 1).alias('r'))\
- .collect()
+ >>> df = sqlContext.createDataFrame([(-42,)], ['a'])
+ >>> df.select(shiftRightUnsigned('a', 1).alias('r')).collect()
[Row(r=9223372036854775787)]
"""
sc = SparkContext._active_spark_context
@@ -544,6 +349,7 @@ def sparkPartitionId():
return Column(sc._jvm.functions.sparkPartitionId())
+@since(1.5)
def expr(str):
"""Parses the expression string into the column that it represents
@@ -555,34 +361,6 @@ def expr(str):
@ignore_unicode_prefix
-@since(1.5)
-def length(col):
- """Calculates the length of a string or binary expression.
-
- >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(length('a').alias('length')).collect()
- [Row(length=3)]
- """
- sc = SparkContext._active_spark_context
- return Column(sc._jvm.functions.length(_to_java_column(col)))
-
-
-@ignore_unicode_prefix
-@since(1.5)
-def format_number(col, d):
- """Formats the number X to a format like '#,###,###.##', rounded to d decimal places,
- and returns the result as a string.
-
- :param col: the column name of the numeric value to be formatted
- :param d: the N decimal places
-
- >>> sqlContext.createDataFrame([(5,)], ['a']).select(format_number('a', 4).alias('v')).collect()
- [Row(v=u'5.0000')]
- """
- sc = SparkContext._active_spark_context
- return Column(sc._jvm.functions.format_number(_to_java_column(col), d))
-
-
-@ignore_unicode_prefix
@since(1.4)
def struct(*cols):
"""Creates a new struct column.
@@ -601,6 +379,38 @@ def struct(*cols):
return Column(jc)
+@since(1.5)
+def greatest(*cols):
+ """
+ Returns the greatest value of the list of column names, skipping null values.
+ This function takes at least 2 parameters. It will return null iff all parameters are null.
+
+ >>> df = sqlContext.createDataFrame([(1, 4, 3)], ['a', 'b', 'c'])
+ >>> df.select(greatest(df.a, df.b, df.c).alias("greatest")).collect()
+ [Row(greatest=4)]
+ """
+ if len(cols) < 2:
+ raise ValueError("greatest should take at least two columns")
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.greatest(_to_seq(sc, cols, _to_java_column)))
+
+
+@since(1.5)
+def least(*cols):
+ """
+ Returns the least value of the list of column names, skipping null values.
+ This function takes at least 2 parameters. It will return null iff all parameters are null.
+
+ >>> df = sqlContext.createDataFrame([(1, 4, 3)], ['a', 'b', 'c'])
+ >>> df.select(least(df.a, df.b, df.c).alias("least")).collect()
+ [Row(least=1)]
+ """
+ if len(cols) < 2:
+ raise ValueError("least should take at least two columns")
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.least(_to_seq(sc, cols, _to_java_column)))
+
+
@since(1.4)
def when(condition, value):
"""Evaluates a list of conditions and returns one of multiple possible result expressions.
@@ -654,6 +464,35 @@ def log2(col):
return Column(sc._jvm.functions.log2(_to_java_column(col)))
+@since(1.5)
+@ignore_unicode_prefix
+def conv(col, fromBase, toBase):
+ """
+ Convert a number in a string column from one base to another.
+
+ >>> df = sqlContext.createDataFrame([("010101",)], ['n'])
+ >>> df.select(conv(df.n, 2, 16).alias('hex')).collect()
+ [Row(hex=u'15')]
+ """
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.conv(_to_java_column(col), fromBase, toBase))
+
+
+@since(1.5)
+def factorial(col):
+ """
+ Computes the factorial of the given value.
+
+ >>> df = sqlContext.createDataFrame([(5,)], ['n'])
+ >>> df.select(factorial(df.n).alias('f')).collect()
+ [Row(f=120)]
+ """
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.factorial(_to_java_column(col)))
+
+
+# --------------- Window functions ------------------------
+
@since(1.4)
def lag(col, count=1, default=None):
"""
@@ -703,9 +542,28 @@ def ntile(n):
return Column(sc._jvm.functions.ntile(int(n)))
+# ---------------------- Date/Timestamp functions ------------------------------
+
+@since(1.5)
+def current_date():
+ """
+ Returns the current date as a date column.
+ """
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.current_date())
+
+
+@since(1.5)
+def current_timestamp():
+ """
+ Returns the current timestamp as a timestamp column.
+ """
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.current_timestamp())
+
+
@ignore_unicode_prefix
@since(1.5)
-def date_format(dateCol, format):
+def date_format(date, format):
"""
Converts a date/timestamp/string to a value of string in the format specified by the date
format given by the second argument.
@@ -721,7 +579,7 @@ def date_format(dateCol, format):
[Row(date=u'04/08/2015')]
"""
sc = SparkContext._active_spark_context
- return Column(sc._jvm.functions.date_format(_to_java_column(dateCol), format))
+ return Column(sc._jvm.functions.date_format(_to_java_column(date), format))
@since(1.5)
@@ -868,6 +726,19 @@ def date_sub(start, days):
@since(1.5)
+def datediff(end, start):
+ """
+ Returns the number of days from `start` to `end`.
+
+ >>> df = sqlContext.createDataFrame([('2015-04-08','2015-05-10')], ['d1', 'd2'])
+ >>> df.select(datediff(df.d2, df.d1).alias('diff')).collect()
+ [Row(diff=32)]
+ """
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.datediff(_to_java_column(end), _to_java_column(start)))
+
+
+@since(1.5)
def add_months(start, months):
"""
Returns the date that is `months` months after `start`
@@ -924,6 +795,269 @@ def trunc(date, format):
@since(1.5)
+def next_day(date, dayOfWeek):
+ """
+ Returns the first date which is later than the value of the date column.
+
+ Day of the week parameter is case insensitive, and accepts:
+ "Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun".
+
+ >>> df = sqlContext.createDataFrame([('2015-07-27',)], ['d'])
+ >>> df.select(next_day(df.d, 'Sun').alias('date')).collect()
+ [Row(date=datetime.date(2015, 8, 2))]
+ """
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.next_day(_to_java_column(date), dayOfWeek))
+
+
+@since(1.5)
+def last_day(date):
+ """
+ Returns the last day of the month which the given date belongs to.
+
+ >>> df = sqlContext.createDataFrame([('1997-02-10',)], ['d'])
+ >>> df.select(last_day(df.d).alias('date')).collect()
+ [Row(date=datetime.date(1997, 2, 28))]
+ """
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.last_day(_to_java_column(date)))
+
+
+@since(1.5)
+def from_unixtime(timestamp, format="yyyy-MM-dd HH:mm:ss"):
+ """
+ Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string
+ representing the timestamp of that moment in the current system time zone in the given
+ format.
+ """
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.from_unixtime(_to_java_column(timestamp), format))
+
+
+@since(1.5)
+def unix_timestamp(timestamp=None, format='yyyy-MM-dd HH:mm:ss'):
+ """
+ Convert a time string with the given pattern ('yyyy-MM-dd HH:mm:ss' by default)
+ to a Unix timestamp (in seconds), using the default timezone and the default
+ locale; returns null on failure.
+
+ If `timestamp` is None, the current timestamp is returned.
+ """
+ sc = SparkContext._active_spark_context
+ if timestamp is None:
+ return Column(sc._jvm.functions.unix_timestamp())
+ return Column(sc._jvm.functions.unix_timestamp(_to_java_column(timestamp), format))
+
+
+@since(1.5)
+def from_utc_timestamp(timestamp, tz):
+ """
+ Assumes given timestamp is UTC and converts to given timezone.
+
+ >>> df = sqlContext.createDataFrame([('1997-02-28 10:30:00',)], ['t'])
+ >>> df.select(from_utc_timestamp(df.t, "PST").alias('t')).collect()
+ [Row(t=datetime.datetime(1997, 2, 28, 2, 30))]
+ """
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.from_utc_timestamp(_to_java_column(timestamp), tz))
+
+
+@since(1.5)
+def to_utc_timestamp(timestamp, tz):
+ """
+ Assumes given timestamp is in given timezone and converts to UTC.
+
+ >>> df = sqlContext.createDataFrame([('1997-02-28 10:30:00',)], ['t'])
+ >>> df.select(to_utc_timestamp(df.t, "PST").alias('t')).collect()
+ [Row(t=datetime.datetime(1997, 2, 28, 18, 30))]
+ """
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.to_utc_timestamp(_to_java_column(timestamp), tz))
+
+
+# ---------------------------- misc functions ----------------------------------
+
+@since(1.5)
+@ignore_unicode_prefix
+def crc32(col):
+ """
+ Calculates the cyclic redundancy check value (CRC32) of a binary column and
+ returns the value as a bigint.
+
+ >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(crc32('a').alias('crc32')).collect()
+ [Row(crc32=2743272264)]
+ """
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.crc32(_to_java_column(col)))
+
+
+@ignore_unicode_prefix
+@since(1.5)
+def md5(col):
+ """Calculates the MD5 digest and returns the value as a 32 character hex string.
+
+ >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(md5('a').alias('hash')).collect()
+ [Row(hash=u'902fbdd2b1df0c4f70b4a5d23525e932')]
+ """
+ sc = SparkContext._active_spark_context
+ jc = sc._jvm.functions.md5(_to_java_column(col))
+ return Column(jc)
+
+
+@ignore_unicode_prefix
+@since(1.5)
+def sha1(col):
+ """Returns the hex string result of SHA-1.
+
+ >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(sha1('a').alias('hash')).collect()
+ [Row(hash=u'3c01bdbb26f358bab27f267924aa2c9a03fcfdb8')]
+ """
+ sc = SparkContext._active_spark_context
+ jc = sc._jvm.functions.sha1(_to_java_column(col))
+ return Column(jc)
+
+
+@ignore_unicode_prefix
+@since(1.5)
+def sha2(col, numBits):
+ """Returns the hex string result of SHA-2 family of hash functions (SHA-224, SHA-256, SHA-384,
+ and SHA-512). The numBits indicates the desired bit length of the result, which must have a
+ value of 224, 256, 384, 512, or 0 (which is equivalent to 256).
+
+ >>> digests = df.select(sha2(df.name, 256).alias('s')).collect()
+ >>> digests[0]
+ Row(s=u'3bc51062973c458d5a6f2d8d64a023246354ad7e064b1e4e009ec8a0699a3043')
+ >>> digests[1]
+ Row(s=u'cd9fb1e148ccd8442e5aa74904cc73bf6fb54d1d54d333bd596aa9bb4bb4e961')
+ """
+ sc = SparkContext._active_spark_context
+ jc = sc._jvm.functions.sha2(_to_java_column(col), numBits)
+ return Column(jc)
+
+
+# ---------------------- String/Binary functions ------------------------------
+
+_string_functions = {
+ 'ascii': 'Computes the numeric value of the first character of the string column.',
+ 'base64': 'Computes the BASE64 encoding of a binary column and returns it as a string column.',
+ 'unbase64': 'Decodes a BASE64 encoded string column and returns it as a binary column.',
+ 'initcap': 'Returns a new string column by converting the first letter of each word to ' +
+ 'uppercase. Words are delimited by whitespace.',
+ 'lower': 'Converts a string column to lower case.',
+ 'upper': 'Converts a string column to upper case.',
+ 'reverse': 'Reverses the string column and returns it as a new string column.',
+ 'ltrim': 'Trim the spaces from left end for the specified string value.',
+ 'rtrim': 'Trim the spaces from right end for the specified string value.',
+ 'trim': 'Trim the spaces from both ends for the specified string column.',
+}
+
+
+for _name, _doc in _string_functions.items():
+ globals()[_name] = since(1.5)(_create_function(_name, _doc))
+del _name, _doc
+
+
+@since(1.5)
+@ignore_unicode_prefix
+def concat(*cols):
+ """
+ Concatenates multiple input string columns together into a single string column.
+
+ >>> df = sqlContext.createDataFrame([('abcd','123')], ['s', 'd'])
+ >>> df.select(concat(df.s, df.d).alias('s')).collect()
+ [Row(s=u'abcd123')]
+ """
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.concat(_to_seq(sc, cols, _to_java_column)))
+
+
+@since(1.5)
+@ignore_unicode_prefix
+def concat_ws(sep, *cols):
+ """
+ Concatenates multiple input string columns together into a single string column,
+ using the given separator.
+
+ >>> df = sqlContext.createDataFrame([('abcd','123')], ['s', 'd'])
+ >>> df.select(concat_ws('-', df.s, df.d).alias('s')).collect()
+ [Row(s=u'abcd-123')]
+ """
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.concat_ws(sep, _to_seq(sc, cols, _to_java_column)))
+
+
+@since(1.5)
+def decode(col, charset):
+ """
+ Computes the first argument into a string from a binary using the provided character set
+ (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
+ """
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.decode(_to_java_column(col), charset))
+
+
+@since(1.5)
+def encode(col, charset):
+ """
+ Computes the first argument into a binary from a string using the provided character set
+ (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
+ """
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.encode(_to_java_column(col), charset))
+
+
+@ignore_unicode_prefix
+@since(1.5)
+def format_number(col, d):
+ """
+ Formats the number X to a format like '#,###,###.##', rounded to d decimal places,
+ and returns the result as a string.
+
+ :param col: the column name of the numeric value to be formatted
+ :param d: the number of decimal places to round to
+
+ >>> sqlContext.createDataFrame([(5,)], ['a']).select(format_number('a', 4).alias('v')).collect()
+ [Row(v=u'5.0000')]
+ """
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.format_number(_to_java_column(col), d))
+
+
+@ignore_unicode_prefix
+@since(1.5)
+def format_string(format, *cols):
+ """
+ Formats the arguments in printf-style and returns the result as a string column.
+
+ :param format: the printf-style format string
+ :param cols: list of column names (string) or :class:`Column` expressions to be formatted
+
+ >>> df = sqlContext.createDataFrame([(5, "hello")], ['a', 'b'])
+ >>> df.select(format_string('%d %s', df.a, df.b).alias('v')).collect()
+ [Row(v=u'5 hello')]
+ """
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.format_string(format, _to_seq(sc, cols, _to_java_column)))
+
+
+@since(1.5)
+def instr(str, substr):
+ """
+ Locate the position of the first occurrence of substr column in the given string.
+ Returns null if either of the arguments is null.
+
+ NOTE: The position is not zero-based but a 1-based index; returns 0 if substr
+ could not be found in str.
+
+ >>> df = sqlContext.createDataFrame([('abcd',)], ['s',])
+ >>> df.select(instr(df.s, 'b').alias('s')).collect()
+ [Row(s=2)]
+ """
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.instr(_to_java_column(str), substr))
+
+
+@since(1.5)
@ignore_unicode_prefix
def substring(str, pos, len):
"""
@@ -960,6 +1094,126 @@ def substring_index(str, delim, count):
@ignore_unicode_prefix
@since(1.5)
+def levenshtein(left, right):
+ """Computes the Levenshtein distance of the two given strings.
+
+ >>> df0 = sqlContext.createDataFrame([('kitten', 'sitting',)], ['l', 'r'])
+ >>> df0.select(levenshtein('l', 'r').alias('d')).collect()
+ [Row(d=3)]
+ """
+ sc = SparkContext._active_spark_context
+ jc = sc._jvm.functions.levenshtein(_to_java_column(left), _to_java_column(right))
+ return Column(jc)
+
+
+@since(1.5)
+def locate(substr, str, pos=0):
+ """
+ Locate the position of the first occurrence of substr in a string column, after position pos.
+
+ NOTE: The position is not zero-based but a 1-based index; returns 0 if substr
+ could not be found in str.
+
+ :param substr: a string
+ :param str: a Column of StringType
+ :param pos: start position (zero based)
+
+ >>> df = sqlContext.createDataFrame([('abcd',)], ['s',])
+ >>> df.select(locate('b', df.s, 1).alias('s')).collect()
+ [Row(s=2)]
+ """
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.locate(substr, _to_java_column(str), pos))
+
+
+@since(1.5)
+@ignore_unicode_prefix
+def lpad(col, len, pad):
+ """
+ Left-pad the string column to width `len` with `pad`.
+
+ >>> df = sqlContext.createDataFrame([('abcd',)], ['s',])
+ >>> df.select(lpad(df.s, 6, '#').alias('s')).collect()
+ [Row(s=u'##abcd')]
+ """
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.lpad(_to_java_column(col), len, pad))
+
+
+@since(1.5)
+@ignore_unicode_prefix
+def rpad(col, len, pad):
+ """
+ Right-pad the string column to width `len` with `pad`.
+
+ >>> df = sqlContext.createDataFrame([('abcd',)], ['s',])
+ >>> df.select(rpad(df.s, 6, '#').alias('s')).collect()
+ [Row(s=u'abcd##')]
+ """
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.rpad(_to_java_column(col), len, pad))
+
+
+@since(1.5)
+@ignore_unicode_prefix
+def repeat(col, n):
+ """
+ Repeats a string column n times, and returns it as a new string column.
+
+ >>> df = sqlContext.createDataFrame([('ab',)], ['s',])
+ >>> df.select(repeat(df.s, 3).alias('s')).collect()
+ [Row(s=u'ababab')]
+ """
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.repeat(_to_java_column(col), n))
+
+
+@since(1.5)
+@ignore_unicode_prefix
+def split(str, pattern):
+ """
+ Splits str around pattern (pattern is a regular expression).
+
+ NOTE: `pattern` is a string representing the regular expression.
+
+ >>> df = sqlContext.createDataFrame([('ab12cd',)], ['s',])
+ >>> df.select(split(df.s, '[0-9]+').alias('s')).collect()
+ [Row(s=[u'ab', u'cd'])]
+ """
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.split(_to_java_column(str), pattern))
+
+
+@ignore_unicode_prefix
+@since(1.5)
+def regexp_extract(str, pattern, idx):
+ """Extract a specific(idx) group identified by a java regex, from the specified string column.
+
+ >>> df = sqlContext.createDataFrame([('100-200',)], ['str'])
+ >>> df.select(regexp_extract('str', '(\d+)-(\d+)', 1).alias('d')).collect()
+ [Row(d=u'100')]
+ """
+ sc = SparkContext._active_spark_context
+ jc = sc._jvm.functions.regexp_extract(_to_java_column(str), pattern, idx)
+ return Column(jc)
+
+
+@ignore_unicode_prefix
+@since(1.5)
+def regexp_replace(str, pattern, replacement):
+ """Replace all substrings of the specified string value that match regexp with rep.
+
+ >>> df = sqlContext.createDataFrame([('100-200',)], ['str'])
+ >>> df.select(regexp_replace('str', '(\\d+)', '--').alias('d')).collect()
+ [Row(d=u'-----')]
+ """
+ sc = SparkContext._active_spark_context
+ jc = sc._jvm.functions.regexp_replace(_to_java_column(str), pattern, replacement)
+ return Column(jc)
+
+
+@ignore_unicode_prefix
+@since(1.5)
def initcap(col):
"""Translate the first letter of each word to upper case in the sentence.
@@ -971,6 +1225,114 @@ def initcap(col):
@since(1.5)
+@ignore_unicode_prefix
+def soundex(col):
+ """
+ Returns the SoundEx encoding for a string.
+
+ >>> df = sqlContext.createDataFrame([("Peters",),("Uhrbach",)], ['name'])
+ >>> df.select(soundex(df.name).alias("soundex")).collect()
+ [Row(soundex=u'P362'), Row(soundex=u'U612')]
+ """
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.soundex(_to_java_column(col)))
+
+
+@ignore_unicode_prefix
+@since(1.5)
+def bin(col):
+ """Returns the string representation of the binary value of the given column.
+
+ >>> df.select(bin(df.age).alias('c')).collect()
+ [Row(c=u'10'), Row(c=u'101')]
+ """
+ sc = SparkContext._active_spark_context
+ jc = sc._jvm.functions.bin(_to_java_column(col))
+ return Column(jc)
+
+
+@ignore_unicode_prefix
+@since(1.5)
+def hex(col):
+ """Computes hex value of the given column, which could be StringType,
+ BinaryType, IntegerType or LongType.
+
+ >>> sqlContext.createDataFrame([('ABC', 3)], ['a', 'b']).select(hex('a'), hex('b')).collect()
+ [Row(hex(a)=u'414243', hex(b)=u'3')]
+ """
+ sc = SparkContext._active_spark_context
+ jc = sc._jvm.functions.hex(_to_java_column(col))
+ return Column(jc)
+
+
+@ignore_unicode_prefix
+@since(1.5)
+def unhex(col):
+ """Inverse of hex. Interprets each pair of characters as a hexadecimal number
+ and converts to the byte representation of number.
+
+ >>> sqlContext.createDataFrame([('414243',)], ['a']).select(unhex('a')).collect()
+ [Row(unhex(a)=bytearray(b'ABC'))]
+ """
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.unhex(_to_java_column(col)))
+
+
+@ignore_unicode_prefix
+@since(1.5)
+def length(col):
+ """Calculates the length of a string or binary expression.
+
+ >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(length('a').alias('length')).collect()
+ [Row(length=3)]
+ """
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.length(_to_java_column(col)))
+
+
+# ---------------------- Collection functions ------------------------------
+
+@since(1.4)
+def array(*cols):
+ """Creates a new array column.
+
+ :param cols: list of column names (string) or list of :class:`Column` expressions that have
+ the same data type.
+
+ >>> df.select(array('age', 'age').alias("arr")).collect()
+ [Row(arr=[2, 2]), Row(arr=[5, 5])]
+ >>> df.select(array([df.age, df.age]).alias("arr")).collect()
+ [Row(arr=[2, 2]), Row(arr=[5, 5])]
+ """
+ sc = SparkContext._active_spark_context
+ if len(cols) == 1 and isinstance(cols[0], (list, set)):
+ cols = cols[0]
+ jc = sc._jvm.functions.array(_to_seq(sc, cols, _to_java_column))
+ return Column(jc)
+
+
+@since(1.4)
+def explode(col):
+ """Returns a new row for each element in the given array or map.
+
+ >>> from pyspark.sql import Row
+ >>> eDF = sqlContext.createDataFrame([Row(a=1, intlist=[1,2,3], mapfield={"a": "b"})])
+ >>> eDF.select(explode(eDF.intlist).alias("anInt")).collect()
+ [Row(anInt=1), Row(anInt=2), Row(anInt=3)]
+
+ >>> eDF.select(explode(eDF.mapfield).alias("key", "value")).show()
+ +---+-----+
+ |key|value|
+ +---+-----+
+ | a| b|
+ +---+-----+
+ """
+ sc = SparkContext._active_spark_context
+ jc = sc._jvm.functions.explode(_to_java_column(col))
+ return Column(jc)
+
+
+@since(1.5)
def size(col):
"""
Collection function: returns the length of the array or map stored in the column.
@@ -1002,19 +1364,7 @@ def sort_array(col, asc=True):
return Column(sc._jvm.functions.sort_array(_to_java_column(col), asc))
-@since
-@ignore_unicode_prefix
-def soundex(col):
- """
- Returns the SoundEx encoding for a string
-
- >>> df = sqlContext.createDataFrame([("Peters",),("Uhrbach",)], ['name'])
- >>> df.select(soundex(df.name).alias("soundex")).collect()
- [Row(soundex=u'P362'), Row(soundex=u'U612')]
- """
- sc = SparkContext._active_spark_context
- return Column(sc._jvm.functions.size(_to_java_column(col)))
-
+# ---------------------------- User Defined Function ----------------------------------
class UserDefinedFunction(object):
"""
@@ -1066,6 +1416,11 @@ def udf(f, returnType=StringType()):
"""
return UserDefinedFunction(f, returnType)
+blacklist = ['map', 'since', 'ignore_unicode_prefix']
+__all__ = [k for k, v in globals().items()
+ if not k.startswith('_') and k[0].islower() and callable(v) and k not in blacklist]
+__all__.sort()
+
def _test():
import doctest