path: root/python
author     Reynold Xin <rxin@databricks.com>    2015-02-02 19:01:47 -0800
committer  Reynold Xin <rxin@databricks.com>    2015-02-02 19:01:47 -0800
commit     554403fd913685da879cf6a280c58a9fad19448a (patch)
tree       b3a63382e7385fa1480b54707b348b0bde02190d /python
parent     eccb9fbb2d1bf6f7c65fb4f017e9205bb3034ec6 (diff)
[SQL] Improve DataFrame API error reporting
1. Throw UnsupportedOperationException if a Column is not computable.
2. Perform eager analysis on DataFrame so we can catch errors when they happen (not when an action is run).

Author: Reynold Xin <rxin@databricks.com>
Author: Davies Liu <davies@databricks.com>

Closes #4296 from rxin/col-computability and squashes the following commits:

6527b86 [Reynold Xin] Merge pull request #8 from davies/col-computability
fd92bc7 [Reynold Xin] Merge branch 'master' into col-computability
f79034c [Davies Liu] fix python tests
5afe1ff [Reynold Xin] Fix scala test.
17f6bae [Reynold Xin] Various fixes.
b932e86 [Reynold Xin] Added eager analysis for error reporting.
e6f00b8 [Reynold Xin] [SQL][API] ComputableColumn vs IncomputableColumn
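As a rough usage sketch (not part of the patch): with eager analysis, a reference to a missing column should fail when the DataFrame expression is built, not later when an action such as collect() runs. The snippet assumes an active SparkContext and an existing DataFrame `df` with columns `key` and `value`, matching the data used in the tests below.

    # Hypothetical illustration of the eager-analysis behaviour described above.
    try:
        bad = df.select("no_such_column")    # analysis should now happen here ...
    except Exception as e:                   # ... so the error surfaces before any action runs
        print("caught at definition time: %s" % e)

    good = df.select(df.key, df.value)       # a valid projection is still built lazily
    print(good.count())                      # actions behave as before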
Diffstat (limited to 'python')
-rw-r--r--  python/pyspark/sql.py    75
-rw-r--r--  python/pyspark/tests.py   6
2 files changed, 56 insertions, 25 deletions
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 3f2d7ac825..32bff0c7e8 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -2124,6 +2124,10 @@ class DataFrame(object):
            return rs[0] if rs else None
        return self.take(n)
+    def first(self):
+        """ Return the first row. """
+        return self.head()
+
    def tail(self):
        raise NotImplemented
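A minimal usage sketch of the new method, reusing the assumed `df` from above:

    row = df.first()      # alias for head() with no argument; a single Row, or None if empty
    rows = df.head(3)     # head(n) still returns a list of the first n rows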
@@ -2159,7 +2163,7 @@ class DataFrame(object):
        else:
            cols = [c._jc for c in cols]
        jcols = ListConverter().convert(cols, self._sc._gateway._gateway_client)
-        jdf = self._jdf.select(self._jdf.toColumnArray(jcols))
+        jdf = self._jdf.select(self.sql_ctx._sc._jvm.Dsl.toColumns(jcols))
        return DataFrame(jdf, self.sql_ctx)
    def filter(self, condition):
@@ -2189,7 +2193,7 @@ class DataFrame(object):
        else:
            cols = [c._jc for c in cols]
        jcols = ListConverter().convert(cols, self._sc._gateway._gateway_client)
-        jdf = self._jdf.groupBy(self._jdf.toColumnArray(jcols))
+        jdf = self._jdf.groupBy(self.sql_ctx._sc._jvm.Dsl.toColumns(jcols))
        return GroupedDataFrame(jdf, self.sql_ctx)
    def agg(self, *exprs):
@@ -2278,14 +2282,17 @@ class GroupedDataFrame(object):
        :param exprs: list or aggregate columns or a map from column
                      name to agregate methods.
        """
+        assert exprs, "exprs should not be empty"
        if len(exprs) == 1 and isinstance(exprs[0], dict):
            jmap = MapConverter().convert(exprs[0],
                                          self.sql_ctx._sc._gateway._gateway_client)
            jdf = self._jdf.agg(jmap)
        else:
            # Columns
-            assert all(isinstance(c, Column) for c in exprs), "all exprs should be Columns"
-            jdf = self._jdf.agg(*exprs)
+            assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
+            jcols = ListConverter().convert([c._jc for c in exprs[1:]],
+                                            self.sql_ctx._sc._gateway._gateway_client)
+            jdf = self._jdf.agg(exprs[0]._jc, self.sql_ctx._sc._jvm.Dsl.toColumns(jcols))
        return DataFrame(jdf, self.sql_ctx)
    @dfapi
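Both aggregation styles touched by this hunk can be exercised from Python; a sketch based on the updated tests below, again assuming `df` has columns `key` and `value`:

    from pyspark.sql import Aggregator as Agg

    g = df.groupBy()                                       # grouping over the whole frame
    g.agg({'key': 'max', 'value': 'count'}).collect()      # dict form: column name -> aggregate method
    g.agg(Agg.first(df.key), Agg.last(df.value)).first()   # Column form, now routed through Dsl.toColumns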
@@ -2347,7 +2354,7 @@ def _create_column_from_literal(literal):
def _create_column_from_name(name):
    sc = SparkContext._active_spark_context
-    return sc._jvm.Column(name)
+    return sc._jvm.IncomputableColumn(name)
def _scalaMethod(name):
@@ -2371,7 +2378,7 @@ def _unary_op(name):
    return _
-def _bin_op(name, pass_literal_through=False):
+def _bin_op(name, pass_literal_through=True):
    """ Create a method for given binary operator
    Keyword arguments:
@@ -2465,10 +2472,10 @@ class Column(DataFrame):
    # __getattr__ = _bin_op("getField")
    # string methods
-    rlike = _bin_op("rlike", pass_literal_through=True)
-    like = _bin_op("like", pass_literal_through=True)
-    startswith = _bin_op("startsWith", pass_literal_through=True)
-    endswith = _bin_op("endsWith", pass_literal_through=True)
+    rlike = _bin_op("rlike")
+    like = _bin_op("like")
+    startswith = _bin_op("startsWith")
+    endswith = _bin_op("endsWith")
    upper = _unary_op("upper")
    lower = _unary_op("lower")
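With literals passed through by default, the string predicates above can be combined with filter(); a small sketch under the same assumptions about `df`:

    df.filter(df.value.startswith("1")).count()    # values beginning with "1"
    df.filter(df.value.like("%9")).count()         # SQL LIKE pattern
    df.filter(df.value.rlike("^[0-9]+$")).count()  # regular-expression match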
@@ -2476,7 +2483,6 @@ class Column(DataFrame):
        if type(startPos) != type(pos):
            raise TypeError("Can not mix the type")
        if isinstance(startPos, (int, long)):
-
            jc = self._jc.substr(startPos, pos)
        elif isinstance(startPos, Column):
            jc = self._jc.substr(startPos._jc, pos._jc)
@@ -2507,16 +2513,21 @@ class Column(DataFrame):
        return Column(self._jc.cast(jdt), self._jdf, self.sql_ctx)
+def _to_java_column(col):
+    if isinstance(col, Column):
+        jcol = col._jc
+    else:
+        jcol = _create_column_from_name(col)
+    return jcol
+
+
def _aggregate_func(name):
    """ Create a function for aggregator by name"""
    def _(col):
        sc = SparkContext._active_spark_context
-        if isinstance(col, Column):
-            jcol = col._jc
-        else:
-            jcol = _create_column_from_name(col)
-        jc = getattr(sc._jvm.org.apache.spark.sql.Dsl, name)(jcol)
+        jc = getattr(sc._jvm.Dsl, name)(_to_java_column(col))
        return Column(jc)
+
    return staticmethod(_)
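Because the generated functions go through `_to_java_column`, they should accept either a Column or a plain column name; a sketch (the Aggregator class that exposes them follows in the next hunk):

    from pyspark.sql import Aggregator as Agg

    df.groupBy().agg(Agg.max(df.key)).first()   # Column argument
    df.groupBy().agg(Agg.max("key")).first()    # string name, resolved via _create_column_from_name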
@@ -2524,13 +2535,31 @@ class Aggregator(object):
"""
A collections of builtin aggregators
"""
- max = _aggregate_func("max")
- min = _aggregate_func("min")
- avg = mean = _aggregate_func("mean")
- sum = _aggregate_func("sum")
- first = _aggregate_func("first")
- last = _aggregate_func("last")
- count = _aggregate_func("count")
+ AGGS = [
+ 'lit', 'col', 'column', 'upper', 'lower', 'sqrt', 'abs',
+ 'min', 'max', 'first', 'last', 'count', 'avg', 'mean', 'sum', 'sumDistinct',
+ ]
+ for _name in AGGS:
+ locals()[_name] = _aggregate_func(_name)
+ del _name
+
+ @staticmethod
+ def countDistinct(col, *cols):
+ sc = SparkContext._active_spark_context
+ jcols = ListConverter().convert([_to_java_column(c) for c in cols],
+ sc._gateway._gateway_client)
+ jc = sc._jvm.Dsl.countDistinct(_to_java_column(col),
+ sc._jvm.Dsl.toColumns(jcols))
+ return Column(jc)
+
+ @staticmethod
+ def approxCountDistinct(col, rsd=None):
+ sc = SparkContext._active_spark_context
+ if rsd is None:
+ jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col))
+ else:
+ jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col), rsd)
+ return Column(jc)
def _test():
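The two explicit static methods are called like the generated ones; a sketch matching the new assertions in the tests below:

    from pyspark.sql import Aggregator as Agg

    g = df.groupBy()
    g.agg(Agg.countDistinct(df.value)).first()[0]       # exact distinct count
    g.agg(Agg.approxCountDistinct(df.key)).first()[0]   # approximate; optional rsd tunes precision
    g.agg(Agg.countDistinct(df.key, df.value)).first()  # additional columns may be passed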
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index bec1961f26..fef6c92875 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -1029,9 +1029,11 @@ class SQLTests(ReusedPySparkTestCase):
        g = df.groupBy()
        self.assertEqual([99, 100], sorted(g.agg({'key': 'max', 'value': 'count'}).collect()[0]))
        self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect())
-        # TODO(davies): fix aggregators
+
        from pyspark.sql import Aggregator as Agg
-        # self.assertEqual((0, '100'), tuple(g.agg(Agg.first(df.key), Agg.last(df.value)).first()))
+        self.assertEqual((0, u'99'), tuple(g.agg(Agg.first(df.key), Agg.last(df.value)).first()))
+        self.assertTrue(95 < g.agg(Agg.approxCountDistinct(df.key)).first()[0])
+        self.assertEqual(100, g.agg(Agg.countDistinct(df.value)).first()[0])
    def test_help_command(self):
        # Regression test for SPARK-5464