path: root/python/pyspark/sql/functions.py
author		Davies Liu <davies@databricks.com>	2015-02-17 10:22:48 -0800
committer	Michael Armbrust <michael@databricks.com>	2015-02-17 10:22:48 -0800
commit		d8adefefcc2a4af32295440ed1d4917a6968f017 (patch)
tree		70f6455183a7a38c94e59b61edcbeb947646727b /python/pyspark/sql/functions.py
parent		c74b07fa94a8da50437d952ae05cf6ac70fbb93e (diff)
download	spark-d8adefefcc2a4af32295440ed1d4917a6968f017.tar.gz
spark-d8adefefcc2a4af32295440ed1d4917a6968f017.tar.bz2
spark-d8adefefcc2a4af32295440ed1d4917a6968f017.zip
[SPARK-5859] [PySpark] [SQL] fix DataFrame Python API
1. added explain()
2. added isLocal()
3. do not call show() in __repr__
4. added foreach() and foreachPartition()
5. added distinct()
6. fixed functions.col()/column()/lit()
7. fixed unit tests in sql/functions.py
8. fixed unicode handling in showString()

Author: Davies Liu <davies@databricks.com>

Closes #4645 from davies/df6 and squashes the following commits:

6b46a2c [Davies Liu] fix DataFrame Python API
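As context for item 6, a minimal sketch of how the repaired helpers are meant
to be used. This is illustration only, not part of the commit; it assumes a
live SQLContext and the `df` doctest fixture built in _test() below:

    from pyspark.sql.functions import col, lit

    # col('name') builds a Column from a column name; lit(1) wraps a
    # plain Python literal in a Column.
    df.select(col('name'), lit(1).alias('one')).collect()
    # roughly: [Row(name=u'Alice', one=1), Row(name=u'Bob', one=1)]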
Diffstat (limited to 'python/pyspark/sql/functions.py')
-rw-r--r--	python/pyspark/sql/functions.py | 12
1 file changed, 5 insertions(+), 7 deletions(-)
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index d0e090607f..fc61162f0b 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -37,7 +37,7 @@ def _create_function(name, doc=""):
     """ Create a function for aggregator by name"""
     def _(col):
         sc = SparkContext._active_spark_context
-        jc = getattr(sc._jvm.functions, name)(_to_java_column(col))
+        jc = getattr(sc._jvm.functions, name)(col._jc if isinstance(col, Column) else col)
         return Column(jc)
     _.__name__ = name
     _.__doc__ = doc
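The change above makes every function generated by _create_function accept
either a Column (unwrapped to its underlying Java handle) or a plain Python
value passed through as-is, which is what lets col()/column()/lit() take raw
names and literals. The dispatch in isolation, as a runnable sketch (this
Column is a minimal stand-in and _jvm_arg a hypothetical name, not Spark API):

    class Column(object):
        """Minimal stand-in for pyspark.sql.Column (illustration only)."""
        def __init__(self, jc):
            self._jc = jc  # the wrapped "Java" column handle

    def _jvm_arg(col):
        # Mirrors the patched line: unwrap Columns, pass raw values through.
        return col._jc if isinstance(col, Column) else col

    assert _jvm_arg(Column("jc-handle")) == "jc-handle"
    assert _jvm_arg("name") == "name"  # e.g. col('name')
    assert _jvm_arg(5) == 5            # e.g. lit(5)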
@@ -140,6 +140,7 @@ class UserDefinedFunction(object):
 def udf(f, returnType=StringType()):
     """Create a user defined function (UDF)
 
+    >>> from pyspark.sql.types import IntegerType
     >>> slen = udf(lambda s: len(s), IntegerType())
     >>> df.select(slen(df.name).alias('slen')).collect()
     [Row(slen=5), Row(slen=3)]
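The new import is needed because the _test() hunk below repoints the doctest
globals at pyspark.sql.functions, whose namespace does not include
IntegerType. The same doctest as a standalone snippet (again assuming a live
SQLContext and the df fixture from _test()):

    from pyspark.sql.functions import udf
    from pyspark.sql.types import IntegerType  # the import the doctest was missing

    slen = udf(lambda s: len(s), IntegerType())  # string -> int UDF
    df.select(slen(df.name).alias('slen')).collect()
    # [Row(slen=5), Row(slen=3)]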
@@ -151,17 +152,14 @@ def _test():
     import doctest
     from pyspark.context import SparkContext
     from pyspark.sql import Row, SQLContext
-    import pyspark.sql.dataframe
-    globs = pyspark.sql.dataframe.__dict__.copy()
+    import pyspark.sql.functions
+    globs = pyspark.sql.functions.__dict__.copy()
     sc = SparkContext('local[4]', 'PythonTest')
     globs['sc'] = sc
     globs['sqlCtx'] = SQLContext(sc)
     globs['df'] = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)]).toDF()
-    globs['df2'] = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]).toDF()
-    globs['df3'] = sc.parallelize([Row(name='Alice', age=2, height=80),
-                                   Row(name='Bob', age=5, height=85)]).toDF()
     (failure_count, test_count) = doctest.testmod(
-        pyspark.sql.dataframe, globs=globs,
+        pyspark.sql.functions, globs=globs,
         optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)
     globs['sc'].stop()
     if failure_count:
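This hunk repoints the doctest harness from pyspark.sql.dataframe (apparently
a copy-paste leftover) to pyspark.sql.functions, and drops the df2/df3
fixtures that only dataframe.py's doctests used. For reference, the harness
pattern in isolation, as a generic runnable sketch (the square example is
illustrative, not Spark code): run a module's own doctests against a copy of
its namespace and exit non-zero on failure.

    import doctest
    import sys

    def square(x):
        """Return x squared.

        >>> square(3)
        9
        """
        return x * x

    if __name__ == "__main__":
        # Mirror _test(): run this module's doctests with its own globals.
        mod = sys.modules[__name__]
        globs = mod.__dict__.copy()
        (failure_count, test_count) = doctest.testmod(
            mod, globs=globs,
            optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)
        if failure_count:
            sys.exit(1)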