author     Davies Liu <davies@databricks.com>  2015-02-17 10:22:48 -0800
committer  Michael Armbrust <michael@databricks.com>  2015-02-17 10:22:48 -0800
commit     d8adefefcc2a4af32295440ed1d4917a6968f017 (patch)
tree       70f6455183a7a38c94e59b61edcbeb947646727b /python/pyspark/sql
parent     c74b07fa94a8da50437d952ae05cf6ac70fbb93e (diff)
[SPARK-5859] [PySpark] [SQL] fix DataFrame Python API
1. added explain()
2. added isLocal()
3. do not call show() in __repr__
4. added foreach() and foreachPartition()
5. added distinct()
6. fixed functions.col()/column()/lit()
7. fixed unit tests in sql/functions.py
8. fixed unicode handling in showString()

Author: Davies Liu <davies@databricks.com>

Closes #4645 from davies/df6 and squashes the following commits:

6b46a2c [Davies Liu] fix DataFrame Python API
Diffstat (limited to 'python/pyspark/sql')
-rw-r--r--  python/pyspark/sql/dataframe.py  | 65
-rw-r--r--  python/pyspark/sql/functions.py  | 12
2 files changed, 59 insertions(+), 18 deletions(-)
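The net effect of the show()/__repr__ changes below is that printing becomes explicit: show() emits the JVM-side showString() output, while the interactive repr is a cheap schema summary that no longer triggers a Spark job. A minimal sketch of the new behavior, assuming the same sc fixture the doctests use; the exact output (including inferred types) is illustrative only:

>>> from pyspark.sql import Row
>>> df = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)]).toDF()
>>> df                       # __repr__: schema summary only, no job is run
DataFrame[age: int, name: string]
>>> df.show()                # show(): first 20 rows, rendered by the JVM
age name
2   Alice
5   Bob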
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 28a59e73a3..841724095f 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -238,6 +238,22 @@ class DataFrame(object):
"""
print (self._jdf.schema().treeString())
+ def explain(self, extended=False):
+ """
+ Prints the plans (logical and physical) to the console for
+ debugging purpose.
+
+ If extended is False, only prints the physical plan.
+ """
+ self._jdf.explain(extended)
+
+ def isLocal(self):
+ """
+ Returns True if the `collect` and `take` methods can be run locally
+ (without any Spark executors).
+ """
+ return self._jdf.isLocal()
+
def show(self):
"""
Print the first 20 rows.
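A short sketch of the two introspection methods added in the hunk above. The plan text is produced by the JVM and varies across Spark versions, so the output shown here is illustrative only:

>>> df.explain()             # prints the physical plan only
PhysicalRDD ...
>>> df.explain(True)         # logical, analyzed, optimized and physical plans
...
>>> df.isLocal()             # False: this DataFrame is backed by distributed data
False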
@@ -247,14 +263,12 @@ class DataFrame(object):
2 Alice
5 Bob
>>> df
- age name
- 2 Alice
- 5 Bob
+ DataFrame[age: int, name: string]
"""
- print (self)
+ print self._jdf.showString().encode('utf8', 'ignore')
def __repr__(self):
- return self._jdf.showString()
+ return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes))
def count(self):
"""Return the number of elements in this RDD.
@@ -336,6 +350,8 @@ class DataFrame(object):
"""
Return a new RDD by applying a function to each partition.
+ It's a shorthand for df.rdd.mapPartitions()
+
>>> rdd = sc.parallelize([1, 2, 3, 4], 4)
>>> def f(iterator): yield 1
>>> rdd.mapPartitions(f).sum()
@@ -343,6 +359,31 @@ class DataFrame(object):
"""
return self.rdd.mapPartitions(f, preservesPartitioning)
+ def foreach(self, f):
+ """
+ Applies a function to all rows of this DataFrame.
+
+ It's a shorthand for df.rdd.foreach()
+
+ >>> def f(person):
+ ... print person.name
+ >>> df.foreach(f)
+ """
+ return self.rdd.foreach(f)
+
+ def foreachPartition(self, f):
+ """
+ Applies a function to each partition of this DataFrame.
+
+ It's a shorthand for df.rdd.foreachPartition()
+
+ >>> def f(people):
+ ... for person in people:
+ ... print person.name
+ >>> df.foreachPartition(f)
+ """
+ return self.rdd.foreachPartition(f)
+
def cache(self):
""" Persist with the default storage level (C{MEMORY_ONLY_SER}).
"""
@@ -377,8 +418,13 @@ class DataFrame(object):
""" Return a new :class:`DataFrame` that has exactly `numPartitions`
partitions.
"""
- rdd = self._jdf.repartition(numPartitions, None)
- return DataFrame(rdd, self.sql_ctx)
+ return DataFrame(self._jdf.repartition(numPartitions, None), self.sql_ctx)
+
+ def distinct(self):
+ """
+ Return a new :class:`DataFrame` containing the distinct rows in this DataFrame.
+ """
+ return DataFrame(self._jdf.distinct(), self.sql_ctx)
def sample(self, withReplacement, fraction, seed=None):
"""
@@ -957,10 +1003,7 @@ class Column(DataFrame):
return Column(jc, self.sql_ctx)
def __repr__(self):
- if self._jdf.isComputable():
- return self._jdf.samples()
- else:
- return 'Column<%s>' % self._jdf.toString()
+ return 'Column<%s>' % self._jdf.toString().encode('utf8')
def toPandas(self):
"""
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index d0e090607f..fc61162f0b 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -37,7 +37,7 @@ def _create_function(name, doc=""):
""" Create a function for aggregator by name"""
def _(col):
sc = SparkContext._active_spark_context
- jc = getattr(sc._jvm.functions, name)(_to_java_column(col))
+ jc = getattr(sc._jvm.functions, name)(col._jc if isinstance(col, Column) else col)
return Column(jc)
_.__name__ = name
_.__doc__ = doc
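With the wrapper fix above, the generated functions forward plain Python values (the string name for col()/column(), a literal for lit()) straight to the JVM and only unwrap genuine Column objects. A minimal sketch against the doctest df (output illustrative):

>>> from pyspark.sql.functions import col, lit
>>> df.select(col('name'), lit(1).alias('one')).collect()
[Row(name=u'Alice', one=1), Row(name=u'Bob', one=1)]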
@@ -140,6 +140,7 @@ class UserDefinedFunction(object):
def udf(f, returnType=StringType()):
"""Create a user defined function (UDF)
+ >>> from pyspark.sql.types import IntegerType
>>> slen = udf(lambda s: len(s), IntegerType())
>>> df.select(slen(df.name).alias('slen')).collect()
[Row(slen=5), Row(slen=3)]
@@ -151,17 +152,14 @@ def _test():
import doctest
from pyspark.context import SparkContext
from pyspark.sql import Row, SQLContext
- import pyspark.sql.dataframe
- globs = pyspark.sql.dataframe.__dict__.copy()
+ import pyspark.sql.functions
+ globs = pyspark.sql.functions.__dict__.copy()
sc = SparkContext('local[4]', 'PythonTest')
globs['sc'] = sc
globs['sqlCtx'] = SQLContext(sc)
globs['df'] = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)]).toDF()
- globs['df2'] = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]).toDF()
- globs['df3'] = sc.parallelize([Row(name='Alice', age=2, height=80),
- Row(name='Bob', age=5, height=85)]).toDF()
(failure_count, test_count) = doctest.testmod(
- pyspark.sql.dataframe, globs=globs,
+ pyspark.sql.functions, globs=globs,
optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)
globs['sc'].stop()
if failure_count: